[Add] Adaptive clustering - add files

This commit is contained in:
2025-07-10 20:36:07 +00:00
committed by trungkienbkhn
parent 8e1d0568a2
commit bfdaa142c2
3 changed files with 772 additions and 0 deletions

View File

@@ -0,0 +1,106 @@
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
from sklearn.mixture import GaussianMixture
import json
import argparse
from typing import List, Tuple
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
"""Loads file paths and embeddings from a JSON file."""
print(f"Loading embeddings from {json_path}...")
with open(json_path, "r") as f:
embedding_data = json.load(f)
file_paths = [record["filepath"] for record in embedding_data]
embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)
print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
return file_paths, embeddings
def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
"""Creates or overwrites a FiftyOne dataset from a list of file paths."""
if dataset_name in fo.list_datasets():
print(f"Dataset '{dataset_name}' already exists. Deleting it.")
fo.delete_dataset(dataset_name)
print(f"Creating new dataset '{dataset_name}'...")
dataset = fo.Dataset(dataset_name)
samples = [fo.Sample(filepath=p) for p in file_paths]
dataset.add_samples(samples)
print("Dataset created successfully.")
return dataset
def add_gmm_clusters_to_dataset(dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int):
"""Fits a GMM and adds cluster labels to the dataset."""
field_name = f"gmm_cluster_{n_clusters}"
print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")
gmm = GaussianMixture(n_components=n_clusters, random_state=42)
cluster_labels = gmm.fit_predict(embeddings)
dataset.add_sample_field(f"{field_name}", fo.IntField)
for sample, label in zip(dataset, cluster_labels):
sample[field_name] = int(label)
sample.save()
print("GMM clustering complete.")
def compute_embedding_visualization(
dataset: fo.Dataset,
embeddings: np.ndarray,
brain_key: str,
method: str = "tsne"
):
"""Computes a 2D visualization of embeddings and adds it to the dataset."""
vis_field_name = f"{brain_key}_{method}"
print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")
results = fob.compute_visualization(
dataset,
embeddings=embeddings,
brain_key=brain_key,
method=method,
verbose=True
)
dataset.set_values(vis_field_name, results.current_points)
print(f"Visualization complete. Points stored in field '{vis_field_name}'.")
def main():
"""Main function to run the analysis pipeline."""
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
parser.add_argument("--dataset_name", type=str, default="embedding_analysis_dataset", help="Name for the FiftyOne dataset.")
parser.add_argument("--n_clusters", type=int, nargs='+', default=[50, 100], help="A list of numbers of clusters for GMM.")
parser.add_argument("--vis_method", type=str, default="tsne", choices=["tsne", "umap"], help="Visualization method.")
parser.add_argument("--brain_key", type=str, default="doc_embeddings_viz", help="FiftyOne brain key for the visualization.")
parser.add_argument("--no_launch", action="store_true", help="Do not launch the FiftyOne App after processing.")
args = parser.parse_args()
# 1. Load data
file_paths, embeddings = load_embeddings_from_json(args.json_path)
# 2. Create FiftyOne Dataset
dataset = create_fiftyone_dataset(args.dataset_name, file_paths)
# 3. Run GMM clustering for each specified number of clusters
for n in args.n_clusters:
add_gmm_clusters_to_dataset(dataset, embeddings, n)
# 4. Compute visualization
compute_embedding_visualization(dataset, embeddings, args.brain_key, args.vis_method)
# 5. Launch the app
if not args.no_launch:
print("Launching the FiftyOne App...")
session = fo.launch_app(dataset)
session.wait()
else:
print("Processing complete. To view the results, launch the app manually with:")
print(f">>> import fiftyone as fo")
print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
print(f">>> session = fo.launch_app(dataset)")
if __name__ == "__main__":
main()