diff --git a/clustering_example.ipynb b/clustering_example.ipynb
index 3fd5e3b..b41d259 100644
--- a/clustering_example.ipynb
+++ b/clustering_example.ipynb
@@ -111,7 +111,15 @@
"execution_count": null,
"id": "27fea4f3",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 100% |███████████████| 1091/1091 [174.1ms elapsed, 0s remaining, 6.3K samples/s] \n"
+ ]
+ }
+ ],
"source": [
"import fiftyone as fo\n",
"import fiftyone.brain as fob\n",
diff --git a/clustering_layoutlm.ipynb b/clustering_layoutlm.ipynb
index cfed077..afb1849 100644
--- a/clustering_layoutlm.ipynb
+++ b/clustering_layoutlm.ipynb
@@ -124,66 +124,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s] \n",
- "Generating visualization...\n",
- "[t-SNE] Computing 91 nearest neighbors...\n",
- "[t-SNE] Indexed 1090 samples in 0.000s...\n",
- "[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n",
- "[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n",
- "[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n",
- "[t-SNE] Mean sigma: 1.117175\n",
- "[t-SNE] Computed conditional probabilities in 0.018s\n",
- "[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n",
- "[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n",
- "[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n",
- "[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n",
- "[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n",
- "[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n",
- "[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n",
- "[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n",
- "[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n",
- "[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n",
- "[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n",
- "[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n",
- "[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n",
- "[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n",
- "[t-SNE] KL divergence after 1000 iterations: 0.727630\n",
- "UMAP visualization computed. Launch the app to see the plot.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Notebook sessions cannot wait\n"
+ " 100% |███████████████| 1090/1090 [182.4ms elapsed, 0s remaining, 6.0K samples/s] \n"
]
}
],
@@ -229,32 +170,135 @@
" sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
- "n_gaussians = 100\n",
- "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
- "gmm.fit(embeddings)\n",
- "cluster_labels = gmm.predict(embeddings)\n",
+ "# n_gaussians = 100\n",
+ "# gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
+ "# gmm.fit(embeddings)\n",
+ "# cluster_labels = gmm.predict(embeddings)\n",
"\n",
- "# Adding labeled embeddings to visulization\n",
- "dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
- "for sample, label in zip(dataset, cluster_labels):\n",
- " sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
- " sample.save()\n",
+ "# # Adding labeled embeddings to visulization\n",
+ "# dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
+ "# for sample, label in zip(dataset, cluster_labels):\n",
+ "# sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
+ "# sample.save()\n",
"\n",
- "# --- Visualize the Embeddings with UMAP ---\n",
- "# This will compute a 2D representation of your embeddings\n",
- "# for visualization.\n",
- "res = fob.compute_visualization(\n",
- " dataset,\n",
- " embeddings=embeddings,\n",
- " brain_key=\"qwen_vision_viz\",\n",
- " method=\"tsne\",\n",
- " verbose=True\n",
- ")\n",
- "dataset.set_values(\"qwen_umap\", res.current_points)\n",
+ "# # --- Visualize the Embeddings with UMAP ---\n",
+ "# # This will compute a 2D representation of your embeddings\n",
+ "# # for visualization.\n",
+ "# res = fob.compute_visualization(\n",
+ "# dataset,\n",
+ "# embeddings=embeddings,\n",
+ "# brain_key=\"qwen_vision_viz\",\n",
+ "# method=\"tsne\",\n",
+ "# verbose=True\n",
+ "# )\n",
+ "# dataset.set_values(\"qwen_umap\", res.current_points)\n",
"\n",
- "print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
- "session = fo.launch_app(dataset)\n",
- "session.wait()"
+ "# print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
+ "# session = fo.launch_app(dataset)\n",
+ "# session.wait()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a8c23ca3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[[ 2.22716806e+06 -1.55841522e+05 1.20176465e+04 ... 7.78593400e+03\n",
+ " 1.90060269e+03 -1.81817591e+02]\n",
+ " [-1.55841522e+05 2.45425677e+06 9.82699914e+04 ... -5.37744672e+04\n",
+ " -1.14976214e+05 1.95498686e+05]\n",
+ " [ 1.20176465e+04 9.82699914e+04 2.05845769e+06 ... -6.42761657e+04\n",
+ " -1.57428613e+03 1.90527654e+05]\n",
+ " ...\n",
+ " [ 7.78593400e+03 -5.37744672e+04 -6.42761657e+04 ... 2.19897397e+06\n",
+ " 9.79439501e+04 2.63393791e+05]\n",
+ " [ 1.90060269e+03 -1.14976214e+05 -1.57428613e+03 ... 9.79439501e+04\n",
+ " 2.09478614e+06 1.35758244e+05]\n",
+ " [-1.81817591e+02 1.95498686e+05 1.90527654e+05 ... 2.63393791e+05\n",
+ " 1.35758244e+05 2.12691079e+06]]\n",
+ "\n",
+ " [[ 2.83658946e+06 -1.34780282e+05 6.49580505e+04 ... 8.54367930e+04\n",
+ " 2.68631019e+04 -2.00238083e+03]\n",
+ " [-1.34780282e+05 3.16043262e+06 1.74730403e+05 ... -1.52450289e+04\n",
+ " -1.33867983e+05 2.33287605e+05]\n",
+ " [ 6.49580505e+04 1.74730403e+05 2.62079599e+06 ... -1.34705133e+05\n",
+ " -3.42729631e+03 2.52694121e+05]\n",
+ " ...\n",
+ " [ 8.54367930e+04 -1.52450289e+04 -1.34705133e+05 ... 2.83179172e+06\n",
+ " 1.20704061e+05 3.58186648e+05]\n",
+ " [ 2.68631019e+04 -1.33867983e+05 -3.42729631e+03 ... 1.20704061e+05\n",
+ " 2.61548713e+06 2.01096011e+05]\n",
+ " [-2.00238083e+03 2.33287605e+05 2.52694121e+05 ... 3.58186648e+05\n",
+ " 2.01096011e+05 2.68561973e+06]]\n",
+ "\n",
+ " [[ 2.46084057e+06 -1.61990924e+05 2.05209921e+04 ... 5.80037219e+04\n",
+ " -2.11338740e+03 -1.79719159e+04]\n",
+ " [-1.61990924e+05 2.77033377e+06 1.21300499e+05 ... -2.88489495e+04\n",
+ " -1.56314630e+05 1.81241062e+05]\n",
+ " [ 2.05209921e+04 1.21300499e+05 2.26999115e+06 ... -1.22265990e+05\n",
+ " -1.16935046e+04 1.98900231e+05]\n",
+ " ...\n",
+ " [ 5.80037219e+04 -2.88489495e+04 -1.22265990e+05 ... 2.44690178e+06\n",
+ " 9.92107868e+04 2.87651396e+05]\n",
+ " [-2.11338740e+03 -1.56314630e+05 -1.16935046e+04 ... 9.92107868e+04\n",
+ " 2.24557593e+06 1.34111998e+05]\n",
+ " [-1.79719159e+04 1.81241062e+05 1.98900231e+05 ... 2.87651396e+05\n",
+ " 1.34111998e+05 2.34916522e+06]]\n",
+ "\n",
+ " ...\n",
+ "\n",
+ " [[ 1.92829727e+06 -1.59315776e+05 1.58992088e+03 ... 2.85003295e+04\n",
+ " 4.22064899e+03 -3.58290465e+04]\n",
+ " [-1.59315776e+05 2.20558282e+06 8.03460647e+04 ... -4.04459318e+04\n",
+ " -1.07734468e+05 1.56172361e+05]\n",
+ " [ 1.58992088e+03 8.03460647e+04 1.83371422e+06 ... -1.11957845e+05\n",
+ " 2.80346243e+03 1.49548841e+05]\n",
+ " ...\n",
+ " [ 2.85003295e+04 -4.04459318e+04 -1.11957845e+05 ... 1.91110989e+06\n",
+ " 9.80180683e+04 1.85174267e+05]\n",
+ " [ 4.22064899e+03 -1.07734468e+05 2.80346243e+03 ... 9.80180683e+04\n",
+ " 1.77703462e+06 8.43241833e+04]\n",
+ " [-3.58290465e+04 1.56172361e+05 1.49548841e+05 ... 1.85174267e+05\n",
+ " 8.43241833e+04 1.89600377e+06]]\n",
+ "\n",
+ " [[ 3.47995050e+06 -2.33949522e+04 2.34412434e+05 ... 1.74855689e+05\n",
+ " 1.42301226e+05 1.03626216e+05]\n",
+ " [-2.33949522e+04 3.91344835e+06 3.68391546e+05 ... 1.48516088e+05\n",
+ " -5.32260301e+04 4.19584565e+05]\n",
+ " [ 2.34412434e+05 3.68391546e+05 3.16329032e+06 ... -7.83149391e+04\n",
+ " 1.04717741e+05 4.01535013e+05]\n",
+ " ...\n",
+ " [ 1.74855689e+05 1.48516088e+05 -7.83149391e+04 ... 3.45461426e+06\n",
+ " 1.69208897e+05 5.00179297e+05]\n",
+ " [ 1.42301226e+05 -5.32260301e+04 1.04717741e+05 ... 1.69208897e+05\n",
+ " 3.23120264e+06 3.85596283e+05]\n",
+ " [ 1.03626216e+05 4.19584565e+05 4.01535013e+05 ... 5.00179297e+05\n",
+ " 3.85596283e+05 3.23650330e+06]]\n",
+ "\n",
+ " [[ 2.87968337e+06 -1.54796543e+05 6.48388780e+04 ... 6.24845766e+04\n",
+ " 4.24673639e+04 -9.91376939e+03]\n",
+ " [-1.54796543e+05 3.29781697e+06 1.82272830e+05 ... 9.13276300e+03\n",
+ " -1.04968382e+05 2.54428804e+05]\n",
+ " [ 6.48388780e+04 1.82272830e+05 2.68323883e+06 ... -1.43183305e+05\n",
+ " 1.35349522e+04 2.71469510e+05]\n",
+ " ...\n",
+ " [ 6.24845766e+04 9.13276300e+03 -1.43183305e+05 ... 2.94868837e+06\n",
+ " 1.25676092e+05 3.72271111e+05]\n",
+ " [ 4.24673639e+04 -1.04968382e+05 1.35349522e+04 ... 1.25676092e+05\n",
+ " 2.69905870e+06 2.40775248e+05]\n",
+ " [-9.91376939e+03 2.54428804e+05 2.71469510e+05 ... 3.72271111e+05\n",
+ " 2.40775248e+05 2.76807176e+06]]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.metrics import silhouette_samples\n",
+ "print(silhouette_samples(np.array(embeddings), labels=cluster_labels))"
]
}
],
diff --git a/src/embedding_extraction.py b/src/embedding_extraction.py
index 4e3e2a3..a730c78 100644
--- a/src/embedding_extraction.py
+++ b/src/embedding_extraction.py
@@ -57,7 +57,7 @@ def get_layoutlm_image_embeddings(
return embeddings.cpu().numpy(), valid_paths
-def get_image_embeddings(
+def get_qwenvl_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],