#!/usr/bin/env python3
"""
Universal Clustering Filter

Filters clustering results for multiple algorithms:
- DBSCAN: handles noise points, uses density-based selection
- GMM: uses probability-based selection, no noise points
- K-Means: standard centroid-based selection

Uses the cosine distance metric for all calculations.
"""

import argparse
import json
import os
from pathlib import Path

import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from sklearn.preprocessing import normalize


class UniversalClusterFilter:
    def __init__(self, embeddings_path, clustering_results_path):
        """
        Initialize the universal cluster filter.

        Args:
            embeddings_path: Path to the embeddings JSON file
            clustering_results_path: Path to the clustering results JSON file
        """
        self.embeddings_path = embeddings_path
        self.clustering_results_path = clustering_results_path
        self.embeddings = None
        self.embeddings_normalized = None
        self.clustering_results = None
        self.filepath_to_embedding = {}
        self.algorithm = None
        # Recorded by filter_all_clusters() so the saved summary can report
        # the selection ratio that was actually used.
        self.selection_ratio = None

    def load_data(self):
        """Load embeddings and clustering results.

        The embeddings file is expected to be a JSON list of objects with
        'filepath' and 'embedding' keys. The clustering results file is
        expected to contain 'method', 'n_clusters', 'n_samples', and a
        'results' list of per-sample records.
        """
        print("Loading embeddings...")
        with open(self.embeddings_path, 'r') as f:
            embeddings_data = json.load(f)

        # Create a mapping from filepath to embedding
        embeddings_list = []
        for item in embeddings_data:
            self.filepath_to_embedding[item['filepath']] = item['embedding']
            embeddings_list.append(item['embedding'])

        self.embeddings = np.array(embeddings_list, dtype=np.float32)
        self.embeddings_normalized = normalize(self.embeddings, norm='l2')
        print(f"Loaded {len(embeddings_list)} embeddings")

        print("Loading clustering results...")
        with open(self.clustering_results_path, 'r') as f:
            self.clustering_results = json.load(f)

        # Detect the algorithm type
        self.algorithm = self.clustering_results.get('method', 'UNKNOWN')
        print(f"Detected algorithm: {self.algorithm}")
        print(f"Loaded clustering results: {self.clustering_results['n_clusters']} clusters, "
              f"{self.clustering_results['n_samples']} samples")

    def group_by_clusters(self):
        """Group data points by cluster label (algorithm-agnostic)."""
        clusters = {}
        noise_points = []

        for result in self.clustering_results['results']:
            cluster_id = result['cluster']
            filepath = result['filepath']

            # Check for noise points (DBSCAN-specific)
            is_noise = result.get('is_noise', False)

            if is_noise or cluster_id == -1:
                noise_points.append({
                    'filepath': filepath,
                    'embedding': self.filepath_to_embedding[filepath],
                    'metadata': result
                })
            else:
                if cluster_id not in clusters:
                    clusters[cluster_id] = []
                clusters[cluster_id].append({
                    'filepath': filepath,
                    'embedding': self.filepath_to_embedding[filepath],
                    'metadata': result
                })

        return clusters, noise_points

    def calculate_cluster_centroid(self, cluster_points):
        """Calculate the centroid of a cluster using normalized embeddings."""
        embeddings = np.array([point['embedding'] for point in cluster_points])
        embeddings_normalized = normalize(embeddings, norm='l2')

        # For cosine distance, the centroid is the re-normalized mean
        centroid = np.mean(embeddings_normalized, axis=0)
        centroid_normalized = normalize(centroid.reshape(1, -1), norm='l2')[0]
        return centroid_normalized

    def calculate_cosine_distances_to_centroid(self, cluster_points, centroid):
        """Calculate cosine distances from each point to the cluster centroid."""
        embeddings = np.array([point['embedding'] for point in cluster_points])
        embeddings_normalized = normalize(embeddings, norm='l2')

        distances = cosine_distances(embeddings_normalized,
                                     centroid.reshape(1, -1)).flatten()
        return distances
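
    # Note on the centroid math above: for unit-length vectors, cosine
    # distance is monotone in Euclidean distance (1 - cos(a, b) equals
    # ||a - b||^2 / 2 when ||a|| = ||b|| = 1), so the re-normalized mean of
    # the L2-normalized embeddings serves as a reasonable cosine "center" of
    # the cluster. A minimal self-contained sketch of the same computation,
    # using hypothetical toy vectors rather than real embeddings:
    #
    #   import numpy as np
    #   from sklearn.preprocessing import normalize
    #   from sklearn.metrics.pairwise import cosine_distances
    #
    #   pts = normalize(np.array([[1.0, 0.0], [0.8, 0.6]]), norm='l2')
    #   centroid = normalize(pts.mean(axis=0).reshape(1, -1), norm='l2')
    #   dists = cosine_distances(pts, centroid).flatten()  # one value per point
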
    def filter_cluster_standard(self, cluster_points, selection_ratio=0.5):
        """
        Standard filtering: of the selected points, 25% come from the
        cluster center and 75% from the cluster border.
        """
        if len(cluster_points) == 0:
            return []

        # Calculate how many points to select
        total_points = len(cluster_points)
        num_to_select = max(1, int(total_points * selection_ratio))

        # If we would select all (or almost all) points, just return them all
        if num_to_select >= total_points:
            return cluster_points

        # Calculate the centroid and each point's distance to it
        centroid = self.calculate_cluster_centroid(cluster_points)
        distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid)

        # Sort (point, distance) pairs from closest to furthest from the centroid
        point_distance_pairs = list(zip(cluster_points, distances))
        point_distance_pairs.sort(key=lambda x: x[1])

        # Split the selection budget: 25% from the center, 75% from the border
        center_count = max(1, int(num_to_select * 0.25))
        border_count = num_to_select - center_count

        selected_points = []

        # Select center points (closest to the centroid)
        center_points = [pair[0] for pair in point_distance_pairs[:center_count]]
        for point in center_points:
            point['selection_type'] = 'center'
        selected_points.extend(center_points)

        # Select border points (furthest from the centroid)
        if border_count > 0:
            border_points = [pair[0] for pair in point_distance_pairs[-border_count:]]
            for point in border_points:
                point['selection_type'] = 'border'
            selected_points.extend(border_points)

        print(f"Cluster with {total_points} points -> selected {len(selected_points)} points "
              f"({center_count} center + {border_count} border)")
        return selected_points

    def filter_cluster_gmm(self, cluster_points, selection_ratio=0.5):
        """
        GMM-specific filtering: use per-point probability scores if available,
        otherwise fall back to standard center/border filtering.
        """
        if len(cluster_points) == 0:
            return []

        # Check whether probability scores are present
        has_probabilities = any('probability' in point['metadata']
                                for point in cluster_points)

        if not has_probabilities:
            # Fall back to standard filtering
            return self.filter_cluster_standard(cluster_points, selection_ratio)

        # Probability-based selection
        total_points = len(cluster_points)
        num_to_select = max(1, int(total_points * selection_ratio))

        if num_to_select >= total_points:
            return cluster_points

        # Sort by probability (highest confidence first) and keep the top slice
        sorted_points = sorted(cluster_points,
                               key=lambda x: x['metadata'].get('probability', 0),
                               reverse=True)
        selected_points = sorted_points[:num_to_select]
        for point in selected_points:
            point['selection_type'] = 'high_probability'

        print(f"GMM cluster with {total_points} points -> selected {len(selected_points)} points "
              f"(top probability)")
        return selected_points
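
    # Worked example of the selection budget (a sketch with hypothetical
    # numbers, not output from a real run): for a 100-point cluster at the
    # default selection_ratio of 0.5, num_to_select = 50, of which
    # center_count = int(50 * 0.25) = 12 points closest to the centroid and
    # border_count = 50 - 12 = 38 points furthest from it are kept. For a
    # GMM cluster whose metadata carries per-point 'probability' scores, the
    # same budget of 50 is instead filled with the 50 highest-probability
    # points.
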
    def filter_all_clusters(self, selection_ratio=0.5):
        """Filter all clusters according to algorithm-specific criteria."""
        print("\n" + "=" * 60)
        print(f"FILTERING {self.algorithm} CLUSTERING RESULTS")
        print("=" * 60)

        # Remember the ratio so the saved summary reports the value actually used
        self.selection_ratio = selection_ratio

        clusters, noise_points = self.group_by_clusters()
        print(f"Found {len(clusters)} clusters and {len(noise_points)} noise points")

        filtered_results = []

        # Process each cluster
        for cluster_id, cluster_points in clusters.items():
            print(f"\nProcessing Cluster {cluster_id}:")

            # Choose the filtering method based on the algorithm
            if self.algorithm.upper() == 'GMM' or 'GAUSSIAN' in self.algorithm.upper():
                filtered_points = self.filter_cluster_gmm(cluster_points, selection_ratio)
            else:
                filtered_points = self.filter_cluster_standard(cluster_points, selection_ratio)

            # Add cluster information
            for point in filtered_points:
                filtered_results.append({
                    'filepath': point['filepath'],
                    'cluster': cluster_id,
                    'is_noise': False,
                    'selection_type': point.get('selection_type', 'cluster_filtered'),
                    'original_metadata': point['metadata']
                })

        # Keep all noise points (DBSCAN only)
        if noise_points:
            print(f"\nAdding all {len(noise_points)} noise points...")
            for point in noise_points:
                filtered_results.append({
                    'filepath': point['filepath'],
                    'cluster': -1,
                    'is_noise': True,
                    'selection_type': 'noise',
                    'original_metadata': point['metadata']
                })

        return filtered_results

    def save_filtered_results(self, filtered_results, output_path=None):
        """Save filtered results to a JSON file."""
        if output_path is None:
            base_name = Path(self.clustering_results_path).stem
            output_path = f"{base_name}_filtered.json"

        # Build summary statistics
        cluster_stats = {}
        noise_count = 0
        selection_type_stats = {}

        for result in filtered_results:
            # Cluster stats
            if result['is_noise']:
                noise_count += 1
            else:
                cluster_id = result['cluster']
                cluster_stats[cluster_id] = cluster_stats.get(cluster_id, 0) + 1

            # Selection type stats
            sel_type = result['selection_type']
            selection_type_stats[sel_type] = selection_type_stats.get(sel_type, 0) + 1

        # Selection ratio actually used for this run (default if never set)
        used_ratio = self.selection_ratio if self.selection_ratio is not None else 0.5

        # Prepare the output data
        output_data = {
            "method": f"{self.algorithm}_FILTERED",
            "original_algorithm": self.algorithm,
            "original_n_clusters": self.clustering_results['n_clusters'],
            "original_n_samples": self.clustering_results['n_samples'],
            "filtered_n_samples": len(filtered_results),
            "filtering_criteria": {
                "cluster_selection_ratio": used_ratio,
                "center_points_ratio": 0.25,
                "border_points_ratio": 0.75,
                "noise_points": "all_selected" if noise_count > 0 else "none_present"
            },
            "cluster_statistics": cluster_stats,
            "selection_type_statistics": selection_type_stats,
            "noise_points": noise_count,
            "results": filtered_results
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=4, ensure_ascii=False)

        print("\n" + "=" * 60)
        print("FILTERING SUMMARY")
        print("=" * 60)
        print(f"Algorithm: {self.algorithm}")
        print(f"Original samples: {self.clustering_results['n_samples']}")
        print(f"Filtered samples: {len(filtered_results)}")
        print(f"Retention ratio: {len(filtered_results) / self.clustering_results['n_samples']:.2%}")

        print("\nCluster breakdown:")
        for cluster_id, count in sorted(cluster_stats.items()):
            print(f"  Cluster {cluster_id}: {count} points")
        if noise_count > 0:
            print(f"  Noise points: {noise_count} points")

        print("\nSelection type breakdown:")
        for sel_type, count in selection_type_stats.items():
            print(f"  {sel_type}: {count} points")

        print(f"\nFiltered results saved to: {output_path}")
        return output_path

    def create_filepath_list(self, filtered_results, output_txt_path=None):
        """Write the filtered filepaths to a plain text file, one per line."""
        if output_txt_path is None:
            base_name = Path(self.clustering_results_path).stem
            output_txt_path = f"{base_name}_filtered_filepaths.txt"

        with open(output_txt_path, 'w', encoding='utf-8') as f:
            for result in filtered_results:
                f.write(f"{result['filepath']}\n")

        print(f"Filepath list saved to: {output_txt_path}")
        return output_txt_path
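
# Example programmatic use of the class, bypassing the CLI below (the file
# names here are hypothetical placeholders):
#
#   f = UniversalClusterFilter("embeddings.json", "dbscan_results.json")
#   f.load_data()
#   results = f.filter_all_clusters(selection_ratio=0.5)
#   f.save_filtered_results(results)    # writes <results-stem>_filtered.json
#   f.create_filepath_list(results)     # optional plain-text filepath list
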
parser.add_argument("--selection_ratio", type=float, default=0.5, help="Ratio of points to select from each cluster (default: 0.5)") parser.add_argument("--create_filepath_list", action="store_true", help="Also create a simple text file with filtered filepaths") args = parser.parse_args() # Validate input files exist if not os.path.exists(args.embeddings_path): print(f"Error: Embeddings file not found: {args.embeddings_path}") return if not os.path.exists(args.clustering_results_path): print(f"Error: Clustering results file not found: {args.clustering_results_path}") return # Initialize filter filter_obj = UniversalClusterFilter(args.embeddings_path, args.clustering_results_path) # Load data filter_obj.load_data() # Filter clusters filtered_results = filter_obj.filter_all_clusters(args.selection_ratio) # Save results filter_obj.save_filtered_results(filtered_results, args.output_path) # Create filepath list if requested if args.create_filepath_list: filter_obj.create_filepath_list(filtered_results) print("\nFiltering completed successfully!") if __name__ == "__main__": main()