From 147c52de718d08d9c32a0b88fb766a3b112b993a Mon Sep 17 00:00:00 2001 From: kiennt Date: Sun, 17 Aug 2025 13:28:25 +0000 Subject: [PATCH] feat: Add app --- .gitignore | 1 + app.py | 105 +++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 83 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index d2232b0..5dcf2fb 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,4 @@ tmp/ data/ *.pth +temp_dir/ diff --git a/app.py b/app.py index d61c161..e1e636e 100644 --- a/app.py +++ b/app.py @@ -1,23 +1,30 @@ import os -import cv2 -from typing import Optional +import uuid +from io import BytesIO from pathlib import Path +from typing import Literal, Optional +import base64 +import cv2 import supervision as sv import uvicorn import yaml -from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile +from fastapi import (FastAPI, File, Form, HTTPException, Query, Request, + UploadFile) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.security.api_key import APIKeyHeader +from PIL import Image from starlette.status import HTTP_403_FORBIDDEN - from groundingdino.util.inference import Model, preprocess_caption from pdf_converter import PdfConverter +PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir")) +PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True) +BASE_URL = "http://127.0.0.1:8000" -API_PARTNER_KEY = "" +API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "") API_KEY_NAME = "x-api-key" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) @@ -40,6 +47,48 @@ grounding_model = Model( BOX_THRESHOLD = 0.4 TEXT_THRESHOLD = 0.25 + +def output_img(processed_img: Image,output: Literal["url", "base64"]) -> dict[str, any]: + if output == "url": + unique_filename = f"{uuid.uuid4()}.png" + save_path = os.path.join(PROCESSED_FOLDER, unique_filename) + # Save the processed image to the filesystem + processed_img.save(save_path, 'PNG') + # Construct the full URL for the client to access the image + # request.base_url gives us the root path (e.g., http://127.0.0.1:8000/) + image_url = f"{BASE_URL}/{PROCESSED_FOLDER}/{unique_filename}" + image_json = { + "image": image_url, + "metadata": { + "type": "url", + "format": "png", + "width": processed_img.width, + "height": processed_img.height, + "mode": processed_img.mode, + }, + } + return image_json + elif output == 'base64': + # Save the image to an in-memory buffer instead of a file + buffered = BytesIO() + processed_img.save(buffered, format="PNG") + # Encode the bytes to a Base64 string + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + # Prepend the data URI scheme + base64_data_uri = f"data:image/png;base64,{img_str}" + image_json = { + "image": base64_data_uri, + "metadata": { + "type": "base64", + "format": "png", + "width": processed_img.width, + "height": processed_img.height, + "mode": processed_img.mode, + }, + } + return image_json + + async def verify_api_key(request: Request): api_key = await api_key_header(request) if API_PARTNER_KEY is None or api_key != API_PARTNER_KEY: @@ -69,13 +118,17 @@ async def crop_object_of_interest( concept_list: Optional[list[str]] = Form( ["ID document"], description="List of concepts to detect e.g. dog, cat, rain" ), - save_files: Optional[bool] = Form( - False, description="True if crop are saved on local" - ), box_threshold: Optional[float] = Form( 0.4, description="Threshold rate to keep confidence detections" ), text_threshold: Optional[float] = Form(0.25, description="Text threshold"), + output: Literal["url", "base64"] = Query( + "base64", description="The desired output format." + ), + render_detection: Optional[bool] = Form( + False, + description="True if the output contains picture with detected objects with bboxes", + ), ): content = await document_file.read() if document_file.content_type not in ["application/pdf", "image/jpeg", "image/png"]: @@ -87,6 +140,8 @@ async def crop_object_of_interest( images = pdf_converter.convert_pdf_bytes_to_jpg(content) text = "".join([f"{x}. " for x in concept_list]) caption = preprocess_caption(text) + image_list = [] + detection_img_list = [] for image in images: detections, labels = grounding_model.predict_with_caption( image=images, @@ -104,26 +159,30 @@ async def crop_object_of_interest( for i, bbox in enumerate(detections.xyxy): x_min, y_min, x_max, y_max = tuple(bbox) patch = image[int(y_min) : int(y_max), int(x_min) : int(x_max)] - # patch_img_path = str( - # Path("outputs") / "extract" / f"{img_path.stem}_{i:d}.png" - # ) - # cv2.imwrite(patch_img_path, patch) + image_json = output_img(processed_img=patch, output=output) + image_list.append(image_json) - box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX) - annotated_frame = box_annotator.annotate( - scene=image.copy(), detections=detections - ) - label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX) - annotated_frame = label_annotator.annotate( - scene=annotated_frame, detections=detections, labels=labels - ) - # cv2.imwrite(str(Path("outputs") / f"{img_path.stem}.png"), annotated_frame) + if render_detection: + box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX) + annotated_frame = box_annotator.annotate( + scene=image.copy(), detections=detections + ) + label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX) + annotated_frame = label_annotator.annotate( + scene=annotated_frame, detections=detections, labels=labels + ) + detection_image_json = output_img(processed_img=annotated_frame, output=output) + detection_img_list.append(detection_image_json) + + response_data = dict(status="succes", images=image_list) + return JSONResponse(status_code=200, content=response_data) except Exception as e: print(f"{e}") + return JSONResponse(status_code=500, content=str(e)) if __name__ == "__main__": - APP_PORT = int(os.environ.get("VLM_APP_PORT", 8009)) - APP_HOST = os.environ.get("VLM_APP_HOST", "0.0.0.0") + APP_PORT = int(os.environ.get("APP_PORT", 8000)) + APP_HOST = os.environ.get("APP_HOST", "0.0.0.0") uvicorn.run(app, host=APP_HOST, port=APP_PORT)