From 641f81980aade296c694c8e0220888015e269293 Mon Sep 17 00:00:00 2001 From: lphatnguyen Date: Fri, 11 Jul 2025 14:20:29 +0000 Subject: [PATCH] [Add] Clustering - adaptive selection --- src/data_image_selection.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/data_image_selection.py b/src/data_image_selection.py index 15507fc..39c24c2 100644 --- a/src/data_image_selection.py +++ b/src/data_image_selection.py @@ -5,6 +5,7 @@ from sklearn.mixture import GaussianMixture from sklearn.metrics import silhouette_samples from tqdm import tqdm import argparse +import os def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]: """Loads file paths and embeddings from a JSON file.""" @@ -30,8 +31,8 @@ def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[G except: print("Passing this, error") return None, np.inf - print("BIC score:", gmm.bic(embeddings)) - return gmm, gmm.bic(embeddings) + print("BIC score:", np.abs(gmm.bic(embeddings))) + return gmm, np.abs(gmm.bic(embeddings)) def analyze_and_select_gmm( embeddings: np.ndarray, @@ -91,7 +92,7 @@ def main(): parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.") parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.") - # parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.") + parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.") # parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.") parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.") parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.") @@ -100,6 +101,8 @@ def main(): # 1. Load data filepaths, embeddings = load_embeddings_from_json(args.json_path) + with open(args.json_label_path) as file: + label_data = json.load(file) gaussian_mixture = analyze_and_select_gmm( embeddings=embeddings, @@ -112,8 +115,17 @@ def main(): embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths ) + selected_label = [] + selected_prefixes = list(set([os.path.splitext(os.path.split(datum)[1])[0].split("_scale")[0][:-2] for datum in selected_filepaths])) + for prefix in tqdm(selected_prefixes): + try: + label = list(filter(lambda x: bool(prefix.count(x["image"])), label_data))[0] + selected_label.append(label) + except: + continue + with open('./output_files.json', "w") as file: - file.write(json.dumps(selected_filepaths, indent=4)) + file.write(json.dumps(selected_label, indent=4)) if __name__ == "__main__": main() \ No newline at end of file