Compare commits

...

14 Commits

Author SHA1 Message Date
kiennt
83c93e85ac feat: Update dockerfile 2025-08-18 08:06:51 +00:00
kiennt
e09128f94c fix: form with multi-part python, np format for model 2025-08-17 14:12:26 +00:00
kiennt
147c52de71 feat: Add app 2025-08-17 13:28:25 +00:00
kiennt
e3f3ab95cc feat: Add app 2025-08-16 21:16:58 +00:00
kiennt
546f444c1c fix: for CUDA version >= 12.6 2025-08-16 09:57:17 +00:00
kiennt
e1420f9335 fix: Fix setup 2025-08-16 09:47:52 +00:00
Ren Tianhe
856dde20ae Grounded SAM 2 Release 2024-08-12 16:52:02 +08:00
Piotr Skalski
5a890bd867 Merge pull request #342 from ethanlee928/main
fix Supervision depreciation of BoxAnnotator
2024-07-24 08:59:41 +02:00
Piotr Skalski
e49e881edd Merge branch 'main' into main 2024-07-24 08:58:19 +02:00
ethanlee928
8b6a55f612 replaced BoundingBoxAnnotator with BoxAnnotator, updated Supervision version 2024-07-23 23:19:52 +08:00
Piotr Skalski
e27a646ca0 Update requirements.txt
`supervision==0.22.0` is deprecating `BoxAnnotator`. I'm freezing the `supervision` version to prevent any problems.
2024-07-12 12:27:24 +02:00
ethanlee928
d75c95daf6 fix Supervision depreciation of BoxAnnotator 2024-06-29 01:10:48 +08:00
Ren Tianhe
df5b48a3ef Update README.md 2024-05-23 20:10:37 +08:00
Ren Tianhe
4330960fa7 Grounding DINO 1.5 Release 2024-05-18 13:36:18 +08:00
17 changed files with 1957 additions and 55 deletions

6
.gitignore vendored
View File

@@ -143,4 +143,8 @@ grounding/config/configs
grounding/version.py
vis/
tmp/
tmp/
data/
*.pth
temp_dir/

View File

@@ -1,35 +1,28 @@
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
ARG DEBIAN_FRONTEND=noninteractive
ENV CUDA_HOME=/usr/local/cuda \
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
SETUPTOOLS_USE_DISTUTILS=stdlib
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX"
RUN conda update conda -y
# RUN conda update conda -y
# Install libraries in the brand new image.
RUN apt-get -y update && apt-get install -y --no-install-recommends \
wget \
build-essential \
git \
python3-opencv \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
RUN apt -y update && apt install tesseract-ocr -y
RUN apt-get -y update && apt-get install libgl1 poppler-utils -y --no-install-recommends
# Set the working directory for all the subsequent Dockerfile instructions.
WORKDIR /opt/program
RUN git clone https://github.com/IDEA-Research/GroundingDINO.git
WORKDIR /app
RUN mkdir weights ; cd weights ; wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth ; cd ..
RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y
ENV CUDA_HOME=$CONDA_PREFIX
COPY app.py /app/
COPY pdf_converter.py /app/
COPY uv.lock /app/
COPY setup.py /app/
ADD groundingdino/ /app/groundingdino/
COPY pyproject.toml /app/
ENV PATH=/usr/local/cuda/bin:$PATH
# RUN cd GroundingDINO/ && python -m pip install .
RUN pip install --no-cache-dir uv && uv sync --no-cache
RUN /app/.venv/bin/python -m pip install -e .
RUN cd GroundingDINO/ && python -m pip install .
COPY docker_test.py docker_test.py
CMD [ "python", "docker_test.py" ]
CMD [ "/app/.venv/bin/python", "app.py" ]

View File

@@ -18,6 +18,10 @@
PyTorch implementation and pretrained models for Grounding DINO. For details, see the paper **[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)**.
- 🔥 **[Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2)** is released now, which combines Grounding DINO with [SAM 2](https://github.com/facebookresearch/segment-anything-2) for any object tracking in open-world scenarios.
- 🔥 **[Grounding DINO 1.5](https://github.com/IDEA-Research/Grounding-DINO-1.5-API)** is released now, which is IDEA Research's **Most Capable** Open-World Object Detection Model!
- 🔥 **[Grounding DINO](https://arxiv.org/abs/2303.05499)** and **[Grounded SAM](https://arxiv.org/abs/2401.14159)** are now supported in Huggingface. For more convenient use, you can refer to [this documentation](https://huggingface.co/docs/transformers/model_doc/grounding-dino)
## :sun_with_face: Helpful Tutorial
- :grapes: [[Read our arXiv Paper](https://arxiv.org/abs/2303.05499)]

192
app.py Normal file
View File

@@ -0,0 +1,192 @@
import os
import uuid
from io import BytesIO
from pathlib import Path
from typing import Any, Literal, Optional
import base64
import numpy as np
import supervision as sv
import uvicorn
import yaml
from fastapi import (FastAPI, File, Form, HTTPException, Query, Request,
UploadFile)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from PIL import Image
from starlette.status import HTTP_403_FORBIDDEN
from groundingdino.util.inference import Model, preprocess_caption
from pdf_converter import PdfConverter
PROCESSED_FOLDER = Path(os.environ.get("TEMP_IMG_FOLDER", "temp_dir"))
PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)
BASE_URL = "http://127.0.0.1:8000"
API_PARTNER_KEY = os.environ.get("API_PARTNER_KEY", "dev-AfghDgr3fgf74vc")
API_KEY_NAME = "x-api-key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
pdf_converter: PdfConverter = PdfConverter(dpi=150)
grounding_model = Model(
model_config_path=os.environ.get(
"GROUNDING_DINO_CONFIG", "groundingdino/config/GroundingDINO_SwinT_OGC.py"
),
model_checkpoint_path=os.environ.get(
"GROUNDING_DINO_CHECKPOINT", "gdino_checkpoints/groundingdino_swint_ogc.pth"
),
device="cuda:0",
)
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25
def output_img(processed_img: Image.Image, output: Literal["url", "base64"]) -> dict[str, Any]:
if output == "url":
unique_filename = f"{uuid.uuid4()}.png"
save_path = os.path.join(PROCESSED_FOLDER, unique_filename)
# Save the processed image to the filesystem
processed_img.save(save_path, 'PNG')
# Construct the full URL for the client to access the image
# BASE_URL is the service root (e.g., http://127.0.0.1:8000)
image_url = f"{BASE_URL}/{PROCESSED_FOLDER}/{unique_filename}"
image_json = {
"image": image_url,
"metadata": {
"type": "url",
"format": "png",
"width": processed_img.width,
"height": processed_img.height,
"mode": processed_img.mode,
},
}
return image_json
elif output == 'base64':
# Save the image to an in-memory buffer instead of a file
buffered = BytesIO()
processed_img.save(buffered, format="PNG")
# Encode the bytes to a Base64 string
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# Prepend the data URI scheme
base64_data_uri = f"data:image/png;base64,{img_str}"
image_json = {
"image": base64_data_uri,
"metadata": {
"type": "base64",
"format": "png",
"width": processed_img.width,
"height": processed_img.height,
"mode": processed_img.mode,
},
}
return image_json
async def verify_api_key(request: Request):
api_key = await api_key_header(request)
if API_PARTNER_KEY is None or api_key != API_PARTNER_KEY:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Could not validate API KEY",
)
@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
# Allow docs and openapi without API key
if request.url.path in ["/docs", "/openapi.json", "/redoc"]:
return await call_next(request)
try:
await verify_api_key(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
return await call_next(request)
@app.post("/crop_ooi")
async def crop_object_of_interest(
document_file: UploadFile = File(
..., description="The document to process."
),
concept_list: Optional[list[str]] = Form(
["ID document"], description="List of concepts to detect e.g. dog, cat, rain"
),
box_threshold: Optional[float] = Form(
0.4, description="Threshold rate to keep confidence detections"
),
text_threshold: Optional[float] = Form(0.25, description="Text threshold"),
output: Literal["url", "base64"] = Query(
"base64", description="The desired output format."
),
render_detection: Optional[bool] = Form(
False,
description="True if the output contains picture with detected objects with bboxes",
),
):
content = await document_file.read()
if document_file.content_type not in ["application/pdf", "image/jpeg", "image/png"]:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type ({document_file.content_type}).",
)
try:
if document_file.content_type == "application/pdf":
images = pdf_converter.convert_pdf_bytes_to_jpg(content)
else:
images = [Image.open(BytesIO(content)).convert("RGB")]
text = "".join([f"{x}. " for x in concept_list])
caption = preprocess_caption(text)
image_list = []
detection_img_list = []
for image in images:
detections, labels = grounding_model.predict_with_caption(
image=np.asarray(image),
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold,
)
confidences = detections.confidence.tolist()
class_names = labels
labels = [
f"{class_name} {confidence:.2f}"
for class_name, confidence in zip(class_names, confidences)
]
for i, bbox in enumerate(detections.xyxy):
x_min, y_min, x_max, y_max = tuple(bbox)
patch = image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
image_json = output_img(processed_img=patch, output=output)
image_list.append(image_json)
if render_detection:
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_frame = box_annotator.annotate(
scene=image.copy(), detections=detections
)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_frame = label_annotator.annotate(
scene=annotated_frame, detections=detections, labels=labels
)
detection_image_json = output_img(processed_img=annotated_frame, output=output)
detection_img_list.append(detection_image_json)
response_data = dict(status="success", images=image_list, detections=detection_img_list)
return JSONResponse(status_code=200, content=response_data)
except Exception as e:
print(f"{e}")
return JSONResponse(status_code=501, content=str(e))
if __name__ == "__main__":
APP_PORT = int(os.environ.get("APP_PORT", 8000))
APP_HOST = os.environ.get("APP_HOST", "0.0.0.0")
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
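A minimal client sketch for the /crop_ooi endpoint added above, assuming the service is running locally on port 8000 with the default dev API key from app.py; the `requests` library and the file name "sample.pdf" are illustrative and not part of the project:

```python
import base64
import requests  # illustrative only; not a project dependency

API_URL = "http://127.0.0.1:8000/crop_ooi"        # assumed local deployment
HEADERS = {"x-api-key": "dev-AfghDgr3fgf74vc"}    # default API_PARTNER_KEY in app.py

with open("sample.pdf", "rb") as f:               # hypothetical input document
    resp = requests.post(
        API_URL,
        params={"output": "base64"},
        headers=HEADERS,
        files={"document_file": ("sample.pdf", f, "application/pdf")},
        data={
            "concept_list": ["ID document"],
            "box_threshold": 0.4,
            "text_threshold": 0.25,
            "render_detection": "false",
        },
    )
resp.raise_for_status()

# Each entry carries a data URI; strip the prefix before decoding the PNG bytes.
for i, item in enumerate(resp.json()["images"]):
    b64_payload = item["image"].split(",", 1)[1]
    with open(f"crop_{i}.png", "wb") as out:
        out.write(base64.b64decode(b64_payload))
```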

View File

@@ -0,0 +1,24 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Define the URLs for the checkpoints
BASE_URL="https://github.com/IDEA-Research/GroundingDINO/releases/download/"
swint_ogc_url="${BASE_URL}v0.1.0-alpha/groundingdino_swint_ogc.pth"
swinb_cogcoor_url="${BASE_URL}v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth"
# Download each of the four checkpoints using wget
echo "Downloading groundingdino_swint_ogc.pth checkpoint..."
wget $swint_ogc_url || { echo "Failed to download checkpoint from $swint_ogc_url"; exit 1; }
echo "Downloading groundingdino_swinb_cogcoor.pth checkpoint..."
wget $swinb_cogcoor_url || { echo "Failed to download checkpoint from $swinb_cogcoor_url"; exit 1; }
echo "All checkpoints are downloaded successfully."

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.layers import DropPath, to_2tuple, trunc_normal_
from groundingdino.util.misc import NestedTensor
@@ -445,7 +445,7 @@ class BasicLayer(nn.Module):
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
x = checkpoint.checkpoint(blk, x, attn_mask, use_reentrant=True)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
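The import change above follows timm's relocation of these helpers: newer timm releases expose them under `timm.layers`, while `timm.models.layers` is a deprecation shim. If compatibility with older timm were needed, a version-tolerant sketch (not part of the diff) could look like this:

```python
# Version-tolerant import sketch: prefer the new timm.layers location,
# fall back to the legacy path on older timm releases.
try:
    from timm.layers import DropPath, to_2tuple, trunc_normal_
except ImportError:
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
```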

View File

@@ -15,11 +15,24 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/version.h>
// Check PyTorch version and define appropriate macros
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
// PyTorch 2.6 and above
#define GET_TENSOR_TYPE(x) x.scalar_type()
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
#else
// Older PyTorch releases
#define GET_TENSOR_TYPE(x) x.type()
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
#endif
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
@@ -51,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
@@ -82,7 +95,7 @@ at::Tensor ms_deform_attn_cuda_forward(
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
@@ -128,11 +141,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
@@ -153,4 +166,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
};
}
} // namespace groundingdino
} // namespace groundingdino
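The macros above exist because `Tensor::type()` was removed from newer PyTorch releases in favour of `scalar_type()` and `device().is_cuda()`. A quick sanity-check sketch after rebuilding, assuming the extension is built as `groundingdino._C` as in the upstream setup.py:

```python
# Sketch: confirm the rebuilt CUDA extension imports against the installed PyTorch.
import torch

print("torch", torch.__version__, "cuda available:", torch.cuda.is_available())

try:
    from groundingdino import _C  # compiled MultiScaleDeformableAttention op
    print("CUDA extension loaded from:", _C.__file__)
except ImportError as e:
    print("Extension not built; the pure-PyTorch fallback will be used:", e)
```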

View File

@@ -8,7 +8,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from timm.layers import DropPath
class FeatureResizer(nn.Module):

View File

@@ -554,6 +554,7 @@ class TransformerEncoder(nn.Module):
memory_text,
key_padding_mask,
text_attention_mask,
use_reentrant=True,
)
else:
output, memory_text = self.fusion_layers[layer_id](
@@ -581,6 +582,7 @@ class TransformerEncoder(nn.Module):
spatial_shapes,
level_start_index,
key_padding_mask,
use_reentrant=True,
)
else:
output = layer(
@@ -859,7 +861,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast("cuda", enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
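This diff makes two API updates: `torch.utils.checkpoint.checkpoint` now gets an explicit `use_reentrant=True` (newer PyTorch warns when the flag is left implicit), and `torch.cuda.amp.autocast(enabled=False)` is replaced with the device-agnostic `torch.amp.autocast` form. A standalone sketch of both call patterns, not project code:

```python
# Minimal sketch of the updated call patterns used in the diff above.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Explicit use_reentrant avoids the PyTorch 2.x deprecation warning.
y = checkpoint.checkpoint(layer, x, use_reentrant=True)

# Device-agnostic autocast, matching forward_ffn above.
with torch.amp.autocast("cuda", enabled=False):
    y = layer(x)
```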

View File

@@ -121,9 +121,11 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
in zip(phrases, logits)
]
box_annotator = sv.BoxAnnotator()
bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
annotated_frame = bbox_annotator.annotate(scene=annotated_frame, detections=detections)
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
return annotated_frame
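Since supervision 0.22, `BoxAnnotator` only draws boxes and labels are rendered separately by `LabelAnnotator`, which is the pattern the patch adopts. A self-contained sketch with synthetic detections (not project code):

```python
# Sketch of the supervision >= 0.22 annotation pattern used above.
import numpy as np
import supervision as sv

scene = np.zeros((240, 320, 3), dtype=np.uint8)      # synthetic BGR frame
detections = sv.Detections(
    xyxy=np.array([[20.0, 30.0, 120.0, 150.0]]),
    confidence=np.array([0.87]),
    class_id=np.array([0]),
)
labels = ["ID document 0.87"]

bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

annotated = bbox_annotator.annotate(scene=scene.copy(), detections=detections)
annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
```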

1
groundingdino/version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

92
main.py Normal file
View File

@@ -0,0 +1,92 @@
from pathlib import Path
import cv2
import supervision as sv
import torch
import yaml
from tqdm import tqdm
from groundingdino.util.inference import Model, preprocess_caption
from pdf_converter import PdfConverter
GROUNDING_DINO_CONFIG = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25
def main(
data_dir: str | Path,
text: str = "ID card. Carte Vitale. Bank details. Human face.",
concept_list_yaml: str | None = None,
device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
):
output_dir = Path("outputs") / "extract"
output_dir.mkdir(parents=True, exist_ok=True)
if concept_list_yaml:
print(f"Overriding concepts !")
with open(concept_list_yaml, "r") as f:
concepts = yaml.load(f)
text = "".join([f" {x}." for x in concepts])
print(f"List of concepts to detect: {text}")
if isinstance(data_dir, str):
data_dir = Path(data_dir)
for img_path in tqdm(
data_dir.glob("*.pdf"), total=len(list(data_dir.glob("*.pdf")))
):
pdf_convertor = PdfConverter(120)
if img_path.suffix == ".pdf":
imgs = pdf_convertor.convert_pdf_to_jpg(str(img_path))
img = imgs[0]
pdf_convertor.save_image_as_png(img, img_path.parent / "test.png")
img = pdf_convertor.to_cv2_image(img)
else:
img = cv2.imread(str(img_path))
# image_source, image = load_image(str(img_path.parent / "test.png"))
grounding_model = Model(
model_config_path=GROUNDING_DINO_CONFIG,
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
device=device,
)
caption = preprocess_caption(text)
detections, labels = grounding_model.predict_with_caption(
image=img,
caption=caption,
box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD,
)
confidences = detections.confidence.tolist()
class_names = labels
labels = [
f"{class_name} {confidence:.2f}"
for class_name, confidence in zip(class_names, confidences)
]
for i, bbox in enumerate(detections.xyxy):
x_min, y_min, x_max, y_max = tuple(bbox)
patch = img[int(y_min) : int(y_max), int(x_min) : int(x_max)]
patch_img_path = str(
Path("outputs") / "extract" / f"{img_path.stem}_{i:d}.png"
)
cv2.imwrite(patch_img_path, patch)
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_frame = box_annotator.annotate(
scene=img.copy(), detections=detections
)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_frame = label_annotator.annotate(
scene=annotated_frame, detections=detections, labels=labels
)
cv2.imwrite(str(Path("outputs") / f"{img_path.stem}.png"), annotated_frame)
if __name__ == "__main__":
main("data")

46
pdf_converter.py Normal file
View File

@@ -0,0 +1,46 @@
import numpy as np
from pdf2image import convert_from_bytes, convert_from_path
from PIL import Image
class PdfConverter:
"""
A service to convert PDF files to images and resize images,
using pdf2image for PDFs and Pillow for images.
"""
def __init__(self, dpi: int):
self._pdfium_initialized = False # Track if PDFium needs explicit init/deinit
self._dpi = dpi
def convert_pdf_to_jpg(self, file_path: str) -> list[Image.Image]:
"""
Converts a PDF file to a list of PIL images at the configured DPI.
"""
pil_images = convert_from_path(file_path, dpi=self._dpi)
return pil_images
def resize_image(self, img: Image.Image, size: tuple[int, int]) -> Image.Image:
"""
Resizes a PIL Image to the specified size.
"""
return img.resize(size, Image.LANCZOS)
@staticmethod
def save_image_as_png(img: Image.Image, file_path: str):
"""
Saves a PIL Image as a PNG file.
"""
img.save(file_path, format="PNG")
@staticmethod
def to_cv2_image(img: Image.Image):
open_cv_image = np.array(img.convert("RGB"))
return open_cv_image[:, :, ::-1].copy()
def convert_pdf_bytes_to_jpg(self, pdf_bytes: bytes) -> list[Image.Image]:
"""
Converts PDF bytes to a list of PIL images at the configured DPI.
"""
pil_images = convert_from_bytes(pdf_bytes, dpi=self._dpi)
return pil_images
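A short usage sketch for the converter above; the file name "sample.pdf" is an assumption. Note that pdf2image also needs the poppler-utils system package, which the updated Dockerfile installs:

```python
# Usage sketch for PdfConverter (illustrative input file).
from pdf_converter import PdfConverter

converter = PdfConverter(dpi=150)

pages = converter.convert_pdf_to_jpg("sample.pdf")        # list of PIL images
converter.save_image_as_png(pages[0], "page_0.png")

with open("sample.pdf", "rb") as f:
    pages_from_bytes = converter.convert_pdf_bytes_to_jpg(f.read())

cv2_page = converter.to_cv2_image(pages_from_bytes[0])    # BGR numpy array
```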

30
pyproject.toml Normal file
View File

@@ -0,0 +1,30 @@
[project]
name = "groundingdino"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"torch>=2.3.1",
"torchvision>=0.18.1",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
"opencv-python-headless>=4.8.0",
"supervision>=0.26.1",
"pycocotools>=2.0.10",
"transformers>=4.55.1",
"addict>=2.4.0",
"yapf>=0.43.0",
"timm>=1.0.19",
"pdf2image>=1.17.0",
"pip>=25.2",
"setuptools>=80.9.0",
"uvicorn>=0.35.0",
"fastapi>=0.116.1",
"openai>=1.99.9",
"starlette>=0.47.2",
"python-multipart>=0.0.20",
]

View File

@@ -6,5 +6,5 @@ yapf
timm
numpy
opencv-python
supervision
supervision>=0.22.0
pycocotools

View File

@@ -27,14 +27,19 @@ import subprocess
import subprocess
import sys
def install_torch():
try:
import torch
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
# Call the function to ensure torch is installed
install_torch()
# install_torch()
sys.path.insert(
0, os.path.join(os.path.dirname(__file__), ".venv/lib/python3.13/site-packages")
)
import torch
from setuptools import find_packages, setup
@@ -48,7 +53,11 @@ cwd = os.path.dirname(os.path.abspath(__file__))
sha = "Unknown"
try:
sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
sha = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
.decode("ascii")
.strip()
)
except Exception:
pass
@@ -67,7 +76,9 @@ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
extensions_dir = os.path.join(
this_dir, "groundingdino", "models", "GroundingDINO", "csrc"
)
main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
@@ -82,7 +93,9 @@ def get_extensions():
extra_compile_args = {"cxx": []}
define_macros = []
if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
if CUDA_HOME is not None and (
torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ
):
print("Compiling with CUDA")
extension = CUDAExtension
sources += source_cuda
@@ -92,6 +105,10 @@ def get_extensions():
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"-gencode=arch=compute_70,code=sm_70",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
]
else:
print("Compiling without CUDA")
@@ -99,7 +116,7 @@ def get_extensions():
extra_compile_args["nvcc"] = []
return None
sources = [os.path.join(extensions_dir, s) for s in sources]
sources = [x.replace(this_dir + "/", "") for x in sources]
include_dirs = [extensions_dir]
ext_modules = [
@@ -208,7 +225,7 @@ if __name__ == "__main__":
url="https://github.com/IDEA-Research/GroundingDINO",
description="open-set object detector",
license=license,
install_requires=parse_requirements("requirements.txt"),
# install_requires=parse_requirements("requirements.txt"),
packages=find_packages(
exclude=(
"configs",

1482
uv.lock generated Normal file

File diff suppressed because it is too large