update source code and pipeline
This commit is contained in:
496
filter/dbscan_v2.py
Normal file
496
filter/dbscan_v2.py
Normal file
@@ -0,0 +1,496 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DBSCAN Clustering Filter
|
||||
|
||||
Filters clustering results based on specific criteria:
|
||||
- For each cluster: select 50% of points
|
||||
- 25% from center region (closest to centroid)
|
||||
- 25% from border region (furthest from centroid)
|
||||
- All noise points are selected
|
||||
- Uses cosine distance metric
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import normalize
|
||||
from sklearn.metrics.pairwise import cosine_distances
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
|
||||
class DBSCANFilter:
|
||||
def __init__(self, embeddings_path, clustering_results_path):
|
||||
"""
|
||||
Initialize DBSCAN filter
|
||||
|
||||
Args:
|
||||
embeddings_path: Path to embeddings JSON file
|
||||
clustering_results_path: Path to DBSCAN clustering results JSON
|
||||
"""
|
||||
self.embeddings_path = embeddings_path
|
||||
self.clustering_results_path = clustering_results_path
|
||||
self.embeddings = None
|
||||
self.embeddings_normalized = None
|
||||
self.clustering_results = None
|
||||
self.filepath_to_embedding = {}
|
||||
|
||||
def load_data(self):
|
||||
"""Load embeddings and clustering results"""
|
||||
print("Loading embeddings...")
|
||||
with open(self.embeddings_path, 'r') as f:
|
||||
embeddings_data = json.load(f)
|
||||
|
||||
# Create mapping from filepath to embedding
|
||||
embeddings_list = []
|
||||
filepaths = []
|
||||
for item in embeddings_data:
|
||||
self.filepath_to_embedding[item['filepath']] = item['embedding']
|
||||
embeddings_list.append(item['embedding'])
|
||||
filepaths.append(item['filepath'])
|
||||
|
||||
self.embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
self.embeddings_normalized = normalize(self.embeddings, norm='l2')
|
||||
print(f"Loaded {len(embeddings_list)} embeddings")
|
||||
|
||||
print("Loading clustering results...")
|
||||
with open(self.clustering_results_path, 'r') as f:
|
||||
self.clustering_results = json.load(f)
|
||||
|
||||
print(f"Loaded clustering results: {self.clustering_results['n_clusters']} clusters, "
|
||||
f"{self.clustering_results['n_samples']} samples")
|
||||
|
||||
def group_by_clusters(self):
|
||||
"""Group data points by cluster labels"""
|
||||
clusters = {}
|
||||
noise_points = []
|
||||
|
||||
for result in self.clustering_results['results']:
|
||||
cluster_id = result['cluster']
|
||||
filepath = result['filepath']
|
||||
if 'is_noise' in result:
|
||||
is_noise = result['is_noise']
|
||||
else:
|
||||
is_noise = False
|
||||
|
||||
if is_noise or cluster_id == -1:
|
||||
noise_points.append({
|
||||
'filepath': filepath,
|
||||
'embedding': self.filepath_to_embedding[filepath]
|
||||
})
|
||||
else:
|
||||
if cluster_id not in clusters:
|
||||
clusters[cluster_id] = []
|
||||
clusters[cluster_id].append({
|
||||
'filepath': filepath,
|
||||
'embedding': self.filepath_to_embedding[filepath]
|
||||
})
|
||||
|
||||
return clusters, noise_points
|
||||
|
||||
def calculate_cluster_centroid(self, cluster_points):
|
||||
"""Calculate centroid of a cluster using normalized embeddings"""
|
||||
embeddings = np.array([point['embedding'] for point in cluster_points])
|
||||
embeddings_normalized = normalize(embeddings, norm='l2')
|
||||
|
||||
# For cosine distance, centroid is the normalized mean
|
||||
centroid = np.mean(embeddings_normalized, axis=0)
|
||||
centroid_normalized = normalize(centroid.reshape(1, -1), norm='l2')[0]
|
||||
|
||||
return centroid_normalized
|
||||
|
||||
def calculate_cosine_distances_to_centroid(self, cluster_points, centroid):
|
||||
"""Calculate cosine distances from each point to cluster centroid"""
|
||||
embeddings = np.array([point['embedding'] for point in cluster_points])
|
||||
embeddings_normalized = normalize(embeddings, norm='l2')
|
||||
|
||||
# Calculate cosine distances to centroid
|
||||
distances = cosine_distances(embeddings_normalized, centroid.reshape(1, -1)).flatten()
|
||||
|
||||
return distances
|
||||
|
||||
# v1 0.5 data, 0.5 center 0.5 border
|
||||
|
||||
# v2 0.5 data, 0.25 center 0.75 border
|
||||
# def filter_cluster(self, cluster_points, selection_ratio=0.5):
|
||||
|
||||
# v3 0.75 data, 0.25 center 0.75 border
|
||||
|
||||
|
||||
#dbscan 014
|
||||
# def filter_cluster(self, cluster_points, selection_ratio=0.3):
|
||||
|
||||
# """
|
||||
# Filter points from a cluster
|
||||
|
||||
# Args:
|
||||
# cluster_points: List of points in the cluster
|
||||
# selection_ratio: Ratio of points to select (default: 0.5 = 50%)
|
||||
|
||||
# Returns:
|
||||
# List of selected points
|
||||
# """
|
||||
# if len(cluster_points) == 0:
|
||||
# return []
|
||||
|
||||
|
||||
# # Calculate how many points to select
|
||||
# total_points = len(cluster_points)
|
||||
# num_to_select = max(15, int(total_points * selection_ratio))
|
||||
|
||||
# # If we need to select all or almost all points, just return all
|
||||
# if num_to_select >= total_points:
|
||||
# return cluster_points
|
||||
|
||||
# # Calculate centroid
|
||||
# centroid = self.calculate_cluster_centroid(cluster_points)
|
||||
|
||||
# # Calculate distances to centroid
|
||||
# distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid)
|
||||
|
||||
# # Create list of (point, distance) pairs
|
||||
# point_distance_pairs = list(zip(cluster_points, distances))
|
||||
|
||||
# # Sort by distance (closest to furthest from centroid)
|
||||
# point_distance_pairs.sort(key=lambda x: x[1])
|
||||
|
||||
# dis = 0.1
|
||||
# # count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis)
|
||||
# all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis]
|
||||
# print(f"Number of center points (distance < {dis}): {len(all_center_points)}")
|
||||
|
||||
# # count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis)
|
||||
# all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis]
|
||||
# print(f"Number of border points (distance >= {dis}): {len(all_border_points)}")
|
||||
|
||||
# # Calculate how many points to select from center and border
|
||||
# n_center = len(all_center_points)
|
||||
# center_count = max(1, int(n_center * 0.15))
|
||||
|
||||
# n_border = len(all_border_points)
|
||||
# if n_border < 70:
|
||||
# border_count = n_border
|
||||
# else:
|
||||
# border_count = max(0, int(n_border * 0.3)) # remaining from border
|
||||
|
||||
# selected_points = []
|
||||
|
||||
# random.seed(42)
|
||||
# # Select center points (closest to centroid)
|
||||
# # center_points = [pair[0] for pair in point_distance_pairs[:center_count]]
|
||||
# center_points = random.sample(all_center_points, center_count)
|
||||
# selected_points.extend(center_points)
|
||||
|
||||
# # # Select border points (furthest from centroid)
|
||||
# if border_count > 0:
|
||||
# # border_points = [pair[0] for pair in point_distance_pairs[-border_count:]]
|
||||
# border_points = random.sample(all_border_points, border_count)
|
||||
# selected_points.extend(border_points)
|
||||
|
||||
# print(f"Cluster with {total_points} points -> selected {len(selected_points)} points "
|
||||
# f"({center_count} center + {border_count} border)")
|
||||
|
||||
# return selected_points
|
||||
|
||||
# dbscan 015
|
||||
def filter_cluster(self, cluster_points, selection_ratio=0.3):
|
||||
|
||||
"""
|
||||
Filter points from a cluster
|
||||
|
||||
Args:
|
||||
cluster_points: List of points in the cluster
|
||||
selection_ratio: Ratio of points to select (default: 0.5 = 50%)
|
||||
|
||||
Returns:
|
||||
List of selected points
|
||||
"""
|
||||
if len(cluster_points) == 0:
|
||||
return []
|
||||
|
||||
|
||||
# Calculate how many points to select
|
||||
total_points = len(cluster_points)
|
||||
num_to_select = max(15, int(total_points * selection_ratio))
|
||||
|
||||
# If we need to select all or almost all points, just return all
|
||||
if num_to_select >= total_points:
|
||||
return cluster_points
|
||||
|
||||
# Calculate centroid
|
||||
centroid = self.calculate_cluster_centroid(cluster_points)
|
||||
|
||||
# Calculate distances to centroid
|
||||
distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid)
|
||||
|
||||
# Create list of (point, distance) pairs
|
||||
point_distance_pairs = list(zip(cluster_points, distances))
|
||||
|
||||
# Sort by distance (closest to furthest from centroid)
|
||||
point_distance_pairs.sort(key=lambda x: x[1])
|
||||
|
||||
dis = 0.1
|
||||
# count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis)
|
||||
all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis]
|
||||
print(f"Number of center points (distance < {dis}): {len(all_center_points)}")
|
||||
|
||||
# count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis)
|
||||
all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis]
|
||||
print(f"Number of border points (distance >= {dis}): {len(all_border_points)}")
|
||||
|
||||
# Calculate how many points to select from center and border
|
||||
n_center = len(all_center_points)
|
||||
center_count = max(1, int(n_center * 0.15))
|
||||
|
||||
n_border = len(all_border_points)
|
||||
if n_border < 70:
|
||||
border_count = n_border
|
||||
else:
|
||||
border_count = max(0, int(n_border * 0.3)) # remaining from border
|
||||
|
||||
selected_points = []
|
||||
|
||||
random.seed(42)
|
||||
# Select center points (closest to centroid)
|
||||
# center_points = [pair[0] for pair in point_distance_pairs[:center_count]]
|
||||
center_points = random.sample(all_center_points, center_count)
|
||||
selected_points.extend(center_points)
|
||||
|
||||
# # Select border points (furthest from centroid)
|
||||
if border_count > 0:
|
||||
# border_points = [pair[0] for pair in point_distance_pairs[-border_count:]]
|
||||
border_points = random.sample(all_border_points, border_count)
|
||||
selected_points.extend(border_points)
|
||||
|
||||
print(f"Cluster with {total_points} points -> selected {len(selected_points)} points "
|
||||
f"({center_count} center + {border_count} border)")
|
||||
|
||||
return selected_points
|
||||
|
||||
#gmm
|
||||
# def filter_cluster(self, cluster_points, selection_ratio=0.3):
|
||||
|
||||
# """
|
||||
# Filter points from a cluster
|
||||
|
||||
# Args:
|
||||
# cluster_points: List of points in the cluster
|
||||
# selection_ratio: Ratio of points to select (default: 0.5 = 50%)
|
||||
|
||||
# Returns:
|
||||
# List of selected points
|
||||
# """
|
||||
# if len(cluster_points) == 0:
|
||||
# return []
|
||||
|
||||
|
||||
# # Calculate how many points to select
|
||||
# total_points = len(cluster_points)
|
||||
# num_to_select = max(15, int(total_points * selection_ratio))
|
||||
|
||||
# # If we need to select all or almost all points, just return all
|
||||
# if num_to_select >= total_points:
|
||||
# return cluster_points
|
||||
|
||||
# # Calculate centroid
|
||||
# centroid = self.calculate_cluster_centroid(cluster_points)
|
||||
|
||||
# # Calculate distances to centroid
|
||||
# distances = self.calculate_cosine_distances_to_centroid(cluster_points, centroid)
|
||||
|
||||
# # Create list of (point, distance) pairs
|
||||
# point_distance_pairs = list(zip(cluster_points, distances))
|
||||
|
||||
# # Sort by distance (closest to furthest from centroid)
|
||||
# point_distance_pairs.sort(key=lambda x: x[1])
|
||||
|
||||
# dis = 0.2
|
||||
# # count_center = sum(1 for pair in point_distance_pairs if pair[1] < dis)
|
||||
# all_center_points = [pair[0] for pair in point_distance_pairs if pair[1] < dis]
|
||||
# print(f"Number of center points (distance < {dis}): {len(all_center_points)}")
|
||||
|
||||
# # count_border = sum(1 for pair in point_distance_pairs if pair[1] >= dis)
|
||||
# all_border_points = [pair[0] for pair in point_distance_pairs if pair[1] >= dis]
|
||||
# print(f"Number of border points (distance >= {dis}): {len(all_border_points)}")
|
||||
|
||||
# # Calculate how many points to select from center and border
|
||||
# n_center = len(all_center_points)
|
||||
# center_count = max(1, int(n_center * 0.15))
|
||||
|
||||
# n_border = len(all_border_points)
|
||||
# if n_border < 70:
|
||||
# border_count = n_border
|
||||
# else:
|
||||
# border_count = max(0, int(n_border * 0.3)) # remaining from border
|
||||
|
||||
# selected_points = []
|
||||
|
||||
# random.seed(42)
|
||||
# # Select center points (closest to centroid)
|
||||
# # center_points = [pair[0] for pair in point_distance_pairs[:center_count]]
|
||||
# center_points = random.sample(all_center_points, center_count)
|
||||
# selected_points.extend(center_points)
|
||||
|
||||
# # # Select border points (furthest from centroid)
|
||||
# if border_count > 0:
|
||||
# # border_points = [pair[0] for pair in point_distance_pairs[-border_count:]]
|
||||
# border_points = random.sample(all_border_points, border_count)
|
||||
# selected_points.extend(border_points)
|
||||
|
||||
# print(f"Cluster with {total_points} points -> selected {len(selected_points)} points "
|
||||
# f"({center_count} center + {border_count} border)")
|
||||
|
||||
# return selected_points
|
||||
|
||||
def filter_all_clusters(self):
|
||||
"""Filter all clusters according to the specified criteria"""
|
||||
print("\n" + "="*60)
|
||||
print("FILTERING DBSCAN CLUSTERING RESULTS")
|
||||
print("="*60)
|
||||
|
||||
clusters, noise_points = self.group_by_clusters()
|
||||
|
||||
print(f"Found {len(clusters)} clusters and {len(noise_points)} noise points")
|
||||
|
||||
filtered_results = []
|
||||
|
||||
# Process each cluster
|
||||
for cluster_id, cluster_points in clusters.items():
|
||||
print(f"\nProcessing Cluster {cluster_id}:")
|
||||
filtered_points = self.filter_cluster(cluster_points)
|
||||
|
||||
# Add cluster information
|
||||
for point in filtered_points:
|
||||
filtered_results.append({
|
||||
'filepath': point['filepath'],
|
||||
'cluster': cluster_id,
|
||||
'is_noise': False,
|
||||
'selection_type': 'cluster_filtered'
|
||||
})
|
||||
|
||||
# Add all noise points
|
||||
print(f"\nAdding all {len(noise_points)} noise points...")
|
||||
for point in noise_points:
|
||||
filtered_results.append({
|
||||
'filepath': point['filepath'],
|
||||
'cluster': -1,
|
||||
'is_noise': True,
|
||||
'selection_type': 'noise'
|
||||
})
|
||||
|
||||
return filtered_results
|
||||
|
||||
def save_filtered_results(self, filtered_results, output_path=None):
|
||||
"""Save filtered results to JSON file"""
|
||||
if output_path is None:
|
||||
# Generate output filename based on input
|
||||
base_name = Path(self.clustering_results_path).stem
|
||||
output_path = f"{base_name}_filtered.json"
|
||||
|
||||
# Create summary statistics
|
||||
cluster_stats = {}
|
||||
noise_count = 0
|
||||
|
||||
for result in filtered_results:
|
||||
if result['is_noise']:
|
||||
noise_count += 1
|
||||
else:
|
||||
cluster_id = result['cluster']
|
||||
if cluster_id not in cluster_stats:
|
||||
cluster_stats[cluster_id] = 0
|
||||
cluster_stats[cluster_id] += 1
|
||||
|
||||
# Prepare output data
|
||||
output_data = {
|
||||
"method": "DBSCAN_FILTERED",
|
||||
"original_n_clusters": self.clustering_results['n_clusters'],
|
||||
"original_n_samples": self.clustering_results['n_samples'],
|
||||
"filtered_n_samples": len(filtered_results),
|
||||
"filtering_criteria": {
|
||||
"cluster_selection_ratio": 0.5,
|
||||
"center_points_ratio": 0.5, # 50% of selected points from center
|
||||
"border_points_ratio": 0.5, # 50% of selected points from border
|
||||
"noise_points": "all_selected"
|
||||
},
|
||||
"cluster_statistics": cluster_stats,
|
||||
"noise_points": noise_count,
|
||||
"results": filtered_results
|
||||
}
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("FILTERING SUMMARY")
|
||||
print("="*60)
|
||||
print(f"Original samples: {self.clustering_results['n_samples']}")
|
||||
print(f"Filtered samples: {len(filtered_results)}")
|
||||
print(f"Reduction ratio: {len(filtered_results)/self.clustering_results['n_samples']:.2%}")
|
||||
print("\nCluster breakdown:")
|
||||
for cluster_id, count in sorted(cluster_stats.items()):
|
||||
print(f" Cluster {cluster_id}: {count} points")
|
||||
print(f" Noise points: {noise_count} points")
|
||||
print(f"\nFiltered results saved to: {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
def create_filepath_list(self, filtered_results, output_txt_path=None):
|
||||
"""Create a simple text file with filtered filepaths"""
|
||||
if output_txt_path is None:
|
||||
base_name = Path(self.clustering_results_path).stem
|
||||
output_txt_path = f"{base_name}_filtered_filepaths.txt"
|
||||
|
||||
filepaths = [result['filepath'] for result in filtered_results]
|
||||
|
||||
with open(output_txt_path, 'w', encoding='utf-8') as f:
|
||||
for filepath in filepaths:
|
||||
f.write(f"{filepath}\n")
|
||||
|
||||
print(f"Filepath list saved to: {output_txt_path}")
|
||||
return output_txt_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Filter DBSCAN clustering results")
|
||||
parser.add_argument("--embeddings_path", required=True,
|
||||
help="Path to embeddings JSON file")
|
||||
parser.add_argument("--clustering_results_path", required=True,
|
||||
help="Path to DBSCAN clustering results JSON file")
|
||||
parser.add_argument("--output_path",
|
||||
help="Output path for filtered results (optional)")
|
||||
parser.add_argument("--create_filepath_list", action="store_true",
|
||||
help="Also create a simple text file with filtered filepaths")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input files exist
|
||||
if not os.path.exists(args.embeddings_path):
|
||||
print(f"Error: Embeddings file not found: {args.embeddings_path}")
|
||||
return
|
||||
|
||||
if not os.path.exists(args.clustering_results_path):
|
||||
print(f"Error: Clustering results file not found: {args.clustering_results_path}")
|
||||
return
|
||||
|
||||
# Initialize filter
|
||||
filter_obj = DBSCANFilter(args.embeddings_path, args.clustering_results_path)
|
||||
|
||||
# Load data
|
||||
filter_obj.load_data()
|
||||
|
||||
# Filter clusters
|
||||
filtered_results = filter_obj.filter_all_clusters()
|
||||
|
||||
# Save results
|
||||
filter_obj.save_filtered_results(filtered_results, args.output_path)
|
||||
|
||||
# Create filepath list if requested
|
||||
if args.create_filepath_list:
|
||||
filter_obj.create_filepath_list(filtered_results)
|
||||
|
||||
print("\nFiltering completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user