[Add] Clustering - adaptive selection
This commit is contained in:
parent
bfdaa142c2
commit
641f81980a
@ -5,6 +5,7 @@ from sklearn.mixture import GaussianMixture
|
||||
from sklearn.metrics import silhouette_samples
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import os
|
||||
|
||||
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
|
||||
"""Loads file paths and embeddings from a JSON file."""
|
||||
@ -30,8 +31,8 @@ def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[G
|
||||
except:
|
||||
print("Passing this, error")
|
||||
return None, np.inf
|
||||
print("BIC score:", gmm.bic(embeddings))
|
||||
return gmm, gmm.bic(embeddings)
|
||||
print("BIC score:", np.abs(gmm.bic(embeddings)))
|
||||
return gmm, np.abs(gmm.bic(embeddings))
|
||||
|
||||
def analyze_and_select_gmm(
|
||||
embeddings: np.ndarray,
|
||||
@ -91,7 +92,7 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
|
||||
|
||||
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
|
||||
# parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
|
||||
parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
|
||||
# parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.")
|
||||
parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.")
|
||||
parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.")
|
||||
@ -100,6 +101,8 @@ def main():
|
||||
|
||||
# 1. Load data
|
||||
filepaths, embeddings = load_embeddings_from_json(args.json_path)
|
||||
with open(args.json_label_path) as file:
|
||||
label_data = json.load(file)
|
||||
|
||||
gaussian_mixture = analyze_and_select_gmm(
|
||||
embeddings=embeddings,
|
||||
@ -112,8 +115,17 @@ def main():
|
||||
embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
|
||||
)
|
||||
|
||||
selected_label = []
|
||||
selected_prefixes = list(set([os.path.splitext(os.path.split(datum)[1])[0].split("_scale")[0][:-2] for datum in selected_filepaths]))
|
||||
for prefix in tqdm(selected_prefixes):
|
||||
try:
|
||||
label = list(filter(lambda x: bool(prefix.count(x["image"])), label_data))[0]
|
||||
selected_label.append(label)
|
||||
except:
|
||||
continue
|
||||
|
||||
with open('./output_files.json', "w") as file:
|
||||
file.write(json.dumps(selected_filepaths, indent=4))
|
||||
file.write(json.dumps(selected_label, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user