[Add] Adaptive clustering - add files
This commit is contained in:
106
src/display_embeddings_and_images.py
Normal file
106
src/display_embeddings_and_images.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import fiftyone as fo
|
||||
import fiftyone.brain as fob
|
||||
import numpy as np
|
||||
from sklearn.mixture import GaussianMixture
|
||||
import json
|
||||
import argparse
|
||||
from typing import List, Tuple
|
||||
|
||||
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
|
||||
"""Loads file paths and embeddings from a JSON file."""
|
||||
print(f"Loading embeddings from {json_path}...")
|
||||
with open(json_path, "r") as f:
|
||||
embedding_data = json.load(f)
|
||||
|
||||
file_paths = [record["filepath"] for record in embedding_data]
|
||||
embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)
|
||||
|
||||
print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
|
||||
return file_paths, embeddings
|
||||
|
||||
def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
|
||||
"""Creates or overwrites a FiftyOne dataset from a list of file paths."""
|
||||
if dataset_name in fo.list_datasets():
|
||||
print(f"Dataset '{dataset_name}' already exists. Deleting it.")
|
||||
fo.delete_dataset(dataset_name)
|
||||
|
||||
print(f"Creating new dataset '{dataset_name}'...")
|
||||
dataset = fo.Dataset(dataset_name)
|
||||
samples = [fo.Sample(filepath=p) for p in file_paths]
|
||||
dataset.add_samples(samples)
|
||||
print("Dataset created successfully.")
|
||||
return dataset
|
||||
|
||||
def add_gmm_clusters_to_dataset(dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int):
|
||||
"""Fits a GMM and adds cluster labels to the dataset."""
|
||||
field_name = f"gmm_cluster_{n_clusters}"
|
||||
print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")
|
||||
|
||||
gmm = GaussianMixture(n_components=n_clusters, random_state=42)
|
||||
cluster_labels = gmm.fit_predict(embeddings)
|
||||
|
||||
dataset.add_sample_field(f"{field_name}", fo.IntField)
|
||||
for sample, label in zip(dataset, cluster_labels):
|
||||
sample[field_name] = int(label)
|
||||
sample.save()
|
||||
print("GMM clustering complete.")
|
||||
|
||||
def compute_embedding_visualization(
|
||||
dataset: fo.Dataset,
|
||||
embeddings: np.ndarray,
|
||||
brain_key: str,
|
||||
method: str = "tsne"
|
||||
):
|
||||
"""Computes a 2D visualization of embeddings and adds it to the dataset."""
|
||||
vis_field_name = f"{brain_key}_{method}"
|
||||
print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")
|
||||
|
||||
results = fob.compute_visualization(
|
||||
dataset,
|
||||
embeddings=embeddings,
|
||||
brain_key=brain_key,
|
||||
method=method,
|
||||
verbose=True
|
||||
)
|
||||
dataset.set_values(vis_field_name, results.current_points)
|
||||
print(f"Visualization complete. Points stored in field '{vis_field_name}'.")
|
||||
|
||||
def main():
|
||||
"""Main function to run the analysis pipeline."""
|
||||
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
|
||||
|
||||
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
|
||||
parser.add_argument("--dataset_name", type=str, default="embedding_analysis_dataset", help="Name for the FiftyOne dataset.")
|
||||
parser.add_argument("--n_clusters", type=int, nargs='+', default=[50, 100], help="A list of numbers of clusters for GMM.")
|
||||
parser.add_argument("--vis_method", type=str, default="tsne", choices=["tsne", "umap"], help="Visualization method.")
|
||||
parser.add_argument("--brain_key", type=str, default="doc_embeddings_viz", help="FiftyOne brain key for the visualization.")
|
||||
parser.add_argument("--no_launch", action="store_true", help="Do not launch the FiftyOne App after processing.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. Load data
|
||||
file_paths, embeddings = load_embeddings_from_json(args.json_path)
|
||||
|
||||
# 2. Create FiftyOne Dataset
|
||||
dataset = create_fiftyone_dataset(args.dataset_name, file_paths)
|
||||
|
||||
# 3. Run GMM clustering for each specified number of clusters
|
||||
for n in args.n_clusters:
|
||||
add_gmm_clusters_to_dataset(dataset, embeddings, n)
|
||||
|
||||
# 4. Compute visualization
|
||||
compute_embedding_visualization(dataset, embeddings, args.brain_key, args.vis_method)
|
||||
|
||||
# 5. Launch the app
|
||||
if not args.no_launch:
|
||||
print("Launching the FiftyOne App...")
|
||||
session = fo.launch_app(dataset)
|
||||
session.wait()
|
||||
else:
|
||||
print("Processing complete. To view the results, launch the app manually with:")
|
||||
print(f">>> import fiftyone as fo")
|
||||
print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
|
||||
print(f">>> session = fo.launch_app(dataset)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user