feat: Add app
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -147,3 +147,4 @@ tmp/
|
|||||||
data/
|
data/
|
||||||
|
|
||||||
*.pth
|
*.pth
|
||||||
|
temp_dir/
|
||||||
|
105
app.py
105
app.py
@@ -1,23 +1,30 @@
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import uuid
|
||||||
from typing import Optional
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal, Optional
|
||||||
|
import base64
|
||||||
|
|
||||||
|
import cv2
|
||||||
import supervision as sv
|
import supervision as sv
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
from fastapi import (FastAPI, File, Form, HTTPException, Query, Request,
|
||||||
|
UploadFile)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.security.api_key import APIKeyHeader
|
from fastapi.security.api_key import APIKeyHeader
|
||||||
|
from PIL import Image
|
||||||
from starlette.status import HTTP_403_FORBIDDEN
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
from groundingdino.util.inference import Model, preprocess_caption
|
from groundingdino.util.inference import Model, preprocess_caption
|
||||||
from pdf_converter import PdfConverter
|
from pdf_converter import PdfConverter
|
||||||
|
|
||||||
|
# Directory that receives images written to disk. It can be overridden with
# the TEMP_IMG_FOLDER environment variable and is created at import time so
# later writes do not fail on a missing folder.
PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir"))
PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)

