[Add] Clustering - adaptive selection
This commit is contained in:
parent
bfdaa142c2
commit
641f81980a
@ -5,6 +5,7 @@ from sklearn.mixture import GaussianMixture
|
|||||||
from sklearn.metrics import silhouette_samples
|
from sklearn.metrics import silhouette_samples
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
|
def load_embeddings_from_json(json_path: str) -> Tuple[List[str], np.ndarray]:
|
||||||
"""Loads file paths and embeddings from a JSON file."""
|
"""Loads file paths and embeddings from a JSON file."""
|
||||||
@ -30,8 +31,8 @@ def create_gmm_and_bic_score(embeddings: np.ndarray, n_clusters: int) -> Tuple[G
|
|||||||
except:
|
except:
|
||||||
print("Passing this, error")
|
print("Passing this, error")
|
||||||
return None, np.inf
|
return None, np.inf
|
||||||
print("BIC score:", gmm.bic(embeddings))
|
print("BIC score:", np.abs(gmm.bic(embeddings)))
|
||||||
return gmm, gmm.bic(embeddings)
|
return gmm, np.abs(gmm.bic(embeddings))
|
||||||
|
|
||||||
def analyze_and_select_gmm(
|
def analyze_and_select_gmm(
|
||||||
embeddings: np.ndarray,
|
embeddings: np.ndarray,
|
||||||
@ -91,7 +92,7 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
|
parser = argparse.ArgumentParser(description="Cluster and visualize document embeddings with FiftyOne.")
|
||||||
|
|
||||||
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
|
parser.add_argument("json_path", type=str, help="Path to the JSON file containing embeddings.")
|
||||||
# parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
|
parser.add_argument("json_label_path", type=str, help="Path to the JSON label file.")
|
||||||
# parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.")
|
# parser.add_argument("--bill-type", type = str, help="Health invoice document type, eg. Ostéopathie, lentille, etc.")
|
||||||
parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.")
|
parser.add_argument("--min_num_clusters", type=int, default=10, help="Minimum numbers of clusters for GMM.")
|
||||||
parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.")
|
parser.add_argument("--max_num_clusters", type=int, default=100, help="Maximum numbers of clusters for GMM.")
|
||||||
@ -100,6 +101,8 @@ def main():
|
|||||||
|
|
||||||
# 1. Load data
|
# 1. Load data
|
||||||
filepaths, embeddings = load_embeddings_from_json(args.json_path)
|
filepaths, embeddings = load_embeddings_from_json(args.json_path)
|
||||||
|
with open(args.json_label_path) as file:
|
||||||
|
label_data = json.load(file)
|
||||||
|
|
||||||
gaussian_mixture = analyze_and_select_gmm(
|
gaussian_mixture = analyze_and_select_gmm(
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
@ -112,8 +115,17 @@ def main():
|
|||||||
embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
|
embeddings=embeddings, gaussian_mixtures=gaussian_mixture, filepaths=filepaths
|
||||||
)
|
)
|
||||||
|
|
||||||
|
selected_label = []
|
||||||
|
selected_prefixes = list(set([os.path.splitext(os.path.split(datum)[1])[0].split("_scale")[0][:-2] for datum in selected_filepaths]))
|
||||||
|
for prefix in tqdm(selected_prefixes):
|
||||||
|
try:
|
||||||
|
label = list(filter(lambda x: bool(prefix.count(x["image"])), label_data))[0]
|
||||||
|
selected_label.append(label)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
with open('./output_files.json', "w") as file:
|
with open('./output_files.json', "w") as file:
|
||||||
file.write(json.dumps(selected_filepaths, indent=4))
|
file.write(json.dumps(selected_label, indent=4))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
Reference in New Issue
Block a user