commit 65adb5d4ba7cf151766ec8ce1bb35cd6513e9f1a Author: lphatnguyen Date: Thu Jul 10 09:04:29 2025 +0000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3d8e64e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.json +extracted_images/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/clustering_example.ipynb b/clustering_example.ipynb new file mode 100644 index 0000000..3fd5e3b --- /dev/null +++ b/clustering_example.ipynb @@ -0,0 +1,205 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "59f8a415", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n", + "from qwen_vl_utils import process_vision_info\n", + "from PIL import Image\n", + "import os\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "# --- Configuration ---\n", + "MODEL_NAME = \"Qwen/Qwen2.5-VL-3B-Instruct\" # You can choose other model sizes\n", + "IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n", + "BATCH_SIZE = 4\n", + "# --- End Configuration ---\n", + "\n", + "# Check for GPU availability\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Load the model and processor\n", + "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n", + " MODEL_NAME, torch_dtype=\"bfloat16\", device_map=\"cuda\", attn_implementation=\"flash_attention_2\",\n", + ")\n", + "processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdfdab0e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_image_embeddings(image_paths):\n", + " \"\"\"\n", + " Processes a batch of images and extracts their embeddings.\n", + " \"\"\"\n", + " images_pil = []\n", + " valid_paths = []\n", + " for path in image_paths:\n", + " if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n", + " try:\n", + " # The processor expects PIL images in RGB format\n", + " images_pil.append(Image.open(path).convert(\"RGB\"))\n", + " valid_paths.append(path)\n", + " except Exception as e:\n", + " print(f\"Warning: Could not load image {path}. Skipping. 
Error: {e}\")\n", + "\n", + " if not images_pil:\n", + " return np.array([]), []\n", + "\n", + " # For pure vision feature extraction, we can provide an empty text prompt.\n", + " # The processor handles tokenizing text and preparing images.\n", + " inputs = processor(\n", + " text=[\"\"] * len(images_pil),\n", + " images=images_pil,\n", + " padding=True,\n", + " return_tensors=\"pt\"\n", + " ).to(device)\n", + "\n", + " with torch.no_grad():\n", + " # Get the vision embeddings from the model's vision tower\n", + " vision_outputs = model.visual(inputs['pixel_values'].to(dtype=model.dtype), grid_thw=inputs['image_grid_thw'])\n", + " # We'll use the pooled output as the embedding\n", + " embeddings = vision_outputs\n", + "\n", + " return embeddings.to(torch.float16).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdaebb7b", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# --- Process all images in the directory ---\n", + "image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n", + "all_embeddings = []\n", + "filepaths = []\n", + "\n", + "with open(\"embeddings_factures_osteopathie_1k_qwen.json\", \"w\") as f:\n", + " f.write(\"[\\n\")\n", + " first = True\n", + " for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n", + " batch_paths = image_files[i:i+BATCH_SIZE]\n", + " batch_embeddings = get_image_embeddings(batch_paths)\n", + " embeddings_list = [emb.tolist() for emb in batch_embeddings]\n", + " for path, emb in zip(batch_paths, embeddings_list):\n", + " if not first:\n", + " f.write(\",\\n\")\n", + " json.dump({\"filepath\": path, \"embedding\": emb}, f)\n", + " first = False\n", + " f.write(\"\\n]\\n\")\n", + "\n", + "print(\"Embeddings extracted and saved.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27fea4f3", + "metadata": {}, + "outputs": [], + "source": [ + "import fiftyone as fo\n", + "import fiftyone.brain as fob\n", + "import numpy as np\n", + "from sklearn.mixture import GaussianMixture\n", + "import json\n", + "\n", + "DATASET_NAME = \"mock\"\n", + "\n", + "json_path = \"./embeddings_factures_osteopathie_1k_qwen.json\"\n", + "\n", + "with open(json_path, \"r\") as file:\n", + " embedding_data = json.load(file)\n", + "\n", + "file_paths = []\n", + "embeddings = []\n", + "for i, record in enumerate(embedding_data):\n", + " file_paths.append(record.get(\"filepath\"))\n", + " embeddings.append(record.get(\"embedding\"))\n", + "\n", + "if DATASET_NAME in fo.list_datasets():\n", + " dataset = fo.load_dataset(DATASET_NAME)\n", + " dataset.delete()\n", + "dataset = fo.Dataset(DATASET_NAME)\n", + "\n", + "# Add samples to the dataset\n", + "samples = [fo.Sample(filepath=p) for p in file_paths]\n", + "dataset.add_samples(samples)\n", + "\n", + "# Building Gaussian mixture model (GMM)\n", + "n_gaussians = 50\n", + "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", + "gmm.fit(embeddings)\n", + "cluster_labels = gmm.predict(embeddings)\n", + "\n", + "# Adding labeled embeddings to visulization\n", + "dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", + "for sample, label in zip(dataset, cluster_labels):\n", + " sample[\"gmm_cluster_50_gaussians\"] = int(label)\n", + " sample.save()\n", + "\n", + "n_gaussians = 200\n", + "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", + "gmm.fit(embeddings)\n", + "cluster_labels = gmm.predict(embeddings)\n", + "\n", + "# Adding labeled 
embeddings to visulization\n", + "dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", + "for sample, label in zip(dataset, cluster_labels):\n", + " sample[\"gmm_cluster_200_gaussians\"] = int(label)\n", + " sample.save()\n", + "\n", + "# --- Visualize the Embeddings with UMAP ---\n", + "# This will compute a 2D representation of your embeddings\n", + "# for visualization.\n", + "res = fob.compute_visualization(\n", + " dataset,\n", + " embeddings=embeddings,\n", + " brain_key=\"qwen_vision_viz\",\n", + " method=\"tsne\",\n", + " verbose=True\n", + ")\n", + "dataset.set_values(\"qwen_umap\", res.current_points)\n", + "\n", + "print(\"UMAP visualization computed. Launch the app to see the plot.\")\n", + "session = fo.launch_app(dataset)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sesame", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/clustering_layoutlm.ipynb b/clustering_layoutlm.ipynb new file mode 100644 index 0000000..cfed077 --- /dev/null +++ b/clustering_layoutlm.ipynb @@ -0,0 +1,282 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "59f8a415", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel\n", + "from qwen_vl_utils import process_vision_info\n", + "from PIL import Image\n", + "import os\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "# --- Configuration ---\n", + "MODEL_NAME = \"microsoft/layoutlmv3-base\" # You can choose other model sizes\n", + "IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n", + "BATCH_SIZE = 32\n", + "# --- End Configuration ---\n", + "\n", + "# Check for GPU availability\n", + "device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Load the model and processor\n", + "model = AutoModel.from_pretrained(\n", + " MODEL_NAME, torch_dtype=\"bfloat16\", device_map=\"cuda\" # , attn_implementation=\"flash_attention_2\",\n", + ")\n", + "processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdfdab0e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_image_embeddings(image_paths):\n", + " \"\"\"\n", + " Processes a batch of images and extracts their embeddings.\n", + " \"\"\"\n", + " images_pil = []\n", + " valid_paths = []\n", + " for path in image_paths:\n", + " if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n", + " try:\n", + " # The processor expects PIL images in RGB format\n", + " images_pil.append(Image.open(path).convert(\"RGB\"))\n", + " valid_paths.append(path)\n", + " except Exception as e:\n", + " print(f\"Warning: Could not load image {path}. Skipping. 
Error: {e}\")\n", + "\n", + " if not images_pil:\n", + " return np.array([]), []\n", + "\n", + " # For pure vision feature extraction, we can provide an empty text prompt.\n", + " # The processor handles tokenizing text and preparing images.\n", + " inputs = processor(\n", + " # text=[\"\"] * len(images_pil),\n", + " images=images_pil,\n", + " padding=True,\n", + " return_tensors=\"pt\"\n", + " ).to(device)\n", + "\n", + " with torch.no_grad():\n", + " # Get the vision embeddings from the model's vision tower\n", + " vision_outputs = model.forward(pixel_values=inputs['pixel_values'].to(dtype=model.dtype)) # , grid_thw=inputs['image_grid_thw'])\n", + " # We'll use the pooled output as the embedding\n", + " embeddings = vision_outputs[0][:,0,:]\n", + "\n", + " return embeddings.to(torch.float16).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdaebb7b", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# --- Process all images in the directory ---\n", + "image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n", + "all_embeddings = []\n", + "filepaths = []\n", + "\n", + "with open(\"embeddings_factures_ostepoathie_1k.json\", \"w\") as f:\n", + " f.write(\"[\\n\")\n", + " first = True\n", + " for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n", + " batch_paths = image_files[i:i+BATCH_SIZE]\n", + " batch_embeddings = get_image_embeddings(batch_paths)\n", + " embeddings_list = [emb.tolist() for emb in batch_embeddings]\n", + " for path, emb in zip(batch_paths, embeddings_list):\n", + " if not first:\n", + " f.write(\",\\n\")\n", + " json.dump({\"filepath\": path, \"embedding\": emb}, f)\n", + " first = False\n", + " f.write(\"\\n]\\n\")\n", + "\n", + "print(\"Embeddings extracted and saved.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "27fea4f3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nguyendc/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s] \n", + "Generating visualization...\n", + "[t-SNE] Computing 91 nearest neighbors...\n", + "[t-SNE] Indexed 1090 samples in 0.000s...\n", + "[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n", + "[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n", + "[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n", + "[t-SNE] Mean sigma: 1.117175\n", + "[t-SNE] Computed conditional probabilities in 0.018s\n", + "[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n", + "[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n", + "[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n", + "[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n", + "[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n", + "[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n", + "[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n", + "[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n", + "[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n", + "[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n", + "[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n", + "[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n", + "[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n", + "[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n", + "[t-SNE] KL divergence after 1000 iterations: 0.727630\n", + "UMAP visualization computed. 
Launch the app to see the plot.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        \n",
       "        "
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Notebook sessions cannot wait\n"
     ]
    }
   ],
   "source": [
    "import fiftyone as fo\n",
    "import fiftyone.brain as fob\n",
    "import numpy as np\n",
    "from sklearn.mixture import GaussianMixture\n",
    "import json\n",
    "\n",
    "DATASET_NAME = \"mock\"\n",
    "\n",
    "json_path = \"./embeddings_factures_osteopathie_1k.json\"\n",
    "\n",
    "with open(json_path, \"r\") as file:\n",
    "    embedding_data = json.load(file)\n",
    "\n",
    "file_paths = []\n",
    "embeddings = []\n",
    "for i, record in enumerate(embedding_data):\n",
    "    file_paths.append(record.get(\"filepath\"))\n",
    "    embeddings.append(record.get(\"embedding\"))\n",
    "\n",
    "if DATASET_NAME in fo.list_datasets():\n",
    "    dataset = fo.load_dataset(DATASET_NAME)\n",
    "    dataset.delete()\n",
    "dataset = fo.Dataset(DATASET_NAME)\n",
    "\n",
    "# Add samples to the dataset\n",
    "samples = [fo.Sample(filepath=p) for p in file_paths]\n",
    "dataset.add_samples(samples)\n",
    "\n",
    "# Build Gaussian mixture models (GMMs)\n",
    "\n",
    "n_gaussians = 50\n",
    "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
    "gmm.fit(embeddings)\n",
    "cluster_labels = gmm.predict(embeddings)\n",
    "\n",
    "# Add the cluster labels to the dataset for visualization\n",
    "dataset.add_sample_field(\"gmm_cluster_50_gaussians\", fo.IntField)\n",
    "for sample, label in zip(dataset, cluster_labels):\n",
    "    sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
    "    sample.save()\n",
    "\n",
    "n_gaussians = 100\n",
    "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
    "gmm.fit(embeddings)\n",
    "cluster_labels = gmm.predict(embeddings)\n",
    "\n",
    "# Add the cluster labels to the dataset for visualization\n",
    "dataset.add_sample_field(\"gmm_cluster_100_gaussians\", fo.IntField)\n",
    "for sample, label in zip(dataset, cluster_labels):\n",
    "    sample[\"gmm_cluster_100_gaussians\"] = int(label)\n",
    "    sample.save()\n",
    "\n",
    "# --- Visualize the embeddings with t-SNE ---\n",
    "# This computes a 2D representation of the embeddings\n",
    "# for visualization.\n",
    "res = fob.compute_visualization(\n",
    "    dataset,\n",
    "    embeddings=embeddings,\n",
    "    brain_key=\"qwen_vision_viz\",\n",
    "    method=\"tsne\",\n",
    "    verbose=True\n",
    ")\n",
    "dataset.set_values(\"qwen_umap\", res.current_points)\n",
    "\n",
    "print(\"t-SNE visualization computed. 
Launch the app to see the plot.\")\n", + "session = fo.launch_app(dataset)\n", + "session.wait()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sesame", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/embedding_extraction.py b/src/embedding_extraction.py new file mode 100644 index 0000000..4e3e2a3 --- /dev/null +++ b/src/embedding_extraction.py @@ -0,0 +1,188 @@ +import torch +from transformers import AutoModel, AutoProcessor +from PIL import Image +import os +import numpy as np +from tqdm import tqdm +import argparse +import json +from typing import Callable, List, Tuple + +def get_layoutlm_image_embeddings( + model: AutoModel, + processor: AutoProcessor, + image_paths: List[str], + device: str +) -> Tuple[np.ndarray, List[str]]: + """ + Processes a batch of images using a LayoutLM-like model and extracts their embeddings. + This function can be replaced with another one that follows the same signature + to support different models or embedding strategies. + + Args: + model: The loaded Hugging Face model. + processor: The loaded Hugging Face processor. + image_paths: A list of file paths for the images in the batch. + device: The device to run the model on ('cpu', 'cuda'). + + Returns: + A tuple containing: + - A numpy array of the extracted embeddings. + - A list of the valid file paths that were successfully processed. + """ + images_pil = [] + valid_paths = [] + for path in image_paths: + if path.lower().endswith(('.png', '.jpg', '.jpeg')): + try: + images_pil.append(Image.open(path).convert("RGB")) + valid_paths.append(path) + except Exception as e: + print(f"Warning: Could not load image {path}. Skipping. Error: {e}") + + if not images_pil: + return np.array([]), [] + + inputs = processor( + images=images_pil, + padding=True, + return_tensors="pt" + ).to(device) + + with torch.no_grad(): + # Forward pass to get model outputs + outputs = model.forward(pixel_values=inputs['pixel_values'].to(dtype=model.dtype)) + # We use the embedding of the [CLS] token as the document representation + embeddings = outputs.last_hidden_state[:, 0, :] + + return embeddings.cpu().numpy(), valid_paths + +def get_image_embeddings( + model: AutoModel, + processor: AutoProcessor, + image_paths: List[str], + device: str +) -> Tuple[np.ndarray, List[str]]: + """ + Processes a batch of images using a LayoutLM-like model and extracts their embeddings. + This function can be replaced with another one that follows the same signature + to support different models or embedding strategies. + + Args: + model: The loaded Hugging Face model. + processor: The loaded Hugging Face processor. + image_paths: A list of file paths for the images in the batch. + device: The device to run the model on ('cpu', 'cuda'). + + Returns: + A tuple containing: + - A numpy array of the extracted embeddings. + - A list of the valid file paths that were successfully processed. + """ + images_pil = [] + valid_paths = [] + for path in image_paths: + if path.lower().endswith(('.png', '.jpg', '.jpeg')): + try: + # The processor expects PIL images in RGB format + images_pil.append(Image.open(path).convert("RGB")) + valid_paths.append(path) + except Exception as e: + print(f"Warning: Could not load image {path}. 
Skipping. Error: {e}") + + if not images_pil: + return np.array([]), [] + + # For pure vision feature extraction, we can provide an empty text prompt. + # The processor handles tokenizing text and preparing images. + inputs = processor( + text=[""] * len(images_pil), + images=images_pil, + padding=True, + return_tensors="pt" + ).to(device) + + with torch.no_grad(): + # Get the vision embeddings from the model's vision tower + vision_outputs = model.visual(inputs['pixel_values'].to(dtype=model.dtype), grid_thw=inputs['image_grid_thw']) + # We'll use the pooled output as the embedding + embeddings = vision_outputs + + return embeddings.to(torch.float16).cpu().numpy() + +def run_extraction( + model_name: str, + image_dir: str, + batch_size: int, + device: str, + output_file: str, + embedding_extractor: Callable +): + """ + Loads a model and processes all images in a directory to save their embeddings. + """ + print(f"Using device: {device}") + + # Load the model and processor + print(f"Loading model: {model_name}") + model = AutoModel.from_pretrained(model_name).to(device) + processor = AutoProcessor.from_pretrained(model_name) + + image_files = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir)) if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + print(f"Found {len(image_files)} images to process.") + + with open(output_file, "w") as f: + f.write("[\n") + is_first_entry = True + for i in tqdm(range(0, len(image_files), batch_size), desc="Extracting Embeddings"): + batch_paths = image_files[i:i+batch_size] + + # Call the provided embedding extractor function + batch_embeddings, valid_paths = embedding_extractor(model, processor, batch_paths, device) + + if len(valid_paths) == 0: + continue + + embeddings_list = [emb.tolist() for emb in batch_embeddings] + + for path, emb in zip(valid_paths, embeddings_list): + if not is_first_entry: + f.write(",\n") + json.dump({"filepath": path, "embedding": emb}, f, indent=4) + is_first_entry = False + f.write("\n]\n") + + print(f"Embedding extraction complete. Results saved to {output_file}") + +def main(): + """ + Main function to parse command-line arguments and start the extraction process. + """ + parser = argparse.ArgumentParser(description="Extract document image embeddings using a transformer model.") + + parser.add_argument("--model_name", type=str, default="microsoft/layoutlmv3-base", help="Hugging Face model name.") + parser.add_argument("--image_dir", type=str, required=True, help="Directory containing the images to process.") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing.") + parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0', 'cpu'). Autodetects if not set.") + parser.add_argument("--output_file", type=str, default="embeddings.json", help="File path to save the output JSON.") + + args = parser.parse_args() + + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + + # The `get_layoutlm_image_embeddings` function is passed here. + # You can define a different function and pass it instead to change the behavior. + run_extraction( + model_name=args.model_name, + image_dir=args.image_dir, + batch_size=args.batch_size, + device=device, + output_file=args.output_file, + embedding_extractor=get_layoutlm_image_embeddings + ) + +if __name__ == "__main__": + main() \ No newline at end of file
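A minimal usage sketch for src/embedding_extraction.py, assuming it is run from the repository root and that extracted_images/ holds the .png/.jpg/.jpeg files; the output path is illustrative. The command-line equivalent is: python src/embedding_extraction.py --image_dir extracted_images --output_file embeddings.json

import torch

from src.embedding_extraction import get_layoutlm_image_embeddings, run_extraction

# Mirrors the script's argparse defaults; any extractor with the same signature
# can be passed instead to change the embedding strategy.
run_extraction(
    model_name="microsoft/layoutlmv3-base",
    image_dir="extracted_images",
    batch_size=32,
    device="cuda" if torch.cuda.is_available() else "cpu",
    output_file="embeddings.json",  # JSON list of {"filepath", "embedding"} records
    embedding_extractor=get_layoutlm_image_embeddings,
)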