138 lines
4.5 KiB
Python
138 lines
4.5 KiB
Python
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import supervision as sv
|
|
import torch
|
|
import yaml
|
|
from pdf2image import convert_from_bytes, convert_from_path
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
from groundingdino.util.inference import Model, preprocess_caption
|
|
|
|
# Grounding DINO model definition and weights; main() passes these to
# ``Model`` as ``model_config_path`` and ``model_checkpoint_path``.
# NOTE(review): both are relative paths — presumably resolved against the
# working directory the script is launched from; confirm.
GROUNDING_DINO_CONFIG = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"

# Thresholds forwarded to ``Model.predict_with_caption`` as ``box_threshold``
# and ``text_threshold`` respectively.
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.25
|
|
|
|
|
|
class PdfConverterService:
    """
    A service to render PDF files to PIL images and to resize, save and
    convert those images.

    PDF rendering is delegated to pdf2image (``convert_from_path`` /
    ``convert_from_bytes``); resizing and saving use Pillow.
    """

    def __init__(self, dpi: int):
        """
        Stores the rendering resolution.

        Args:
            dpi: Value forwarded as ``dpi`` to pdf2image on every conversion.
        """
        # Kept for backward compatibility; no method of this class reads it.
        self._pdfium_initialized = False
        self._dpi = dpi

    def convert_pdf_to_jpg(self, file_path: str | Path) -> list[Image.Image]:
        """
        Renders the PDF stored at ``file_path`` at the configured DPI.

        Despite the method name, nothing is written to disk here: the images
        returned by pdf2image are handed back as-is.

        Args:
            file_path: Location of the PDF file.

        Returns:
            The list of images returned by ``convert_from_path``.
        """
        return convert_from_path(file_path, dpi=self._dpi)

    def resize_image(self, img: Image.Image, size: tuple[int, int]) -> Image.Image:
        """
        Resizes a PIL Image to the specified size using LANCZOS resampling.

        Args:
            img: Image to resize.
            size: Target size, passed unchanged to ``Image.resize``.

        Returns:
            The resized image returned by ``Image.resize``.
        """
        return img.resize(size, Image.LANCZOS)

    @staticmethod
    def save_image_as_png(img: Image.Image, file_path: str | Path):
        """
        Saves a PIL Image as a PNG file.

        Args:
            img: Image to save.
            file_path: Destination path; the PNG format is forced regardless
                of the file extension.
        """
        img.save(file_path, format="PNG")

    @staticmethod
    def to_cv2_image(img: Image.Image) -> np.ndarray:
        """
        Converts a PIL Image to a NumPy array with reversed channel order.

        The image is first converted to RGB, then the last axis is reversed
        (RGB -> BGR) so the array matches the channel order OpenCV uses.

        Args:
            img: Image to convert.

        Returns:
            A fresh copy of the pixel data with channels in BGR order.
        """
        rgb_pixels = np.array(img.convert("RGB"))
        # .copy() materialises the reversed view as an independent array.
        return rgb_pixels[:, :, ::-1].copy()

    def convert_pdf_bytes_to_jpg(self, pdf_bytes: bytes) -> list[Image.Image]:
        """
        Renders an in-memory PDF at the configured DPI.

        Args:
            pdf_bytes: Raw content of the PDF document.

        Returns:
            The list of images returned by ``convert_from_bytes``.
        """
        return convert_from_bytes(pdf_bytes, dpi=self._dpi)
|
|
|
|
|
|
def main(
|
|
data_dir: str | Path,
|
|
text: str = "ID card. Carte Vitale. Bank details. Human face.",
|
|
concept_list_yaml: str | None = None,
|
|
device: str = "cuda:0" if torch.cuda.is_available() else "cpu",
|
|
):
|
|
output_dir = Path("outputs") / "extract"
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
if concept_list_yaml:
|
|
print(f"Overriding concepts !")
|
|
with open(concept_list_yaml, "r") as f:
|
|
concepts = yaml.load(f)
|
|
text = "".join([f" {x}." for x in concepts])
|
|
|
|
print(f"List of concepts to detect: {text}")
|
|
|
|
if isinstance(data_dir, str):
|
|
data_dir = Path(data_dir)
|
|
|
|
for img_path in tqdm(
|
|
data_dir.glob("*.pdf"), total=len(list(data_dir.glob("*.pdf")))
|
|
):
|
|
pdf_convertor = PdfConverterService(120)
|
|
if img_path.suffix == ".pdf":
|
|
imgs = pdf_convertor.convert_pdf_to_jpg(str(img_path))
|
|
img = imgs[0]
|
|
pdf_convertor.save_image_as_png(img, img_path.parent / "test.png")
|
|
img = pdf_convertor.to_cv2_image(img)
|
|
else:
|
|
img = cv2.imread(str(img_path))
|
|
|
|
# image_source, image = load_image(str(img_path.parent / "test.png"))
|
|
|
|
grounding_model = Model(
|
|
model_config_path=GROUNDING_DINO_CONFIG,
|
|
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
|
|
device=device,
|
|
)
|
|
caption = preprocess_caption(text)
|
|
detections, labels = grounding_model.predict_with_caption(
|
|
image=img,
|
|
caption=caption,
|
|
box_threshold=BOX_THRESHOLD,
|
|
text_threshold=TEXT_THRESHOLD,
|
|
)
|
|
confidences = detections.confidence.tolist()
|
|
class_names = labels
|
|
|
|
labels = [
|
|
f"{class_name} {confidence:.2f}"
|
|
for class_name, confidence in zip(class_names, confidences)
|
|
]
|
|
|
|
for i, bbox in enumerate(detections.xyxy):
|
|
x_min, y_min, x_max, y_max = tuple(bbox)
|
|
patch = img[int(y_min) : int(y_max), int(x_min) : int(x_max)]
|
|
patch_img_path = str(
|
|
Path("outputs") / "extract" / f"{img_path.stem}_{i:d}.png"
|
|
)
|
|
cv2.imwrite(patch_img_path, patch)
|
|
|
|
box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
|
annotated_frame = box_annotator.annotate(
|
|
scene=img.copy(), detections=detections
|
|
)
|
|
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
|
|
annotated_frame = label_annotator.annotate(
|
|
scene=annotated_frame, detections=detections, labels=labels
|
|
)
|
|
|
|
cv2.imwrite(str(Path("outputs") / f"{img_path.stem}.png"), annotated_frame)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the detection on the "data" directory with the
    # default prompt and device.
    main("data")
|