#!/usr/bin/env python3
"""
Extensive Gaussian Mixture Model clustering with grid search for optimal parameters
Includes BIC and AIC metrics for model selection
"""

import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.decomposition import PCA
import datetime
import csv
import argparse
import warnings
warnings.filterwarnings('ignore')


class GMMExtensiveClustering:
    def __init__(self, embeddings_path):
        self.embeddings_path = embeddings_path
        self.embeddings = None
        self.file_paths = None
        self.load_embeddings()

    def load_embeddings(self):
        """Load embeddings from JSON file"""
        print(f"Loading embeddings from {self.embeddings_path}...")
        with open(self.embeddings_path, 'r') as f:
            data = json.load(f)

        self.file_paths = []
        embeddings_list = []

        for item in data:
            self.file_paths.append(item['filepath'])
            embeddings_list.append(item['embedding'])

        self.embeddings = np.array(embeddings_list, dtype=np.float32)
        print(f"Loaded {len(self.file_paths)} samples with embedding dimension {self.embeddings.shape[1]}")

        # Standardize embeddings for better clustering
        self.scaler = StandardScaler()
        self.embeddings_scaled = self.scaler.fit_transform(self.embeddings)
    def run_gmm_grid_search(self):
        """Run GMM with optimized grid search for faster execution"""
        print("\n" + "="*70)
        print("RUNNING GAUSSIAN MIXTURE MODEL CLUSTERING WITH OPTIMIZED GRID SEARCH")
        print("="*70)

        # Optimized GMM parameter candidates for faster execution

        # Smart n_components range with larger steps
        max_components = min(50, len(self.embeddings_scaled) // 20)  # Reduced max and increased divisor
        n_components_candidates = []

        # Progressive step sizes: smaller steps for low numbers, larger for high
        for n in range(2, max_components + 1):
            if n <= 5:
                n_components_candidates.append(n)  # 2, 3, 4, 5
            elif n <= 10:
                if n % 2 == 0:  # 6, 8, 10
                    n_components_candidates.append(n)
            else:
                if n % 3 == 2:  # 11, 14, 17, 20
                    n_components_candidates.append(n)

        # Reduced covariance types - focus on most important ones
        covariance_types = [
            # 'full', 'diag',
            'tied', 'spherical'
        ]  # 'full' and 'diag' are commented out; only 'tied' and 'spherical' are searched

        # Simplified regularization - focus on key values
        reg_covar_candidates = [1e-5, 1e-4, 1e-3]  # Removed extreme values

        # Reduced n_init - 1 is often sufficient for good initialization methods
        n_init_candidates = [1, 5]  # Removed 10 to save time

        # Focus on best initialization methods
        init_params_candidates = ['kmeans', 'k-means++']  # Removed 'random' and 'random_from_data'

        # Simplified max_iter - most problems converge quickly
        max_iter_candidates = [100, 300]  # Removed 500, added 300 as middle ground

        print("Optimized parameter combinations:")
        print(f" - n_components: {len(n_components_candidates)} values {n_components_candidates}")
        print(f" - covariance_types: {len(covariance_types)} options {covariance_types}")
        print(f" - reg_covar: {len(reg_covar_candidates)} values {reg_covar_candidates}")
        print(f" - n_init: {len(n_init_candidates)} values {n_init_candidates}")
        print(f" - init_params: {len(init_params_candidates)} options {init_params_candidates}")
        print(f" - max_iter: {len(max_iter_candidates)} values {max_iter_candidates}")

        total_combinations = (len(n_components_candidates) * len(covariance_types) *
                              len(reg_covar_candidates) * len(n_init_candidates) *
                              len(init_params_candidates) * len(max_iter_candidates))
        print(f"Total combinations: {total_combinations} (optimized for speed)")

        # Estimate time
        estimated_time_per_combination = 0.5  # seconds (conservative estimate)
        estimated_total_time = total_combinations * estimated_time_per_combination
        print(f"Estimated runtime: {estimated_total_time/60:.1f} minutes")
        print("This should be much faster...\n")

        # Track all results for analysis
        all_results = []

        # Early stopping criteria for speed optimization
        early_stopping_threshold = 0.7  # If we find a very good silhouette score, we can be less exhaustive
        good_results_found = 0
        max_good_results = 5  # Stop early if we find several very good results

        best_bic_score = float('inf')
        best_aic_score = float('inf')
        best_silhouette_score = -1
        best_params_bic = None
        best_params_aic = None
        best_params_silhouette = None
        best_labels_bic = None
        best_labels_aic = None
        best_labels_silhouette = None

        current_combination = 0

        # Optimized iteration order: test simpler models first (fewer components, simpler covariance)
        for covariance_type in covariance_types:  # Start with covariance type
            for n_components in n_components_candidates:  # Then components
                for init_params in init_params_candidates:  # Good initialization methods
                    for reg_covar in reg_covar_candidates:  # Regularization
                        for n_init in n_init_candidates:  # Number of initializations
                            for max_iter in max_iter_candidates:  # Iterations last
                                current_combination += 1

                                # Progress indicator with time estimation
                                if current_combination % 50 == 0 or current_combination == total_combinations:
                                    progress = (current_combination / total_combinations) * 100
                                    print(f"Progress: {current_combination}/{total_combinations} ({progress:.1f}%) - "
                                          f"Best scores so far: BIC={best_bic_score:.2f}, Silhouette={best_silhouette_score:.3f}")

                                try:
                                    # Early convergence check for faster models
                                    tol = 1e-3 if n_components <= 5 else 1e-4  # Less strict tolerance for simple models

                                    # Run GMM
                                    gmm = GaussianMixture(
                                        n_components=n_components,
                                        covariance_type=covariance_type,
                                        reg_covar=reg_covar,
                                        n_init=n_init,
                                        init_params=init_params,
                                        max_iter=max_iter,
                                        tol=tol,  # Added tolerance for faster convergence
                                        random_state=42
                                    )

                                    # Fit and predict
                                    gmm.fit(self.embeddings_scaled)
                                    labels = gmm.predict(self.embeddings_scaled)

                                    # Quick validation - skip if model didn't converge properly
                                    if not gmm.converged_ and max_iter <= 100:
                                        continue  # Skip non-converged simple models

                                    # Calculate metrics
                                    bic_score = gmm.bic(self.embeddings_scaled)
                                    aic_score = gmm.aic(self.embeddings_scaled)
                                    log_likelihood = gmm.score(self.embeddings_scaled)

                                    # Only calculate clustering metrics if we have multiple clusters
                                    if len(set(labels)) > 1:
                                        silhouette = silhouette_score(self.embeddings_scaled, labels)
                                        calinski_harabasz = calinski_harabasz_score(self.embeddings_scaled, labels)
                                        davies_bouldin = davies_bouldin_score(self.embeddings_scaled, labels)

                                        # Early stopping check
                                        if silhouette > early_stopping_threshold:
                                            good_results_found += 1
                                            print(f"🎯 Excellent result found: n_comp={n_components}, cov={covariance_type}, "
                                                  f"silhouette={silhouette:.4f}")
                                    else:
                                        silhouette = -1
                                        calinski_harabasz = 0
                                        davies_bouldin = float('inf')

                                    # Store result for analysis
                                    result_info = {
                                        'n_components': n_components,
                                        'covariance_type': covariance_type,
                                        'reg_covar': reg_covar,
                                        'n_init': n_init,
                                        'init_params': init_params,
                                        'max_iter': max_iter,
                                        'bic_score': bic_score,
                                        'aic_score': aic_score,
                                        'log_likelihood': log_likelihood,
                                        'silhouette_score': silhouette,
                                        'calinski_harabasz_score': calinski_harabasz,
                                        'davies_bouldin_score': davies_bouldin,
                                        'converged': gmm.converged_,
                                        'n_iter': gmm.n_iter_,
                                        'unique_clusters': len(set(labels))
                                    }

                                    all_results.append(result_info)

                                    # Print promising results
                                    if (silhouette > 0.3 and bic_score < np.percentile([r['bic_score'] for r in all_results], 25)):
                                        print(f"n_components={n_components}, cov={covariance_type}, init={init_params}: "
                                              f"BIC={bic_score:.2f}, AIC={aic_score:.2f}, silhouette={silhouette:.4f}")

                                    # Track best results for different criteria
                                    if bic_score < best_bic_score:
                                        best_bic_score = bic_score
                                        best_params_bic = {
                                            'n_components': n_components,
                                            'covariance_type': covariance_type,
                                            'reg_covar': reg_covar,
                                            'n_init': n_init,
                                            'init_params': init_params,
                                            'max_iter': max_iter
                                        }
                                        best_labels_bic = labels

                                    if aic_score < best_aic_score:
                                        best_aic_score = aic_score
                                        best_params_aic = {
                                            'n_components': n_components,
                                            'covariance_type': covariance_type,
                                            'reg_covar': reg_covar,
                                            'n_init': n_init,
                                            'init_params': init_params,
                                            'max_iter': max_iter
                                        }
                                        best_labels_aic = labels

                                    if silhouette > best_silhouette_score and len(set(labels)) > 1:
                                        best_silhouette_score = silhouette
                                        best_params_silhouette = {
                                            'n_components': n_components,
                                            'covariance_type': covariance_type,
                                            'reg_covar': reg_covar,
                                            'n_init': n_init,
                                            'init_params': init_params,
                                            'max_iter': max_iter
                                        }
                                        best_labels_silhouette = labels

                                    # Early stopping check
                                    if good_results_found >= max_good_results and silhouette > 0.6:
                                        print(f"🛑 Early stopping triggered: Found {good_results_found} excellent results. "
                                              f"Stopping at {current_combination}/{total_combinations} combinations.")
                                        break

                                except Exception:
                                    # Skip problematic parameter combinations
                                    continue

                            # Break from nested loops if early stopping triggered
                            if good_results_found >= max_good_results and best_silhouette_score > 0.6:
                                break
                        if good_results_found >= max_good_results and best_silhouette_score > 0.6:
                            break
                    if good_results_found >= max_good_results and best_silhouette_score > 0.6:
                        break
                if good_results_found >= max_good_results and best_silhouette_score > 0.6:
                    break
            if good_results_found >= max_good_results and best_silhouette_score > 0.6:
                break
        # Analysis of results
        print("\n" + "="*70)
        print("GAUSSIAN MIXTURE MODEL GRID SEARCH ANALYSIS")
        print("="*70)

        if all_results:
            import pandas as pd
            df_results = pd.DataFrame(all_results)

            print(f"Total parameter combinations tested: {len(df_results)}")

            # Filter results with valid clustering (more than 1 cluster)
            valid_results = df_results[df_results['unique_clusters'] > 1]
            print(f"Combinations with valid clustering: {len(valid_results)}")

            if len(valid_results) > 0:
                # Best scores analysis
                print("\nModel Selection Metrics:")
                print(f"Best BIC score: {df_results['bic_score'].min():.2f}")
                print(f"Best AIC score: {df_results['aic_score'].min():.2f}")
                print(f"Best Log-Likelihood: {df_results['log_likelihood'].max():.2f}")

                print("\nClustering Quality Metrics:")
                print(f"Best silhouette score: {valid_results['silhouette_score'].max():.4f}")
                print(f"Mean silhouette score: {valid_results['silhouette_score'].mean():.4f}")
                print(f"Best Calinski-Harabasz score: {valid_results['calinski_harabasz_score'].max():.2f}")
                print(f"Best Davies-Bouldin score: {valid_results['davies_bouldin_score'].min():.4f}")

            # Top results by different criteria
            print("\nTop 5 results by BIC (lower is better):")
            top_bic = df_results.nsmallest(5, 'bic_score')
            for idx, row in top_bic.iterrows():
                print(f" n_comp={row['n_components']}, cov={row['covariance_type']}: "
                      f"BIC={row['bic_score']:.2f}, AIC={row['aic_score']:.2f}")

            print("\nTop 5 results by AIC (lower is better):")
            top_aic = df_results.nsmallest(5, 'aic_score')
            for idx, row in top_aic.iterrows():
                print(f" n_comp={row['n_components']}, cov={row['covariance_type']}: "
                      f"BIC={row['bic_score']:.2f}, AIC={row['aic_score']:.2f}")

            if len(valid_results) > 0:
                print("\nTop 5 results by Silhouette Score:")
                top_silhouette = valid_results.nlargest(5, 'silhouette_score')
                for idx, row in top_silhouette.iterrows():
                    print(f" n_comp={row['n_components']}, cov={row['covariance_type']}: "
                          f"silhouette={row['silhouette_score']:.4f}")

            # Component count analysis
            component_performance = df_results.groupby('n_components').agg({
                'bic_score': 'min',
                'aic_score': 'min',
                'silhouette_score': 'max'
            }).reset_index()

            print("\nComponent count analysis (top 10 by BIC):")
            top_components = component_performance.nsmallest(10, 'bic_score')
            for idx, row in top_components.iterrows():
                print(f" {row['n_components']} components: "
                      f"BIC={row['bic_score']:.2f}, AIC={row['aic_score']:.2f}, "
                      f"silhouette={row['silhouette_score']:.4f}")

        print("\n📁 SAVING DETAILED RESULTS...")
        print("="*30)

        # Save detailed grid search results
        self.save_gmm_grid_search_results(all_results,
                                          best_params_bic, best_bic_score,
                                          best_params_aic, best_aic_score,
                                          best_params_silhouette, best_silhouette_score)

        # Return best results based on BIC (primary), AIC (secondary), Silhouette (tertiary)
        results = {
            'bic': (best_labels_bic, best_params_bic, best_bic_score),
            'aic': (best_labels_aic, best_params_aic, best_aic_score),
            'silhouette': (best_labels_silhouette, best_params_silhouette, best_silhouette_score)
        }

        # Print best results
        if best_labels_bic is not None:
            print("\nBest GMM result by BIC:")
            print(f"Parameters: {best_params_bic}")
            print(f"BIC score: {best_bic_score:.2f}")

        if best_labels_aic is not None:
            print("\nBest GMM result by AIC:")
            print(f"Parameters: {best_params_aic}")
            print(f"AIC score: {best_aic_score:.2f}")

        if best_labels_silhouette is not None:
            print("\nBest GMM result by Silhouette:")
            print(f"Parameters: {best_params_silhouette}")
            print(f"Silhouette score: {best_silhouette_score:.4f}")

        return results
    def save_gmm_grid_search_results(self, all_results,
                                     best_params_bic, best_bic_score,
                                     best_params_aic, best_aic_score,
                                     best_params_silhouette, best_silhouette_score):
        """Save detailed GMM grid search results to JSON file"""

        # Prepare comprehensive results data
        grid_search_data = {
            "experiment_info": {
                "timestamp": datetime.datetime.now().isoformat(),
                "dataset_path": self.embeddings_path,
                "total_samples": len(self.file_paths),
                "embedding_dimension": self.embeddings.shape[1],
                "total_combinations_tested": len(all_results),
                "method": "Gaussian Mixture Model"
            },
            "best_results": {
                "by_bic": {
                    "parameters": best_params_bic,
                    "bic_score": best_bic_score if best_bic_score != float('inf') else None
                },
                "by_aic": {
                    "parameters": best_params_aic,
                    "aic_score": best_aic_score if best_aic_score != float('inf') else None
                },
                "by_silhouette": {
                    "parameters": best_params_silhouette,
                    "silhouette_score": best_silhouette_score if best_silhouette_score > -1 else None
                }
            },
            "all_trials": []
        }

        # Add all trial results
        for i, result in enumerate(all_results):
            trial_data = {
                "trial_id": i + 1,
                "parameters": {
                    "n_components": result['n_components'],
                    "covariance_type": result['covariance_type'],
                    "reg_covar": result['reg_covar'],
                    "n_init": result['n_init'],
                    "init_params": result['init_params'],
                    "max_iter": result['max_iter']
                },
                "results": {
                    "bic_score": result['bic_score'],
                    "aic_score": result['aic_score'],
                    "log_likelihood": result['log_likelihood'],
                    "silhouette_score": result['silhouette_score'],
                    "calinski_harabasz_score": result['calinski_harabasz_score'],
                    "davies_bouldin_score": result['davies_bouldin_score'],
                    "converged": result['converged'],
                    "n_iter": result['n_iter'],
                    "unique_clusters": result['unique_clusters']
                }
            }
            grid_search_data["all_trials"].append(trial_data)

        # Calculate summary statistics
        if all_results:
            bic_scores = [r['bic_score'] for r in all_results]
            aic_scores = [r['aic_score'] for r in all_results]
            log_likelihoods = [r['log_likelihood'] for r in all_results]

            valid_silhouette = [r['silhouette_score'] for r in all_results if r['silhouette_score'] > -1]

            grid_search_data["summary_statistics"] = {
                "total_trials": len(all_results),
                "valid_clustering_trials": len(valid_silhouette),
                "bic_score": {
                    "best": min(bic_scores),
                    "worst": max(bic_scores),
                    "mean": sum(bic_scores) / len(bic_scores),
                    "median": sorted(bic_scores)[len(bic_scores)//2]
                },
                "aic_score": {
                    "best": min(aic_scores),
                    "worst": max(aic_scores),
                    "mean": sum(aic_scores) / len(aic_scores),
                    "median": sorted(aic_scores)[len(aic_scores)//2]
                },
                "log_likelihood": {
                    "best": max(log_likelihoods),
                    "worst": min(log_likelihoods),
                    "mean": sum(log_likelihoods) / len(log_likelihoods)
                }
            }

            if valid_silhouette:
                grid_search_data["summary_statistics"]["silhouette_score"] = {
                    "best": max(valid_silhouette),
                    "worst": min(valid_silhouette),
                    "mean": sum(valid_silhouette) / len(valid_silhouette),
                    "median": sorted(valid_silhouette)[len(valid_silhouette)//2]
                }

            # Top 10 results by different criteria
            sorted_by_bic = sorted(all_results, key=lambda x: x['bic_score'])
            sorted_by_aic = sorted(all_results, key=lambda x: x['aic_score'])
            valid_results = [r for r in all_results if r['silhouette_score'] > -1]
            sorted_by_silhouette = sorted(valid_results, key=lambda x: x['silhouette_score'], reverse=True)

            grid_search_data["top_10_results"] = {
                "by_bic": [],
                "by_aic": [],
                "by_silhouette": []
            }

            for i, result in enumerate(sorted_by_bic[:10]):
                grid_search_data["top_10_results"]["by_bic"].append({
                    "rank": i + 1,
                    "parameters": {
                        "n_components": result['n_components'],
                        "covariance_type": result['covariance_type'],
                        "init_params": result['init_params']
                    },
                    "bic_score": result['bic_score'],
                    "aic_score": result['aic_score']
                })

            for i, result in enumerate(sorted_by_aic[:10]):
                grid_search_data["top_10_results"]["by_aic"].append({
                    "rank": i + 1,
                    "parameters": {
                        "n_components": result['n_components'],
                        "covariance_type": result['covariance_type'],
                        "init_params": result['init_params']
                    },
                    "bic_score": result['bic_score'],
                    "aic_score": result['aic_score']
                })

            for i, result in enumerate(sorted_by_silhouette[:10]):
                grid_search_data["top_10_results"]["by_silhouette"].append({
                    "rank": i + 1,
                    "parameters": {
                        "n_components": result['n_components'],
                        "covariance_type": result['covariance_type'],
                        "init_params": result['init_params']
                    },
                    "silhouette_score": result['silhouette_score']
                })

        # Save to file with timestamp
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"gmm_grid_search_detailed_{timestamp}.json"

        # Write the detailed results so the message below is accurate
        with open(filename, 'w') as f:
            json.dump(grid_search_data, f, indent=4, ensure_ascii=False)

        print(f"Detailed grid search results saved to: {filename}")

        # Also save a CSV summary for easy analysis
        csv_filename = f"gmm_grid_search_summary_{timestamp}.csv"
        self.save_grid_search_csv(all_results, csv_filename)
        print(f"Grid search summary CSV saved to: {csv_filename}")
    def save_grid_search_csv(self, all_results, filename):
        """Save grid search results as CSV for easy analysis"""

        with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['trial_id', 'n_components', 'covariance_type', 'reg_covar',
                          'n_init', 'init_params', 'max_iter', 'bic_score', 'aic_score',
                          'log_likelihood', 'silhouette_score', 'calinski_harabasz_score',
                          'davies_bouldin_score', 'converged', 'n_iter', 'unique_clusters']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

            writer.writeheader()
            for i, result in enumerate(all_results):
                writer.writerow({
                    'trial_id': i + 1,
                    'n_components': result['n_components'],
                    'covariance_type': result['covariance_type'],
                    'reg_covar': result['reg_covar'],
                    'n_init': result['n_init'],
                    'init_params': result['init_params'],
                    'max_iter': result['max_iter'],
                    'bic_score': result['bic_score'],
                    'aic_score': result['aic_score'],
                    'log_likelihood': result['log_likelihood'],
                    'silhouette_score': result['silhouette_score'],
                    'calinski_harabasz_score': result['calinski_harabasz_score'],
                    'davies_bouldin_score': result['davies_bouldin_score'],
                    'converged': result['converged'],
                    'n_iter': result['n_iter'],
                    'unique_clusters': result['unique_clusters']
                })
    def visualize_results(self, results):
        """Visualize clustering results using PCA"""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Reduce dimensions for visualization
        pca = PCA(n_components=2, random_state=42)
        embeddings_2d = pca.fit_transform(self.embeddings_scaled)

        methods = ['bic', 'aic', 'silhouette']
        titles = ['Best by BIC', 'Best by AIC', 'Best by Silhouette']

        for idx, (method, title) in enumerate(zip(methods, titles)):
            labels, params, score = results[method]

            if labels is not None:
                unique_labels = set(labels)
                colors = plt.cm.Set3(np.linspace(0, 1, len(unique_labels)))

                for label, color in zip(unique_labels, colors):
                    mask = labels == label
                    axes[idx].scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                                      c=[color], s=50, alpha=0.7, label=f'Cluster {label}')

                axes[idx].set_title(f'{title}\nn_components={params["n_components"]}, '
                                    f'cov={params["covariance_type"]}')
            else:
                axes[idx].text(0.5, 0.5, 'No valid clustering', ha='center', va='center',
                               transform=axes[idx].transAxes, fontsize=12)
                axes[idx].set_title(f'{title}\n(Failed)')

            axes[idx].set_xlabel('PCA Component 1')
            axes[idx].set_ylabel('PCA Component 2')
            axes[idx].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('gmm_clustering_results.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("Visualization saved as 'gmm_clustering_results.png'")
    def save_clustering_results(self, results):
        """Save final clustering results to JSON files"""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        for method in ['bic', 'aic', 'silhouette']:
            labels, params, score = results[method]

            if labels is not None:
                clustering_results = []
                for filepath, label in zip(self.file_paths, labels):
                    clustering_results.append({
                        "filepath": filepath,
                        "cluster": int(label)
                    })

                filename = f"gmm_final_results_{method}_{timestamp}.json"

                with open(filename, 'w') as f:
                    json.dump({
                        "method": f"GMM (best by {method.upper()})",
                        "parameters": params,
                        "n_components": params['n_components'],
                        "n_samples": len(labels),
                        f"{method}_score": score,
                        "results": clustering_results
                    }, f, indent=4)

                print(f"Final clustering results ({method}) saved to: {filename}")
def main():
    parser = argparse.ArgumentParser(description="Run extensive Gaussian Mixture Model clustering on document embeddings")
    parser.add_argument("--embeddings_path", required=True, help="Path to embeddings JSON file")

    args = parser.parse_args()

    # Initialize clustering
    clustering = GMMExtensiveClustering(args.embeddings_path)

    # Run extensive grid search
    results = clustering.run_gmm_grid_search()

    if any(labels is not None for labels, _, _ in results.values()):
        # Visualize and save results
        clustering.visualize_results(results)
        clustering.save_clustering_results(results)
        print("\nGMM extensive clustering completed successfully!")
    else:
        print("\nGMM extensive clustering did not find suitable clusters.")


if __name__ == "__main__":
    main()
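# Usage note (added for clarity, not part of the original logic): load_embeddings()
# expects the --embeddings_path JSON to be a list of objects with "filepath" and
# "embedding" keys, e.g. [{"filepath": "docs/a.txt", "embedding": [0.1, 0.2, ...]}, ...].
# Assuming the script is saved as gmm_extensive_clustering.py (hypothetical filename),
# a typical invocation would be:
#   python gmm_extensive_clustering.py --embeddings_path embeddings.json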