check vision extract model
90
extract/extract.py
Normal file
@@ -0,0 +1,90 @@
import torch
from transformers import LayoutLMv3ImageProcessor, LayoutLMv3Model
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm

# --- Configuration ---
MODEL_NAME = "microsoft/layoutlmv3-base"  # You can choose other model sizes
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
BATCH_SIZE = 8
# --- End Configuration ---

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model and image processor on the selected device.
# apply_ocr=False disables the processor's default Tesseract OCR pass,
# since only pixel features are needed for pure vision embeddings.
model = LayoutLMv3Model.from_pretrained(MODEL_NAME).to(device)
processor = LayoutLMv3ImageProcessor.from_pretrained(MODEL_NAME, apply_ocr=False)

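# For reference: layoutlmv3-base uses a 768-dim hidden size, so each embedding
# written below is expected to be a 768-float vector (noted here, not asserted).
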
def get_image_embeddings(image_paths):
    """
    Processes a batch of images and extracts their embeddings.
    Returns the embeddings together with the paths that loaded successfully.
    """
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                # The processor expects PIL images in RGB format
                images_pil.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")

    if not images_pil:
        return np.array([]), []

    # The image processor only prepares pixel values; no text input is needed.
    # LayoutLMv3 expects 224x224 images by default.
    inputs = processor(
        images=images_pil,
        size={"height": 224, "width": 224},
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        # Image-only forward pass: LayoutLMv3 accepts pixel_values without input_ids
        vision_outputs = model(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))
        # Use the visual [CLS] token of the last hidden state as the embedding
        embeddings = vision_outputs[0][:, 0, :]

    return embeddings.to(torch.float16).cpu().numpy(), valid_paths
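
# Note (a sketch, not part of the original script): the visual [CLS] token is one
# pooling choice; mean-pooling the patch tokens is a common alternative, e.g.
#     embeddings = vision_outputs[0][:, 1:, :].mean(dim=1)
# Either vector can be L2-normalized before cosine-similarity comparisons.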

# --- Process all images in the directory ---
image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

# Stream records to disk incrementally so large runs don't hold all embeddings in memory
with open("embeddings_factures_ostepoathie_1k.json", "w") as f:
    f.write("[\n")
    first = True
    for i in tqdm(range(0, len(image_files), BATCH_SIZE)):
        batch_paths = image_files[i:i+BATCH_SIZE]
        batch_embeddings, valid_paths = get_image_embeddings(batch_paths)
        embeddings_list = [emb.tolist() for emb in batch_embeddings]
        # Zip against valid_paths so records stay aligned if an image fails to load
        for path, emb in zip(valid_paths, embeddings_list):
            if not first:
                f.write(",\n")
            json.dump({"filepath": path, "embedding": emb}, f)
            first = False
    f.write("\n]\n")

print("Embeddings extracted and saved.")
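
# Sanity-check sketch (an assumption, not part of the original commit): reload the
# saved JSON and compare two embeddings by cosine similarity, e.g.
#     with open("embeddings_factures_ostepoathie_1k.json") as f:
#         records = json.load(f)
#     a = np.array(records[0]["embedding"])
#     b = np.array(records[1]["embedding"])
#     print(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))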