diff --git a/clustering_example.ipynb b/clustering_example.ipynb index 3fd5e3b..b41d259 100644 --- a/clustering_example.ipynb +++ b/clustering_example.ipynb @@ -111,7 +111,15 @@ "execution_count": null, "id": "27fea4f3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 100% |███████████████| 1091/1091 [174.1ms elapsed, 0s remaining, 6.3K samples/s] \n" + ] + } + ], "source": [ "import fiftyone as fo\n", "import fiftyone.brain as fob\n", diff --git a/clustering_layoutlm.ipynb b/clustering_layoutlm.ipynb index cfed077..afb1849 100644 --- a/clustering_layoutlm.ipynb +++ b/clustering_layoutlm.ipynb @@ -124,66 +124,7 @@ "name": "stdout", "output_type": "stream", "text": [ - " 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s] \n", - "Generating visualization...\n", - "[t-SNE] Computing 91 nearest neighbors...\n", - "[t-SNE] Indexed 1090 samples in 0.000s...\n", - "[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n", - "[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n", - "[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n", - "[t-SNE] Mean sigma: 1.117175\n", - "[t-SNE] Computed conditional probabilities in 0.018s\n", - "[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n", - "[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n", - "[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n", - "[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n", - "[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n", - "[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n", - "[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n", - "[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n", - "[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n", - "[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n", - "[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n", - "[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n", - "[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n", - "[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n", - "[t-SNE] KL divergence after 1000 iterations: 0.727630\n", - "UMAP visualization computed. Launch the app to see the plot.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Notebook sessions cannot wait\n" + " 100% |███████████████| 1090/1090 [182.4ms elapsed, 0s remaining, 6.0K samples/s] \n" ] } ], @@ -229,32 +170,135 @@ " sample[\"gmm_cluster_50_gaussians\"] = int(label)\n", " sample.save()\n", "\n", - "n_gaussians = 100\n", - "gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", - "gmm.fit(embeddings)\n", - "cluster_labels = gmm.predict(embeddings)\n", + "# n_gaussians = 100\n", + "# gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n", + "# gmm.fit(embeddings)\n", + "# cluster_labels = gmm.predict(embeddings)\n", "\n", - "# Adding labeled embeddings to visulization\n", - "dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", - "for sample, label in zip(dataset, cluster_labels):\n", - " sample[\"gmm_cluster_200_gaussians\"] = int(label)\n", - " sample.save()\n", + "# # Adding labeled embeddings to visulization\n", + "# dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n", + "# for sample, label in zip(dataset, cluster_labels):\n", + "# sample[\"gmm_cluster_200_gaussians\"] = int(label)\n", + "# sample.save()\n", "\n", - "# --- Visualize the Embeddings with UMAP ---\n", - "# This will compute a 2D representation of your embeddings\n", - "# for visualization.\n", - "res = fob.compute_visualization(\n", - " dataset,\n", - " embeddings=embeddings,\n", - " brain_key=\"qwen_vision_viz\",\n", - " method=\"tsne\",\n", - " verbose=True\n", - ")\n", - "dataset.set_values(\"qwen_umap\", res.current_points)\n", + "# # --- Visualize the Embeddings with UMAP ---\n", + "# # This will compute a 2D representation of your embeddings\n", + "# # for visualization.\n", + "# res = fob.compute_visualization(\n", + "# dataset,\n", + "# embeddings=embeddings,\n", + "# brain_key=\"qwen_vision_viz\",\n", + "# method=\"tsne\",\n", + "# verbose=True\n", + "# )\n", + "# dataset.set_values(\"qwen_umap\", res.current_points)\n", "\n", - "print(\"UMAP visualization computed. Launch the app to see the plot.\")\n", - "session = fo.launch_app(dataset)\n", - "session.wait()" + "# print(\"UMAP visualization computed. Launch the app to see the plot.\")\n", + "# session = fo.launch_app(dataset)\n", + "# session.wait()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8c23ca3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[ 2.22716806e+06 -1.55841522e+05 1.20176465e+04 ... 7.78593400e+03\n", + " 1.90060269e+03 -1.81817591e+02]\n", + " [-1.55841522e+05 2.45425677e+06 9.82699914e+04 ... -5.37744672e+04\n", + " -1.14976214e+05 1.95498686e+05]\n", + " [ 1.20176465e+04 9.82699914e+04 2.05845769e+06 ... -6.42761657e+04\n", + " -1.57428613e+03 1.90527654e+05]\n", + " ...\n", + " [ 7.78593400e+03 -5.37744672e+04 -6.42761657e+04 ... 2.19897397e+06\n", + " 9.79439501e+04 2.63393791e+05]\n", + " [ 1.90060269e+03 -1.14976214e+05 -1.57428613e+03 ... 9.79439501e+04\n", + " 2.09478614e+06 1.35758244e+05]\n", + " [-1.81817591e+02 1.95498686e+05 1.90527654e+05 ... 2.63393791e+05\n", + " 1.35758244e+05 2.12691079e+06]]\n", + "\n", + " [[ 2.83658946e+06 -1.34780282e+05 6.49580505e+04 ... 8.54367930e+04\n", + " 2.68631019e+04 -2.00238083e+03]\n", + " [-1.34780282e+05 3.16043262e+06 1.74730403e+05 ... -1.52450289e+04\n", + " -1.33867983e+05 2.33287605e+05]\n", + " [ 6.49580505e+04 1.74730403e+05 2.62079599e+06 ... -1.34705133e+05\n", + " -3.42729631e+03 2.52694121e+05]\n", + " ...\n", + " [ 8.54367930e+04 -1.52450289e+04 -1.34705133e+05 ... 2.83179172e+06\n", + " 1.20704061e+05 3.58186648e+05]\n", + " [ 2.68631019e+04 -1.33867983e+05 -3.42729631e+03 ... 1.20704061e+05\n", + " 2.61548713e+06 2.01096011e+05]\n", + " [-2.00238083e+03 2.33287605e+05 2.52694121e+05 ... 3.58186648e+05\n", + " 2.01096011e+05 2.68561973e+06]]\n", + "\n", + " [[ 2.46084057e+06 -1.61990924e+05 2.05209921e+04 ... 5.80037219e+04\n", + " -2.11338740e+03 -1.79719159e+04]\n", + " [-1.61990924e+05 2.77033377e+06 1.21300499e+05 ... -2.88489495e+04\n", + " -1.56314630e+05 1.81241062e+05]\n", + " [ 2.05209921e+04 1.21300499e+05 2.26999115e+06 ... -1.22265990e+05\n", + " -1.16935046e+04 1.98900231e+05]\n", + " ...\n", + " [ 5.80037219e+04 -2.88489495e+04 -1.22265990e+05 ... 2.44690178e+06\n", + " 9.92107868e+04 2.87651396e+05]\n", + " [-2.11338740e+03 -1.56314630e+05 -1.16935046e+04 ... 9.92107868e+04\n", + " 2.24557593e+06 1.34111998e+05]\n", + " [-1.79719159e+04 1.81241062e+05 1.98900231e+05 ... 2.87651396e+05\n", + " 1.34111998e+05 2.34916522e+06]]\n", + "\n", + " ...\n", + "\n", + " [[ 1.92829727e+06 -1.59315776e+05 1.58992088e+03 ... 2.85003295e+04\n", + " 4.22064899e+03 -3.58290465e+04]\n", + " [-1.59315776e+05 2.20558282e+06 8.03460647e+04 ... -4.04459318e+04\n", + " -1.07734468e+05 1.56172361e+05]\n", + " [ 1.58992088e+03 8.03460647e+04 1.83371422e+06 ... -1.11957845e+05\n", + " 2.80346243e+03 1.49548841e+05]\n", + " ...\n", + " [ 2.85003295e+04 -4.04459318e+04 -1.11957845e+05 ... 1.91110989e+06\n", + " 9.80180683e+04 1.85174267e+05]\n", + " [ 4.22064899e+03 -1.07734468e+05 2.80346243e+03 ... 9.80180683e+04\n", + " 1.77703462e+06 8.43241833e+04]\n", + " [-3.58290465e+04 1.56172361e+05 1.49548841e+05 ... 1.85174267e+05\n", + " 8.43241833e+04 1.89600377e+06]]\n", + "\n", + " [[ 3.47995050e+06 -2.33949522e+04 2.34412434e+05 ... 1.74855689e+05\n", + " 1.42301226e+05 1.03626216e+05]\n", + " [-2.33949522e+04 3.91344835e+06 3.68391546e+05 ... 1.48516088e+05\n", + " -5.32260301e+04 4.19584565e+05]\n", + " [ 2.34412434e+05 3.68391546e+05 3.16329032e+06 ... -7.83149391e+04\n", + " 1.04717741e+05 4.01535013e+05]\n", + " ...\n", + " [ 1.74855689e+05 1.48516088e+05 -7.83149391e+04 ... 3.45461426e+06\n", + " 1.69208897e+05 5.00179297e+05]\n", + " [ 1.42301226e+05 -5.32260301e+04 1.04717741e+05 ... 1.69208897e+05\n", + " 3.23120264e+06 3.85596283e+05]\n", + " [ 1.03626216e+05 4.19584565e+05 4.01535013e+05 ... 5.00179297e+05\n", + " 3.85596283e+05 3.23650330e+06]]\n", + "\n", + " [[ 2.87968337e+06 -1.54796543e+05 6.48388780e+04 ... 6.24845766e+04\n", + " 4.24673639e+04 -9.91376939e+03]\n", + " [-1.54796543e+05 3.29781697e+06 1.82272830e+05 ... 9.13276300e+03\n", + " -1.04968382e+05 2.54428804e+05]\n", + " [ 6.48388780e+04 1.82272830e+05 2.68323883e+06 ... -1.43183305e+05\n", + " 1.35349522e+04 2.71469510e+05]\n", + " ...\n", + " [ 6.24845766e+04 9.13276300e+03 -1.43183305e+05 ... 2.94868837e+06\n", + " 1.25676092e+05 3.72271111e+05]\n", + " [ 4.24673639e+04 -1.04968382e+05 1.35349522e+04 ... 1.25676092e+05\n", + " 2.69905870e+06 2.40775248e+05]\n", + " [-9.91376939e+03 2.54428804e+05 2.71469510e+05 ... 3.72271111e+05\n", + " 2.40775248e+05 2.76807176e+06]]]\n" + ] + } + ], + "source": [ + "from sklearn.metrics import silhouette_samples\n", + "print(silhouette_samples(np.array(embeddings), labels=cluster_labels))" ] } ], diff --git a/src/embedding_extraction.py b/src/embedding_extraction.py index 4e3e2a3..a730c78 100644 --- a/src/embedding_extraction.py +++ b/src/embedding_extraction.py @@ -57,7 +57,7 @@ def get_layoutlm_image_embeddings( return embeddings.cpu().numpy(), valid_paths -def get_image_embeddings( +def get_qwenvl_image_embeddings( model: AutoModel, processor: AutoProcessor, image_paths: List[str],