feat: Add app
This commit is contained in:
129
app.py
Normal file
129
app.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import logging
import os
import secrets
from pathlib import Path
from typing import Optional

import cv2
import supervision as sv
import uvicorn
import yaml
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from groundingdino.util.inference import Model, preprocess_caption
from starlette.status import HTTP_403_FORBIDDEN

from pdf_converter import PdfConverter
|
||||
|
||||
|
||||
# Shared secret that clients must send in the ``x-api-key`` header.  It is read
# from the environment so the real key never has to be committed to the
# repository; it defaults to the empty string when the variable is unset.
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "")
API_KEY_NAME = "x-api-key"
# auto_error=False lets verify_api_key() decide how to reject a request
# instead of having the security helper raise on its own.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy -- confirm this is intended before exposing the
# service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

pdf_converter: PdfConverter = PdfConverter()

# The model is loaded once at import time; the config and checkpoint paths are
# taken from the GROUNDING_DINO_CONFIG / GROUNDING_DINO_CHECKPOINT variables.
grounding_model = Model(
    model_config_path=os.environ.get("GROUNDING_DINO_CONFIG"),
    model_checkpoint_path=os.environ.get("GROUNDING_DINO_CHECKPOINT"),
    device="cuda:0",
)

# Default detection thresholds (same values as the /crop_ooi form defaults).
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25
|
||||
|
||||
async def verify_api_key(request: Request) -> None:
    """Reject the request unless it carries the configured partner API key.

    Args:
        request: Incoming request whose ``x-api-key`` header is inspected.

    Raises:
        HTTPException: With status 403 when no partner key is configured,
            when the header is absent, or when the supplied key differs from
            ``API_PARTNER_KEY``.
    """
    api_key = await api_key_header(request)

    key_is_valid = (
        # Fail closed: an unset or empty server-side key must never match.
        bool(API_PARTNER_KEY)
        and api_key is not None
        # Constant-time comparison so response timing does not leak how much
        # of the key was correct.  Bytes are compared because compare_digest
        # rejects non-ASCII ``str`` arguments.
        and secrets.compare_digest(
            api_key.encode("utf-8"), API_PARTNER_KEY.encode("utf-8")
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Could not validate API KEY",
        )
|
||||
|
||||
|
||||
@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
    """Enforce the API-key check on every route except the documentation pages.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a JSON error response carrying the status
        code and detail of the ``HTTPException`` raised by ``verify_api_key``.
    """
    # The interactive docs and the OpenAPI schema stay reachable without a key.
    requires_key = request.url.path not in ["/docs", "/openapi.json", "/redoc"]

    if requires_key:
        try:
            await verify_api_key(request)
        except HTTPException as auth_error:
            # Exceptions raised inside a middleware are not routed through
            # FastAPI's exception handlers, so build the error body by hand.
            return JSONResponse(
                status_code=auth_error.status_code,
                content={"detail": auth_error.detail},
            )

    return await call_next(request)
|
||||
|
||||
|
||||
def _clamped_crop(image, bbox):
    """Return the region of ``image`` inside ``bbox``, clipped to the image bounds."""
    height, width = image.shape[:2]
    x_min, y_min, x_max, y_max = (int(value) for value in bbox)
    # A negative start index would be read as "count from the end" by the
    # slice below and silently produce the wrong crop, so clamp first.
    x_min, y_min = max(x_min, 0), max(y_min, 0)
    x_max, y_max = min(x_max, width), min(y_max, height)
    return image[y_min:y_max, x_min:x_max]


def _save_page_outputs(image, detections, labels, stem):
    """Write one page's crops to ``outputs/extract`` and its annotated frame to ``outputs``."""
    extract_dir = Path("outputs") / "extract"
    extract_dir.mkdir(parents=True, exist_ok=True)

    for index, bbox in enumerate(detections.xyxy):
        patch = _clamped_crop(image, bbox)
        # Skip degenerate boxes: there is nothing to encode for an empty crop.
        if patch.size:
            cv2.imwrite(str(extract_dir / f"{stem}_{index:d}.png"), patch)

    box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_frame = box_annotator.annotate(
        scene=image.copy(), detections=detections
    )
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=labels
    )
    cv2.imwrite(str(Path("outputs") / f"{stem}.png"), annotated_frame)


