Review notes (text before the first "diff --git" line is ignored by git apply):
- The crop box now follows Pillow's (left, upper, right, lower) order.
- The hunks that added a hard-coded default API key and changed the error
  status from 500 to 501 were dropped; app.py keeps its original lines there.
- Whitespace was reconstructed from the hunk line counts, and the "index" blob
  hashes were not recomputed; apply against the original tree to confirm.

diff --git a/app.py b/app.py
index e1e636e..4e56f01 100644
--- a/app.py
+++ b/app.py
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Literal, Optional
 
 import base64
-import cv2
+import numpy as np
 import supervision as sv
 import uvicorn
 import yaml
@@ -38,10 +38,14 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-pdf_converter: PdfConverter = PdfConverter()
+pdf_converter: PdfConverter = PdfConverter(dpi=150)
 grounding_model = Model(
-    model_config_path=os.environ.get("GROUNDING_DINO_CONFIG"),
-    model_checkpoint_path=os.environ.get("GROUNDING_DINO_CHECKPOINT"),
+    model_config_path=os.environ.get(
+        "GROUNDING_DINO_CONFIG", "groundingdino/config/GroundingDINO_SwinT_OGC.py"
+    ),
+    model_checkpoint_path=os.environ.get(
+        "GROUNDING_DINO_CHECKPOINT", "gdino_checkpoints/groundingdino_swint_ogc.pth"
+    ),
     device="cuda:0",
 )
 BOX_THRESHOLD = 0.4
@@ -112,8 +116,8 @@ async def api_key_middleware(request: Request, call_next):
 
 @app.post("/crop_ooi")
 async def crop_object_of_interest(
-    document_file: Optional[UploadFile] = File(
-        None, description="The document to process."
+    document_file: UploadFile = File(
+        ..., description="The document to process."
     ),
     concept_list: Optional[list[str]] = Form(
         ["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
@@ -144,7 +148,7 @@ async def crop_object_of_interest(
         detection_img_list = []
         for image in images:
             detections, labels = grounding_model.predict_with_caption(
-                image=images,
+                image=np.asarray(image),
                 caption=caption,
                 box_threshold=box_threshold,
                 text_threshold=text_threshold,
@@ -158,7 +162,8 @@ async def crop_object_of_interest(
             ]
             for i, bbox in enumerate(detections.xyxy):
                 x_min, y_min, x_max, y_max = tuple(bbox)
-                patch = image[int(y_min) : int(y_max), int(x_min) : int(x_max)]
+                # Image.crop takes its box as (left, upper, right, lower).
+                patch = image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
                 image_json = output_img(processed_img=patch, output=output)
                 image_list.append(image_json)
 
diff --git a/pyproject.toml b/pyproject.toml
index 2bb3c73..8e1b235 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,4 +26,5 @@ dependencies = [
     "fastapi>=0.116.1",
     "openai>=1.99.9",
     "starlette>=0.47.2",
+    "python-multipart>=0.0.20",
 ]
diff --git a/uv.lock b/uv.lock
index b209d6d..c688bb3 100644
--- a/uv.lock
+++ b/uv.lock
@@ -268,6 +268,7 @@ dependencies = [
     { name = "pillow" },
     { name = "pip" },
     { name = "pycocotools" },
+    { name = "python-multipart" },
     { name = "setuptools" },
     { name = "starlette" },
     { name = "supervision" },
@@ -293,6 +294,7 @@ requires-dist = [
     { name = "pillow", specifier = ">=9.4.0" },
     { name = "pip", specifier = ">=25.2" },
     { name = "pycocotools", specifier = ">=2.0.10" },
+    { name = "python-multipart", specifier = ">=0.0.20" },
     { name = "setuptools", specifier = ">=80.9.0" },
     { name = "starlette", specifier = ">=0.47.2" },
     { name = "supervision", specifier = ">=0.26.1" },
@@ -1049,6 +1051,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
 ]
 
+[[package]]
+name = "python-multipart"
+version = "0.0.20"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158, upload-time = "2024-12-16T19:45:46.972Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" },
+]
+
 [[package]]
 name = "pywin32"
 version = "311"