[Add] Clustering - adaptive selection

This commit is contained in:
lphatnguyen 2025-07-11 14:20:29 +00:00 committed by trungkienbkhn
parent bfdaa142c2
commit 641f81980a

View File

@ -5,6 +5,7 @@ from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_samples
from tqdm import tqdm
import argparse
import os
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
"""Loads file paths and embeddings from a JSON file."""
@ -30,8 +31,8 @@ def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[G
except:
print("Passing this, error")
return None, np.inf
print("BIC score:", gmm.bic(embeddings))
return gmm, gmm.bic(embeddings)
print("BIC score:", np.abs(gmm.bic(embeddings)))
return gmm, np.abs(gmm.bic(embeddings))
def analyze_and_select_gmm(
embeddings: np.ndarray,
@ -91,7 +92,7 @@ def main():
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
# parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
# parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.")
parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.")
parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.")
@ -100,6 +101,8 @@ def main():
# 1. Load data
filepaths, embeddings = load_embeddings_from_json(args.json_path)
with open(args.json_label_path) as file:
label_data = json.load(file)
gaussian_mixture = analyze_and_select_gmm(
embeddings=embeddings,
@ -112,8 +115,17 @@ def main():
embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
)
selected_label = []
selected_prefixes = list(set([os.path.splitext(os.path.split(datum)[1])[0].split("_scale")[0][:-2] for datum in selected_filepaths]))
for prefix in tqdm(selected_prefixes):
try:
label = list(filter(lambda x: bool(prefix.count(x["image"])), label_data))[0]
selected_label.append(label)
except:
continue
with open('./output_files.json', "w") as file:
file.write(json.dumps(selected_filepaths, indent=4))
file.write(json.dumps(selected_label, indent=4))
if __name__ == "__main__":
main()