"""FastAPI service that detects objects of interest in a document and crops them.

An uploaded PDF or image is run page by page through a GroundingDINO model
using a caption built from the requested concepts. Every detected bounding box
is cropped out and returned either as a base64 data URI or as a URL pointing
at a PNG saved under ``PROCESSED_FOLDER``.
"""

import base64
import logging
import os
import secrets
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Literal, Optional

import numpy as np
import supervision as sv
import uvicorn
import yaml
from fastapi import (FastAPI, File, Form, HTTPException, Query, Request,
                     UploadFile)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from PIL import Image
from starlette.status import HTTP_403_FORBIDDEN

from groundingdino.util.inference import Model, preprocess_caption
from pdf_converter import PdfConverter

logger = logging.getLogger(__name__)

# Folder where images are written when the client asks for ``output=url``.
PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir"))
PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)

BASE_URL = "http://127.0.0.1:8000"

# SECURITY NOTE(review): a fallback key is hard-coded here, so the service is
# never unauthenticated-by-accident but also never protected by a secret key
# unless API_PARTNER_KEY is set in the environment. Set it in production.
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "dev-AfghDgr3fgf74vc")
API_KEY_NAME = "x-api-key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# Paths that stay reachable without an API key (interactive documentation).
PUBLIC_PATHS = frozenset({"/docs", "/openapi.json", "/redoc"})

# Upload content types the endpoint accepts.
PDF_CONTENT_TYPE = "application/pdf"
SUPPORTED_CONTENT_TYPES = (PDF_CONTENT_TYPE, "image/jpeg", "image/png")

# Concepts searched for when the client does not send any.
DEFAULT_CONCEPTS = ["ID document"]

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

pdf_converter: PdfConverter = PdfConverter(dpi=150)

grounding_model = Model(
    model_config_path=os.environ.get(
        "GROUNDING_DINO_CONFIG", "groundingdino/config/GroundingDINO_SwinT_OGC.py"
    ),
    model_checkpoint_path=os.environ.get(
        "GROUNDING_DINO_CHECKPOINT", "gdino_checkpoints/groundingdino_swint_ogc.pth"
    ),
    device="cuda:0",
)

BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25


def output_img(
    processed_img: Image.Image, output: Literal["url", "base64"]
) -> dict[str, Any]:
    """Serialize an image into the JSON payload returned to the client.

    Args:
        processed_img: Image to serialize; it is always encoded as PNG.
        output: ``"url"`` saves the image under ``PROCESSED_FOLDER`` and
            returns a link built from ``BASE_URL``; ``"base64"`` returns a
            ``data:image/png;base64,...`` URI.

    Returns:
        A dict with the ``image`` reference and a ``metadata`` dict holding
        the output type, format, width, height and mode of the image.

    Raises:
        ValueError: If ``output`` is neither ``"url"`` nor ``"base64"``.
    """
    if output == "url":
        unique_filename = f"{uuid.uuid4()}.png"
        processed_img.save(PROCESSED_FOLDER / unique_filename, "PNG")
        # NOTE(review): no static route for PROCESSED_FOLDER is registered in
        # this module, so something else has to serve these files.
        image_ref = f"{BASE_URL}/{PROCESSED_FOLDER.as_posix()}/{unique_filename}"
    elif output == "base64":
        # Encode through an in-memory buffer so nothing touches the disk.
        buffered = BytesIO()
        processed_img.save(buffered, format="PNG")
        encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
        image_ref = f"data:image/png;base64,{encoded}"
    else:
        # Previously this fell through and silently returned None.
        raise ValueError(f"Unsupported output format: {output!r}")

    return {
        "image": image_ref,
        "metadata": {
            "type": output,
            "format": "png",
            "width": processed_img.width,
            "height": processed_img.height,
            "mode": processed_img.mode,
        },
    }


