Files
grounding-dino/app.py

193 lines
6.9 KiB
Python
Raw Normal View History

2025-08-16 21:16:58 +00:00
import base64
import logging
import os
import secrets
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Literal, Optional

import numpy as np
import supervision as sv
import uvicorn
import yaml
from fastapi import (FastAPI, File, Form, HTTPException, Query, Request,
                     UploadFile)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from PIL import Image
from starlette.status import HTTP_403_FORBIDDEN
from groundingdino.util.inference import Model, preprocess_caption

from pdf_converter import PdfConverter
2025-08-17 13:28:25 +00:00
# Directory where PNGs are written when a client asks for output="url".
PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir"))
PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)
# Public root used to build image URLs. It is configurable because APP_PORT /
# APP_HOST (see the __main__ entry) and any reverse proxy can change the
# address clients must use; the default matches the previous hard-coded value.
BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000")

# NOTE(review): the fallback is a hard-coded development key; a deployment that
# forgets to set API_PARTNER_KEY accepts this value published in source.
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "dev-AfghDgr3fgf74vc")

API_KEY_NAME = "x-api-key"
# auto_error=False: a missing header yields None instead of raising, so the
# caller decides how to reject the request.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True makes
# Starlette reflect any requesting origin; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

pdf_converter: PdfConverter = PdfConverter(dpi=150)

# Loaded once at import time and shared by all requests.
grounding_model = Model(
    model_config_path=os.environ.get(
        "GROUNDING_DINO_CONFIG", "groundingdino/config/GroundingDINO_SwinT_OGC.py"
    ),
    model_checkpoint_path=os.environ.get(
        "GROUNDING_DINO_CHECKPOINT", "gdino_checkpoints/groundingdino_swint_ogc.pth"
    ),
    # Overridable so the service can target another GPU (or CPU) without edits.
    device=os.environ.get("GROUNDING_DINO_DEVICE", "cuda:0"),
)

# Default detection thresholds (same values as the /crop_ooi form defaults).
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25
2025-08-17 13:28:25 +00:00
def output_img(
    processed_img: "Image.Image", output: Literal["url", "base64"]
) -> dict[str, Any]:
    """Encode a PIL image as PNG and describe it for a JSON response.

    Args:
        processed_img: Image to serialize.
        output: ``"url"`` saves a uniquely named PNG into ``PROCESSED_FOLDER``
            and returns a link rooted at ``BASE_URL``; ``"base64"`` returns the
            PNG inline as a ``data:image/png;base64,...`` URI.

    Returns:
        ``{"image": <url or data URI>, "metadata": {"type", "format", "width",
        "height", "mode"}}``.

    Raises:
        ValueError: If ``output`` is neither ``"url"`` nor ``"base64"``
            (previously this silently returned ``None``).
    """
    if output == "url":
        unique_filename = f"{uuid.uuid4()}.png"
        processed_img.save(PROCESSED_FOLDER / unique_filename, format="PNG")
        # as_posix() keeps "/" separators in the URL regardless of host OS.
        # NOTE(review): no StaticFiles mount for PROCESSED_FOLDER is visible in
        # this file; confirm something actually serves this URL.
        image_ref = f"{BASE_URL}/{PROCESSED_FOLDER.as_posix()}/{unique_filename}"
    elif output == "base64":
        # Encode in memory; nothing touches the filesystem on this path.
        buffered = BytesIO()
        processed_img.save(buffered, format="PNG")
        encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
        image_ref = f"data:image/png;base64,{encoded}"
    else:
        raise ValueError(f"Unsupported output format: {output!r}")

    return {
        "image": image_ref,
        "metadata": {
            "type": output,
            "format": "png",
            "width": processed_img.width,
            "height": processed_img.height,
            "mode": processed_img.mode,
        },
    }
2025-08-16 21:16:58 +00:00
async def verify_api_key(request: Request) -> None:
    """Reject the request unless its ``x-api-key`` header matches ``API_PARTNER_KEY``.

    Raises:
        HTTPException: 403 when the configured key is unset/empty, the header
            is missing, or the supplied key does not match.
    """
    api_key = await api_key_header(request)
    # compare_digest avoids content-based short-circuiting, so response timing
    # does not reveal how much of a guessed key was correct. Comparing bytes
    # also avoids compare_digest's TypeError on non-ASCII str input.
    if (
        not API_PARTNER_KEY
        or api_key is None
        or not secrets.compare_digest(api_key.encode(), API_PARTNER_KEY.encode())
    ):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Could not validate API KEY",
        )
# Paths served without an API key: the interactive docs and their schema.
_PUBLIC_PATHS = frozenset({"/docs", "/openapi.json", "/redoc"})


@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
    """Require a valid API key on every request except docs and CORS preflights.

    Rejected requests get a JSON ``{"detail": ...}`` body with the status from
    ``verify_api_key``; all others are passed on to ``call_next``.
    """
    # Browsers never attach custom headers such as x-api-key to a CORS
    # preflight. Starlette runs the most recently added middleware outermost,
    # so this check runs before CORSMiddleware; without this bypass every
    # preflight would get a 403 with no CORS headers and cross-origin calls fail.
    is_preflight = (
        request.method == "OPTIONS"
        and "access-control-request-method" in request.headers
    )
    if is_preflight or request.url.path in _PUBLIC_PATHS:
        return await call_next(request)
    try:
        await verify_api_key(request)
    except HTTPException as exc:
        return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
    return await call_next(request)
# Upload content types accepted by /crop_ooi.
_SUPPORTED_CONTENT_TYPES = ("application/pdf", "image/jpeg", "image/png")


