Files
grounding-dino/main.py

93 lines
3.1 KiB
Python
Raw Permalink Normal View History

2025-08-16 09:57:17 +00:00
from pathlib import Path
import cv2
import supervision as sv
import torch
import yaml
from tqdm import tqdm
from groundingdino.util.inference import Model, preprocess_caption
2025-08-16 21:16:58 +00:00
from pdf_converter import PdfConverter
2025-08-16 09:57:17 +00:00
# Model definition file passed to groundingdino's `Model` as `model_config_path`.
# Relative path — presumably resolved against the current working directory.
GROUNDING_DINO_CONFIG = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
# Weights file passed to `Model` as `model_checkpoint_path` (relative path as well).
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
# Forwarded as `box_threshold` to `Model.predict_with_caption`.
BOX_THRESHOLD = 0.4
# Forwarded as `text_threshold` to `Model.predict_with_caption`.
TEXT_THRESHOLD = 0.25
def main(
    data_dir: str | Path,
    text: str = "ID card. Carte Vitale. Bank details. Human face.",
    concept_list_yaml: str | None = None,
    device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
):
    """Detect text-prompted concepts in the PDFs of a folder and save the results.

    For every ``*.pdf`` directly inside ``data_dir`` the first image returned
    by ``PdfConverter.convert_pdf_to_jpg`` (presumably the first page) is run
    through Grounding DINO. Each detected box is cropped to
    ``outputs/extract/<pdf stem>_<index>.png`` and an annotated copy of the
    image is written to ``outputs/<pdf stem>.png``.

    Args:
        data_dir: Folder that contains the PDF files to process.
        text: Caption listing the concepts to detect, each terminated by a
            period.
        concept_list_yaml: Optional path to a YAML file holding a sequence of
            concept names. When given, it replaces ``text``.
        device: Device string handed to the Grounding DINO model.
    """
    outputs_root = Path("outputs")
    output_dir = outputs_root / "extract"
    output_dir.mkdir(parents=True, exist_ok=True)

    if concept_list_yaml:
        print("Overriding concepts !")
        with open(concept_list_yaml, "r", encoding="utf-8") as f:
            # safe_load: a bare yaml.load(f) needs an explicit Loader on
            # PyYAML >= 6 and can build arbitrary objects on older versions.
            concepts = yaml.safe_load(f)
        text = "".join(f" {x}." for x in concepts)
    print(f"List of concepts to detect: {text}")

    data_dir = Path(data_dir)
    # List the directory once; sorting keeps the processing order stable.
    pdf_paths = sorted(data_dir.glob("*.pdf"))
    if not pdf_paths:
        print(f"No PDF files found in {data_dir}")
        return

    # Everything below is loop-invariant: build it once instead of reloading
    # the model checkpoint for every single PDF.
    pdf_convertor = PdfConverter(120)  # NOTE(review): 120 is presumably the DPI — confirm
    grounding_model = Model(
        model_config_path=GROUNDING_DINO_CONFIG,
        model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
        device=device,
    )
    caption = preprocess_caption(text)
    box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

    for img_path in tqdm(pdf_paths):
        imgs = pdf_convertor.convert_pdf_to_jpg(str(img_path))
        try:
            first_page = imgs[0]
        except IndexError:
            print(f"No image could be extracted from {img_path}, skipping")
            continue
        # NOTE(review): debug artefact kept from the original code; it is
        # overwritten for every PDF located in the same folder.
        pdf_convertor.save_image_as_png(first_page, img_path.parent / "test.png")
        img = pdf_convertor.to_cv2_image(first_page)

        detections, class_names = grounding_model.predict_with_caption(
            image=img,
            caption=caption,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
        )
        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(
                class_names, detections.confidence.tolist()
            )
        ]

        # Assumes `img` is a numpy-style (height, width, ...) array — it is
        # sliced that way below.
        height, width = img.shape[:2]
        for i, bbox in enumerate(detections.xyxy):
            x_min, y_min, x_max, y_max = (int(v) for v in bbox)
            # Clamp to the image: a negative index would wrap around when
            # slicing, and an empty crop makes cv2.imwrite raise.
            x_min, y_min = max(x_min, 0), max(y_min, 0)
            x_max, y_max = min(x_max, width), min(y_max, height)
            if x_max <= x_min or y_max <= y_min:
                continue
            patch = img[y_min:y_max, x_min:x_max]
            cv2.imwrite(str(output_dir / f"{img_path.stem}_{i:d}.png"), patch)

        annotated_frame = box_annotator.annotate(
            scene=img.copy(), detections=detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=detections, labels=labels
        )
        cv2.imwrite(str(outputs_root / f"{img_path.stem}.png"), annotated_frame)
if __name__ == "__main__":
    # Script entry point: process the PDFs found in the local "data" folder
    # with the default caption and device.
    main(data_dir="data")