[Add] Adaptive clustering - add files

This commit is contained in:
2025-07-10 20:36:07 +00:00
committed by trungkienbkhn
parent 8e1d0568a2
commit bfdaa142c2
3 changed files with 772 additions and 0 deletions

119
src/data_image_selection.py Normal file
View File

@@ -0,0 +1,119 @@
import json
import numpy as np
from typing import List, Tuple
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_samples
from tqdm import tqdm
import argparse
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
"""Loads file paths and embeddings from a JSON file."""
print(f"Loading embeddings from {json_path}...")
with open(json_path, "r") as f:
embedding_data = json.load(f)
file_paths = [record["filepath"] for record in embedding_data]
embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)
print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
return file_paths, embeddings
def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[GaussianMixture, float]:
"""Fits a GMM and adds cluster labels to the dataset."""
field_name = f"gmm_cluster_{n_clusters}"
print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")
gmm = GaussianMixture(n_components=n_clusters, random_state=42)
try:
gmm.fit(embeddings)
print("GMM clustering complete.")
except:
print("Passing this, error")
return None, np.inf
print("BIC score:", gmm.bic(embeddings))
return gmm, gmm.bic(embeddings)
def analyze_and_select_gmm(
embeddings: np.ndarray,
min_num_clusters: int,
max_num_clusters: int,
step_size: int = 10
):
n_clusters_list = list(range(min_num_clusters, max_num_clusters + 1, step_size))
bic_list = []
gmm_list = []
for n_clusters in tqdm(n_clusters_list, desc = "GMM analysis based on number of clusters"):
gmm, bic_score = create_gmm_and_bic_score(embeddings=embeddings, n_clusters=n_clusters)
gmm_list.append(gmm)
bic_list.append(bic_score)
bic_min_index = np.argmin(bic_list)
selected_gmm = gmm_list[bic_min_index]
return selected_gmm
def select_samples(
embeddings: np.ndarray,
gaussian_mixtures: GaussianMixture,
filepaths: List[str]
):
percentage_to_label:float = 0.5
percentage_representative: float = 0.33
total_number_images = len(filepaths)
total_images_to_label = int(total_number_images * percentage_to_label)
num_high_confidence = int(total_images_to_label * percentage_representative)
num_low_confidence = total_images_to_label - num_high_confidence
print(f"Total images: {total_number_images}")
print(f"Total to label ({percentage_to_label*100}%): {total_images_to_label}")
print(f" - Highly representative samples: {num_high_confidence}")
print(f" - Least representative samples: {num_low_confidence}")
labels = gaussian_mixtures.predict(embeddings)
individual_scores = silhouette_samples(embeddings, labels)
sorted_indices = np.argsort(individual_scores)
# 3. Identify the indices for low and high confidence samples
low_confidence_indices = sorted_indices[:num_low_confidence]
high_confidence_indices = sorted_indices[-num_high_confidence:]
# 4. Combine them into a final list of indices to be labeled
final_indices_to_label = np.concatenate([low_confidence_indices, high_confidence_indices])
# 5. Get the corresponding image file paths
labeling_queue = [filepaths[i] for i in final_indices_to_label]
return labeling_queue
def main():
"""Main function to run the analysis pipeline."""
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
# parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
# parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.")
parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.")
parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.")
parser.add_argument("--step_size_clusters", type=int, default=10, help="Maximum numbers of clusters for GMM.")
args = parser.parse_args()
# 1. Load data
filepaths, embeddings = load_embeddings_from_json(args.json_path)
gaussian_mixture = analyze_and_select_gmm(
embeddings=embeddings,
min_num_clusters=args.min_num_clusters,
max_num_clusters=args.max_num_clusters,
step_size=args.step_size_clusters
)
selected_filepaths = select_samples(
embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
)
with open('./output_files.json', "w") as file:
file.write(json.dumps(selected_filepaths, indent=4))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,106 @@
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
from sklearn.mixture import GaussianMixture
import json
import argparse
from typing import List, Tuple
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
"""Loads file paths and embeddings from a JSON file."""
print(f"Loading embeddings from {json_path}...")
with open(json_path, "r") as f:
embedding_data = json.load(f)
file_paths = [record["filepath"] for record in embedding_data]
embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)
print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
return file_paths, embeddings
def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
"""Creates or overwrites a FiftyOne dataset from a list of file paths."""
if dataset_name in fo.list_datasets():
print(f"Dataset '{dataset_name}' already exists. Deleting it.")
fo.delete_dataset(dataset_name)
print(f"Creating new dataset '{dataset_name}'...")
dataset = fo.Dataset(dataset_name)
samples = [fo.Sample(filepath=p) for p in file_paths]
dataset.add_samples(samples)
print("Dataset created successfully.")
return dataset
def add_gmm_clusters_to_dataset(dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int):
"""Fits a GMM and adds cluster labels to the dataset."""
field_name = f"gmm_cluster_{n_clusters}"
print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")
gmm = GaussianMixture(n_components=n_clusters, random_state=42)
cluster_labels = gmm.fit_predict(embeddings)
dataset.add_sample_field(f"{field_name}", fo.IntField)
for sample, label in zip(dataset, cluster_labels):
sample[field_name] = int(label)
sample.save()
print("GMM clustering complete.")
def compute_embedding_visualization(
dataset: fo.Dataset,
embeddings: np.ndarray,
brain_key: str,
method: str = "tsne"
):
"""Computes a 2D visualization of embeddings and adds it to the dataset."""
vis_field_name = f"{brain_key}_{method}"
print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")
results = fob.compute_visualization(
dataset,
embeddings=embeddings,
brain_key=brain_key,
method=method,
verbose=True
)
dataset.set_values(vis_field_name, results.current_points)
print(f"Visualization complete. Points stored in field '{vis_field_name}'.")
def main():
"""Main function to run the analysis pipeline."""
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
parser.add_argument("--dataset_name", type=str, default="embedding_analysis_dataset", help="Name for the FiftyOne dataset.")
parser.add_argument("--n_clusters", type=int, nargs='+', default=[50, 100], help="A list of numbers of clusters for GMM.")
parser.add_argument("--vis_method", type=str, default="tsne", choices=["tsne", "umap"], help="Visualization method.")
parser.add_argument("--brain_key", type=str, default="doc_embeddings_viz", help="FiftyOne brain key for the visualization.")
parser.add_argument("--no_launch", action="store_true", help="Do not launch the FiftyOne App after processing.")
args = parser.parse_args()
# 1. Load data
file_paths, embeddings = load_embeddings_from_json(args.json_path)
# 2. Create FiftyOne Dataset
dataset = create_fiftyone_dataset(args.dataset_name, file_paths)
# 3. Run GMM clustering for each specified number of clusters
for n in args.n_clusters:
add_gmm_clusters_to_dataset(dataset, embeddings, n)
# 4. Compute visualization
compute_embedding_visualization(dataset, embeddings, args.brain_key, args.vis_method)
# 5. Launch the app
if not args.no_launch:
print("Launching the FiftyOne App...")
session = fo.launch_app(dataset)
session.wait()
else:
print("Processing complete. To view the results, launch the app manually with:")
print(f">>> import fiftyone as fo")
print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
print(f">>> session = fo.launch_app(dataset)")
if __name__ == "__main__":
main()