fix: form with multi-part python, np format for model

This commit is contained in:
kiennt
2025-08-17 14:12:26 +00:00
parent 147c52de71
commit e09128f94c
3 changed files with 26 additions and 10 deletions

24
app.py
View File

@@ -5,7 +5,7 @@ from pathlib import Path
from typing import Literal, Optional from typing import Literal, Optional
import base64 import base64
import cv2 import numpy as np
import supervision as sv import supervision as sv
import uvicorn import uvicorn
import yaml import yaml
@@ -24,7 +24,7 @@ PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir"))
PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True) PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)
BASE_URL = "http://127.0.0.1:8000" BASE_URL = "http://127.0.0.1:8000"
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "") API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "dev-AfghDgr3fgf74vc")
API_KEY_NAME = "x-api-key" API_KEY_NAME = "x-api-key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
@@ -38,10 +38,14 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
pdf_converter: PdfConverter = PdfConverter() pdf_converter: PdfConverter = PdfConverter(dpi=150)
grounding_model = Model( grounding_model = Model(
model_config_path=os.environ.get("GROUNDING_DINO_CONFIG"), model_config_path=os.environ.get(
model_checkpoint_path=os.environ.get("GROUNDING_DINO_CHECKPOINT"), "GROUNDING_DINO_CONFIG", "groundingdino/config/GroundingDINO_SwinT_OGC.py"
),
model_checkpoint_path=os.environ.get(
"GROUNDING_DINO_CHECKPOINT", "gdino_checkpoints/groundingdino_swint_ogc.pth"
),
device="cuda:0", device="cuda:0",
) )
BOX_THRESHOLD = 0.4 BOX_THRESHOLD = 0.4
@@ -112,8 +116,8 @@ async def api_key_middleware(request: Request, call_next):
@app.post("/crop_ooi") @app.post("/crop_ooi")
async def crop_object_of_interest( async def crop_object_of_interest(
document_file: Optional[UploadFile] = File( document_file: UploadFile = File(
None, description="The document to process." ..., description="The document to process."
), ),
concept_list: Optional[list[str]] = Form( concept_list: Optional[list[str]] = Form(
["ID document"], description="List of concepts to detect e.g. dog, cat, rain" ["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
@@ -144,7 +148,7 @@ async def crop_object_of_interest(
detection_img_list = [] detection_img_list = []
for image in images: for image in images:
detections, labels = grounding_model.predict_with_caption( detections, labels = grounding_model.predict_with_caption(
image=images, image=np.asarray(image),
caption=caption, caption=caption,
box_threshold=box_threshold, box_threshold=box_threshold,
text_threshold=text_threshold, text_threshold=text_threshold,
@@ -158,7 +162,7 @@ async def crop_object_of_interest(
] ]
for i, bbox in enumerate(detections.xyxy): for i, bbox in enumerate(detections.xyxy):
x_min, y_min, x_max, y_max = tuple(bbox) x_min, y_min, x_max, y_max = tuple(bbox)
patch = image[int(y_min) : int(y_max), int(x_min) : int(x_max)] patch = image.crop((int(y_min), int(y_max), int(x_min), int(x_max)))
image_json = output_img(processed_img=patch, output=output) image_json = output_img(processed_img=patch, output=output)
image_list.append(image_json) image_list.append(image_json)
@@ -179,7 +183,7 @@ async def crop_object_of_interest(
except Exception as e: except Exception as e:
print(f"{e}") print(f"{e}")
return JSONResponse(status_code=500, content=str(e)) return JSONResponse(status_code=501, content=str(e))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -26,4 +26,5 @@ dependencies = [
"fastapi>=0.116.1", "fastapi>=0.116.1",
"openai>=1.99.9", "openai>=1.99.9",
"starlette>=0.47.2", "starlette>=0.47.2",
"python-multipart>=0.0.20",
] ]

11
uv.lock generated
View File

@@ -268,6 +268,7 @@ dependencies = [
{ name = "pillow" }, { name = "pillow" },
{ name = "pip" }, { name = "pip" },
{ name = "pycocotools" }, { name = "pycocotools" },
{ name = "python-multipart" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "starlette" }, { name = "starlette" },
{ name = "supervision" }, { name = "supervision" },
@@ -293,6 +294,7 @@ requires-dist = [
{ name = "pillow", specifier = ">=9.4.0" }, { name = "pillow", specifier = ">=9.4.0" },
{ name = "pip", specifier = ">=25.2" }, { name = "pip", specifier = ">=25.2" },
{ name = "pycocotools", specifier = ">=2.0.10" }, { name = "pycocotools", specifier = ">=2.0.10" },
{ name = "python-multipart", specifier = ">=0.0.20" },
{ name = "setuptools", specifier = ">=80.9.0" }, { name = "setuptools", specifier = ">=80.9.0" },
{ name = "starlette", specifier = ">=0.47.2" }, { name = "starlette", specifier = ">=0.47.2" },
{ name = "supervision", specifier = ">=0.26.1" }, { name = "supervision", specifier = ">=0.26.1" },
@@ -1049,6 +1051,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
] ]
[[package]]
name = "python-multipart"
version = "0.0.20"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158, upload-time = "2024-12-16T19:45:46.972Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" },
]
[[package]] [[package]]
name = "pywin32" name = "pywin32"
version = "311" version = "311"