embedding-clustering/src/display_embeddings_and_images.py

import argparse
import json
from typing import List, Tuple

import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
from sklearn.mixture import GaussianMixture

def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
    """Loads file paths and embeddings from a JSON file."""
    print(f"Loading embeddings from {json_path}...")
    with open(json_path, "r") as f:
        embedding_data = json.load(f)

    file_paths = [record["filepath"] for record in embedding_data]
    embeddings = np.array(
        [record["embedding"] for record in embedding_data], dtype=np.float32
    )
    print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
    return file_paths, embeddings
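
# Expected JSON schema, inferred from the loader above (the paths shown are
# hypothetical, for illustration only):
#
#   [
#       {"filepath": "/data/images/page_0001.png", "embedding": [0.12, -0.34, ...]},
#       {"filepath": "/data/images/page_0002.png", "embedding": [0.56, 0.78, ...]}
#   ]
#
# Every record must share the same embedding dimension, since the embeddings
# are stacked into a single (num_samples, dim) float32 array.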

def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
    """Creates or overwrites a FiftyOne dataset from a list of file paths."""
    if dataset_name in fo.list_datasets():
        print(f"Dataset '{dataset_name}' already exists. Deleting it.")
        fo.delete_dataset(dataset_name)

    print(f"Creating new dataset '{dataset_name}'...")
    dataset = fo.Dataset(dataset_name)
    samples = [fo.Sample(filepath=p) for p in file_paths]
    dataset.add_samples(samples)
    print("Dataset created successfully.")
    return dataset
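
# Note: FiftyOne datasets are non-persistent by default, so the dataset
# created above is deleted when the database shuts down. If you want it to
# survive across sessions, mark it persistent after creation, e.g.:
#
#   dataset = create_fiftyone_dataset("my_dataset", file_paths)  # hypothetical name
#   dataset.persistent = True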

def add_gmm_clusters_to_dataset(dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int):
    """Fits a GMM and adds cluster labels to the dataset."""
    field_name = f"gmm_cluster_{n_clusters}"
    print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")
    gmm = GaussianMixture(n_components=n_clusters, random_state=42)
    cluster_labels = gmm.fit_predict(embeddings)

    dataset.add_sample_field(field_name, fo.IntField)
    for sample, label in zip(dataset, cluster_labels):
        sample[field_name] = int(label)
        sample.save()

    print("GMM clustering complete.")
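
# A sketch of a faster bulk write, assuming the dataset's default sample order
# still matches the order of `cluster_labels` (true here, since the samples
# were added in the same order as the embeddings). This replaces the
# per-sample save() loop above with a single batched update:
#
#   dataset.set_values(field_name, [int(label) for label in cluster_labels])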

def compute_embedding_visualization(
    dataset: fo.Dataset,
    embeddings: np.ndarray,
    brain_key: str,
    method: str = "tsne",
):
    """Computes a 2D visualization of embeddings and adds it to the dataset."""
    vis_field_name = f"{brain_key}_{method}"
    print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")
    results = fob.compute_visualization(
        dataset,
        embeddings=embeddings,
        brain_key=brain_key,
        method=method,
        verbose=True,
    )
    dataset.set_values(vis_field_name, results.current_points)
    print(f"Visualization complete. Points stored in field '{vis_field_name}'.")
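
# The visualization is also stored on the dataset under the brain key itself,
# so the results can be reloaded later without recomputation:
#
#   results = dataset.load_brain_results("doc_embeddings_viz")  # brain key from the CLI default
#   points = results.points  # (num_samples, 2) array of 2D coordinates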

def main():
    """Main function to run the analysis pipeline."""
    parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
    parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
    parser.add_argument("--dataset_name", type=str, default="embedding_analysis_dataset", help="Name for the FiftyOne dataset.")
    parser.add_argument("--n_clusters", type=int, nargs="+", default=[50, 100], help="A list of numbers of clusters for GMM.")
    parser.add_argument("--vis_method", type=str, default="tsne", choices=["tsne", "umap"], help="Visualization method.")
    parser.add_argument("--brain_key", type=str, default="doc_embeddings_viz", help="FiftyOne brain key for the visualization.")
    parser.add_argument("--no_launch", action="store_true", help="Do not launch the FiftyOne App after processing.")
    args = parser.parse_args()

    # 1. Load data
    file_paths, embeddings = load_embeddings_from_json(args.json_path)

    # 2. Create FiftyOne dataset
    dataset = create_fiftyone_dataset(args.dataset_name, file_paths)

    # 3. Run GMM clustering for each specified number of clusters
    for n in args.n_clusters:
        add_gmm_clusters_to_dataset(dataset, embeddings, n)

    # 4. Compute the 2D embedding visualization
    compute_embedding_visualization(dataset, embeddings, args.brain_key, args.vis_method)

    # 5. Launch the app
    if not args.no_launch:
        print("Launching the FiftyOne App...")
        session = fo.launch_app(dataset)
        session.wait()
    else:
        print("Processing complete. To view the results, launch the app manually with:")
        print(">>> import fiftyone as fo")
        print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
        print(">>> session = fo.launch_app(dataset)")


if __name__ == "__main__":
    main()
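
# Example invocation (hypothetical file and path names):
#
#   python display_embeddings_and_images.py /data/embeddings.json \
#       --dataset_name doc_dataset --n_clusters 25 50 --vis_method umap
#
# Note that --vis_method umap requires the umap-learn package to be installed.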