Initial commit

commit 65adb5d4ba
lphatnguyen authored 2025-07-10 09:04:29 +00:00, committed by trungkienbkhn
5 changed files with 677 additions and 0 deletions

.gitignore vendored Normal file (2 lines)

@ -0,0 +1,2 @@
*.json
extracted_images/

README.md Normal file (0 lines)

clustering_example.ipynb Normal file (205 lines)

@ -0,0 +1,205 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "59f8a415",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n",
"from qwen_vl_utils import process_vision_info\n",
"from PIL import Image\n",
"import os\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"# --- Configuration ---\n",
"MODEL_NAME = \"Qwen/Qwen2.5-VL-3B-Instruct\" # You can choose other model sizes\n",
"IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n",
"BATCH_SIZE = 4\n",
"# --- End Configuration ---\n",
"\n",
"# Check for GPU availability\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Load the model and processor\n",
"model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n",
" MODEL_NAME, torch_dtype=\"bfloat16\", device_map=\"cuda\", attn_implementation=\"flash_attention_2\",\n",
")\n",
"processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdfdab0e",
"metadata": {},
"outputs": [],
"source": [
"def get_image_embeddings(image_paths):\n",
" \"\"\"\n",
" Processes a batch of images and extracts their embeddings.\n",
" \"\"\"\n",
" images_pil = []\n",
" valid_paths = []\n",
" for path in image_paths:\n",
" if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
" try:\n",
" # The processor expects PIL images in RGB format\n",
" images_pil.append(Image.open(path).convert(\"RGB\"))\n",
" valid_paths.append(path)\n",
" except Exception as e:\n",
" print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n",
"\n",
" if not images_pil:\n",
" return np.array([]), []\n",
"\n",
" # For pure vision feature extraction, we can provide an empty text prompt.\n",
" # The processor handles tokenizing text and preparing images.\n",
" inputs = processor(\n",
" text=[\"\"] * len(images_pil),\n",
" images=images_pil,\n",
" padding=True,\n",
" return_tensors=\"pt\"\n",
" ).to(device)\n",
"\n",
" with torch.no_grad():\n",
" # Get the vision embeddings from the model's vision tower\n",
" vision_outputs = model.visual(inputs['pixel_values'].to(dtype=model.dtype), grid_thw=inputs['image_grid_thw'])\n",
" # We'll use the pooled output as the embedding\n",
" embeddings = vision_outputs\n",
"\n",
" return embeddings.to(torch.float16).cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdaebb7b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# --- Process all images in the directory ---\n",
"image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
"all_embeddings = []\n",
"filepaths = []\n",
"\n",
"with open(\"embeddings_factures_osteopathie_1k_qwen.json\", \"w\") as f:\n",
" f.write(\"[\\n\")\n",
" first = True\n",
" for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n",
" batch_paths = image_files[i:i+BATCH_SIZE]\n",
" batch_embeddings = get_image_embeddings(batch_paths)\n",
" embeddings_list = [emb.tolist() for emb in batch_embeddings]\n",
" for path, emb in zip(batch_paths, embeddings_list):\n",
" if not first:\n",
" f.write(\",\\n\")\n",
" json.dump({\"filepath\": path, \"embedding\": emb}, f)\n",
" first = False\n",
" f.write(\"\\n]\\n\")\n",
"\n",
"print(\"Embeddings extracted and saved.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27fea4f3",
"metadata": {},
"outputs": [],
"source": [
"import fiftyone as fo\n",
"import fiftyone.brain as fob\n",
"import numpy as np\n",
"from sklearn.mixture import GaussianMixture\n",
"import json\n",
"\n",
"DATASET_NAME = \"mock\"\n",
"\n",
"json_path = \"./embeddings_factures_osteopathie_1k_qwen.json\"\n",
"\n",
"with open(json_path, \"r\") as file:\n",
" embedding_data = json.load(file)\n",
"\n",
"file_paths = []\n",
"embeddings = []\n",
"for i, record in enumerate(embedding_data):\n",
" file_paths.append(record.get(\"filepath\"))\n",
" embeddings.append(record.get(\"embedding\"))\n",
"\n",
"if DATASET_NAME in fo.list_datasets():\n",
" dataset = fo.load_dataset(DATASET_NAME)\n",
" dataset.delete()\n",
"dataset = fo.Dataset(DATASET_NAME)\n",
"\n",
"# Add samples to the dataset\n",
"samples = [fo.Sample(filepath=p) for p in file_paths]\n",
"dataset.add_samples(samples)\n",
"\n",
"# Building Gaussian mixture model (GMM)\n",
"n_gaussians = 50\n",
"gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Adding labeled embeddings to visulization\n",
"dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"for sample, label in zip(dataset, cluster_labels):\n",
" sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
"n_gaussians = 200\n",
"gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Adding labeled embeddings to visulization\n",
"dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"for sample, label in zip(dataset, cluster_labels):\n",
" sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
"# --- Visualize the Embeddings with UMAP ---\n",
"# This will compute a 2D representation of your embeddings\n",
"# for visualization.\n",
"res = fob.compute_visualization(\n",
" dataset,\n",
" embeddings=embeddings,\n",
" brain_key=\"qwen_vision_viz\",\n",
" method=\"tsne\",\n",
" verbose=True\n",
")\n",
"dataset.set_values(\"qwen_umap\", res.current_points)\n",
"\n",
"print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
"session = fo.launch_app(dataset)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sesame",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
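
For reference, a minimal sketch (not part of the notebook) of reading the embeddings file written above back into numpy; the filename is the one used in the extraction cell:

import json
import numpy as np

with open("embeddings_factures_osteopathie_1k_qwen.json", "r") as f:
    records = json.load(f)

paths = [r["filepath"] for r in records]
X = np.asarray([r["embedding"] for r in records], dtype=np.float32)
print(X.shape)  # (num_images, embedding_dim)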

clustering_layoutlm.ipynb Normal file (282 lines)

@ -0,0 +1,282 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "59f8a415",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel\n",
"from qwen_vl_utils import process_vision_info\n",
"from PIL import Image\n",
"import os\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"# --- Configuration ---\n",
"MODEL_NAME = \"microsoft/layoutlmv3-base\" # You can choose other model sizes\n",
"IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n",
"BATCH_SIZE = 32\n",
"# --- End Configuration ---\n",
"\n",
"# Check for GPU availability\n",
"device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Load the model and processor\n",
"model = AutoModel.from_pretrained(\n",
" MODEL_NAME, torch_dtype=\"bfloat16\", device_map=\"cuda\" # , attn_implementation=\"flash_attention_2\",\n",
")\n",
"processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdfdab0e",
"metadata": {},
"outputs": [],
"source": [
"def get_image_embeddings(image_paths):\n",
" \"\"\"\n",
" Processes a batch of images and extracts their embeddings.\n",
" \"\"\"\n",
" images_pil = []\n",
" valid_paths = []\n",
" for path in image_paths:\n",
" if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
" try:\n",
" # The processor expects PIL images in RGB format\n",
" images_pil.append(Image.open(path).convert(\"RGB\"))\n",
" valid_paths.append(path)\n",
" except Exception as e:\n",
" print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n",
"\n",
" if not images_pil:\n",
" return np.array([]), []\n",
"\n",
" # For pure vision feature extraction, we can provide an empty text prompt.\n",
" # The processor handles tokenizing text and preparing images.\n",
" inputs = processor(\n",
" # text=[\"\"] * len(images_pil),\n",
" images=images_pil,\n",
" padding=True,\n",
" return_tensors=\"pt\"\n",
" ).to(device)\n",
"\n",
" with torch.no_grad():\n",
" # Get the vision embeddings from the model's vision tower\n",
" vision_outputs = model.forward(pixel_values=inputs['pixel_values'].to(dtype=model.dtype)) # , grid_thw=inputs['image_grid_thw'])\n",
" # We'll use the pooled output as the embedding\n",
" embeddings = vision_outputs[0][:,0,:]\n",
"\n",
" return embeddings.to(torch.float16).cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdaebb7b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# --- Process all images in the directory ---\n",
"image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
"all_embeddings = []\n",
"filepaths = []\n",
"\n",
"with open(\"embeddings_factures_ostepoathie_1k.json\", \"w\") as f:\n",
" f.write(\"[\\n\")\n",
" first = True\n",
" for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n",
" batch_paths = image_files[i:i+BATCH_SIZE]\n",
" batch_embeddings = get_image_embeddings(batch_paths)\n",
" embeddings_list = [emb.tolist() for emb in batch_embeddings]\n",
" for path, emb in zip(batch_paths, embeddings_list):\n",
" if not first:\n",
" f.write(\",\\n\")\n",
" json.dump({\"filepath\": path, \"embedding\": emb}, f)\n",
" first = False\n",
" f.write(\"\\n]\\n\")\n",
"\n",
"print(\"Embeddings extracted and saved.\")\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "27fea4f3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/nguyendc/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s] \n",
"Generating visualization...\n",
"[t-SNE] Computing 91 nearest neighbors...\n",
"[t-SNE] Indexed 1090 samples in 0.000s...\n",
"[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n",
"[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n",
"[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n",
"[t-SNE] Mean sigma: 1.117175\n",
"[t-SNE] Computed conditional probabilities in 0.018s\n",
"[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n",
"[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n",
"[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n",
"[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n",
"[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n",
"[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n",
"[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n",
"[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n",
"[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n",
"[t-SNE] KL divergence after 1000 iterations: 0.727630\n",
"UMAP visualization computed. Launch the app to see the plot.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"800\"\n",
" src=\"http://localhost:5151/?notebook=True&subscription=9f7aad1e-72b8-41ab-873d-b776dbb9e9d5\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x798114e78160>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Notebook sessions cannot wait\n"
]
}
],
"source": [
"import fiftyone as fo\n",
"import fiftyone.brain as fob\n",
"import numpy as np\n",
"from sklearn.mixture import GaussianMixture\n",
"import json\n",
"\n",
"DATASET_NAME = \"mock\"\n",
"\n",
"json_path = \"./embeddings_factures_osteopathie_1k.json\"\n",
"\n",
"with open(json_path, \"r\") as file:\n",
" embedding_data = json.load(file)\n",
"\n",
"file_paths = []\n",
"embeddings = []\n",
"for i, record in enumerate(embedding_data):\n",
" file_paths.append(record.get(\"filepath\"))\n",
" embeddings.append(record.get(\"embedding\"))\n",
"\n",
"if DATASET_NAME in fo.list_datasets():\n",
" dataset = fo.load_dataset(DATASET_NAME)\n",
" dataset.delete()\n",
"dataset = fo.Dataset(DATASET_NAME)\n",
"\n",
"# Add samples to the dataset\n",
"samples = [fo.Sample(filepath=p) for p in file_paths]\n",
"dataset.add_samples(samples)\n",
"\n",
"# Building Gaussian mixture model (GMM)\n",
"\n",
"n_gaussians = 50\n",
"gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Adding labeled embeddings to visulization\n",
"dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"for sample, label in zip(dataset, cluster_labels):\n",
" sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
"n_gaussians = 100\n",
"gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Adding labeled embeddings to visulization\n",
"dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"for sample, label in zip(dataset, cluster_labels):\n",
" sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
"# --- Visualize the Embeddings with UMAP ---\n",
"# This will compute a 2D representation of your embeddings\n",
"# for visualization.\n",
"res = fob.compute_visualization(\n",
" dataset,\n",
" embeddings=embeddings,\n",
" brain_key=\"qwen_vision_viz\",\n",
" method=\"tsne\",\n",
" verbose=True\n",
")\n",
"dataset.set_values(\"qwen_umap\", res.current_points)\n",
"\n",
"print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
"session = fo.launch_app(dataset)\n",
"session.wait()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sesame",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
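
A possible extension, shown only as a hedged sketch (not part of this commit): rather than hard-coding 50, 100, or 200 mixture components, candidate counts can be compared with the Bayesian information criterion exposed by scikit-learn's GaussianMixture. The snippet assumes `embeddings` is the list loaded from the JSON file in the clustering cell above.

import numpy as np
from sklearn.mixture import GaussianMixture

X = np.asarray(embeddings)  # `embeddings` as loaded in the clustering cell
candidate_counts = [25, 50, 100, 200]
bic_scores = {}
for k in candidate_counts:
    gmm = GaussianMixture(n_components=k, random_state=42).fit(X)
    bic_scores[k] = gmm.bic(X)  # lower BIC indicates a better fit/complexity trade-off
best_k = min(bic_scores, key=bic_scores.get)
print(f"BIC scores: {bic_scores} -> best n_components: {best_k}")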

src/embedding_extraction.py Normal file (188 lines)

@ -0,0 +1,188 @@
import torch
from transformers import AutoModel, AutoProcessor
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import argparse
import json
from typing import Callable, List, Tuple


def get_layoutlm_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],
device: str
) -> Tuple[np.ndarray, List[str]]:
"""
    Processes a batch of images using a LayoutLM-like model and extracts their embeddings.

    This function can be replaced with another one that follows the same signature
    to support different models or embedding strategies.

    Args:
        model: The loaded Hugging Face model.
        processor: The loaded Hugging Face processor.
        image_paths: A list of file paths for the images in the batch.
        device: The device to run the model on ('cpu', 'cuda').

    Returns:
A tuple containing:
- A numpy array of the extracted embeddings.
- A list of the valid file paths that were successfully processed.
"""
images_pil = []
valid_paths = []
for path in image_paths:
if path.lower().endswith(('.png', '.jpg', '.jpeg')):
try:
images_pil.append(Image.open(path).convert("RGB"))
valid_paths.append(path)
except Exception as e:
print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
if not images_pil:
return np.array([]), []
inputs = processor(
images=images_pil,
padding=True,
return_tensors="pt"
).to(device)
with torch.no_grad():
# Forward pass to get model outputs
outputs = model.forward(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))
# We use the embedding of the [CLS] token as the document representation
embeddings = outputs.last_hidden_state[:, 0, :]
return embeddings.cpu().numpy(), valid_paths
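

# Illustrative sketch (not part of the original script): any callable with the same
# (model, processor, image_paths, device) -> (embeddings, valid_paths) signature can be
# passed to `run_extraction` below. This hypothetical variant mean-pools the last hidden
# state instead of taking the [CLS] token; it assumes the model exposes `last_hidden_state`,
# as LayoutLMv3 does.
def get_mean_pooled_image_embeddings(
    model: AutoModel,
    processor: AutoProcessor,
    image_paths: List[str],
    device: str
) -> Tuple[np.ndarray, List[str]]:
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                images_pil.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
    if not images_pil:
        return np.array([]), []
    inputs = processor(images=images_pil, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))
        # Mean-pool over the sequence dimension instead of using the [CLS] token
        embeddings = outputs.last_hidden_state.mean(dim=1)
    # Cast to float32 so the numpy conversion works regardless of the model dtype
    return embeddings.float().cpu().numpy(), valid_paths

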
def get_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],
device: str
) -> Tuple[np.ndarray, List[str]]:
"""
    Processes a batch of images with a Qwen2.5-VL-style model (one that exposes a `visual`
    vision tower) and extracts their embeddings.

    This function can be replaced with another one that follows the same signature
    to support different models or embedding strategies.

    Args:
        model: The loaded Hugging Face model.
        processor: The loaded Hugging Face processor.
        image_paths: A list of file paths for the images in the batch.
        device: The device to run the model on ('cpu', 'cuda').

    Returns:
A tuple containing:
- A numpy array of the extracted embeddings.
- A list of the valid file paths that were successfully processed.
"""
images_pil = []
valid_paths = []
for path in image_paths:
if path.lower().endswith(('.png', '.jpg', '.jpeg')):
try:
# The processor expects PIL images in RGB format
images_pil.append(Image.open(path).convert("RGB"))
valid_paths.append(path)
except Exception as e:
print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
if not images_pil:
return np.array([]), []
# For pure vision feature extraction, we can provide an empty text prompt.
# The processor handles tokenizing text and preparing images.
inputs = processor(
text=[""] * len(images_pil),
images=images_pil,
padding=True,
return_tensors="pt"
).to(device)
with torch.no_grad():
# Get the vision embeddings from the model's vision tower
vision_outputs = model.visual(inputs['pixel_values'].to(dtype=model.dtype), grid_thw=inputs['image_grid_thw'])
# We'll use the pooled output as the embedding
embeddings = vision_outputs
    return embeddings.to(torch.float16).cpu().numpy(), valid_paths


def run_extraction(
model_name: str,
image_dir: str,
batch_size: int,
device: str,
output_file: str,
embedding_extractor: Callable
):
"""
Loads a model and processes all images in a directory to save their embeddings.
"""
print(f"Using device: {device}")
# Load the model and processor
print(f"Loading model: {model_name}")
model = AutoModel.from_pretrained(model_name).to(device)
processor = AutoProcessor.from_pretrained(model_name)
image_files = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir)) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"Found {len(image_files)} images to process.")
with open(output_file, "w") as f:
f.write("[\n")
is_first_entry = True
for i in tqdm(range(0, len(image_files), batch_size), desc="Extracting Embeddings"):
batch_paths = image_files[i:i+batch_size]
# Call the provided embedding extractor function
batch_embeddings, valid_paths = embedding_extractor(model, processor, batch_paths, device)
if len(valid_paths) == 0:
continue
embeddings_list = [emb.tolist() for emb in batch_embeddings]
for path, emb in zip(valid_paths, embeddings_list):
if not is_first_entry:
f.write(",\n")
json.dump({"filepath": path, "embedding": emb}, f, indent=4)
is_first_entry = False
f.write("\n]\n")
print(f"Embedding extraction complete. Results saved to {output_file}")


def main():
"""
Main function to parse command-line arguments and start the extraction process.
"""
parser = argparse.ArgumentParser(description="Extract document image embeddings using a transformer model.")
parser.add_argument("--model_name", type=str, default="microsoft/layoutlmv3-base", help="Hugging Face model name.")
parser.add_argument("--image_dir", type=str, required=True, help="Directory containing the images to process.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing.")
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0', 'cpu'). Autodetects if not set.")
parser.add_argument("--output_file", type=str, default="embeddings.json", help="File path to save the output JSON.")
args = parser.parse_args()
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
# The `get_layoutlm_image_embeddings` function is passed here.
# You can define a different function and pass it instead to change the behavior.
run_extraction(
model_name=args.model_name,
image_dir=args.image_dir,
batch_size=args.batch_size,
device=device,
output_file=args.output_file,
embedding_extractor=get_layoutlm_image_embeddings
)


if __name__ == "__main__":
main()
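
For reference, a minimal usage sketch (not part of this commit; paths and the batch size are placeholders) showing how the script can be driven from the command line or programmatically with the pluggable extractor:

# CLI (placeholder paths):
#   python src/embedding_extraction.py --image_dir ./extracted_images --output_file embeddings.json
# Programmatic use, assuming the script's directory is on PYTHONPATH:
import torch
from embedding_extraction import get_layoutlm_image_embeddings, run_extraction

run_extraction(
    model_name="microsoft/layoutlmv3-base",
    image_dir="./extracted_images",   # placeholder image directory
    batch_size=8,                     # placeholder batch size
    device="cuda" if torch.cuda.is_available() else "cpu",
    output_file="embeddings.json",
    embedding_extractor=get_layoutlm_image_embeddings,
)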