# Root URL used to build links to saved images. It is read from the BASE_URL
# environment variable so the links stay valid when the service is not
# reachable at the local default address. A trailing slash is stripped
# because consumers join path segments with "/".
BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000").rstrip("/")
|
||||||
|
|
||||||
API_PARTNER_KEY = ""
|
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "")
|
||||||
API_KEY_NAME = "x-api-key"
|
API_KEY_NAME = "x-api-key"
|
||||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||||
|
|
||||||
@@ -40,6 +47,48 @@ grounding_model = Model(
|
|||||||
BOX_THRESHOLD = 0.4
|
BOX_THRESHOLD = 0.4
|
||||||
TEXT_THRESHOLD = 0.25
|
TEXT_THRESHOLD = 0.25
|
||||||
|
|
||||||
|
|
||||||
|
def output_img(
    processed_img: "Image.Image", output: Literal["url", "base64"]
) -> dict[str, object]:
    """Encode an image as PNG and return a JSON-serialisable description.

    Args:
        processed_img: Image object exposing ``save``, ``width``, ``height``
            and ``mode`` (the PIL image interface).
            NOTE(review): call sites in this file appear to pass array
            slices rather than PIL images -- confirm they are converted
            before reaching this function.
        output: ``"url"`` writes the PNG under ``PROCESSED_FOLDER`` with a
            random UUID file name and returns a link built from ``BASE_URL``;
            ``"base64"`` encodes the PNG in memory and returns a
            ``data:image/png;base64,...`` URI.

    Returns:
        A dict with the key ``"image"`` (URL or data URI) and the key
        ``"metadata"`` (type, format, width, height and mode).

    Raises:
        ValueError: If ``output`` is neither ``"url"`` nor ``"base64"``.
    """
    if output == "url":
        unique_filename = f"{uuid.uuid4()}.png"
        save_path = os.path.join(PROCESSED_FOLDER, unique_filename)
        # Persist the image so it can be served from the processed folder.
        processed_img.save(save_path, "PNG")
        image_ref = f"{BASE_URL}/{PROCESSED_FOLDER}/{unique_filename}"
    elif output == "base64":
        # Encode through an in-memory buffer; nothing is written to disk.
        buffered = BytesIO()
        processed_img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        image_ref = f"data:image/png;base64,{img_str}"
    else:
        # Fail loudly instead of implicitly returning None to the caller.
        raise ValueError(
            f"Unsupported output format {output!r}; expected 'url' or 'base64'"
        )

    return {
        "image": image_ref,
        "metadata": {
            "type": output,
            "format": "png",
            "width": processed_img.width,
            "height": processed_img.height,
            "mode": processed_img.mode,
        },
    }
|
||||||
|
|
||||||
|
|
||||||
async def verify_api_key(request: Request):
|
async def verify_api_key(request: Request):
|
||||||
api_key = await api_key_header(request)
|
api_key = await api_key_header(request)
|
||||||
if API_PARTNER_KEY is None or api_key != API_PARTNER_KEY:
|
if API_PARTNER_KEY is None or api_key != API_PARTNER_KEY:
|
||||||
@@ -69,13 +118,17 @@ async def crop_object_of_interest(
|
|||||||
concept_list: Optional[list[str]] = Form(
|
concept_list: Optional[list[str]] = Form(
|
||||||
["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
|
["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
|
||||||
),
|
),
|
||||||
save_files: Optional[bool] = Form(
|
|
||||||
False, description="True if crop are saved on local"
|
|
||||||
),
|
|
||||||
box_threshold: Optional[float] = Form(
|
box_threshold: Optional[float] = Form(
|
||||||
0.4, description="Threshold rate to keep confidence detections"
|
0.4, description="Threshold rate to keep confidence detections"
|
||||||
),
|
),
|
||||||
text_threshold: Optional[float] = Form(0.25, description="Text threshold"),
|
text_threshold: Optional[float] = Form(0.25, description="Text threshold"),
|
||||||
|
output: Literal["url", "base64"] = Query(
|
||||||
|
"base64", description="The desired output format."
|
||||||
|
),
|
||||||
|
render_detection: Optional[bool] = Form(
|
||||||
|
False,
|
||||||
|
description="True if the output contains picture with detected objects with bboxes",
|
||||||
|
),
|
||||||
):
|
):
|
||||||
content = await document_file.read()
|
content = await document_file.read()
|
||||||
if document_file.content_type not in ["application/pdf", "image/jpeg", "image/png"]:
|
if document_file.content_type not in ["application/pdf", "image/jpeg", "image/png"]:
|
||||||
@@ -87,6 +140,8 @@ async def crop_object_of_interest(
|
|||||||
images = pdf_converter.convert_pdf_bytes_to_jpg(content)
|
images = pdf_converter.convert_pdf_bytes_to_jpg(content)
|
||||||
text = "".join([f"{x}. " for x in concept_list])
|
text = "".join([f"{x}. " for x in concept_list])
|
||||||
caption = preprocess_caption(text)
|
caption = preprocess_caption(text)
|
||||||
|
image_list = []
|
||||||
|
detection_img_list = []
|
||||||
for image in images:
|
for image in images:
|
||||||
detections, labels = grounding_model.predict_with_caption(
|
detections, labels = grounding_model.predict_with_caption(
|
||||||
image=images,
|
image=images,
|
||||||
@@ -104,26 +159,30 @@ async def crop_object_of_interest(
|
|||||||
for i, bbox in enumerate(detections.xyxy):
|
for i, bbox in enumerate(detections.xyxy):
|
||||||
x_min, y_min, x_max, y_max = tuple(bbox)
|
x_min, y_min, x_max, y_max = tuple(bbox)
|
||||||
patch = image[int(y_min) : int(y_max), int(x_min) : int(x_max)]
|
patch = image[int(y_min) : int(y_max), int(x_min) : int(x_max)]
|
||||||
# patch_img_path = str(
|
image_json = output_img(processed_img=patch, output=output)
|
||||||
# Path("outputs") / "extract" / f"{img_path.stem}_{i:d}.png"
|
image_list.append(image_json)
|
||||||
# )
|
|
||||||
# cv2.imwrite(patch_img_path, patch)
|
|
||||||
|
|
||||||
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
if render_detection:
|
||||||
annotated_frame = box_annotator.annotate(
|
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
||||||
scene=image.copy(), detections=detections
|
annotated_frame = box_annotator.annotate(
|
||||||
)
|
scene=image.copy(), detections=detections
|
||||||
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
)
|
||||||
annotated_frame = label_annotator.annotate(
|
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
||||||
scene=annotated_frame, detections=detections, labels=labels
|
annotated_frame = label_annotator.annotate(
|
||||||
)
|
scene=annotated_frame, detections=detections, labels=labels
|
||||||
# cv2.imwrite(str(Path("outputs") / f"{img_path.stem}.png"), annotated_frame)
|
)
|
||||||
|
detection_image_json = output_img(processed_img=annotated_frame, output=output)
|
||||||
|
detection_img_list.append(detection_image_json)
|
||||||
|
|
||||||
|
response_data = dict(status="succes", images=image_list)
|
||||||
|
return JSONResponse(status_code=200, content=response_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{e}")
|
print(f"{e}")
|
||||||
|
return JSONResponse(status_code=500, content=str(e))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
APP_PORT = int(os.environ.get("VLM_APP_PORT", 8009))
|
APP_PORT = int(os.environ.get("APP_PORT", 8000))
|
||||||
APP_HOST = os.environ.get("VLM_APP_HOST", "0.0.0.0")
|
APP_HOST = os.environ.get("APP_HOST", "0.0.0.0")
|
||||||
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
|
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
|
||||||
|
Reference in New Issue
Block a user