{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1f0c3d2",
   "metadata": {},
   "source": [
    "# Qwen2.5-VL image embeddings, GMM clustering and FiftyOne visualization\n",
    "\n",
    "1. **Extract**: embed every image in `IMAGE_DIR` with the Qwen2.5-VL vision tower (one mean-pooled vector per image) and stream the result to `EMBEDDINGS_JSON`.\n",
    "2. **Explore**: reload the JSON file, cluster the embeddings with Gaussian mixtures and inspect everything in the FiftyOne App.\n",
    "\n",
    "The second stage only depends on the JSON file, so it can be re-run without loading the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f8a415",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from qwen_vl_utils import process_vision_info\n",
    "from tqdm import tqdm\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor\n",
    "\n",
    "# --- Configuration ---\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-VL-3B-Instruct\"  # You can choose other model sizes\n",
    "# Override with the IMAGE_DIR environment variable to avoid editing the notebook.\n",
    "IMAGE_DIR = os.environ.get(\"IMAGE_DIR\", \"/home/nguyendc/phat-dev/clustering/extracted_images\")\n",
    "EMBEDDINGS_JSON = \"embeddings_factures_osteopathie_1k_qwen.json\"\n",
    "IMAGE_EXTENSIONS = (\".png\", \".jpg\", \".jpeg\")\n",
    "BATCH_SIZE = 4\n",
    "# --- End Configuration ---\n",
    "\n",
    "# Check for GPU availability\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Load the model and processor on the detected device.\n",
    "# flash_attention_2 is CUDA-only, so fall back to PyTorch SDPA attention on CPU.\n",
    "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    torch_dtype=\"bfloat16\",\n",
    "    device_map=device,\n",
    "    attn_implementation=\"flash_attention_2\" if device == \"cuda\" else \"sdpa\",\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdfdab0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_image_embeddings(image_paths):\n",
    "    \"\"\"\n",
    "    Embed a batch of images with the Qwen2.5-VL vision tower.\n",
    "\n",
    "    Args:\n",
    "        image_paths: iterable of image file paths. Paths whose extension is not\n",
    "            in IMAGE_EXTENSIONS, or that fail to load, are skipped (a warning\n",
    "            is printed for load failures).\n",
    "\n",
    "    Returns:\n",
    "        (embeddings, valid_paths):\n",
    "            embeddings: float32 numpy array with one row per loaded image,\n",
    "                the mean of that image's vision tokens.\n",
    "            valid_paths: the paths that were actually embedded, in row order.\n",
    "        When no image could be loaded this is (np.array([]), []).\n",
    "    \"\"\"\n",
    "    images_pil = []\n",
    "    valid_paths = []\n",
    "    for path in image_paths:\n",
    "        if not path.lower().endswith(IMAGE_EXTENSIONS):\n",
    "            continue\n",
    "        try:\n",
    "            # The processor expects PIL images in RGB format. convert() returns a\n",
    "            # new in-memory image, so the file handle can be closed right away.\n",
    "            with Image.open(path) as img:\n",
    "                images_pil.append(img.convert(\"RGB\"))\n",
    "            valid_paths.append(path)\n",
    "        except Exception as e:\n",
    "            # Deliberately best-effort: one unreadable file must not abort the run.\n",
    "            print(f\"Warning: Could not load image {path}. Skipping. Error: {e}\")\n",
    "\n",
    "    if not images_pil:\n",
    "        return np.array([]), []\n",
    "\n",
    "    # For pure vision feature extraction, we can provide an empty text prompt.\n",
    "    # The processor handles tokenizing text and preparing images.\n",
    "    inputs = processor(\n",
    "        text=[\"\"] * len(images_pil),\n",
    "        images=images_pil,\n",
    "        padding=True,\n",
    "        return_tensors=\"pt\",\n",
    "    ).to(device)\n",
    "    grid_thw = inputs[\"image_grid_thw\"]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # The vision tower returns the merged patch tokens of ALL images in the\n",
    "        # batch concatenated along dim 0 -- not one vector per image.\n",
    "        vision_tokens = model.visual(\n",
    "            inputs[\"pixel_values\"].to(dtype=model.dtype), grid_thw=grid_thw\n",
    "        )\n",
    "\n",
    "    # After patch merging each image owns t*h*w / merge_size**2 consecutive tokens.\n",
    "    merge_unit = model.visual.spatial_merge_size ** 2\n",
    "    tokens_per_image = (grid_thw.prod(dim=-1) // merge_unit).tolist()\n",
    "\n",
    "    # Mean-pool every image's tokens into a single vector. Pool in float32:\n",
    "    # casting bfloat16 to float16 can overflow to inf, and numpy has no bfloat16.\n",
    "    pooled = torch.stack(\n",
    "        [chunk.float().mean(dim=0) for chunk in vision_tokens.split(tokens_per_image)]\n",
    "    )\n",
    "    return pooled.cpu().numpy(), valid_paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdaebb7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Process all images in the directory ---\n",
    "# Sorted so the output order is reproducible (os.listdir order is arbitrary).\n",
    "image_files = sorted(\n",
    "    os.path.join(IMAGE_DIR, name)\n",
    "    for name in os.listdir(IMAGE_DIR)\n",
    "    if name.lower().endswith(IMAGE_EXTENSIONS)\n",
    ")\n",
    "\n",
    "# Stream the records into a JSON array batch by batch so that the full set of\n",
    "# embeddings never has to be held in memory.\n",
    "n_written = 0\n",
    "with open(EMBEDDINGS_JSON, \"w\") as f:\n",
    "    f.write(\"[\\n\")\n",
    "    for start in tqdm(range(0, len(image_files), BATCH_SIZE)):\n",
    "        batch_paths = image_files[start:start + BATCH_SIZE]\n",
    "        batch_embeddings, embedded_paths = get_image_embeddings(batch_paths)\n",
    "        # Pair with the paths that were actually embedded, not with batch_paths:\n",
    "        # a skipped image would otherwise shift every following pair.\n",
    "        for path, emb in zip(embedded_paths, batch_embeddings):\n",
    "            if n_written:\n",
    "                f.write(\",\\n\")\n",
    "            json.dump({\"filepath\": path, \"embedding\": emb.tolist()}, f)\n",
    "            n_written += 1\n",
    "    f.write(\"\\n]\\n\")\n",
    "\n",
    "print(\"Embeddings extracted and saved.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e49a10",
   "metadata": {},
   "source": [
    "## Cluster and visualize\n",
    "\n",
    "This cell is self-contained: it has its own imports and reads the embeddings back from the JSON file, so it can be run without executing the extraction cells above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27fea4f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 100% |███████████████| 1091/1091 [174.1ms elapsed, 0s remaining, 6.3K samples/s] \n"
     ]
    }
   ],
   "source": [
    "import fiftyone as fo\n",
    "import fiftyone.brain as fob\n",
    "import numpy as np\n",
    "from sklearn.mixture import GaussianMixture\n",
    "import json\n",
    "\n",
    "DATASET_NAME = \"mock\"\n",
    "GMM_COMPONENT_COUNTS = (50, 200)\n",
    "VIZ_METHOD = \"tsne\"  # passed to fob.compute_visualization(method=...)\n",
    "SEED = 42\n",
    "\n",
    "json_path = \"./embeddings_factures_osteopathie_1k_qwen.json\"\n",
    "\n",
    "with open(json_path, \"r\") as file:\n",
    "    embedding_data = json.load(file)\n",
    "\n",
    "# Index with [] rather than .get() so a malformed record fails here instead of\n",
    "# injecting None into the arrays below.\n",
    "file_paths = [record[\"filepath\"] for record in embedding_data]\n",
    "embeddings = np.asarray([record[\"embedding\"] for record in embedding_data])\n",
    "assert embeddings.ndim == 2 and len(embeddings) == len(file_paths), (\n",
    "    f\"Expected one fixed-size embedding per file, got array of shape {embeddings.shape}\"\n",
    ")\n",
    "\n",
    "# Recreate the dataset from scratch so re-running this cell is idempotent.\n",
    "dataset = fo.Dataset(DATASET_NAME, overwrite=True)\n",
    "\n",
    "# Add samples to the dataset (kept in file_paths order, which set_values relies on)\n",
    "dataset.add_samples([fo.Sample(filepath=p) for p in file_paths])\n",
    "\n",
    "\n",
    "def add_gmm_clusters(dataset, embeddings, n_components, random_state=SEED):\n",
    "    \"\"\"\n",
    "    Fit a Gaussian mixture model and store each sample's cluster id.\n",
    "\n",
    "    Args:\n",
    "        dataset: FiftyOne dataset whose sample order matches `embeddings` rows.\n",
    "        embeddings: (n_samples, n_dims) array.\n",
    "        n_components: number of mixture components.\n",
    "        random_state: seed forwarded to GaussianMixture.\n",
    "\n",
    "    Returns:\n",
    "        Name of the integer sample field that now holds the cluster ids,\n",
    "        `gmm_cluster_{n_components}_gaussians`.\n",
    "    \"\"\"\n",
    "    field = f\"gmm_cluster_{n_components}_gaussians\"\n",
    "    # NOTE(review): the default covariance_type='full' estimates a dense\n",
    "    # n_dims x n_dims matrix per component; consider 'diag' if this is too\n",
    "    # slow or memory hungry for high-dimensional embeddings.\n",
    "    gmm = GaussianMixture(n_components=n_components, random_state=random_state)\n",
    "    labels = gmm.fit(embeddings).predict(embeddings)\n",
    "    # Declare the field that is actually written, then set all values in one\n",
    "    # batch instead of saving sample by sample.\n",
    "    dataset.add_sample_field(field, fo.IntField)\n",
    "    dataset.set_values(field, labels.tolist())\n",
    "    return field\n",
    "\n",
    "\n",
    "# Building Gaussian mixture models (GMM) at several granularities\n",
    "for n_gaussians in GMM_COMPONENT_COUNTS:\n",
    "    if n_gaussians > len(embeddings):\n",
    "        # GaussianMixture needs at least as many samples as components.\n",
    "        print(\n",
    "            f\"Skipping GMM with {n_gaussians} components: \"\n",
    "            f\"only {len(embeddings)} samples available.\"\n",
    "        )\n",
    "        continue\n",
    "    add_gmm_clusters(dataset, embeddings, n_gaussians)\n",
    "\n",
    "# --- Visualize the embeddings in 2D ---\n",
    "# This computes a 2D representation of the embeddings (method set by\n",
    "# VIZ_METHOD) for the App's embeddings panel.\n",
    "res = fob.compute_visualization(\n",
    "    dataset,\n",
    "    embeddings=embeddings,\n",
    "    brain_key=\"qwen_vision_viz\",\n",
    "    method=VIZ_METHOD,\n",
    "    seed=SEED,\n",
    "    verbose=True,\n",
    ")\n",
    "# Name the field after the method that produced the points.\n",
    "dataset.set_values(f\"qwen_{VIZ_METHOD}\", res.current_points)\n",
    "\n",
    "print(f\"{VIZ_METHOD} visualization computed. Launch the app to see the plot.\")\n",
    "session = fo.launch_app(dataset)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sesame",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}