Files
embedding-clustering/filter/fillter_all.py

376 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Universal Clustering Filter
Filters clustering results for multiple algorithms:
- DBSCAN: handles noise points, uses density-based selection
- GMM: uses probability-based selection, no noise points
- K-Means: standard centroid-based selection
Uses cosine distance metric for all calculations.
"""
import json
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_distances
import argparse
import os
from pathlib import Path
class UniversalClusterFilter:
    """Select a subset of points from the clusters of a clustering run.

    The clustering results are grouped by cluster label. Each cluster is
    reduced to ``selection_ratio`` of its points, using probability ranking
    for GMM results (when probabilities are present) and a centroid-distance
    based center/border split otherwise. Points labelled as noise are always
    kept. All distances are cosine distances on L2-normalized embeddings.
    """

    # Share of the selected points taken closest to the centroid; the
    # remaining share is taken from the points furthest from it.
    CENTER_FRACTION = 0.25
    # Ratio reported in the saved summary when no filtering run recorded one.
    DEFAULT_SELECTION_RATIO = 0.5

    def __init__(self, embeddings_path, clustering_results_path):
        """
        Initialize universal cluster filter

        Args:
            embeddings_path: Path to embeddings JSON file
            clustering_results_path: Path to clustering results JSON
        """
        self.embeddings_path = embeddings_path
        self.clustering_results_path = clustering_results_path
        self.embeddings = None
        self.embeddings_normalized = None
        self.clustering_results = None
        self.filepath_to_embedding = {}
        self.algorithm = None
        # Set by filter_all_clusters so the saved summary reports the ratio
        # that was actually applied.
        self.selection_ratio = None

    def load_data(self):
        """Load embeddings and clustering results from their JSON files.

        Populates ``filepath_to_embedding``, ``embeddings``,
        ``embeddings_normalized``, ``clustering_results`` and ``algorithm``.
        """
        print("Loading embeddings...")
        # Read as UTF-8 explicitly: the results written by this class use
        # UTF-8 with ensure_ascii=False, so filepaths may be non-ASCII.
        with open(self.embeddings_path, 'r', encoding='utf-8') as f:
            embeddings_data = json.load(f)

        # Create mapping from filepath to embedding
        embeddings_list = []
        for item in embeddings_data:
            self.filepath_to_embedding[item['filepath']] = item['embedding']
            embeddings_list.append(item['embedding'])

        self.embeddings = np.array(embeddings_list, dtype=np.float32)
        self.embeddings_normalized = normalize(self.embeddings, norm='l2')
        print(f"Loaded {len(embeddings_list)} embeddings")

        print("Loading clustering results...")
        with open(self.clustering_results_path, 'r', encoding='utf-8') as f:
            self.clustering_results = json.load(f)

        # Detect algorithm type
        self.algorithm = self.clustering_results.get('method', 'UNKNOWN')
        print(f"Detected algorithm: {self.algorithm}")
        print(f"Loaded clustering results: {self.clustering_results['n_clusters']} clusters, "
              f"{self.clustering_results['n_samples']} samples")

    def group_by_clusters(self):
        """Group data points by cluster labels (algorithm-agnostic).

        Returns:
            Tuple ``(clusters, noise_points)``: ``clusters`` maps each cluster
            id to a list of point dicts (``filepath``, ``embedding``,
            ``metadata``); ``noise_points`` is a list of such dicts for
            entries flagged ``is_noise`` or labelled ``-1``.

        Raises:
            KeyError: If a clustered filepath has no loaded embedding.
        """
        clusters = {}
        noise_points = []
        for result in self.clustering_results['results']:
            cluster_id = result['cluster']
            filepath = result['filepath']
            try:
                embedding = self.filepath_to_embedding[filepath]
            except KeyError:
                raise KeyError(
                    f"No embedding found for clustered file: {filepath}"
                ) from None

            point = {
                'filepath': filepath,
                'embedding': embedding,
                'metadata': result,
            }
            # Noise points (DBSCAN specific) are flagged or labelled -1
            if result.get('is_noise', False) or cluster_id == -1:
                noise_points.append(point)
            else:
                clusters.setdefault(cluster_id, []).append(point)
        return clusters, noise_points

    def calculate_cluster_centroid(self, cluster_points):
        """Calculate centroid of a cluster using normalized embeddings.

        Returns:
            The L2-normalized mean of the L2-normalized point embeddings.
        """
        embeddings = np.array([point['embedding'] for point in cluster_points])
        embeddings_normalized = normalize(embeddings, norm='l2')
        # For cosine distance, centroid is the normalized mean
        centroid = np.mean(embeddings_normalized, axis=0)
        return normalize(centroid.reshape(1, -1), norm='l2')[0]

    def calculate_cosine_distances_to_centroid(self, cluster_points, centroid):
        """Calculate cosine distances from each point to cluster centroid.

        Returns:
            1-D array of distances, in the same order as ``cluster_points``.
        """
        embeddings = np.array([point['embedding'] for point in cluster_points])
        embeddings_normalized = normalize(embeddings, norm='l2')
        return cosine_distances(embeddings_normalized, centroid.reshape(1, -1)).flatten()

    def filter_cluster_standard(self, cluster_points, selection_ratio=0.5):
        """
        Standard filtering: 25% center + 75% border of selected points

        Args:
            cluster_points: Point dicts of a single cluster.
            selection_ratio: Fraction of the cluster to keep (at least one
                point is always kept).

        Returns:
            The selected point dicts, center points first. Selected dicts are
            tagged in place with ``selection_type`` ('center' or 'border').
            When the whole cluster is selected it is returned untouched.
        """
        if not cluster_points:
            return []

        # Calculate how many points to select
        total_points = len(cluster_points)
        num_to_select = max(1, int(total_points * selection_ratio))
        # If we need to select all or almost all points, just return all
        if num_to_select >= total_points:
            return cluster_points

        centroid = self.calculate_cluster_centroid(cluster_points)
        distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid)

        # Order points from closest to furthest from the centroid
        point_distance_pairs = sorted(zip(cluster_points, distances), key=lambda pair: pair[1])

        # center_count + border_count < total_points here, so the two slices
        # below can never overlap.
        center_count = max(1, int(num_to_select * self.CENTER_FRACTION))
        border_count = num_to_select - center_count

        # Select center points (closest to centroid)
        selected_points = [pair[0] for pair in point_distance_pairs[:center_count]]
        for point in selected_points:
            point['selection_type'] = 'center'

        # Select border points (furthest from centroid); guard against
        # border_count == 0 because a [-0:] slice would take everything.
        if border_count > 0:
            border_points = [pair[0] for pair in point_distance_pairs[-border_count:]]
            for point in border_points:
                point['selection_type'] = 'border'
            selected_points.extend(border_points)

        print(f"Cluster with {total_points} points -> selected {len(selected_points)} points "
              f"({center_count} center + {border_count} border)")
        return selected_points

    def filter_cluster_gmm(self, cluster_points, selection_ratio=0.5):
        """
        GMM-specific filtering: consider probability scores if available

        Keeps the points with the highest ``probability`` metadata value.
        Falls back to ``filter_cluster_standard`` when no point carries a
        probability.
        """
        if not cluster_points:
            return []

        has_probabilities = any('probability' in point['metadata'] for point in cluster_points)
        if not has_probabilities:
            # Fall back to standard filtering
            return self.filter_cluster_standard(cluster_points, selection_ratio)

        total_points = len(cluster_points)
        num_to_select = max(1, int(total_points * selection_ratio))
        if num_to_select >= total_points:
            return cluster_points

        # Sort by probability (highest confidence first); points without a
        # probability rank last.
        sorted_points = sorted(cluster_points,
                               key=lambda x: x['metadata'].get('probability', 0),
                               reverse=True)
        selected_points = sorted_points[:num_to_select]
        for point in selected_points:
            point['selection_type'] = 'high_probability'

        print(f"GMM Cluster with {total_points} points -> selected {len(selected_points)} points "
              f"(top probability)")
        return selected_points

    def filter_all_clusters(self, selection_ratio=0.5):
        """Filter all clusters according to algorithm-specific criteria.

        Args:
            selection_ratio: Fraction of each cluster to keep.

        Returns:
            List of result dicts (``filepath``, ``cluster``, ``is_noise``,
            ``selection_type``, ``original_metadata``): the selected points of
            every cluster followed by all noise points.
        """
        # Remember the ratio so save_filtered_results can report it.
        self.selection_ratio = selection_ratio

        print("\n" + "="*60)
        print(f"FILTERING {self.algorithm} CLUSTERING RESULTS")
        print("="*60)

        clusters, noise_points = self.group_by_clusters()
        print(f"Found {len(clusters)} clusters and {len(noise_points)} noise points")

        # The algorithm does not change between clusters; decide once.
        algorithm_name = self.algorithm.upper()
        use_gmm = algorithm_name == 'GMM' or 'GAUSSIAN' in algorithm_name
        filter_cluster = self.filter_cluster_gmm if use_gmm else self.filter_cluster_standard

        filtered_results = []
        for cluster_id, cluster_points in clusters.items():
            print(f"\nProcessing Cluster {cluster_id}:")
            for point in filter_cluster(cluster_points, selection_ratio):
                filtered_results.append({
                    'filepath': point['filepath'],
                    'cluster': cluster_id,
                    'is_noise': False,
                    'selection_type': point.get('selection_type', 'cluster_filtered'),
                    'original_metadata': point['metadata']
                })

        # Add all noise points (DBSCAN only)
        if noise_points:
            print(f"\nAdding all {len(noise_points)} noise points...")
            for point in noise_points:
                filtered_results.append({
                    'filepath': point['filepath'],
                    'cluster': -1,
                    'is_noise': True,
                    'selection_type': 'noise',
                    'original_metadata': point['metadata']
                })
        return filtered_results

    def save_filtered_results(self, filtered_results, output_path=None):
        """Save filtered results to JSON file and print a summary.

        Args:
            filtered_results: Output of ``filter_all_clusters``.
            output_path: Destination file; defaults to
                ``<clustering results stem>_filtered.json`` in the current
                working directory.

        Returns:
            The path the results were written to.
        """
        if output_path is None:
            base_name = Path(self.clustering_results_path).stem
            output_path = f"{base_name}_filtered.json"

        # Create summary statistics
        cluster_stats = {}
        noise_count = 0
        selection_type_stats = {}
        for result in filtered_results:
            if result['is_noise']:
                noise_count += 1
            else:
                cluster_id = result['cluster']
                cluster_stats[cluster_id] = cluster_stats.get(cluster_id, 0) + 1
            sel_type = result['selection_type']
            selection_type_stats[sel_type] = selection_type_stats.get(sel_type, 0) + 1

        # Report the ratio used by the last filtering run, not a fixed value.
        selection_ratio = getattr(self, 'selection_ratio', None)
        if selection_ratio is None:
            selection_ratio = self.DEFAULT_SELECTION_RATIO

        original_samples = self.clustering_results['n_samples']

        # Prepare output data
        output_data = {
            "method": f"{self.algorithm}_FILTERED",
            "original_algorithm": self.algorithm,
            "original_n_clusters": self.clustering_results['n_clusters'],
            "original_n_samples": original_samples,
            "filtered_n_samples": len(filtered_results),
            "filtering_criteria": {
                "cluster_selection_ratio": selection_ratio,
                "center_points_ratio": self.CENTER_FRACTION,
                "border_points_ratio": 1 - self.CENTER_FRACTION,
                "noise_points": "all_selected" if noise_count > 0 else "none_present"
            },
            "cluster_statistics": cluster_stats,
            "selection_type_statistics": selection_type_stats,
            "noise_points": noise_count,
            "results": filtered_results
        }
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=4, ensure_ascii=False)

        print("\n" + "="*60)
        print("FILTERING SUMMARY")
        print("="*60)
        print(f"Algorithm: {self.algorithm}")
        print(f"Original samples: {original_samples}")
        print(f"Filtered samples: {len(filtered_results)}")
        # Avoid a ZeroDivisionError when the clustering run had no samples.
        if original_samples:
            print(f"Reduction ratio: {len(filtered_results)/original_samples:.2%}")
        else:
            print("Reduction ratio: n/a (no original samples)")
        print("\nCluster breakdown:")
        for cluster_id, count in sorted(cluster_stats.items()):
            print(f" Cluster {cluster_id}: {count} points")
        if noise_count > 0:
            print(f" Noise points: {noise_count} points")
        print("\nSelection type breakdown:")
        for sel_type, count in selection_type_stats.items():
            print(f" {sel_type}: {count} points")
        print(f"\nFiltered results saved to: {output_path}")
        return output_path

    def create_filepath_list(self, filtered_results, output_txt_path=None):
        """Create a simple text file with filtered filepaths, one per line.

        Args:
            filtered_results: Output of ``filter_all_clusters``.
            output_txt_path: Destination file; defaults to
                ``<clustering results stem>_filtered_filepaths.txt`` in the
                current working directory.

        Returns:
            The path the list was written to.
        """
        if output_txt_path is None:
            base_name = Path(self.clustering_results_path).stem
            output_txt_path = f"{base_name}_filtered_filepaths.txt"

        with open(output_txt_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{result['filepath']}\n" for result in filtered_results)
        print(f"Filepath list saved to: {output_txt_path}")
        return output_txt_path
def main():
    """Command-line entry point: load data, filter clusters, save results.

    Returns:
        Process exit status: 0 on success, 1 when an input file is missing.
    """
    parser = argparse.ArgumentParser(description="Universal filter for clustering results")
    parser.add_argument("--embeddings_path", required=True,
                        help="Path to embeddings JSON file")
    parser.add_argument("--clustering_results_path", required=True,
                        help="Path to clustering results JSON file")
    parser.add_argument("--output_path",
                        help="Output path for filtered results (optional)")
    parser.add_argument("--selection_ratio", type=float, default=0.5,
                        help="Ratio of points to select from each cluster (default: 0.5)")
    parser.add_argument("--create_filepath_list", action="store_true",
                        help="Also create a simple text file with filtered filepaths")
    args = parser.parse_args()

    # Validate input files exist. isfile (not exists) so a directory path is
    # rejected here with a clear message instead of failing later in open().
    if not os.path.isfile(args.embeddings_path):
        print(f"Error: Embeddings file not found: {args.embeddings_path}")
        return 1
    if not os.path.isfile(args.clustering_results_path):
        print(f"Error: Clustering results file not found: {args.clustering_results_path}")
        return 1

    # Initialize filter and load data
    filter_obj = UniversalClusterFilter(args.embeddings_path, args.clustering_results_path)
    filter_obj.load_data()

    # Filter clusters and save results
    filtered_results = filter_obj.filter_all_clusters(args.selection_ratio)
    filter_obj.save_filtered_results(filtered_results, args.output_path)

    # Create filepath list if requested
    if args.create_filepath_list:
        filter_obj.create_filepath_list(filtered_results)

    print("\nFiltering completed successfully!")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and pipelines can detect failures.
    raise SystemExit(main())