{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "59f8a415", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel\n", "from qwen_vl_utils import process_vision_info\n", "from PIL import Image\n", "import os\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "# --- Configuration ---\n", "MODEL_NAME = \"microsoft/layoutlmv3-base\" # You can choose other model sizes\n", "IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n", "BATCH_SIZE = 32\n", "# --- End Configuration ---\n", "\n", "# Check for GPU availability\n", "device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using device: {device}\")\n", "\n", "# Load the model and processor\n", "model = AutoModel.from_pretrained(\n", " MODEL_NAME, torch_dtype=\"bfloat16\", device_map=\"cuda\" # , attn_implementation=\"flash_attention_2\",\n", ")\n", "processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "cdfdab0e", "metadata": {}, "outputs": [], "source": [ "def get_image_embeddings(image_paths):\n", " \"\"\"\n", " Processes a batch of images and extracts their embeddings.\n", " \"\"\"\n", " images_pil = []\n", " valid_paths = []\n", " for path in image_paths:\n", " if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n", " try:\n", " # The processor expects PIL images in RGB format\n", " images_pil.append(Image.open(path).convert(\"RGB\"))\n", " valid_paths.append(path)\n", " except Exception as e:\n", " print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n", "\n", " if not images_pil:\n", " return np.array([]), []\n", "\n", " # For pure vision feature extraction, we can provide an empty text prompt.\n", " # The processor handles tokenizing text and preparing images.\n", " inputs = processor(\n", " # text=[\"\"] * len(images_pil),\n", " images=images_pil,\n", " padding=True,\n", " return_tensors=\"pt\"\n", " ).to(device)\n", "\n", " with torch.no_grad():\n", " # Get the vision embeddings from the model's vision tower\n", " vision_outputs = model.forward(pixel_values=inputs['pixel_values'].to(dtype=model.dtype)) # , grid_thw=inputs['image_grid_thw'])\n", " # We'll use the pooled output as the embedding\n", " embeddings = vision_outputs[0][:,0,:]\n", "\n", " return embeddings.to(torch.float16).cpu().numpy()" ] }, { "cell_type": "code", "execution_count": null, "id": "cdaebb7b", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# --- Process all images in the directory ---\n", "image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n", "all_embeddings = []\n", "filepaths = []\n", "\n", "with open(\"embeddings_factures_ostepoathie_1k.json\", \"w\") as f:\n", " f.write(\"[\\n\")\n", " first = True\n", " for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n", " batch_paths = image_files[i:i+BATCH_SIZE]\n", " batch_embeddings = get_image_embeddings(batch_paths)\n", " embeddings_list = [emb.tolist() for emb in batch_embeddings]\n", " for path, emb in zip(batch_paths, embeddings_list):\n", " if not first:\n", " f.write(\",\\n\")\n", " json.dump({\"filepath\": path, \"embedding\": emb}, f)\n", " first = False\n", " f.write(\"\\n]\\n\")\n", "\n", "print(\"Embeddings extracted and saved.\")\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "27fea4f3", "metadata": {}, "outputs": [ { 
"name": "stderr", "output_type": "stream", "text": [ "/home/nguyendc/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 100% |███████████████| 1090/1090 [182.4ms elapsed, 0s remaining, 6.0K samples/s] \n" ] } ], "source": [ "import fiftyone as fo\n", "import fiftyone.brain as fob\n", "import numpy as np\n", "from sklearn.mixture import GaussianMixture\n", "import json\n", "\n", "DATASET_NAME = \"mock\"\n", "\n", "json_path = \"./embeddings_factures_osteopathie_1k.json\"\n", "\n", "with open(json_path, \"r\") as file:\n", " embedding_data = json.load(file)\n", "\n", "file_paths = []\n", "embeddings = []\n", "for i, record in enumerate(embedding_data):\n", " file_paths.append(record.get(\"filepath\"))\n", " embeddings.append(record.get(\"embedding\"))\n", "\n", "if DATASET_NAME in fo.list_datasets():\n", " dataset = fo.load_dataset(DATASET_NAME)\n", " dataset.delete()\n", "dataset = fo.Dataset(DATASET_NAME)\n", "\n", "# Add samples to the dataset\n", "samples = [fo.Sample(filepath=p) for p in file_paths]\n", "dataset.add_samples(samples)\n", "\n", "# Building Gaussian mixture model (GMM)\n", "\n", "n_gaussians = 50\n", "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", "gmm.fit(embeddings)\n", "cluster_labels = gmm.predict(embeddings)\n", "\n", "# Adding labeled embeddings to visulization\n", "dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", "for sample, label in zip(dataset, cluster_labels):\n", " sample[\"gmm_cluster_50_gaussians\"] = int(label)\n", " sample.save()\n", "\n", "# n_gaussians = 100\n", "# gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", "# gmm.fit(embeddings)\n", "# cluster_labels = gmm.predict(embeddings)\n", "\n", "# # Adding labeled embeddings to visulization\n", "# dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", "# for sample, label in zip(dataset, cluster_labels):\n", "# sample[\"gmm_cluster_200_gaussians\"] = int(label)\n", "# sample.save()\n", "\n", "# # --- Visualize the Embeddings with UMAP ---\n", "# # This will compute a 2D representation of your embeddings\n", "# # for visualization.\n", "# res = fob.compute_visualization(\n", "# dataset,\n", "# embeddings=embeddings,\n", "# brain_key=\"qwen_vision_viz\",\n", "# method=\"tsne\",\n", "# verbose=True\n", "# )\n", "# dataset.set_values(\"qwen_umap\", res.current_points)\n", "\n", "# print(\"UMAP visualization computed. Launch the app to see the plot.\")\n", "# session = fo.launch_app(dataset)\n", "# session.wait()" ] }, { "cell_type": "code", "execution_count": null, "id": "a8c23ca3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[ 2.22716806e+06 -1.55841522e+05 1.20176465e+04 ... 7.78593400e+03\n", " 1.90060269e+03 -1.81817591e+02]\n", " [-1.55841522e+05 2.45425677e+06 9.82699914e+04 ... -5.37744672e+04\n", " -1.14976214e+05 1.95498686e+05]\n", " [ 1.20176465e+04 9.82699914e+04 2.05845769e+06 ... -6.42761657e+04\n", " -1.57428613e+03 1.90527654e+05]\n", " ...\n", " [ 7.78593400e+03 -5.37744672e+04 -6.42761657e+04 ... 2.19897397e+06\n", " 9.79439501e+04 2.63393791e+05]\n", " [ 1.90060269e+03 -1.14976214e+05 -1.57428613e+03 ... 9.79439501e+04\n", " 2.09478614e+06 1.35758244e+05]\n", " [-1.81817591e+02 1.95498686e+05 1.90527654e+05 ... 
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c23ca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import silhouette_samples\n",
    "print(silhouette_samples(np.array(embeddings), labels=cluster_labels))"
   ]
  },
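  {
   "cell_type": "markdown",
   "id": "f7a8b9c0",
   "metadata": {},
   "source": [
    "Per-sample silhouette values are hard to read as a raw array. The sketch below aggregates them into an overall score and a per-cluster mean, assuming `embeddings` and `cluster_labels` from the cells above; clusters with a low mean silhouette are candidates for merging or re-fitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import silhouette_samples, silhouette_score\n",
    "\n",
    "X = np.array(embeddings)\n",
    "sil = silhouette_samples(X, labels=cluster_labels)\n",
    "\n",
    "print(f\"Mean silhouette over all samples: {silhouette_score(X, cluster_labels):.4f}\")\n",
    "\n",
    "# Mean silhouette per cluster highlights poorly separated clusters\n",
    "for k in np.unique(cluster_labels):\n",
    "    mask = cluster_labels == k\n",
    "    print(f\"cluster {k:2d}: n={mask.sum():4d}, mean silhouette={sil[mask].mean():.4f}\")"
   ]
  }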
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sesame",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}