"""Cluster and visualize document embeddings with FiftyOne.

Loads precomputed embeddings from a JSON file, groups them with Gaussian
mixture models at one or more granularities, computes a 2D embedding
visualization, and (optionally) launches the FiftyOne App to explore the
results.
"""

import argparse
import json
from typing import List, Tuple

import numpy as np
from sklearn.mixture import GaussianMixture

import fiftyone as fo
import fiftyone.brain as fob


def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
    """Loads file paths and embeddings from a JSON file."""
    print(f"Loading embeddings from {json_path}...")
    with open(json_path, "r") as f:
        embedding_data = json.load(f)

    file_paths = [record["filepath"] for record in embedding_data]
    embeddings = np.array(
        [record["embedding"] for record in embedding_data], dtype=np.float32
    )
    print(
        f"Loaded {len(file_paths)} samples with embedding dimension "
        f"{embeddings.shape[1]}."
    )
    return file_paths, embeddings


def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
    """Creates or overwrites a FiftyOne dataset from a list of file paths."""
    if dataset_name in fo.list_datasets():
        print(f"Dataset '{dataset_name}' already exists. Deleting it.")
        fo.delete_dataset(dataset_name)

    print(f"Creating new dataset '{dataset_name}'...")
    dataset = fo.Dataset(dataset_name)

    # Persist the dataset so it can be reloaded in later sessions
    # (e.g., when running with --no_launch and viewing the results later)
    dataset.persistent = True

    samples = [fo.Sample(filepath=p) for p in file_paths]
    dataset.add_samples(samples)
    print("Dataset created successfully.")
    return dataset


def add_gmm_clusters_to_dataset(
    dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int
):
    """Fits a GMM and adds cluster labels to the dataset."""
    field_name = f"gmm_cluster_{n_clusters}"
    print(
        f"Running GMM with {n_clusters} components. "
        f"Results will be in field '{field_name}'..."
    )
    gmm = GaussianMixture(n_components=n_clusters, random_state=42)
    cluster_labels = gmm.fit_predict(embeddings)

    dataset.add_sample_field(field_name, fo.IntField)

    # Samples were added in the same order as the embeddings, so zipping the
    # dataset with the labels pairs each sample with its cluster assignment
    for sample, label in zip(dataset, cluster_labels):
        sample[field_name] = int(label)
        sample.save()

    print("GMM clustering complete.")


def compute_embedding_visualization(
    dataset: fo.Dataset,
    embeddings: np.ndarray,
    brain_key: str,
    method: str = "tsne",
):
    """Computes a 2D visualization of embeddings and adds it to the dataset."""
    vis_field_name = f"{brain_key}_{method}"
    print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")
    results = fob.compute_visualization(
        dataset,
        embeddings=embeddings,
        brain_key=brain_key,
        method=method,
        verbose=True,
    )

    # The results are stored under `brain_key`; additionally store the 2D
    # points on each sample for convenient access via sample fields
    dataset.set_values(vis_field_name, results.points.tolist())
    print(f"Visualization complete. Points stored in field '{vis_field_name}'.")


def main():
    """Main function to run the analysis pipeline."""
    parser = argparse.ArgumentParser(
        description="Cluster and visualize document embeddings with FiftyOne."
    )
    parser.add_argument(
        "json_path", type=str, help="Path to the JSON file containing embeddings."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="embedding_analysis_dataset",
        help="Name for the FiftyOne dataset.",
    )
    parser.add_argument(
        "--n_clusters",
        type=int,
        nargs="+",
        default=[50, 100],
        help="A list of numbers of clusters for GMM.",
    )
    parser.add_argument(
        "--vis_method",
        type=str,
        default="tsne",
        choices=["tsne", "umap"],
        help="Visualization method.",
    )
    parser.add_argument(
        "--brain_key",
        type=str,
        default="doc_embeddings_viz",
        help="FiftyOne brain key for the visualization.",
    )
    parser.add_argument(
        "--no_launch",
        action="store_true",
        help="Do not launch the FiftyOne App after processing.",
    )
    args = parser.parse_args()

    # 1. Load data
    file_paths, embeddings = load_embeddings_from_json(args.json_path)

    # 2. Create FiftyOne Dataset
    dataset = create_fiftyone_dataset(args.dataset_name, file_paths)

    # 3. Run GMM clustering for each specified number of clusters
    for n in args.n_clusters:
        add_gmm_clusters_to_dataset(dataset, embeddings, n)

    # 4. Compute visualization
    compute_embedding_visualization(
        dataset, embeddings, args.brain_key, args.vis_method
    )

    # 5. Launch the app
    if not args.no_launch:
        print("Launching the FiftyOne App...")
        session = fo.launch_app(dataset)
        session.wait()
    else:
        print("Processing complete. To view the results, launch the app manually with:")
        print(">>> import fiftyone as fo")
        print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
        print(">>> session = fo.launch_app(dataset)")


if __name__ == "__main__":
    main()
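

# ---------------------------------------------------------------------------
# Example input and invocation (a sketch; the file names and paths below are
# hypothetical, not part of the pipeline itself).
#
# load_embeddings_from_json() expects a JSON array of records, each with a
# "filepath" string and an "embedding" list of floats of uniform length:
#
#   [
#       {"filepath": "/data/docs/page_0001.png", "embedding": [0.12, -0.53, ...]},
#       {"filepath": "/data/docs/page_0002.png", "embedding": [0.08, 0.41, ...]},
#       ...
#   ]
#
# Assuming this script is saved as embedding_analysis.py, a typical run that
# clusters at two granularities and uses UMAP instead of t-SNE would be:
#
#   python embedding_analysis.py embeddings.json \
#       --dataset_name doc_pages \
#       --n_clusters 25 50 \
#       --vis_method umap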