#!/usr/bin/env python3 """ DBSCAN Clustering Filter Filters clustering results based on specific criteria: - For each cluster: select 50% of points - 25% from center region (closest to centroid) - 25% from border region (furthest from centroid) - All noise points are selected - Uses cosine distance metric """ import json import numpy as np from sklearn.preprocessing import normalize from sklearn.metrics.pairwise import cosine_distances import argparse import os from pathlib import Path import random class DBSCANFilter: def __init__(self, embeddings_path, clustering_results_path): """ Initialize DBSCAN filter Args: embeddings_path: Path to embeddings JSON file clustering_results_path: Path to DBSCAN clustering results JSON """ self.embeddings_path = embeddings_path self.clustering_results_path = clustering_results_path self.embeddings = None self.embeddings_normalized = None self.clustering_results = None self.filepath_to_embedding = {} def load_data(self): """Load embeddings and clustering results""" print("Loading embeddings...") with open(self.embeddings_path, 'r') as f: embeddings_data = json.load(f) # Create mapping from filepath to embedding embeddings_list = [] filepaths = [] for item in embeddings_data: self.filepath_to_embedding[item['filepath']] = item['embedding'] embeddings_list.append(item['embedding']) filepaths.append(item['filepath']) self.embeddings = np.array(embeddings_list, dtype=np.float32) self.embeddings_normalized = normalize(self.embeddings, norm='l2') print(f"Loaded {len(embeddings_list)} embeddings") print("Loading clustering results...") with open(self.clustering_results_path, 'r') as f: self.clustering_results = json.load(f) print(f"Loaded clustering results: {self.clustering_results['n_clusters']} clusters, " f"{self.clustering_results['n_samples']} samples") def group_by_clusters(self): """Group data points by cluster labels""" clusters = {} noise_points = [] for result in self.clustering_results['results']: cluster_id = result['cluster'] filepath = result['filepath'] if 'is_noise' in result: is_noise = result['is_noise'] else: is_noise = False if is_noise or cluster_id == -1: noise_points.append({ 'filepath': filepath, 'embedding': self.filepath_to_embedding[filepath] }) else: if cluster_id not in clusters: clusters[cluster_id] = [] clusters[cluster_id].append({ 'filepath': filepath, 'embedding': self.filepath_to_embedding[filepath] }) return clusters, noise_points def calculate_cluster_centroid(self, cluster_points): """Calculate centroid of a cluster using normalized embeddings""" embeddings = np.array([point['embedding'] for point in cluster_points]) embeddings_normalized = normalize(embeddings, norm='l2') # For cosine distance, centroid is the normalized mean centroid = np.mean(embeddings_normalized, axis=0) centroid_normalized = normalize(centroid.reshape(1, -1), norm='l2')[0] return centroid_normalized def calculate_cosine_distances_to_centroid(self, cluster_points, centroid): """Calculate cosine distances from each point to cluster centroid""" embeddings = np.array([point['embedding'] for point in cluster_points]) embeddings_normalized = normalize(embeddings, norm='l2') # Calculate cosine distances to centroid distances = cosine_distances(embeddings_normalized, centroid.reshape(1, -1)).flatten() return distances # v1 0.5 data, 0.5 center 0.5 border # v2 0.5 data, 0.25 center 0.75 border # def filter_cluster(self, cluster_points, selection_ratio=0.5): # v3 0.75 data, 0.25 center 0.75 border #dbscan 014 # def filter_cluster(self, cluster_points, selection_ratio=0.3): # """ # Filter points from a cluster # Args: # cluster_points: List of points in the cluster # selection_ratio: Ratio of points to select (default: 0.5 = 50%) # Returns: # List of selected points # """ # if len(cluster_points) == 0: # return [] # # Calculate how many points to select # total_points = len(cluster_points) # num_to_select = max(15, int(total_points * selection_ratio)) # # If we need to select all or almost all points, just return all # if num_to_select >= total_points: # return cluster_points # # Calculate centroid # centroid = self.calculate_cluster_centroid(cluster_points) # # Calculate distances to centroid # distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid) # # Create list of (point, distance) pairs # point_distance_pairs = list(zip(cluster_points, distances)) # # Sort by distance (closest to furthest from centroid) # point_distance_pairs.sort(key=lambda x: x[1]) # dis = 0.1 # # count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis) # all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis] # print(f"Number of center points (distance < {dis}): {len(all_center_points)}") # # count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis) # all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis] # print(f"Number of border points (distance >= {dis}): {len(all_border_points)}") # # Calculate how many points to select from center and border # n_center = len(all_center_points) # center_count = max(1, int(n_center * 0.15)) # n_border = len(all_border_points) # if n_border < 70: # border_count = n_border # else: # border_count = max(0, int(n_border * 0.3)) # remaining from border # selected_points = [] # random.seed(42) # # Select center points (closest to centroid) # # center_points = [pair[0] for pair in point_distance_pairs[:center_count]] # center_points = random.sample(all_center_points, center_count) # selected_points.extend(center_points) # # # Select border points (furthest from centroid) # if border_count > 0: # # border_points = [pair[0] for pair in point_distance_pairs[-border_count:]] # border_points = random.sample(all_border_points, border_count) # selected_points.extend(border_points) # print(f"Cluster with {total_points} points -> selected {len(selected_points)} points " # f"({center_count} center + {border_count} border)") # return selected_points # dbscan 015 def filter_cluster(self, cluster_points, selection_ratio=0.3): """ Filter points from a cluster Args: cluster_points: List of points in the cluster selection_ratio: Ratio of points to select (default: 0.5 = 50%) Returns: List of selected points """ if len(cluster_points) == 0: return [] # Calculate how many points to select total_points = len(cluster_points) num_to_select = max(15, int(total_points * selection_ratio)) # If we need to select all or almost all points, just return all if num_to_select >= total_points: return cluster_points # Calculate centroid centroid = self.calculate_cluster_centroid(cluster_points) # Calculate distances to centroid distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid) # Create list of (point, distance) pairs point_distance_pairs = list(zip(cluster_points, distances)) # Sort by distance (closest to furthest from centroid) point_distance_pairs.sort(key=lambda x: x[1]) dis = 0.1 # count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis) all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis] print(f"Number of center points (distance < {dis}): {len(all_center_points)}") # count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis) all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis] print(f"Number of border points (distance >= {dis}): {len(all_border_points)}") # Calculate how many points to select from center and border n_center = len(all_center_points) center_count = max(1, int(n_center * 0.15)) n_border = len(all_border_points) if n_border < 70: border_count = n_border else: border_count = max(0, int(n_border * 0.3)) # remaining from border selected_points = [] random.seed(42) # Select center points (closest to centroid) # center_points = [pair[0] for pair in point_distance_pairs[:center_count]] center_points = random.sample(all_center_points, center_count) selected_points.extend(center_points) # # Select border points (furthest from centroid) if border_count > 0: # border_points = [pair[0] for pair in point_distance_pairs[-border_count:]] border_points = random.sample(all_border_points, border_count) selected_points.extend(border_points) print(f"Cluster with {total_points} points -> selected {len(selected_points)} points " f"({center_count} center + {border_count} border)") return selected_points #gmm # def filter_cluster(self, cluster_points, selection_ratio=0.3): # """ # Filter points from a cluster # Args: # cluster_points: List of points in the cluster # selection_ratio: Ratio of points to select (default: 0.5 = 50%) # Returns: # List of selected points # """ # if len(cluster_points) == 0: # return [] # # Calculate how many points to select # total_points = len(cluster_points) # num_to_select = max(15, int(total_points * selection_ratio)) # # If we need to select all or almost all points, just return all # if num_to_select >= total_points: # return cluster_points # # Calculate centroid # centroid = self.calculate_cluster_centroid(cluster_points) # # Calculate distances to centroid # distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid) # # Create list of (point, distance) pairs # point_distance_pairs = list(zip(cluster_points, distances)) # # Sort by distance (closest to furthest from centroid) # point_distance_pairs.sort(key=lambda x: x[1]) # dis = 0.2 # # count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis) # all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis] # print(f"Number of center points (distance < {dis}): {len(all_center_points)}") # # count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis) # all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis] # print(f"Number of border points (distance >= {dis}): {len(all_border_points)}") # # Calculate how many points to select from center and border # n_center = len(all_center_points) # center_count = max(1, int(n_center * 0.15)) # n_border = len(all_border_points) # if n_border < 70: # border_count = n_border # else: # border_count = max(0, int(n_border * 0.3)) # remaining from border # selected_points = [] # random.seed(42) # # Select center points (closest to centroid) # # center_points = [pair[0] for pair in point_distance_pairs[:center_count]] # center_points = random.sample(all_center_points, center_count) # selected_points.extend(center_points) # # # Select border points (furthest from centroid) # if border_count > 0: # # border_points = [pair[0] for pair in point_distance_pairs[-border_count:]] # border_points = random.sample(all_border_points, border_count) # selected_points.extend(border_points) # print(f"Cluster with {total_points} points -> selected {len(selected_points)} points " # f"({center_count} center + {border_count} border)") # return selected_points def filter_all_clusters(self): """Filter all clusters according to the specified criteria""" print("\n" + "="*60) print("FILTERING DBSCAN CLUSTERING RESULTS") print("="*60) clusters, noise_points = self.group_by_clusters() print(f"Found {len(clusters)} clusters and {len(noise_points)} noise points") filtered_results = [] # Process each cluster for cluster_id, cluster_points in clusters.items(): print(f"\nProcessing Cluster {cluster_id}:") filtered_points = self.filter_cluster(cluster_points) # Add cluster information for point in filtered_points: filtered_results.append({ 'filepath': point['filepath'], 'cluster': cluster_id, 'is_noise': False, 'selection_type': 'cluster_filtered' }) # Add all noise points print(f"\nAdding all {len(noise_points)} noise points...") for point in noise_points: filtered_results.append({ 'filepath': point['filepath'], 'cluster': -1, 'is_noise': True, 'selection_type': 'noise' }) return filtered_results def save_filtered_results(self, filtered_results, output_path=None): """Save filtered results to JSON file""" if output_path is None: # Generate output filename based on input base_name = Path(self.clustering_results_path).stem output_path = f"{base_name}_filtered.json" # Create summary statistics cluster_stats = {} noise_count = 0 for result in filtered_results: if result['is_noise']: noise_count += 1 else: cluster_id = result['cluster'] if cluster_id not in cluster_stats: cluster_stats[cluster_id] = 0 cluster_stats[cluster_id] += 1 # Prepare output data output_data = { "method": "DBSCAN_FILTERED", "original_n_clusters": self.clustering_results['n_clusters'], "original_n_samples": self.clustering_results['n_samples'], "filtered_n_samples": len(filtered_results), "filtering_criteria": { "cluster_selection_ratio": 0.5, "center_points_ratio": 0.5, # 50% of selected points from center "border_points_ratio": 0.5, # 50% of selected points from border "noise_points": "all_selected" }, "cluster_statistics": cluster_stats, "noise_points": noise_count, "results": filtered_results } with open(output_path, 'w', encoding='utf-8') as f: json.dump(output_data, f, indent=4, ensure_ascii=False) print("\n" + "="*60) print("FILTERING SUMMARY") print("="*60) print(f"Original samples: {self.clustering_results['n_samples']}") print(f"Filtered samples: {len(filtered_results)}") print(f"Reduction ratio: {len(filtered_results)/self.clustering_results['n_samples']:.2%}") print("\nCluster breakdown:") for cluster_id, count in sorted(cluster_stats.items()): print(f" Cluster {cluster_id}: {count} points") print(f" Noise points: {noise_count} points") print(f"\nFiltered results saved to: {output_path}") return output_path def create_filepath_list(self, filtered_results, output_txt_path=None): """Create a simple text file with filtered filepaths""" if output_txt_path is None: base_name = Path(self.clustering_results_path).stem output_txt_path = f"{base_name}_filtered_filepaths.txt" filepaths = [result['filepath'] for result in filtered_results] with open(output_txt_path, 'w', encoding='utf-8') as f: for filepath in filepaths: f.write(f"{filepath}\n") print(f"Filepath list saved to: {output_txt_path}") return output_txt_path def main(): parser = argparse.ArgumentParser(description="Filter DBSCAN clustering results") parser.add_argument("--embeddings_path", required=True, help="Path to embeddings JSON file") parser.add_argument("--clustering_results_path", required=True, help="Path to DBSCAN clustering results JSON file") parser.add_argument("--output_path", help="Output path for filtered results (optional)") parser.add_argument("--create_filepath_list", action="store_true", help="Also create a simple text file with filtered filepaths") args = parser.parse_args() # Validate input files exist if not os.path.exists(args.embeddings_path): print(f"Error: Embeddings file not found: {args.embeddings_path}") return if not os.path.exists(args.clustering_results_path): print(f"Error: Clustering results file not found: {args.clustering_results_path}") return # Initialize filter filter_obj = DBSCANFilter(args.embeddings_path, args.clustering_results_path) # Load data filter_obj.load_data() # Filter clusters filtered_results = filter_obj.filter_all_clusters() # Save results filter_obj.save_filtered_results(filtered_results, args.output_path) # Create filepath list if requested if args.create_filepath_list: filter_obj.create_filepath_list(filtered_results) print("\nFiltering completed successfully!") if __name__ == "__main__": main()