[Add] Adaptive clustering

This commit is contained in:
trungkienbkhn 2025-07-10 20:35:46 +00:00
parent 65adb5d4ba
commit 8e1d0568a2
3 changed files with 137 additions and 85 deletions

View File

@ -111,7 +111,15 @@
"execution_count": null,
"id": "27fea4f3",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 100% |███████████████| 1091/1091 [174.1ms elapsed, 0s remaining, 6.3K samples/s] \n"
]
}
],
"source": [
"import fiftyone as fo\n",
"import fiftyone.brain as fob\n",

View File

@ -124,66 +124,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
" 100% |███████████████| 1090/1090 [179.3ms elapsed, 0s remaining, 6.1K samples/s] \n",
"Generating visualization...\n",
"[t-SNE] Computing 91 nearest neighbors...\n",
"[t-SNE] Indexed 1090 samples in 0.000s...\n",
"[t-SNE] Computed neighbors for 1090 samples in 0.116s...\n",
"[t-SNE] Computed conditional probabilities for sample 1000 / 1090\n",
"[t-SNE] Computed conditional probabilities for sample 1090 / 1090\n",
"[t-SNE] Mean sigma: 1.117175\n",
"[t-SNE] Computed conditional probabilities in 0.018s\n",
"[t-SNE] Iteration 50: error = 63.3665466, gradient norm = 0.0725846 (50 iterations in 0.175s)\n",
"[t-SNE] Iteration 100: error = 61.4698219, gradient norm = 0.0620725 (50 iterations in 0.064s)\n",
"[t-SNE] Iteration 150: error = 60.9316177, gradient norm = 0.0641517 (50 iterations in 0.061s)\n",
"[t-SNE] Iteration 200: error = 60.9066200, gradient norm = 0.0622800 (50 iterations in 0.063s)\n",
"[t-SNE] Iteration 250: error = 60.9928894, gradient norm = 0.0575472 (50 iterations in 0.062s)\n",
"[t-SNE] KL divergence after 250 iterations with early exaggeration: 60.992889\n",
"[t-SNE] Iteration 300: error = 0.9203525, gradient norm = 0.0060658 (50 iterations in 0.060s)\n",
"[t-SNE] Iteration 350: error = 0.8058189, gradient norm = 0.0041889 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 400: error = 0.7736985, gradient norm = 0.0027541 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 450: error = 0.7591555, gradient norm = 0.0019742 (50 iterations in 0.064s)\n",
"[t-SNE] Iteration 500: error = 0.7508169, gradient norm = 0.0015802 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 550: error = 0.7453735, gradient norm = 0.0013677 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 600: error = 0.7403219, gradient norm = 0.0012937 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 650: error = 0.7358239, gradient norm = 0.0008007 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 700: error = 0.7335896, gradient norm = 0.0008103 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 750: error = 0.7321761, gradient norm = 0.0006007 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 800: error = 0.7311318, gradient norm = 0.0006798 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 850: error = 0.7299591, gradient norm = 0.0006054 (50 iterations in 0.057s)\n",
"[t-SNE] Iteration 900: error = 0.7292019, gradient norm = 0.0004605 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 950: error = 0.7283084, gradient norm = 0.0005973 (50 iterations in 0.056s)\n",
"[t-SNE] Iteration 1000: error = 0.7276300, gradient norm = 0.0005537 (50 iterations in 0.057s)\n",
"[t-SNE] KL divergence after 1000 iterations: 0.727630\n",
"UMAP visualization computed. Launch the app to see the plot.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"800\"\n",
" src=\"http://localhost:5151/?notebook=True&subscription=9f7aad1e-72b8-41ab-873d-b776dbb9e9d5\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x798114e78160>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Notebook sessions cannot wait\n"
" 100% |███████████████| 1090/1090 [182.4ms elapsed, 0s remaining, 6.0K samples/s] \n"
]
}
],
@ -229,32 +170,135 @@
" sample[\"gmm_cluster_50_gaussians\"] = int(label)\n",
" sample.save()\n",
"\n",
"n_gaussians = 100\n",
"gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"gmm.fit(embeddings)\n",
"cluster_labels = gmm.predict(embeddings)\n",
"# n_gaussians = 100\n",
"# gmm = GaussianMixture(n_components=n_gaussians, random_state=42)\n",
"# gmm.fit(embeddings)\n",
"# cluster_labels = gmm.predict(embeddings)\n",
"\n",
"# Adding labeled embeddings to visulization\n",
"dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"for sample, label in zip(dataset, cluster_labels):\n",
" sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
" sample.save()\n",
"# # Adding labeled embeddings to visulization\n",
"# dataset.add_sample_field(\"gmm_cluster\", fo.IntField)\n",
"# for sample, label in zip(dataset, cluster_labels):\n",
"# sample[\"gmm_cluster_200_gaussians\"] = int(label)\n",
"# sample.save()\n",
"\n",
"# --- Visualize the Embeddings with UMAP ---\n",
"# This will compute a 2D representation of your embeddings\n",
"# for visualization.\n",
"res = fob.compute_visualization(\n",
" dataset,\n",
" embeddings=embeddings,\n",
" brain_key=\"qwen_vision_viz\",\n",
" method=\"tsne\",\n",
" verbose=True\n",
")\n",
"dataset.set_values(\"qwen_umap\", res.current_points)\n",
"# # --- Visualize the Embeddings with UMAP ---\n",
"# # This will compute a 2D representation of your embeddings\n",
"# # for visualization.\n",
"# res = fob.compute_visualization(\n",
"# dataset,\n",
"# embeddings=embeddings,\n",
"# brain_key=\"qwen_vision_viz\",\n",
"# method=\"tsne\",\n",
"# verbose=True\n",
"# )\n",
"# dataset.set_values(\"qwen_umap\", res.current_points)\n",
"\n",
"print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
"session = fo.launch_app(dataset)\n",
"session.wait()"
"# print(\"UMAP visualization computed. Launch the app to see the plot.\")\n",
"# session = fo.launch_app(dataset)\n",
"# session.wait()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8c23ca3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[[ 2.22716806e+06 -1.55841522e+05 1.20176465e+04 ... 7.78593400e+03\n",
" 1.90060269e+03 -1.81817591e+02]\n",
" [-1.55841522e+05 2.45425677e+06 9.82699914e+04 ... -5.37744672e+04\n",
" -1.14976214e+05 1.95498686e+05]\n",
" [ 1.20176465e+04 9.82699914e+04 2.05845769e+06 ... -6.42761657e+04\n",
" -1.57428613e+03 1.90527654e+05]\n",
" ...\n",
" [ 7.78593400e+03 -5.37744672e+04 -6.42761657e+04 ... 2.19897397e+06\n",
" 9.79439501e+04 2.63393791e+05]\n",
" [ 1.90060269e+03 -1.14976214e+05 -1.57428613e+03 ... 9.79439501e+04\n",
" 2.09478614e+06 1.35758244e+05]\n",
" [-1.81817591e+02 1.95498686e+05 1.90527654e+05 ... 2.63393791e+05\n",
" 1.35758244e+05 2.12691079e+06]]\n",
"\n",
" [[ 2.83658946e+06 -1.34780282e+05 6.49580505e+04 ... 8.54367930e+04\n",
" 2.68631019e+04 -2.00238083e+03]\n",
" [-1.34780282e+05 3.16043262e+06 1.74730403e+05 ... -1.52450289e+04\n",
" -1.33867983e+05 2.33287605e+05]\n",
" [ 6.49580505e+04 1.74730403e+05 2.62079599e+06 ... -1.34705133e+05\n",
" -3.42729631e+03 2.52694121e+05]\n",
" ...\n",
" [ 8.54367930e+04 -1.52450289e+04 -1.34705133e+05 ... 2.83179172e+06\n",
" 1.20704061e+05 3.58186648e+05]\n",
" [ 2.68631019e+04 -1.33867983e+05 -3.42729631e+03 ... 1.20704061e+05\n",
" 2.61548713e+06 2.01096011e+05]\n",
" [-2.00238083e+03 2.33287605e+05 2.52694121e+05 ... 3.58186648e+05\n",
" 2.01096011e+05 2.68561973e+06]]\n",
"\n",
" [[ 2.46084057e+06 -1.61990924e+05 2.05209921e+04 ... 5.80037219e+04\n",
" -2.11338740e+03 -1.79719159e+04]\n",
" [-1.61990924e+05 2.77033377e+06 1.21300499e+05 ... -2.88489495e+04\n",
" -1.56314630e+05 1.81241062e+05]\n",
" [ 2.05209921e+04 1.21300499e+05 2.26999115e+06 ... -1.22265990e+05\n",
" -1.16935046e+04 1.98900231e+05]\n",
" ...\n",
" [ 5.80037219e+04 -2.88489495e+04 -1.22265990e+05 ... 2.44690178e+06\n",
" 9.92107868e+04 2.87651396e+05]\n",
" [-2.11338740e+03 -1.56314630e+05 -1.16935046e+04 ... 9.92107868e+04\n",
" 2.24557593e+06 1.34111998e+05]\n",
" [-1.79719159e+04 1.81241062e+05 1.98900231e+05 ... 2.87651396e+05\n",
" 1.34111998e+05 2.34916522e+06]]\n",
"\n",
" ...\n",
"\n",
" [[ 1.92829727e+06 -1.59315776e+05 1.58992088e+03 ... 2.85003295e+04\n",
" 4.22064899e+03 -3.58290465e+04]\n",
" [-1.59315776e+05 2.20558282e+06 8.03460647e+04 ... -4.04459318e+04\n",
" -1.07734468e+05 1.56172361e+05]\n",
" [ 1.58992088e+03 8.03460647e+04 1.83371422e+06 ... -1.11957845e+05\n",
" 2.80346243e+03 1.49548841e+05]\n",
" ...\n",
" [ 2.85003295e+04 -4.04459318e+04 -1.11957845e+05 ... 1.91110989e+06\n",
" 9.80180683e+04 1.85174267e+05]\n",
" [ 4.22064899e+03 -1.07734468e+05 2.80346243e+03 ... 9.80180683e+04\n",
" 1.77703462e+06 8.43241833e+04]\n",
" [-3.58290465e+04 1.56172361e+05 1.49548841e+05 ... 1.85174267e+05\n",
" 8.43241833e+04 1.89600377e+06]]\n",
"\n",
" [[ 3.47995050e+06 -2.33949522e+04 2.34412434e+05 ... 1.74855689e+05\n",
" 1.42301226e+05 1.03626216e+05]\n",
" [-2.33949522e+04 3.91344835e+06 3.68391546e+05 ... 1.48516088e+05\n",
" -5.32260301e+04 4.19584565e+05]\n",
" [ 2.34412434e+05 3.68391546e+05 3.16329032e+06 ... -7.83149391e+04\n",
" 1.04717741e+05 4.01535013e+05]\n",
" ...\n",
" [ 1.74855689e+05 1.48516088e+05 -7.83149391e+04 ... 3.45461426e+06\n",
" 1.69208897e+05 5.00179297e+05]\n",
" [ 1.42301226e+05 -5.32260301e+04 1.04717741e+05 ... 1.69208897e+05\n",
" 3.23120264e+06 3.85596283e+05]\n",
" [ 1.03626216e+05 4.19584565e+05 4.01535013e+05 ... 5.00179297e+05\n",
" 3.85596283e+05 3.23650330e+06]]\n",
"\n",
" [[ 2.87968337e+06 -1.54796543e+05 6.48388780e+04 ... 6.24845766e+04\n",
" 4.24673639e+04 -9.91376939e+03]\n",
" [-1.54796543e+05 3.29781697e+06 1.82272830e+05 ... 9.13276300e+03\n",
" -1.04968382e+05 2.54428804e+05]\n",
" [ 6.48388780e+04 1.82272830e+05 2.68323883e+06 ... -1.43183305e+05\n",
" 1.35349522e+04 2.71469510e+05]\n",
" ...\n",
" [ 6.24845766e+04 9.13276300e+03 -1.43183305e+05 ... 2.94868837e+06\n",
" 1.25676092e+05 3.72271111e+05]\n",
" [ 4.24673639e+04 -1.04968382e+05 1.35349522e+04 ... 1.25676092e+05\n",
" 2.69905870e+06 2.40775248e+05]\n",
" [-9.91376939e+03 2.54428804e+05 2.71469510e+05 ... 3.72271111e+05\n",
" 2.40775248e+05 2.76807176e+06]]]\n"
]
}
],
"source": [
"from sklearn.metrics import silhouette_samples\n",
"print(silhouette_samples(np.array(embeddings), labels=cluster_labels))"
]
}
],

View File

@ -57,7 +57,7 @@ def get_layoutlm_image_embeddings(
return embeddings.cpu().numpy(), valid_paths
def get_image_embeddings(
def get_qwenvl_image_embeddings(
model: AutoModel,
processor: AutoProcessor,
image_paths: List[str],