async def verify_api_key(request: Request) -> None:
    """Validate the ``x-api-key`` header of a request.

    Args:
        request: Incoming request whose headers are inspected.

    Raises:
        HTTPException: With status 403 when the header is missing or does not
            match ``API_PARTNER_KEY``.
    """
    api_key = await api_key_header(request)
    # compare_digest avoids leaking how many leading characters matched;
    # bytes are used because the str variant only accepts ASCII input.
    key_is_valid = (
        bool(API_PARTNER_KEY)
        and api_key is not None
        and secrets.compare_digest(
            api_key.encode("utf-8"), API_PARTNER_KEY.encode("utf-8")
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Could not validate API KEY",
        )


@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
    """Reject requests without a valid API key, except for public paths.

    This middleware is registered after ``CORSMiddleware`` and therefore runs
    before it, so CORS preflight (``OPTIONS``) requests, which never carry the
    API key header, must be let through here to reach the CORS handling.
    """
    if request.method == "OPTIONS" or request.url.path in PUBLIC_PATHS:
        return await call_next(request)
    try:
        await verify_api_key(request)
    except HTTPException as exc:
        return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
    return await call_next(request)


def _load_pages(content: bytes, content_type: str) -> list:
    """Turn an uploaded document into a list of page images.

    PDFs go through ``pdf_converter``; JPEG/PNG uploads are decoded directly
    with PIL and normalized to RGB so the model always gets three channels.
    """
    if content_type == PDF_CONTENT_TYPE:
        return pdf_converter.convert_pdf_bytes_to_jpg(content)
    return [Image.open(BytesIO(content)).convert("RGB")]


@app.post("/crop_ooi")
async def crop_object_of_interest(
    document_file: UploadFile = File(
        ..., description="The document to process."
    ),
    concept_list: Optional[list[str]] = Form(
        ["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
    ),
    box_threshold: Optional[float] = Form(
        BOX_THRESHOLD, description="Threshold rate to keep confidence detections"
    ),
    text_threshold: Optional[float] = Form(
        TEXT_THRESHOLD, description="Text threshold"
    ),
    output: Literal["url", "base64"] = Query(
        "base64", description="The desired output format."
    ),
    render_detection: Optional[bool] = Form(
        False,
        description="True if the output contains picture with detected objects with bboxes",
    ),
):
    """Detect the requested concepts in a document and return the cropped patches.

    Args:
        document_file: Uploaded PDF, JPEG or PNG.
        concept_list: Concepts to look for; defaults to ``["ID document"]``.
        box_threshold: Box confidence threshold passed to the model.
        text_threshold: Text threshold passed to the model.
        output: ``"url"`` or ``"base64"``, forwarded to ``output_img``.
        render_detection: When true, the response also carries one annotated
            image per page under ``detection_images``.

    Returns:
        A JSON response with ``status`` and ``images`` (one entry per detected
        box), plus ``detection_images`` when ``render_detection`` is set. On an
        unexpected failure, a 500 response whose body is the error message.

    Raises:
        HTTPException: With status 400 when the upload has an unsupported
            content type.
    """
    # Validate before reading so unsupported uploads are rejected cheaply.
    if document_file.content_type not in SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type ({document_file.content_type}).",
        )
    content = await document_file.read()

    # The form fields are Optional, so fall back to the defaults on None.
    concepts = concept_list or DEFAULT_CONCEPTS
    if box_threshold is None:
        box_threshold = BOX_THRESHOLD
    if text_threshold is None:
        text_threshold = TEXT_THRESHOLD

    try:
        pages = _load_pages(content, document_file.content_type)
        caption = preprocess_caption("".join(f"{concept}. " for concept in concepts))

        image_list = []
        detection_img_list = []
        for page in pages:
            # NOTE(review): inference is a blocking call inside an async
            # handler. Also, upstream GroundingDINO documents a BGR array for
            # this method while PIL yields RGB - confirm against the installed
            # version.
            detections, class_names = grounding_model.predict_with_caption(
                image=np.asarray(page),
                caption=caption,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
            )

            for x_min, y_min, x_max, y_max in detections.xyxy:
                # PIL expects the box as (left, upper, right, lower).
                patch = page.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
                image_list.append(output_img(processed_img=patch, output=output))

            if render_detection:
                labels = [
                    f"{class_name} {confidence:.2f}"
                    for class_name, confidence in zip(
                        class_names, detections.confidence.tolist()
                    )
                ]
                box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
                annotated_frame = box_annotator.annotate(
                    scene=page.copy(), detections=detections
                )
                label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
                annotated_frame = label_annotator.annotate(
                    scene=annotated_frame, detections=detections, labels=labels
                )
                detection_img_list.append(
                    output_img(processed_img=annotated_frame, output=output)
                )

        # "succes" is misspelled but kept as-is: clients may match on it.
        response_data = dict(status="succes", images=image_list)
        if render_detection:
            # The annotated pages used to be built and then dropped.
            response_data["detection_images"] = detection_img_list
        return JSONResponse(status_code=200, content=response_data)

    except Exception as e:
        # Top-level boundary: log with traceback and report a server error
        # (501 "Not Implemented" was misleading for a runtime failure).
        logger.exception("Failed to process uploaded document")
        return JSONResponse(status_code=500, content=str(e))


if __name__ == "__main__":
    APP_PORT = int(os.environ.get("APP_PORT", 8000))
    APP_HOST = os.environ.get("APP_HOST", "0.0.0.0")
    uvicorn.run(app, host=APP_HOST, port=APP_PORT)