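# File: embedding-clustering/extract/extract.py
# (Python, 90 lines, 3.2 KiB; last updated 2025-09-02 15:01:50 +00:00)
# Extracts LayoutLMv3 image embeddings for every PNG/JPEG in IMAGE_DIR and
# writes them to a JSON file.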
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel
# from qwen_vl_utils import process_vision_info
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm
from transformers import LayoutLMv3ImageProcessor, LayoutLMv3Model
# --- Configuration ---
MODEL_NAME = "microsoft/layoutlmv3-base"  # You can choose other model sizes
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
BATCH_SIZE = 8
# --- End Configuration ---

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model and image processor.
# The model is placed on `device` instead of a hard-coded "cuda". Otherwise the
# CPU fallback chosen above is ignored and loading fails on machines without a GPU.
model = LayoutLMv3Model.from_pretrained(MODEL_NAME, device_map=device)
# Only pixel_values are used downstream, so the image processor's OCR step
# (on by default for LayoutLMv3ImageProcessor, per its docs) is disabled.
# trust_remote_code is not passed: this is a built-in transformers class.
processor = LayoutLMv3ImageProcessor.from_pretrained(MODEL_NAME, apply_ocr=False)
# File extensions treated as images, both when listing IMAGE_DIR and when
# filtering a batch inside get_image_embeddings.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def get_image_embeddings(image_paths, return_paths=False):
    """Load a batch of images and return one LayoutLMv3 embedding per image.

    Paths whose extension (case-insensitive) is not .png/.jpg/.jpeg are
    skipped. So are images that fail to open; a warning is printed for those.
    The remaining images are resized to 224x224 by the module-level
    ``processor`` and run through the module-level ``model`` with pixel
    inputs only (no text tokens). Each image's embedding is the first token
    of the model's last hidden state.

    Args:
        image_paths: Iterable of image file paths.
        return_paths: If True, also return the paths that were actually
            embedded, in row order. Use this whenever a path may be skipped;
            otherwise rows cannot be matched back to their files.

    Returns:
        A float16 NumPy array with one row per successfully loaded image.
        If no image could be loaded, the array has zero rows.
        If ``return_paths`` is True, returns ``(embeddings, valid_paths)``.
    """
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if not path.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        try:
            # The processor expects PIL images in RGB format. `with` closes the
            # file handle; convert() reads the pixel data before that happens.
            with Image.open(path) as img:
                images_pil.append(img.convert("RGB"))
        except Exception as e:  # deliberate best-effort: skip unreadable files
            print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
            continue
        valid_paths.append(path)

    if not images_pil:
        # Keep the return type consistent with the normal path (an array with
        # one row per image). The original returned a tuple here, which broke
        # callers that iterate the result as rows.
        empty = np.empty((0, model.config.hidden_size), dtype=np.float16)
        return (empty, []) if return_paths else empty

    # Image-only preprocessing: resize to 224x224 and build pixel_values.
    # apply_ocr=False because the OCR words/boxes would never be used.
    inputs = processor(
        images=images_pil,
        size={"height": 224, "width": 224},
        apply_ocr=False,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # Only pixel_values are passed, so the model runs on image patch
        # embeddings alone.
        outputs = model(pixel_values=inputs["pixel_values"].to(dtype=model.dtype))

    # outputs[0] is the last hidden state. Index 0 on the token axis takes the
    # first token (presumably LayoutLMv3's visual CLS token — confirm); this
    # is NOT the pooler output.
    embeddings = outputs[0][:, 0, :]
    # Cast to half precision on the way to NumPy (values are rounded to float16).
    embeddings = embeddings.to(torch.float16).cpu().numpy()
    return (embeddings, valid_paths) if return_paths else embeddings


# --- Process all images in the directory ---
# Sorted so the output order is reproducible (os.listdir order is arbitrary).
image_files = sorted(
    os.path.join(IMAGE_DIR, name)
    for name in os.listdir(IMAGE_DIR)
    if name.lower().endswith(_IMAGE_EXTENSIONS)
)

# Stream a JSON array of {"filepath", "embedding"} records batch by batch,
# so embeddings are not accumulated in memory.
with open("embeddings_factures_ostepoathie_1k.json", "w", encoding="utf-8") as f:
    f.write("[\n")
    first = True
    for start in tqdm(range(0, len(image_files), BATCH_SIZE)):
        batch_paths = image_files[start:start + BATCH_SIZE]
        # Pair embeddings with the paths actually embedded, not batch_paths.
        # Images that fail to load are dropped, and zipping against
        # batch_paths would attach embeddings to the wrong files.
        batch_embeddings, embedded_paths = get_image_embeddings(
            batch_paths, return_paths=True
        )
        for path, emb in zip(embedded_paths, batch_embeddings):
            if not first:
                f.write(",\n")
            json.dump({"filepath": path, "embedding": emb.tolist()}, f)
            first = False
    f.write("\n]\n")
print("Embeddings extracted and saved.")