embedding-clustering/clustering_layoutlm.ipynb

327 lines
13 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "59f8a415",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel\n",
"from qwen_vl_utils import process_vision_info\n",
"from PIL import Image\n",
"import os\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"# NOTE(review): Qwen2_5_VLForConditionalGeneration, AutoTokenizer and process_vision_info are not\n",
"# used anywhere in this notebook; they look like leftovers from a Qwen-based version.\n",
"\n",
"# --- Configuration ---\n",
"MODEL_NAME = \"microsoft/layoutlmv3-base\"  # You can choose other model sizes\n",
"IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n",
"BATCH_SIZE = 32\n",
"CUDA_DEVICE = \"cuda:1\"  # GPU used when CUDA is available\n",
"# --- End Configuration ---\n",
"\n",
"# Check for GPU availability\n",
"device = CUDA_DEVICE if torch.cuda.is_available() else \"cpu\"\n",
"# bfloat16 on GPU to save memory; float32 on CPU, where half-precision kernels may be slow or unsupported.\n",
"model_dtype = torch.bfloat16 if device != \"cpu\" else torch.float32\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Load the model and processor.\n",
"# The model must sit on the same device that get_image_embeddings() moves its inputs to;\n",
"# a bare device_map=\"cuda\" would place it on cuda:0 while the inputs go to `device`.\n",
"model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=model_dtype, device_map=device)\n",
"# LayoutLMv3 is built into transformers, so no remote code needs to be trusted.\n",
"processor = AutoProcessor.from_pretrained(MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdfdab0e",
"metadata": {},
"outputs": [],
"source": [
"def get_image_embeddings(image_paths, return_paths=False):\n",
"    \"\"\"\n",
"    Embed a batch of images with the LayoutLMv3 `model` / `processor` loaded above.\n",
"\n",
"    Paths not ending in .png/.jpg/.jpeg are ignored and images that fail to load are\n",
"    skipped with a warning, so the result can have fewer rows than `image_paths`.\n",
"\n",
"    Args:\n",
"        image_paths: iterable of image file paths.\n",
"        return_paths: if True, also return the paths that were actually embedded, in\n",
"            the same order as the embedding rows (default False: embeddings only).\n",
"\n",
"    Returns:\n",
"        float32 array of shape (n_loaded, hidden_size), one row per loaded image, or\n",
"        `(embeddings, valid_paths)` when `return_paths` is True.\n",
"    \"\"\"\n",
"    images_pil = []\n",
"    valid_paths = []\n",
"    for path in image_paths:\n",
"        if not path.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
"            continue\n",
"        try:\n",
"            # The image processor expects RGB; `with` closes the underlying file handle.\n",
"            with Image.open(path) as img:\n",
"                images_pil.append(img.convert(\"RGB\"))\n",
"            valid_paths.append(path)\n",
"        except Exception as e:\n",
"            print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n",
"\n",
"    if not images_pil:\n",
"        # Same type and rank as the non-empty case, so callers can always iterate rows.\n",
"        empty = np.empty((0, model.config.hidden_size), dtype=np.float32)\n",
"        return (empty, valid_paths) if return_paths else empty\n",
"\n",
"    # Only pixel values are needed for a vision-only embedding. The full LayoutLMv3 processor\n",
"    # applies OCR by default to build text tokens and boxes that are never used here, so call\n",
"    # its image processor directly with apply_ocr=False.\n",
"    pixel_values = processor.image_processor(\n",
"        images=images_pil, apply_ocr=False, return_tensors=\"pt\"\n",
"    )[\"pixel_values\"].to(device=device, dtype=model.dtype)\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = model(pixel_values=pixel_values)\n",
"        # Use the first token of the last hidden state (presumably LayoutLMv3's leading visual\n",
"        # [CLS] position) as the image embedding.\n",
"        embeddings = outputs.last_hidden_state[:, 0, :]\n",
"\n",
"    # float32 holds every bfloat16 value exactly; float16 would overflow above ~65504.\n",
"    embeddings = embeddings.float().cpu().numpy()\n",
"    return (embeddings, valid_paths) if return_paths else embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdaebb7b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Output file; must match `json_path` in the clustering cell below.\n",
"EMBEDDINGS_JSON = \"embeddings_factures_osteopathie_1k.json\"\n",
"\n",
"# --- Process all images in the directory ---\n",
"# os.listdir order is arbitrary, so sort to make the output (and the clustering) reproducible.\n",
"image_files = sorted(\n",
"    os.path.join(IMAGE_DIR, f)\n",
"    for f in os.listdir(IMAGE_DIR)\n",
"    if f.lower().endswith(('.png', '.jpg', '.jpeg'))\n",
")\n",
"\n",
"\n",
"def embed_batch_aligned(batch_paths):\n",
"    \"\"\"Return (path, embedding) pairs for the images in `batch_paths` that could be embedded.\n",
"\n",
"    get_image_embeddings() drops images it fails to load, which would shift every later\n",
"    embedding onto the wrong path if its output were zipped with `batch_paths` directly.\n",
"    When the batch result does not have exactly one row per path, re-embed the images\n",
"    one at a time so each embedding stays paired with its own file.\n",
"    \"\"\"\n",
"    batch_embeddings = get_image_embeddings(batch_paths)\n",
"    # get_image_embeddings() may return a tuple (np.array([]), []) when nothing in the\n",
"    # batch loads; treat that like any other row/path mismatch.\n",
"    if not isinstance(batch_embeddings, tuple) and len(batch_embeddings) == len(batch_paths):\n",
"        return list(zip(batch_paths, batch_embeddings))\n",
"    pairs = []\n",
"    for path in batch_paths:\n",
"        single = get_image_embeddings([path])\n",
"        if not isinstance(single, tuple) and len(single) == 1:\n",
"            pairs.append((path, single[0]))\n",
"    return pairs\n",
"\n",
"\n",
"n_written = 0\n",
"# Stream one JSON record per image so all embeddings never have to be held in memory at once.\n",
"with open(EMBEDDINGS_JSON, \"w\") as f:\n",
"    f.write(\"[\\n\")\n",
"    for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n",
"        for path, emb in embed_batch_aligned(image_files[i:i + BATCH_SIZE]):\n",
"            if n_written:\n",
"                f.write(\",\\n\")\n",
"            json.dump({\"filepath\": path, \"embedding\": emb.tolist()}, f)\n",
"            n_written += 1\n",
"    f.write(\"\\n]\\n\")\n",
"\n",
"print(f\"Embeddings extracted and saved to {EMBEDDINGS_JSON} ({n_written}/{len(image_files)} images).\")\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "27fea4f3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/nguyendc/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 100% |███████████████| 1090/1090 [182.4ms elapsed, 0s remaining, 6.0K samples/s] \n"
]
}
],
"source": [
"import fiftyone as fo\n",
"import fiftyone.brain as fob\n",
"import numpy as np\n",
"from sklearn.mixture import GaussianMixture\n",
"import json\n",
"\n",
"DATASET_NAME = \"mock\"\n",
"json_path = \"./embeddings_factures_osteopathie_1k.json\"\n",
"\n",
"# GMM settings. n_gaussians also names the label field, so runs with different sizes can coexist.\n",
"n_gaussians = 50\n",
"# NOTE(review): ~1k samples spread over 50 components in a high-dimensional embedding space leaves a\n",
"# \"full\" covariance per component badly under-determined; consider \"diag\" or reducing dimensionality.\n",
"covariance_type = \"full\"\n",
"\n",
"with open(json_path, \"r\") as file:\n",
"    embedding_data = json.load(file)\n",
"\n",
"# Index the keys directly so a malformed record fails loudly instead of yielding None.\n",
"file_paths = [record[\"filepath\"] for record in embedding_data]\n",
"embeddings = np.asarray([record[\"embedding\"] for record in embedding_data])\n",
"\n",
"# Recreate the dataset from scratch so reruns do not accumulate duplicate samples.\n",
"if DATASET_NAME in fo.list_datasets():\n",
"    fo.load_dataset(DATASET_NAME).delete()\n",
"dataset = fo.Dataset(DATASET_NAME)\n",
"\n",
"# Fit the Gaussian mixture model (GMM) and assign each embedding to its most likely component.\n",
"gmm = GaussianMixture(n_components=n_gaussians, covariance_type=covariance_type, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Declare the field that is actually written, then build each sample with its label so the\n",
"# path/label pairing follows the JSON order directly instead of the dataset's iteration order.\n",
"cluster_field = f\"gmm_cluster_{n_gaussians}_gaussians\"\n",
"dataset.add_sample_field(cluster_field, fo.IntField)\n",
"samples = [\n",
"    fo.Sample(filepath=path, **{cluster_field: int(label)})\n",
"    for path, label in zip(file_paths, cluster_labels)\n",
"]\n",
"dataset.add_samples(samples)\n",
"\n",
"# Optional: 2-D t-SNE projection for the FiftyOne embeddings panel (uncomment to run).\n",
"# res = fob.compute_visualization(\n",
"#     dataset,\n",
"#     embeddings=embeddings,\n",
"#     brain_key=\"layoutlm_vision_viz\",\n",
"#     method=\"tsne\",\n",
"#     verbose=True,\n",
"# )\n",
"# dataset.set_values(\"layoutlm_tsne\", res.current_points)\n",
"# print(\"t-SNE visualization computed. Launch the app to see the plot.\")\n",
"# session = fo.launch_app(dataset)\n",
"# session.wait()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8c23ca3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 2.22716806e+06 -1.55841522e+05 1.20176465e+04 ... 7.78593400e+03\n",
" 1.90060269e+03 -1.81817591e+02]\n",
" [-1.55841522e+05 2.45425677e+06 9.82699914e+04 ... -5.37744672e+04\n",
" -1.14976214e+05 1.95498686e+05]\n",
" [ 1.20176465e+04 9.82699914e+04 2.05845769e+06 ... -6.42761657e+04\n",
" -1.57428613e+03 1.90527654e+05]\n",
" ...\n",
" [ 7.78593400e+03 -5.37744672e+04 -6.42761657e+04 ... 2.19897397e+06\n",
" 9.79439501e+04 2.63393791e+05]\n",
" [ 1.90060269e+03 -1.14976214e+05 -1.57428613e+03 ... 9.79439501e+04\n",
" 2.09478614e+06 1.35758244e+05]\n",
" [-1.81817591e+02 1.95498686e+05 1.90527654e+05 ... 2.63393791e+05\n",
" 1.35758244e+05 2.12691079e+06]]\n",
"\n",
" [[ 2.83658946e+06 -1.34780282e+05 6.49580505e+04 ... 8.54367930e+04\n",
" 2.68631019e+04 -2.00238083e+03]\n",
" [-1.34780282e+05 3.16043262e+06 1.74730403e+05 ... -1.52450289e+04\n",
" -1.33867983e+05 2.33287605e+05]\n",
" [ 6.49580505e+04 1.74730403e+05 2.62079599e+06 ... -1.34705133e+05\n",
" -3.42729631e+03 2.52694121e+05]\n",
" ...\n",
" [ 8.54367930e+04 -1.52450289e+04 -1.34705133e+05 ... 2.83179172e+06\n",
" 1.20704061e+05 3.58186648e+05]\n",
" [ 2.68631019e+04 -1.33867983e+05 -3.42729631e+03 ... 1.20704061e+05\n",
" 2.61548713e+06 2.01096011e+05]\n",
" [-2.00238083e+03 2.33287605e+05 2.52694121e+05 ... 3.58186648e+05\n",
" 2.01096011e+05 2.68561973e+06]]\n",
"\n",
" [[ 2.46084057e+06 -1.61990924e+05 2.05209921e+04 ... 5.80037219e+04\n",
" -2.11338740e+03 -1.79719159e+04]\n",
" [-1.61990924e+05 2.77033377e+06 1.21300499e+05 ... -2.88489495e+04\n",
" -1.56314630e+05 1.81241062e+05]\n",
" [ 2.05209921e+04 1.21300499e+05 2.26999115e+06 ... -1.22265990e+05\n",
" -1.16935046e+04 1.98900231e+05]\n",
" ...\n",
" [ 5.80037219e+04 -2.88489495e+04 -1.22265990e+05 ... 2.44690178e+06\n",
" 9.92107868e+04 2.87651396e+05]\n",
" [-2.11338740e+03 -1.56314630e+05 -1.16935046e+04 ... 9.92107868e+04\n",
" 2.24557593e+06 1.34111998e+05]\n",
" [-1.79719159e+04 1.81241062e+05 1.98900231e+05 ... 2.87651396e+05\n",
" 1.34111998e+05 2.34916522e+06]]\n",
"\n",
" ...\n",
"\n",
" [[ 1.92829727e+06 -1.59315776e+05 1.58992088e+03 ... 2.85003295e+04\n",
" 4.22064899e+03 -3.58290465e+04]\n",
" [-1.59315776e+05 2.20558282e+06 8.03460647e+04 ... -4.04459318e+04\n",
" -1.07734468e+05 1.56172361e+05]\n",
" [ 1.58992088e+03 8.03460647e+04 1.83371422e+06 ... -1.11957845e+05\n",
" 2.80346243e+03 1.49548841e+05]\n",
" ...\n",
" [ 2.85003295e+04 -4.04459318e+04 -1.11957845e+05 ... 1.91110989e+06\n",
" 9.80180683e+04 1.85174267e+05]\n",
" [ 4.22064899e+03 -1.07734468e+05 2.80346243e+03 ... 9.80180683e+04\n",
" 1.77703462e+06 8.43241833e+04]\n",
" [-3.58290465e+04 1.56172361e+05 1.49548841e+05 ... 1.85174267e+05\n",
" 8.43241833e+04 1.89600377e+06]]\n",
"\n",
" [[ 3.47995050e+06 -2.33949522e+04 2.34412434e+05 ... 1.74855689e+05\n",
" 1.42301226e+05 1.03626216e+05]\n",
" [-2.33949522e+04 3.91344835e+06 3.68391546e+05 ... 1.48516088e+05\n",
" -5.32260301e+04 4.19584565e+05]\n",
" [ 2.34412434e+05 3.68391546e+05 3.16329032e+06 ... -7.83149391e+04\n",
" 1.04717741e+05 4.01535013e+05]\n",
" ...\n",
" [ 1.74855689e+05 1.48516088e+05 -7.83149391e+04 ... 3.45461426e+06\n",
" 1.69208897e+05 5.00179297e+05]\n",
" [ 1.42301226e+05 -5.32260301e+04 1.04717741e+05 ... 1.69208897e+05\n",
" 3.23120264e+06 3.85596283e+05]\n",
" [ 1.03626216e+05 4.19584565e+05 4.01535013e+05 ... 5.00179297e+05\n",
" 3.85596283e+05 3.23650330e+06]]\n",
"\n",
" [[ 2.87968337e+06 -1.54796543e+05 6.48388780e+04 ... 6.24845766e+04\n",
" 4.24673639e+04 -9.91376939e+03]\n",
" [-1.54796543e+05 3.29781697e+06 1.82272830e+05 ... 9.13276300e+03\n",
" -1.04968382e+05 2.54428804e+05]\n",
" [ 6.48388780e+04 1.82272830e+05 2.68323883e+06 ... -1.43183305e+05\n",
" 1.35349522e+04 2.71469510e+05]\n",
" ...\n",
" [ 6.24845766e+04 9.13276300e+03 -1.43183305e+05 ... 2.94868837e+06\n",
" 1.25676092e+05 3.72271111e+05]\n",
" [ 4.24673639e+04 -1.04968382e+05 1.35349522e+04 ... 1.25676092e+05\n",
" 2.69905870e+06 2.40775248e+05]\n",
" [-9.91376939e+03 2.54428804e+05 2.71469510e+05 ... 3.72271111e+05\n",
" 2.40775248e+05 2.76807176e+06]]]\n"
]
}
],
"source": [
"from sklearn.metrics import silhouette_samples\n",
"\n",
"# Silhouette coefficients lie in [-1, 1] (higher = better separated). Summarize them rather\n",
"# than printing one value per image.\n",
"sample_silhouettes = silhouette_samples(np.asarray(embeddings), labels=cluster_labels)\n",
"print(f\"Mean silhouette over {len(sample_silhouettes)} images: {sample_silhouettes.mean():.4f}\")\n",
"\n",
"# Per-cluster mean silhouette, worst-separated clusters first.\n",
"cluster_ids, cluster_sizes = np.unique(cluster_labels, return_counts=True)\n",
"per_cluster = sorted(\n",
"    (float(sample_silhouettes[cluster_labels == c].mean()), int(c), int(n))\n",
"    for c, n in zip(cluster_ids, cluster_sizes)\n",
")\n",
"for score, cluster_id, size in per_cluster:\n",
"    print(f\"cluster {cluster_id:3d}: {score:+.4f} (n={size})\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sesame",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}