[Add] Adaptive clustering - add files
src/data_image_selection.py  (new file, 119 lines)

@@ -0,0 +1,119 @@
import json
import numpy as np
from typing import List, Tuple
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_samples
from tqdm import tqdm
import argparse


def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
    """Loads file paths and embeddings from a JSON file."""
    print(f"Loading embeddings from {json_path}...")
    with open(json_path, "r") as f:
        embedding_data = json.load(f)

    file_paths = [record["filepath"] for record in embedding_data]
    embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)

    print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
    return file_paths, embeddings


def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[GaussianMixture, float]:
    """Fits a GMM with the given number of components and returns it with its BIC score."""
    print(f"Fitting GMM with {n_clusters} components...")

    gmm = GaussianMixture(n_components=n_clusters, random_state=42)
    try:
        gmm.fit(embeddings)
        print("GMM fitting complete.")
    except Exception as exc:
        print(f"GMM fitting failed for {n_clusters} components: {exc}")
        return None, np.inf

    bic_score = gmm.bic(embeddings)
    print("BIC score:", bic_score)
    return gmm, bic_score


def analyze_and_select_gmm(
    embeddings: np.ndarray,
    min_num_clusters: int,
    max_num_clusters: int,
    step_size: int = 10
) -> GaussianMixture:
    """Fits GMMs over a range of cluster counts and returns the one with the lowest BIC."""
    n_clusters_list = list(range(min_num_clusters, max_num_clusters + 1, step_size))
    bic_list = []
    gmm_list = []
    for n_clusters in tqdm(n_clusters_list, desc="GMM analysis based on number of clusters"):
        gmm, bic_score = create_gmm_and_bic_score(embeddings=embeddings, n_clusters=n_clusters)
        gmm_list.append(gmm)
        bic_list.append(bic_score)

    bic_min_index = np.argmin(bic_list)
    selected_gmm = gmm_list[bic_min_index]
    if selected_gmm is None:
        raise RuntimeError("All GMM fits failed; no model could be selected.")
    return selected_gmm


def select_samples(
    embeddings: np.ndarray,
    gaussian_mixtures: GaussianMixture,
    filepaths: List[str]
):
    """Builds a labeling queue from the least and most representative samples by silhouette score."""
    percentage_to_label: float = 0.5
    percentage_representative: float = 0.33

    total_number_images = len(filepaths)
    total_images_to_label = int(total_number_images * percentage_to_label)

    num_high_confidence = int(total_images_to_label * percentage_representative)
    num_low_confidence = total_images_to_label - num_high_confidence

    print(f"Total images: {total_number_images}")
    print(f"Total to label ({percentage_to_label * 100}%): {total_images_to_label}")
    print(f"  - Highly representative samples: {num_high_confidence}")
    print(f"  - Least representative samples: {num_low_confidence}")

    # 1. Assign each sample to its most likely GMM cluster
    labels = gaussian_mixtures.predict(embeddings)
    # 2. Score how representative each sample is of its cluster
    individual_scores = silhouette_samples(embeddings, labels)

    sorted_indices = np.argsort(individual_scores)
    # 3. Identify the indices for low and high confidence samples
    low_confidence_indices = sorted_indices[:num_low_confidence]
    high_confidence_indices = sorted_indices[-num_high_confidence:]

    # 4. Combine them into a final list of indices to be labeled
    final_indices_to_label = np.concatenate([low_confidence_indices, high_confidence_indices])

    # 5. Get the corresponding image file paths
    labeling_queue = [filepaths[i] for i in final_indices_to_label]

    return labeling_queue


def main():
    """Main function to run the selection pipeline."""
    parser = argparse.ArgumentParser(description="Select images to label by clustering document embeddings with a GMM.")

    parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
    # parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
    # parser.add_argument("--bill-type", type=str, help="Health invoice document type, e.g. Ostéopathie, lentille, etc.")
    parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum number of clusters for GMM.")
    parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum number of clusters for GMM.")
    parser.add_argument("--step_size_clusters", type=int, default=10, help="Step size between successive numbers of clusters for GMM.")
    args = parser.parse_args()

    # 1. Load data
    filepaths, embeddings = load_embeddings_from_json(args.json_path)

    # 2. Pick the GMM whose number of clusters minimizes the BIC
    gaussian_mixture = analyze_and_select_gmm(
        embeddings=embeddings,
        min_num_clusters=args.min_num_clusters,
        max_num_clusters=args.max_num_clusters,
        step_size=args.step_size_clusters
    )

    # 3. Build the labeling queue
    selected_filepaths = select_samples(
        embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
    )

    # 4. Write the selected file paths to disk
    with open("./output_files.json", "w") as file:
        json.dump(selected_filepaths, file, indent=4)


if __name__ == "__main__":
    main()
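
Note: both scripts read the same input format via load_embeddings_from_json: a JSON list of records with "filepath" and "embedding" keys. A minimal sketch of producing such a file from precomputed embeddings (the file name, paths, and vectors below are placeholders, not part of this commit):

import json
import numpy as np

# Placeholder data: in practice these come from your image list and embedding model.
filepaths = ["images/doc_0001.png", "images/doc_0002.png"]
embeddings = np.random.rand(2, 512).astype(np.float32)

# Serialize into the record format that load_embeddings_from_json() expects.
records = [
    {"filepath": path, "embedding": vector.tolist()}
    for path, vector in zip(filepaths, embeddings)
]

with open("embeddings.json", "w") as f:
    json.dump(records, f)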

src/display_embeddings_and_images.py  (new file, 106 lines)

@@ -0,0 +1,106 @@
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
from sklearn.mixture import GaussianMixture
import json
import argparse
from typing import List, Tuple


def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
    """Loads file paths and embeddings from a JSON file."""
    print(f"Loading embeddings from {json_path}...")
    with open(json_path, "r") as f:
        embedding_data = json.load(f)

    file_paths = [record["filepath"] for record in embedding_data]
    embeddings = np.array([record["embedding"] for record in embedding_data], dtype=np.float32)

    print(f"Loaded {len(file_paths)} samples with embedding dimension {embeddings.shape[1]}.")
    return file_paths, embeddings


def create_fiftyone_dataset(dataset_name: str, file_paths: List[str]) -> fo.Dataset:
    """Creates or overwrites a FiftyOne dataset from a list of file paths."""
    if dataset_name in fo.list_datasets():
        print(f"Dataset '{dataset_name}' already exists. Deleting it.")
        fo.delete_dataset(dataset_name)

    print(f"Creating new dataset '{dataset_name}'...")
    dataset = fo.Dataset(dataset_name)
    samples = [fo.Sample(filepath=p) for p in file_paths]
    dataset.add_samples(samples)
    print("Dataset created successfully.")
    return dataset


def add_gmm_clusters_to_dataset(dataset: fo.Dataset, embeddings: np.ndarray, n_clusters: int):
    """Fits a GMM and adds cluster labels to the dataset."""
    field_name = f"gmm_cluster_{n_clusters}"
    print(f"Running GMM with {n_clusters} components. Results will be in field '{field_name}'...")

    gmm = GaussianMixture(n_components=n_clusters, random_state=42)
    cluster_labels = gmm.fit_predict(embeddings)

    dataset.add_sample_field(field_name, fo.IntField)
    for sample, label in zip(dataset, cluster_labels):
        sample[field_name] = int(label)
        sample.save()
    print("GMM clustering complete.")


def compute_embedding_visualization(
    dataset: fo.Dataset,
    embeddings: np.ndarray,
    brain_key: str,
    method: str = "tsne"
):
    """Computes a 2D visualization of embeddings and adds it to the dataset."""
    vis_field_name = f"{brain_key}_{method}"
    print(f"Computing {method.upper()} visualization with brain key '{brain_key}'...")

    results = fob.compute_visualization(
        dataset,
        embeddings=embeddings,
        brain_key=brain_key,
        method=method,
        verbose=True
    )
    dataset.set_values(vis_field_name, results.current_points)
    print(f"Visualization complete. Points stored in field '{vis_field_name}'.")


def main():
    """Main function to run the analysis pipeline."""
    parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")

    parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
    parser.add_argument("--dataset_name", type=str, default="embedding_analysis_dataset", help="Name for the FiftyOne dataset.")
    parser.add_argument("--n_clusters", type=int, nargs="+", default=[50, 100], help="A list of numbers of clusters for GMM.")
    parser.add_argument("--vis_method", type=str, default="tsne", choices=["tsne", "umap"], help="Visualization method.")
    parser.add_argument("--brain_key", type=str, default="doc_embeddings_viz", help="FiftyOne brain key for the visualization.")
    parser.add_argument("--no_launch", action="store_true", help="Do not launch the FiftyOne App after processing.")

    args = parser.parse_args()

    # 1. Load data
    file_paths, embeddings = load_embeddings_from_json(args.json_path)

    # 2. Create FiftyOne Dataset
    dataset = create_fiftyone_dataset(args.dataset_name, file_paths)

    # 3. Run GMM clustering for each specified number of clusters
    for n in args.n_clusters:
        add_gmm_clusters_to_dataset(dataset, embeddings, n)

    # 4. Compute visualization
    compute_embedding_visualization(dataset, embeddings, args.brain_key, args.vis_method)

    # 5. Launch the app
    if not args.no_launch:
        print("Launching the FiftyOne App...")
        session = fo.launch_app(dataset)
        session.wait()
    else:
        print("Processing complete. To view the results, launch the app manually with:")
        print(">>> import fiftyone as fo")
        print(f">>> dataset = fo.load_dataset('{args.dataset_name}')")
        print(">>> session = fo.launch_app(dataset)")


if __name__ == "__main__":
    main()
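
After the script has run, the cluster assignments live in per-sample integer fields named gmm_cluster_{n}. A hedged sketch of reloading the dataset and inspecting a single cluster afterwards (the dataset name and field match the script's defaults; cluster index 0 is an arbitrary example):

import fiftyone as fo
from fiftyone import ViewField as F

# Load the dataset created by display_embeddings_and_images.py (default name).
dataset = fo.load_dataset("embedding_analysis_dataset")

# Filter to one GMM cluster; "gmm_cluster_50" matches the default --n_clusters 50.
cluster_view = dataset.match(F("gmm_cluster_50") == 0)
print(f"Cluster 0 contains {len(cluster_view)} samples")

# Open the App on just this cluster.
session = fo.launch_app(cluster_view)
session.wait()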