Initial commit

commit 65adb5d4ba, committed by trungkienbkhn on 2025-07-10 09:04:29 +00:00
5 changed files with 677 additions and 0 deletions

src/embedding_extraction.py (new file, 188 lines)

@@ -0,0 +1,188 @@
import torch
from transformers import AutoModel, AutoProcessor
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import argparse
import json
from typing import Callable, List, Tuple


def get_layoutlm_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],
device: str
) -> Tuple[np.ndarray, List[str]]:
"""
Processes a batch of images using a LayoutLM-like model and extracts their embeddings.
This function can be replaced with another one that follows the same signature
to support different models or embedding strategies.
Args:
model: The loaded Hugging Face model.
processor: The loaded Hugging Face processor.
image_paths: A list of file paths for the images in the batch.
device: The device to run the model on ('cpu', 'cuda').
Returns:
A tuple containing:
- A numpy array of the extracted embeddings.
- A list of the valid file paths that were successfully processed.
"""
images_pil = []
valid_paths = []
for path in image_paths:
if path.lower().endswith(('.png', '.jpg', '.jpeg')):
try:
images_pil.append(Image.open(path).convert("RGB"))
valid_paths.append(path)
except Exception as e:
print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
if not images_pil:
return np.array([]), []
inputs = processor(
images=images_pil,
padding=True,
return_tensors="pt"
).to(device)
with torch.no_grad():
# Forward pass to get model outputs
        outputs = model(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))
# We use the embedding of the [CLS] token as the document representation
embeddings = outputs.last_hidden_state[:, 0, :]
return embeddings.cpu().numpy(), valid_paths
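

# A minimal sketch of an alternative extractor, shown to illustrate the swap-in
# point described above: any function with the same signature can replace
# `get_layoutlm_image_embeddings` in `run_extraction`. This variant mean-pools
# all token embeddings instead of taking the [CLS] token. It is illustrative and
# not part of the original pipeline; the function name is hypothetical.
def get_mean_pooled_image_embeddings(
    model: AutoModel,
    processor: AutoProcessor,
    image_paths: List[str],
    device: str
) -> Tuple[np.ndarray, List[str]]:
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                images_pil.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
    if not images_pil:
        return np.array([]), []
    inputs = processor(images=images_pil, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))
        # Average across the sequence dimension instead of taking the first token
        embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.cpu().numpy(), valid_paths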


def get_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],
device: str
) -> Tuple[np.ndarray, List[str]]:
"""
Processes a batch of images using a LayoutLM-like model and extracts their embeddings.
This function can be replaced with another one that follows the same signature
to support different models or embedding strategies.
Args:
model: The loaded Hugging Face model.
processor: The loaded Hugging Face processor.
image_paths: A list of file paths for the images in the batch.
device: The device to run the model on ('cpu', 'cuda').
Returns:
A tuple containing:
- A numpy array of the extracted embeddings.
- A list of the valid file paths that were successfully processed.
"""
images_pil = []
valid_paths = []
for path in image_paths:
if path.lower().endswith(('.png', '.jpg', '.jpeg')):
try:
# The processor expects PIL images in RGB format
images_pil.append(Image.open(path).convert("RGB"))
valid_paths.append(path)
except Exception as e:
print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
if not images_pil:
return np.array([]), []
# For pure vision feature extraction, we can provide an empty text prompt.
# The processor handles tokenizing text and preparing images.
inputs = processor(
text=[""] * len(images_pil),
images=images_pil,
padding=True,
return_tensors="pt"
).to(device)
    with torch.no_grad():
        # Get the vision embeddings from the model's vision tower
        vision_outputs = model.visual(inputs['pixel_values'].to(dtype=model.dtype), grid_thw=inputs['image_grid_thw'])
        # Use the vision tower's output directly as the embedding
        embeddings = vision_outputs
    # Return valid_paths as well, so the signature matches get_layoutlm_image_embeddings
    # and run_extraction can unpack the result
    return embeddings.to(torch.float16).cpu().numpy(), valid_paths
def run_extraction(
model_name: str,
image_dir: str,
batch_size: int,
device: str,
output_file: str,
embedding_extractor: Callable
):
"""
Loads a model and processes all images in a directory to save their embeddings.
"""
print(f"Using device: {device}")
# Load the model and processor
print(f"Loading model: {model_name}")
model = AutoModel.from_pretrained(model_name).to(device)
processor = AutoProcessor.from_pretrained(model_name)
image_files = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir)) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"Found {len(image_files)} images to process.")
with open(output_file, "w") as f:
f.write("[\n")
is_first_entry = True
for i in tqdm(range(0, len(image_files), batch_size), desc="Extracting Embeddings"):
batch_paths = image_files[i:i+batch_size]
# Call the provided embedding extractor function
batch_embeddings, valid_paths = embedding_extractor(model, processor, batch_paths, device)
if len(valid_paths) == 0:
continue
embeddings_list = [emb.tolist() for emb in batch_embeddings]
for path, emb in zip(valid_paths, embeddings_list):
if not is_first_entry:
f.write(",\n")
json.dump({"filepath": path, "embedding": emb}, f, indent=4)
is_first_entry = False
f.write("\n]\n")
print(f"Embedding extraction complete. Results saved to {output_file}")


def main():
"""
Main function to parse command-line arguments and start the extraction process.
"""
parser = argparse.ArgumentParser(description="Extract document image embeddings using a transformer model.")
parser.add_argument("--model_name", type=str, default="microsoft/layoutlmv3-base", help="Hugging Face model name.")
parser.add_argument("--image_dir", type=str, required=True, help="Directory containing the images to process.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing.")
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0', 'cpu'). Autodetects if not set.")
parser.add_argument("--output_file", type=str, default="embeddings.json", help="File path to save the output JSON.")
args = parser.parse_args()
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
    # The `get_layoutlm_image_embeddings` function is passed here. Any function with
    # the same signature (e.g. `get_image_embeddings` for a vision-language model)
    # can be passed instead to change the behavior.
run_extraction(
model_name=args.model_name,
image_dir=args.image_dir,
batch_size=args.batch_size,
device=device,
output_file=args.output_file,
embedding_extractor=get_layoutlm_image_embeddings
)


if __name__ == "__main__":
main()
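

# Example invocation (illustrative; the image directory path is a placeholder):
#   python src/embedding_extraction.py --image_dir ./scanned_docs --batch_size 16 \
#       --output_file embeddings.json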