@app.post("/crop_ooi")
async def crop_object_of_interest(
    document_file: Optional[UploadFile] = File(
        None, description="The document to process."
    ),
    concept_list: Optional[list[str]] = Form(
        ["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
    ),
    save_files: Optional[bool] = Form(
        False, description="True if crop are saved on local"
    ),
    box_threshold: Optional[float] = Form(
        0.4, description="Threshold rate to keep confidence detections"
    ),
    text_threshold: Optional[float] = Form(0.25, description="Text threshold"),
):
    """Detect the requested concepts on every page of the uploaded document.

    Args:
        document_file: Uploaded PDF, JPEG or PNG file.
        concept_list: Concepts to look for; they are joined into one caption.
        save_files: When true, the crops and the annotated page images are
            written below the local ``outputs`` directory.
        box_threshold: Box confidence threshold forwarded to the model.
        text_threshold: Text threshold forwarded to the model.

    Returns:
        ``{"pages": [...]}`` with one entry per converted page, each holding
        the page index and a list of ``label`` / ``confidence`` / ``box``
        (``[x_min, y_min, x_max, y_max]``) detections.

    Raises:
        HTTPException: 400 when no file is sent or its content type is not
            supported; 500 when conversion or detection fails.
    """
    if document_file is None:
        raise HTTPException(status_code=400, detail="No document file provided.")
    # Validate before reading so unsupported uploads are rejected cheaply.
    if document_file.content_type not in ["application/pdf", "image/jpeg", "image/png"]:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type ({document_file.content_type}).",
        )

    content = await document_file.read()
    # Only the stem of the client-supplied name is used, so directory
    # components cannot leak into the output paths.
    document_stem = Path(document_file.filename or "").stem or "document"

    try:
        # NOTE(review): JPEG/PNG uploads are routed through the PDF converter
        # as well -- confirm that PdfConverter accepts raw image bytes.
        # Pages are assumed to be array-like images (they are sliced with
        # [rows, cols] when cropping) -- TODO confirm against PdfConverter.
        images = pdf_converter.convert_pdf_bytes_to_jpg(content)
        caption = preprocess_caption(
            "".join(f"{concept}. " for concept in concept_list)
        )

        pages = []
        for page_index, image in enumerate(images):
            # Run the model on the current page only (not on the whole list).
            detections, class_names = grounding_model.predict_with_caption(
                image=image,
                caption=caption,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
            )
            confidences = detections.confidence.tolist()
            boxes = detections.xyxy.tolist()

            if save_files:
                labels = [
                    f"{class_name} {confidence:.2f}"
                    for class_name, confidence in zip(class_names, confidences)
                ]
                _save_page_outputs(
                    image, detections, labels, f"{document_stem}_{page_index:d}"
                )

            pages.append(
                {
                    "page": page_index,
                    "detections": [
                        {"label": class_name, "confidence": confidence, "box": box}
                        for class_name, confidence, box in zip(
                            class_names, confidences, boxes
                        )
                    ],
                }
            )
    except Exception as exc:
        # Surface the failure to the client instead of answering 200 with an
        # empty body; the traceback goes to the server log only.
        logging.getLogger(__name__).exception("Object-of-interest detection failed")
        raise HTTPException(
            status_code=500, detail="Failed to process the document."
        ) from exc

    return {"pages": pages}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn, taking the bind address
    # and port from VLM_APP_HOST / VLM_APP_PORT when they are set.
    APP_HOST = os.getenv("VLM_APP_HOST", "0.0.0.0")
    APP_PORT = int(os.getenv("VLM_APP_PORT", 8009))
    uvicorn.run(app, port=APP_PORT, host=APP_HOST)
|
Reference in New Issue
Block a user