import torch
from transformers import LayoutLMv3ImageProcessor, LayoutLMv3Model
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm

# --- Configuration ---
MODEL_NAME = "microsoft/layoutlmv3-base"  # other LayoutLMv3 checkpoint sizes also work
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
BATCH_SIZE = 8
# --- End Configuration ---

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model and image processor
model = LayoutLMv3Model.from_pretrained(MODEL_NAME).to(device)
# apply_ocr=False skips the image processor's default Tesseract OCR pass;
# only pixel_values are needed for visual embeddings.
processor = LayoutLMv3ImageProcessor.from_pretrained(MODEL_NAME, apply_ocr=False)
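
# LayoutLMv3 can also fuse OCR'd words and their layout boxes with the image.
# A sketch of that alternative (not used below; assumes Tesseract is installed
# for the processor's built-in OCR pass):
#   from transformers import LayoutLMv3Processor
#   full_processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME)  # apply_ocr=True by default
#   fused_inputs = full_processor(images=images_pil, padding=True,
#                                 truncation=True, return_tensors="pt")
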
def get_image_embeddings(image_paths):
    """
    Processes a batch of images and extracts their embeddings.
    Returns the embeddings and the list of paths that loaded successfully,
    so callers can keep paths and embeddings aligned when images are skipped.
    """
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                # The processor expects PIL images in RGB format
                images_pil.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")

    if not images_pil:
        return np.array([]), []

    # For pure vision feature extraction we only need pixel_values; the
    # image processor resizes to LayoutLMv3's default 224x224 input.
    inputs = processor(
        images=images_pil,
        size={"height": 224, "width": 224},
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # With only pixel_values, LayoutLMv3Model runs just its visual branch.
        outputs = model(pixel_values=inputs["pixel_values"].to(dtype=model.dtype))

    # Use the [CLS] token of the last hidden state as the image embedding
    # (LayoutLMv3Model does not expose a pooled output).
    embeddings = outputs.last_hidden_state[:, 0, :]

    return embeddings.to(torch.float16).cpu().numpy(), valid_paths
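
# Quick sanity check (sketch; the path below is hypothetical):
#   embs, ok_paths = get_image_embeddings(["/path/to/invoice.png"])
#   embs.shape -> (len(ok_paths), 768) for layoutlmv3-base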

# --- Process all images in the directory ---
image_files = [
    os.path.join(IMAGE_DIR, f)
    for f in os.listdir(IMAGE_DIR)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
]

with open("embeddings_factures_ostepoathie_1k.json", "w") as f:
    f.write("[\n")
    first = True
    for i in tqdm(range(0, len(image_files), BATCH_SIZE)):
        batch_paths = image_files[i:i + BATCH_SIZE]
        batch_embeddings, valid_paths = get_image_embeddings(batch_paths)
        # Zip against valid_paths (not batch_paths) so a skipped image
        # cannot shift embeddings onto the wrong file.
        for path, emb in zip(valid_paths, batch_embeddings):
            if not first:
                f.write(",\n")
            json.dump({"filepath": path, "embedding": emb.tolist()}, f)
            first = False
    f.write("\n]\n")

print("Embeddings extracted and saved.")
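
# A minimal sketch for reading the embeddings back and running a
# cosine-similarity nearest-neighbour check. The file name and record layout
# mirror the output written above; nothing here runs unless called explicitly.
def load_embeddings(json_path="embeddings_factures_ostepoathie_1k.json"):
    """Load the saved records into parallel structures: paths and an (N, D) matrix."""
    with open(json_path) as f:
        records = json.load(f)
    paths = [r["filepath"] for r in records]
    matrix = np.array([r["embedding"] for r in records], dtype=np.float32)
    return paths, matrix

# Example usage (commented out so the script's behaviour is unchanged):
#   paths, matrix = load_embeddings()
#   unit = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
#   sims = unit @ unit[0]                      # cosine similarity to image 0
#   print(paths[int(np.argsort(-sims)[1])])    # closest other image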