def _load_pages(content: bytes, content_type: str) -> list:
    """Decode an upload into RGB PIL images: each PDF page, or the single image."""
    if content_type == "application/pdf":
        # Assumes the converter yields PIL images (they are cropped/copied with
        # PIL methods below) — TODO confirm against pdf_converter.
        pages = pdf_converter.convert_pdf_bytes_to_jpg(content)
    else:
        # JPEG/PNG bytes are not PDFs; decode them directly.
        pages = [Image.open(BytesIO(content))]
    # Normalise to 3-channel RGB (PNGs may be RGBA, palette or greyscale).
    return [page.convert("RGB") for page in pages]


def _to_bgr_array(image) -> np.ndarray:
    """Return an RGB PIL image as a contiguous BGR ndarray for GroundingDINO."""
    # groundingdino's Model.predict_with_caption expects a cv2-style BGR array
    # (it converts BGR->RGB internally); reversing the channel axis undoes the
    # swap, and the contiguous copy keeps OpenCV happy with the view's strides.
    return np.ascontiguousarray(np.asarray(image)[:, :, ::-1])


def _crop_box(bbox, width: int, height: int) -> Optional[tuple[int, int, int, int]]:
    """Clamp an (x_min, y_min, x_max, y_max) box to the image as a PIL crop box.

    Returns ``(left, upper, right, lower)``, or ``None`` if the clamped box has
    no area (a zero-size crop cannot be encoded as PNG).
    """
    x_min, y_min, x_max, y_max = (float(v) for v in bbox)
    left, upper = max(0, int(x_min)), max(0, int(y_min))
    right, lower = min(width, int(x_max)), min(height, int(y_max))
    if right <= left or lower <= upper:
        return None
    return left, upper, right, lower


@app.post("/crop_ooi")
async def crop_object_of_interest(
    document_file: UploadFile = File(
        ..., description="The document to process."
    ),
    concept_list: Optional[list[str]] = Form(
        ["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
    ),
    box_threshold: Optional[float] = Form(
        0.4, description="Threshold rate to keep confidence detections"
    ),
    text_threshold: Optional[float] = Form(0.25, description="Text threshold"),
    output: Literal["url", "base64"] = Query(
        "base64", description="The desired output format."
    ),
    render_detection: Optional[bool] = Form(
        False,
        description="True if the output contains picture with detected objects with bboxes",
    ),
):
    """Detect the requested concepts in a PDF/JPEG/PNG upload and return crops.

    Every detected box on every page is cropped and serialized with
    ``output_img`` in the requested ``output`` format.

    Returns:
        200 with ``{"status": "succes", "images": [...]}``. When
        ``render_detection`` is true, it also includes ``"detection_images"``:
        one annotated page (boxes and "<phrase> <confidence>" labels) per input
        page.
        400 for an unsupported content type or an empty ``concept_list``.
        500 with the error text if decoding or inference fails.
    """
    content_type = document_file.content_type
    if content_type not in _SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type ({content_type}).",
        )
    if not concept_list:
        raise HTTPException(
            status_code=400,
            detail="concept_list must contain at least one concept.",
        )
    content = await document_file.read()

    # Fall back to the module defaults when a client explicitly sends no value.
    box_thr = BOX_THRESHOLD if box_threshold is None else box_threshold
    text_thr = TEXT_THRESHOLD if text_threshold is None else text_threshold

    try:
        pages = _load_pages(content, content_type)
        # "dog. cat. " -> preprocess_caption normalises to "dog. cat."
        caption = preprocess_caption("".join(f"{concept}. " for concept in concept_list))

        if render_detection:
            # Built once per request rather than once per page.
            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

        crops = []
        rendered = []
        for page in pages:
            # NOTE(review): inference runs synchronously inside this async
            # handler and blocks the event loop while it runs.
            detections, phrases = grounding_model.predict_with_caption(
                image=_to_bgr_array(page),
                caption=caption,
                box_threshold=box_thr,
                text_threshold=text_thr,
            )
            for bbox in detections.xyxy:
                box = _crop_box(bbox, page.width, page.height)
                if box is None:
                    continue
                crops.append(output_img(processed_img=page.crop(box), output=output))

            if render_detection:
                labels = [
                    f"{phrase} {confidence:.2f}"
                    for phrase, confidence in zip(phrases, detections.confidence.tolist())
                ]
                annotated = box_annotator.annotate(scene=page.copy(), detections=detections)
                annotated = label_annotator.annotate(
                    scene=annotated, detections=detections, labels=labels
                )
                rendered.append(output_img(processed_img=annotated, output=output))

        # NOTE(review): "succes" is misspelled but kept byte-for-byte because
        # existing clients may match on it.
        response_data = {"status": "succes", "images": crops}
        if render_detection:
            response_data["detection_images"] = rendered
        return JSONResponse(status_code=200, content=response_data)
    except Exception as exc:
        # Top-level request boundary: log the traceback and report a server
        # error (500, not 501 "Not Implemented").
        logging.getLogger(__name__).exception("crop_ooi processing failed")
        return JSONResponse(status_code=500, content=str(exc))
2025-08-16 21:16:58 +00:00
if __name__ == "__main__":
    # Serve the app directly with uvicorn. The port comes from APP_PORT and
    # the bind address from APP_HOST, defaulting to 8000 on all interfaces.
    listen_port = int(os.environ.get("APP_PORT", 8000))
    listen_host = os.environ.get("APP_HOST", "0.0.0.0")
    uvicorn.run(app, port=listen_port, host=listen_host)