{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f8a415",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModel\n",
    "from qwen_vl_utils import process_vision_info\n",
    "from PIL import Image\n",
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# --- Configuration ---\n",
    "MODEL_NAME = \"microsoft/layoutlmv3-base\"  # Swap in another checkpoint if desired\n",
    "IMAGE_DIR = \"/home/nguyendc/phat-dev/clustering/extracted_images\"\n",
    "BATCH_SIZE = 32\n",
    "# --- End Configuration ---\n",
    "\n",
    "# Check for GPU availability\n",
    "device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Load the model and processor on the device selected above\n",
    "model = AutoModel.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype=\"bfloat16\", device_map=device  # , attn_implementation=\"flash_attention_2\",\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)"
   ]
  },
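  {
   "cell_type": "markdown",
   "id": "f0a1b2c3",
   "metadata": {},
   "source": [
    "Optional sanity check (added sketch, not part of the original run): confirm the encoder's dtype, device placement, and embedding width before extracting features. It assumes the checkpoint exposes a standard Hugging Face `config.hidden_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d5e6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check: verify the model loaded as expected.\n",
    "print(type(model).__name__)\n",
    "print(\"dtype:\", model.dtype)\n",
    "print(\"device:\", next(model.parameters()).device)\n",
    "print(\"hidden size:\", model.config.hidden_size)"
   ]
  },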
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdfdab0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_image_embeddings(image_paths):\n",
    "    \"\"\"\n",
    "    Processes a batch of images and returns their embeddings along with\n",
    "    the paths that were successfully loaded.\n",
    "    \"\"\"\n",
    "    images_pil = []\n",
    "    valid_paths = []\n",
    "    for path in image_paths:\n",
    "        if path.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
    "            try:\n",
    "                # The processor expects PIL images in RGB format\n",
    "                images_pil.append(Image.open(path).convert(\"RGB\"))\n",
    "                valid_paths.append(path)\n",
    "            except Exception as e:\n",
    "                print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n",
    "\n",
    "    if not images_pil:\n",
    "        return np.array([]), []\n",
    "\n",
    "    # For pure vision feature extraction, an empty text prompt would suffice;\n",
    "    # the processor handles preparing the images.\n",
    "    inputs = processor(\n",
    "        # text=[\"\"] * len(images_pil),\n",
    "        images=images_pil,\n",
    "        padding=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Run the encoder on the pixel values only\n",
    "        outputs = model(pixel_values=inputs['pixel_values'].to(dtype=model.dtype))  # , grid_thw=inputs['image_grid_thw'])\n",
    "        # Use the first ([CLS]) token of the last hidden state as the embedding\n",
    "        embeddings = outputs[0][:, 0, :]\n",
    "\n",
    "    return embeddings.to(torch.float16).cpu().numpy(), valid_paths"
   ]
  },
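  {
   "cell_type": "markdown",
   "id": "8f9e0d1c",
   "metadata": {},
   "source": [
    "Minimal smoke test for `get_image_embeddings` (added sketch, assuming `IMAGE_DIR` already contains at least one `.png`/`.jpg`/`.jpeg` file): run a small batch and confirm there is one embedding row per successfully loaded image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d3c4b5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical smoke test on a handful of images from IMAGE_DIR.\n",
    "sample_paths = sorted(\n",
    "    os.path.join(IMAGE_DIR, f)\n",
    "    for f in os.listdir(IMAGE_DIR)\n",
    "    if f.lower().endswith(('.png', '.jpg', '.jpeg'))\n",
    ")[:4]\n",
    "sample_embeddings, sample_valid_paths = get_image_embeddings(sample_paths)\n",
    "print(sample_embeddings.shape)  # expected: (len(sample_valid_paths), hidden_size)"
   ]
  },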
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdaebb7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# --- Process all images in the directory ---\n",
    "image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
    "\n",
    "# Stream the embeddings to disk as a JSON array, one record per image\n",
    "with open(\"embeddings_factures_osteopathie_1k.json\", \"w\") as f:\n",
    "    f.write(\"[\\n\")\n",
    "    first = True\n",
    "    for i in tqdm(range(0, len(image_files), BATCH_SIZE)):\n",
    "        batch_paths = image_files[i:i+BATCH_SIZE]\n",
    "        batch_embeddings, batch_valid_paths = get_image_embeddings(batch_paths)\n",
    "        embeddings_list = [emb.tolist() for emb in batch_embeddings]\n",
    "        # Zip against the paths that actually loaded so records stay aligned\n",
    "        for path, emb in zip(batch_valid_paths, embeddings_list):\n",
    "            if not first:\n",
    "                f.write(\",\\n\")\n",
    "            json.dump({\"filepath\": path, \"embedding\": emb}, f)\n",
    "            first = False\n",
    "    f.write(\"\\n]\\n\")\n",
    "\n",
    "print(\"Embeddings extracted and saved.\")\n"
   ]
  },
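  {
   "cell_type": "markdown",
   "id": "1a2b3c4d",
   "metadata": {},
   "source": [
    "Quick verification of the exported file (added sketch, assuming the export cell above has completed): reload the JSON and check the record count and embedding dimensionality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6f7a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical check: confirm the saved records are well-formed.\n",
    "with open(\"embeddings_factures_osteopathie_1k.json\", \"r\") as f:\n",
    "    saved_records = json.load(f)\n",
    "print(f\"{len(saved_records)} records saved\")\n",
    "print(f\"embedding dim: {len(saved_records[0]['embedding'])}\")"
   ]
  },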
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27fea4f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nguyendc/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s]      \n",
      "Generating visualization...\n",
      "[t-SNE] Computing 91 nearest neighbors...\n",
      "[t-SNE] Indexed 1090 samples in 0.000s...\n",
      "[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n",
      "[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n",
      "[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n",
      "[t-SNE] Mean sigma: 1.117175\n",
      "[t-SNE] Computed conditional probabilities in 0.018s\n",
      "[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n",
      "[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n",
      "[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n",
      "[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n",
      "[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n",
      "[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n",
      "[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n",
      "[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n",
      "[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n",
      "[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n",
      "[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n",
      "[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n",
      "[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n",
      "[t-SNE] KL divergence after 1000 iterations: 0.727630\n",
      "t-SNE visualization computed. Launch the app to see the plot.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"800\"\n",
       "            src=\"http://localhost:5151/?notebook=True&subscription=9f7aad1e-72b8-41ab-873d-b776dbb9e9d5\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x798114e78160>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Notebook sessions cannot wait\n"
     ]
    }
   ],
   "source": [
    "import fiftyone as fo\n",
    "import fiftyone.brain as fob\n",
    "import numpy as np\n",
    "from sklearn.mixture import GaussianMixture\n",
    "import json\n",
    "\n",
    "DATASET_NAME = \"mock\"\n",
    "\n",
    "json_path = \"./embeddings_factures_osteopathie_1k.json\"\n",
    "\n",
    "with open(json_path, \"r\") as file:\n",
    "    embedding_data = json.load(file)\n",
    "\n",
    "file_paths = []\n",
    "embeddings = []\n",
    "for record in embedding_data:\n",
    "    file_paths.append(record.get(\"filepath\"))\n",
    "    embeddings.append(record.get(\"embedding\"))\n",
    "\n",
    "# Recreate the dataset from scratch on each run\n",
    "if DATASET_NAME in fo.list_datasets():\n",
    "    dataset = fo.load_dataset(DATASET_NAME)\n",
    "    dataset.delete()\n",
    "dataset = fo.Dataset(DATASET_NAME)\n",
    "\n",
    "# Add samples to the dataset\n",
    "samples = [fo.Sample(filepath=p) for p in file_paths]\n",
    "dataset.add_samples(samples)\n",
    "\n",
    "# --- Cluster the embeddings with Gaussian mixture models (GMMs) ---\n",
    "\n",
    "n_gaussians = 50\n",
    "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
    "gmm.fit(embeddings)\n",
    "cluster_labels = gmm.predict(embeddings)\n",
    "\n",
    "# Add the 50-component cluster labels to the dataset for visualization\n",
    "dataset.add_sample_field(\"gmm_cluster_50_gaussians\", fo.IntField)\n",
    "for sample, label in zip(dataset, cluster_labels):\n",
    "    sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
    "    sample.save()\n",
    "\n",
    "n_gaussians = 100\n",
    "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
    "gmm.fit(embeddings)\n",
    "cluster_labels = gmm.predict(embeddings)\n",
    "\n",
    "# Add the 100-component cluster labels to the dataset for visualization\n",
    "dataset.add_sample_field(\"gmm_cluster_100_gaussians\", fo.IntField)\n",
    "for sample, label in zip(dataset, cluster_labels):\n",
    "    sample[\"gmm_cluster_100_gaussians\"] = int(label)\n",
    "    sample.save()\n",
    "\n",
    "# --- Visualize the embeddings with t-SNE ---\n",
    "# This computes a 2D representation of the embeddings for visualization.\n",
    "res = fob.compute_visualization(\n",
    "    dataset,\n",
    "    embeddings=embeddings,\n",
    "    brain_key=\"qwen_vision_viz\",\n",
    "    method=\"tsne\",\n",
    "    verbose=True\n",
    ")\n",
    "dataset.set_values(\"qwen_tsne\", res.current_points)\n",
    "\n",
    "print(\"t-SNE visualization computed. Launch the app to see the plot.\")\n",
    "session = fo.launch_app(dataset)\n",
    "session.wait()"
   ]
  },
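  {
   "cell_type": "markdown",
   "id": "9c8d7e6f",
   "metadata": {},
   "source": [
    "Optional follow-up (added sketch, assuming the clustering cell above has run): inspect how the samples are distributed across the 100 GMM components before exploring them in the App."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1a2c3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical inspection of cluster sizes for the 100-component GMM.\n",
    "counts = np.bincount(dataset.values(\"gmm_cluster_100_gaussians\"), minlength=100)\n",
    "print(\"largest cluster:\", counts.max(), \"samples\")\n",
    "print(\"empty clusters:\", int((counts == 0).sum()))"
   ]
  }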
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sesame",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}