update source code and pipeline
2801
check_filter/.visual_cache/plot_1.csv
Normal file
File diff suppressed because it is too large
2801
check_filter/.visual_cache/plot_2.csv
Normal file
File diff suppressed because it is too large
2801
check_filter/.visual_cache/plot_3.csv
Normal file
File diff suppressed because it is too large
2801
check_filter/.visual_cache/plot_4.csv
Normal file
File diff suppressed because it is too large
1
check_filter/run.sh
Normal file
@@ -0,0 +1 @@
streamlit run visual_data.py --server.port 8501
579
check_filter/visual_data copy 2.py
Normal file
@@ -0,0 +1,579 @@
"""
Streamlit app for visualizing embeddings in 3D (PCA / UMAP / t-SNE) plus clustering.

Run:
    streamlit run visual_data.py --server.port 8501

Requirements (install once):
    pip install streamlit plotly scikit-learn umap-learn numpy pandas

Features:
- Load a large JSON file of objects {"filepath": ..., "embedding": [...]} or JSON-lines format.
- Optionally sample n items at random to speed things up.
- Choice of dimensionality-reduction algorithm: PCA, UMAP, t-SNE.
- Tunable parameters: n_neighbors, min_dist (UMAP); perplexity (t-SNE); n_components=3.
- Optional KMeans clustering to color the points, or coloring by regex/substring in the file name.
- Filtering by keyword in the file path.
- Download of the 3D coordinates plus cluster labels.

The embedding files seen in practice may not be well-formed JSON arrays; the script tries:
1. Parsing as a JSON array.
2. Parsing as JSON lines (one object per line).
3. Manual parsing by scanning for the pattern {"filepath": ... , "embedding": [ ... ]}.

If the file is larger than ~1e6 bytes, read it in streaming fashion to reduce RAM usage.
"""

from __future__ import annotations

import json
import os
import re
import math
import random
import sys
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import streamlit as st
import colorsys
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

try:
    import umap  # type: ignore
except ImportError:  # pragma: no cover
    umap = None  # handled later

@dataclass
class EmbeddingRecord:
    filepath: str
    embedding: List[float]


def _smart_json_object_stream(raw_text: str) -> Iterable[str]:
    """Yield JSON object strings from a large raw buffer.

    Heuristic: find balanced braces starting with {"filepath": ...}.
    This is a fallback when content is not standard array / jsonlines.
    """
    brace = 0
    buf = []
    in_obj = False
    for ch in raw_text:
        if ch == '{':
            if not in_obj:
                in_obj = True
                buf = ['{']
                brace = 1
            else:
                brace += 1
                buf.append(ch)
        elif ch == '}':
            if in_obj:
                brace -= 1
                buf.append('}')
                if brace == 0:
                    yield ''.join(buf)
                    in_obj = False
            else:
                # stray closing brace outside any object
                continue
        else:
            if in_obj:
                buf.append(ch)

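# Editor's sketch (illustrative, not part of the original pipeline): the
# scanner above yields one balanced {...} string per record, even when objects
# are concatenated with no array brackets or newlines between them:
#
#     raw = '{"filepath": "a.png", "embedding": [0.1]}{"filepath": "b.png", "embedding": [0.2]}'
#     list(_smart_json_object_stream(raw))
#     # -> ['{"filepath": "a.png", "embedding": [0.1]}',
#     #     '{"filepath": "b.png", "embedding": [0.2]}']
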
def load_embeddings(
    path: str,
    sample_size: Optional[int] = None,
    sampling_seed: int = 42,
    max_objects: Optional[int] = None,
) -> List[EmbeddingRecord]:
    """Load embeddings from a possibly large JSON / JSONL / raw file.

    Args:
        path: file path
        sample_size: random sample (after load) if provided
        sampling_seed: RNG seed
        max_objects: hard cap to stop early (for speed)
    """
    # size = os.path.getsize(path)  # file size could be used later to optimize streaming reads
    # First attempt: JSON array
    records: List[EmbeddingRecord] = []

    def to_rec(obj) -> Optional[EmbeddingRecord]:
        if not isinstance(obj, dict):
            return None
        if 'embedding' in obj:
            fp = str(obj.get('filepath') or obj.get('file_path') or obj.get('path') or '')
            emb = obj['embedding']
            if isinstance(emb, list) and fp:
                return EmbeddingRecord(fp, emb)
        return None

    try:
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        text_stripped = text.strip()
        if text_stripped.startswith('[') and text_stripped.endswith(']'):
            arr = json.loads(text_stripped)
            for obj in arr:
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
        else:
            raise ValueError('Not a JSON array')
    except Exception:
        # Retry as JSON lines
        records = []
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line_strip = line.strip().rstrip(',')
                    if not line_strip:
                        continue
                    if not line_strip.startswith('{'):
                        continue
                    try:
                        obj = json.loads(line_strip)
                        rec = to_rec(obj)
                        if rec:
                            records.append(rec)
                            if max_objects and len(records) >= max_objects:
                                break
                    except json.JSONDecodeError:
                        continue
            if not records:
                raise ValueError('No JSONL records')
        except Exception:
            # Fallback: heuristic extraction
            records = []
            with open(path, 'r', encoding='utf-8') as f:
                raw = f.read()
            for obj_str in _smart_json_object_stream(raw):
                if 'embedding' not in obj_str:
                    continue
                # Clean possible trailing ','
                try:
                    # Attempt to fix malformed numbers like '1.2\n421875' (broken newline) by removing stray newlines inside arrays
                    fixed = re.sub(r"(\d)\n(\d)", r"\1\2", obj_str)
                    obj = json.loads(fixed)
                except Exception:
                    continue
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
    if not records:
        raise RuntimeError("Could not load any embeddings from the file.")

    # Random sample if needed
    if sample_size and sample_size < len(records):
        random.seed(sampling_seed)
        records = random.sample(records, sample_size)
    return records

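# Editor's sketch (illustrative; assumes a file in one of the accepted layouts
# exists at this path): typical use is to load a sample, then stack the vectors
# into a float32 matrix for the reducers below, mirroring what main() does.
#
#     recs = load_embeddings('embeddings_factures_osteopathie_1k_qwen.json', sample_size=1000)
#     X = np.vstack([r.embedding for r in recs]).astype(np.float32)  # shape (n, dim)
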
def reduce_embeddings(
    X: np.ndarray,
    method: str,
    random_state: int = 42,
    umap_neighbors: int = 15,
    umap_min_dist: float = 0.1,
    tsne_perplexity: int = 30,
    tsne_learning_rate: str | float = 'auto',
) -> Tuple[np.ndarray, dict]:
    """Project high-dim embeddings to 3D.

    Returns (coords (n,3), meta_info)
    """
    meta = {"method": method}
    if method == 'PCA':
        pca = PCA(n_components=3, random_state=random_state)
        coords = pca.fit_transform(X)
        meta['explained_variance_ratio'] = pca.explained_variance_ratio_.tolist()
        return coords, meta
    if method == 'UMAP':
        if umap is None:
            raise RuntimeError("umap-learn is not installed: pip install umap-learn")
        reducer = umap.UMAP(
            n_components=3,
            n_neighbors=umap_neighbors,
            min_dist=umap_min_dist,
            metric='cosine',
            random_state=random_state,
        )
        coords = reducer.fit_transform(X)
        meta['umap_graph_connectivity'] = float(reducer.graph_.getnnz())
        return coords, meta
    if method == 't-SNE':
        # t-SNE requires perplexity < n_samples; clamp it for small inputs
        perplexity = min(tsne_perplexity, max(5, (X.shape[0] - 1) // 3))
        tsne = TSNE(
            n_components=3,
            perplexity=perplexity,
            learning_rate=tsne_learning_rate,
            init='pca',
            random_state=random_state,
            n_iter=1000,
            verbose=0,
        )
        coords = tsne.fit_transform(X)
        meta['effective_perplexity'] = perplexity
        return coords, meta
    raise ValueError(f"Unknown method {method}")

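# Editor's sketch (illustrative): all three reducers share one call shape and
# report method-specific diagnostics through `meta`.
#
#     coords, meta = reduce_embeddings(X, 'PCA')    # meta['explained_variance_ratio']
#     coords, meta = reduce_embeddings(X, 'UMAP')   # requires umap-learn
#     coords, meta = reduce_embeddings(X, 't-SNE')  # meta['effective_perplexity']
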
def kmeans_cluster(coords: np.ndarray, n_clusters: int, seed: int = 42) -> Tuple[np.ndarray, float]:
    if n_clusters <= 1:
        return np.zeros(coords.shape[0], dtype=int), float('nan')
    km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=seed)
    labels = km.fit_predict(coords)
    score = float('nan')
    if len(set(labels)) > 1 and coords.shape[0] >= n_clusters * 5:
        try:
            score = silhouette_score(coords, labels)
        except Exception:
            pass
    return labels, score

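# Editor's note: silhouette_score ranges over [-1, 1]; values near 1 indicate
# tight, well-separated clusters and values near 0 heavy overlap. The guard
# above returns NaN instead of a score for degenerate or tiny inputs, e.g.
#
#     labels, sil = kmeans_cluster(coords, n_clusters=10)  # sil may be NaN
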
def build_dataframe(recs: List[EmbeddingRecord]) -> pd.DataFrame:
    return pd.DataFrame({
        'filepath': [r.filepath for r in recs],
        'embedding': [r.embedding for r in recs],
    })

def load_cluster_file(
    path: str,
    expected_n: int,
    noise_label: int = -1,
) -> Tuple[np.ndarray, Optional[Dict[str, int]]]:
    """Load cluster labels from a JSON result file.

    Supports formats:
    - {"results": [ ... ]}
    - [ ... ]
    Each item may contain one of: cluster, cluster_id, label, is_noise, filepath.
    If only is_noise exists: non-noise -> 0, noise -> noise_label.
    If filepath present, mapping is done by filepath, otherwise by index order.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = json.load(f)
    except Exception as e:
        raise RuntimeError(f'Could not read the cluster file: {e}')

    if isinstance(content, dict) and 'results' in content:
        items = content['results']
    elif isinstance(content, list):
        items = content
    else:
        raise RuntimeError('Invalid cluster file format (need a list or a "results" key).')

    # Detect if filepath-based mapping
    use_filepath = any(isinstance(it, dict) and 'filepath' in it for it in items)

    labels = np.full(expected_n, noise_label, dtype=int)
    if use_filepath:
        # Build path -> label
        mapping = {}
        for it in items:
            if not isinstance(it, dict):
                continue
            fp = it.get('filepath')
            if not fp:
                continue
            if 'cluster' in it:
                val = it['cluster']
            elif 'cluster_id' in it:
                val = it['cluster_id']
            elif 'label' in it:
                val = it['label']
            elif 'is_noise' in it:
                val = (0 if not it.get('is_noise') else noise_label)
            else:
                val = 0
            try:
                mapping[str(fp)] = int(val)
            except Exception:
                continue
        return labels, mapping  # second value used later to map onto df

    # Index-based mapping
    collected = []
    for it in items:
        if not isinstance(it, dict):
            # accept raw int labels
            if isinstance(it, int):
                collected.append(int(it))
            continue
        if 'cluster' in it:
            val = it['cluster']
        elif 'cluster_id' in it:
            val = it['cluster_id']
        elif 'label' in it:
            val = it['label']
        elif 'is_noise' in it:
            val = (0 if not it.get('is_noise') else noise_label)
        else:
            val = 0
        try:
            collected.append(int(val))
        except Exception:
            collected.append(noise_label)

    for i in range(min(expected_n, len(collected))):
        labels[i] = collected[i]
    return labels, None

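# Editor's sketch of a cluster file that load_cluster_file accepts (field names
# per the docstring above; paths and values illustrative):
#
#     {"results": [
#         {"filepath": "img/a.png", "cluster": 0},
#         {"filepath": "img/b.png", "is_noise": true}
#     ]}
#
# With 'filepath' present, labels are matched by path; otherwise by index order.
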
def main():  # pragma: no cover - Streamlit entry
    st.set_page_config(page_title="Embedding 3D Viewer", layout="wide")
    st.title("🔍 Embedding 3D Viewer")
    st.caption("Visualize the structure of invoice embeddings (Qwen2-VL).")

    default_path = 'embeddings_factures_osteopathie_1k_qwen.json'
    path = st.text_input('Embedding file path', value=default_path)
    col_top = st.columns(4)
    sample_size = col_top[0].number_input('Sample (0 = all)', min_value=0, value=1000, step=100)
    max_objects = col_top[1].number_input('Max objects to read (0 = no limit)', min_value=0, value=0, step=500)
    seed = col_top[2].number_input('Seed', min_value=0, value=42, step=1)
    show_raw = col_top[3].checkbox('Show raw table', value=False)

    algo = st.selectbox('Dimensionality-reduction algorithm', ['UMAP', 'PCA', 't-SNE'], index=0)
    with st.expander('Reduction parameters'):
        if algo == 'UMAP':
            umap_neighbors = st.slider('n_neighbors', 5, 100, 15, 1)
            umap_min_dist = st.slider('min_dist', 0.0, 1.0, 0.1, 0.01)
            tsne_perplexity = 30
        elif algo == 't-SNE':
            tsne_perplexity = st.slider('perplexity', 5, 100, 30, 1)
            umap_neighbors = 15
            umap_min_dist = 0.1
        else:
            umap_neighbors = 15
            umap_min_dist = 0.1
            tsne_perplexity = 30

    with st.expander('Clustering & colors'):
        cluster_source = st.radio('Cluster source', ['KMeans', 'Load file', 'None'], index=0, horizontal=True)
        n_clusters = st.slider('Number of KMeans clusters', 2, 100, 10, 1, disabled=(cluster_source != 'KMeans'))
        cluster_file_path = st.text_input('Cluster file path (JSON)', value='cluster/dbscan_results.json', disabled=(cluster_source != 'Load file'))
        noise_label = st.number_input('Label value for noise', value=-1, step=1, disabled=(cluster_source != 'Load file'))
        path_filter = st.text_input('Filter substring (filters filepath, comma = OR)', value='')
        color_by_substring = st.text_input('Color by substring (e.g. osteo, facture)', value='')
        palette_name = st.selectbox('Color palette', ['auto','okabe-ito','tol','plotly','bold','dark24','set3','d3','hcl'], index=0, help='Pick a more distinguishable palette (okabe-ito, tol = colorblind-friendly)')
        marker_size = st.slider('Point size', 2, 15, 5, 1)
        st.caption('If the cluster file only contains is_noise: noise -> noise_label, everything else -> 0.')

    load_btn = st.button('🚀 Load & reduce', type='primary')

    if load_btn:
        if not os.path.isfile(path):
            st.error(f'File does not exist: {path}')
            return
        with st.spinner('Loading embeddings...'):
            recs = load_embeddings(
                path,
                sample_size=sample_size or None,
                sampling_seed=int(seed),
                max_objects=max_objects or None,
            )
        st.success(f'Loaded {len(recs)} embeddings')

        df = build_dataframe(recs)
        dim = len(df['embedding'].iloc[0])
        st.write(f'Original dimensionality: {dim}')

        # Filter
        if path_filter.strip():
            tokens = [t.strip() for t in path_filter.split(',') if t.strip()]
            if tokens:
                mask = df['filepath'].apply(lambda p: any(tok.lower() in p.lower() for tok in tokens))
                df = df[mask].reset_index(drop=True)
                st.info(f'{len(df)} records left after filtering.')
        if df.empty:
            return

        X = np.vstack(df['embedding'].values).astype(np.float32)
        with st.spinner('Reducing dimensionality...'):
            coords, meta = reduce_embeddings(
                X,
                algo,
                random_state=int(seed),
                umap_neighbors=umap_neighbors,
                umap_min_dist=umap_min_dist,
                tsne_perplexity=tsne_perplexity,
            )
        df[['x', 'y', 'z']] = coords
        st.write('Meta:', meta)

        # Clustering (three modes)
        if cluster_source == 'KMeans':
            with st.spinner('KMeans clustering...'):
                labels, sil = kmeans_cluster(coords, n_clusters, seed)
            df['cluster'] = labels
            if not math.isnan(sil):
                st.write(f'Silhouette: {sil:.4f}')
            else:
                st.write('Silhouette: N/A')
        elif cluster_source == 'Load file':
            if not os.path.isfile(cluster_file_path):
                st.error(f'Cluster file not found: {cluster_file_path}')
                return
            with st.spinner('Loading labels from the cluster file...'):
                loaded_labels, mapping = load_cluster_file(cluster_file_path, len(df), noise_label=int(noise_label))
            if mapping is not None:
                # map by filepath
                labels = []
                miss = 0
                for fp in df['filepath']:
                    lab = mapping.get(fp)
                    if lab is None:
                        lab = int(noise_label)
                        miss += 1
                    labels.append(lab)
                labels = np.array(labels, dtype=int)
                if miss:
                    st.warning(f'{miss} filepaths not found in the cluster file; assigned noise.')
            else:
                labels = loaded_labels
            df['cluster'] = labels
            # Try silhouette when >1 cluster and not only noise
            uniq = set(labels)
            if len([u for u in uniq if u != int(noise_label)]) > 1:
                try:
                    mask = labels != int(noise_label)
                    sil = silhouette_score(coords[mask], labels[mask])
                    st.write(f'Silhouette (exclude noise): {sil:.4f}')
                except Exception:
                    st.write('Silhouette: N/A')
            else:
                st.write('Silhouette: N/A')
        else:  # None
            df['cluster'] = -1
            st.write('No clustering applied.')

        # Color scheme
        if color_by_substring.strip():
            subs = [s.strip() for s in color_by_substring.split(',') if s.strip()]

            def color_from_sub(p: str) -> str:
                for ssub in subs:
                    if ssub.lower() in p.lower():
                        return ssub
                return 'other'
            df['color_group'] = df['filepath'].apply(color_from_sub)
            color_col = 'color_group'
        else:
            color_col = 'cluster'

        # --- Palette handling -------------------------------------------------
        def get_base_palette(name: str) -> List[str]:
            name = name.lower()
            if name == 'okabe-ito':  # 8 colors, colorblind-safe
                return ["#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
            if name == 'tol':  # Paul Tol (12)
                return ["#4477AA", "#66CCEE", "#228833", "#CCBB44", "#EE6677", "#AA3377", "#BBBBBB", "#000000", "#EEDD88", "#FFAABB", "#99DDFF", "#44BB99"]
            if name == 'plotly':
                from plotly.colors import qualitative as q
                return q.Plotly
            if name == 'bold':
                from plotly.colors import qualitative as q
                return q.Bold
            if name == 'dark24':
                from plotly.colors import qualitative as q
                return q.Dark24
            if name == 'set3':
                from plotly.colors import qualitative as q
                return q.Set3
            if name == 'd3':
                from plotly.colors import qualitative as q
                return q.D3
            return []

        def generate_hcl_like(n: int) -> List[str]:
            # Simple evenly spaced hues in HSV, then adjust to look more balanced
            cols = []
            for i in range(n):
                h = (i / n) % 1.0
                s = 0.55 + 0.35 * ((i * 37) % 2)  # alternate saturation
                v = 0.85 if (i % 3) else 0.98
                r, g, b = colorsys.hsv_to_rgb(h, s, v)
                cols.append('#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255)))
            return cols

        def build_palette(name: str, k: int) -> List[str]:
            if name == 'auto':
                return []  # let plotly decide
            if name == 'hcl':
                return generate_hcl_like(k)
            base = get_base_palette(name)
            if k <= len(base):
                return base[:k]
            # extend by generating extra colors using HSV golden ratio
            cols = list(base)
            gold = 0.61803398875
            h = 0.1
            while len(cols) < k:
                h = (h + gold) % 1.0
                r, g, b = colorsys.hsv_to_rgb(h, 0.6, 0.95)
                newc = '#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255))
                if newc not in cols:
                    cols.append(newc)
            return cols

        color_display_col = color_col + '_display'
        # ensure discrete (string) colors for numeric cluster labels
        if np.issubdtype(df[color_col].dtype, np.number):
            df[color_display_col] = df[color_col].astype(int).astype(str)
        else:
            df[color_display_col] = df[color_col].astype(str)
        groups = df[color_display_col].unique()
        palette_seq = build_palette(palette_name, len(groups)) if palette_name else []

        import plotly.express as px
        fig = px.scatter_3d(
            df,
            x='x', y='y', z='z',
            color=color_display_col,
            color_discrete_sequence=palette_seq if palette_seq else None,
            hover_data={'filepath': True, 'cluster': True, 'x': ':.2f', 'y': ':.2f', 'z': ':.2f'},
            title=f'Embedding 3D ({algo})',
            opacity=0.9,
            height=800,
        )
        fig.update_traces(marker={'size': int(marker_size)})
        fig.update_layout(legend=dict(title='Group', itemsizing='constant'))
        st.plotly_chart(fig, use_container_width=True)

        if show_raw:
            st.dataframe(df.head(100))

        # Download
        out_csv = df[['filepath', 'x', 'y', 'z', 'cluster']].to_csv(index=False).encode('utf-8')
        st.download_button('⬇️ Download 3D coordinates (CSV)', out_csv, file_name='embedding_3d.csv', mime='text/csv')

        st.caption('Done.')


if __name__ == '__main__':  # pragma: no cover
    # When run via 'streamlit run', sys.argv only contains the file name, so we always call main().
    # For a quick CLI test, pass the '--cli-test' argument.
    if '--cli-test' in sys.argv:
        test_path = 'embeddings_factures_osteopathie_1k_qwen.json'
        if os.path.exists(test_path):
            recs = load_embeddings(test_path, sample_size=5)
            print(f'[CLI TEST] Loaded {len(recs)} embeddings dim={len(recs[0].embedding)}')
        else:
            print('[CLI TEST] Test file not found.')
    else:
        main()
600
check_filter/visual_data copy.py
Normal file
@@ -0,0 +1,600 @@
"""
Streamlit app for visualizing embeddings in 3D (PCA / UMAP / t-SNE) plus clustering.

Run:
    streamlit run visual_data.py --server.port 8501

Requirements (install once):
    pip install streamlit plotly scikit-learn umap-learn numpy pandas

Features:
- Load a large JSON file of objects {"filepath": ..., "embedding": [...]} or JSON-lines format.
- Optionally sample n items at random to speed things up.
- Choice of dimensionality-reduction algorithm: PCA, UMAP, t-SNE.
- Tunable parameters: n_neighbors, min_dist (UMAP); perplexity (t-SNE); n_components=3.
- Optional KMeans clustering to color the points, or coloring by regex/substring in the file name.
- Filtering by keyword in the file path.
- Download of the 3D coordinates plus cluster labels.

The embedding files seen in practice may not be well-formed JSON arrays; the script tries:
1. Parsing as a JSON array.
2. Parsing as JSON lines (one object per line).
3. Manual parsing by scanning for the pattern {"filepath": ... , "embedding": [ ... ]}.

If the file is larger than ~1e6 bytes, read it in streaming fashion to reduce RAM usage.
"""

from __future__ import annotations

import json
import os
import re
import math
import random
import sys
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Dict, Any

import numpy as np
import pandas as pd
import streamlit as st
import colorsys
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

try:
    import umap  # type: ignore
except ImportError:  # pragma: no cover
    umap = None  # handled later

@dataclass
class EmbeddingRecord:
    filepath: str
    embedding: List[float]


def _smart_json_object_stream(raw_text: str) -> Iterable[str]:
    """Yield JSON object strings from a large raw buffer.

    Heuristic: find balanced braces starting with {"filepath": ...}.
    This is a fallback when content is not standard array / jsonlines.
    """
    brace = 0
    buf = []
    in_obj = False
    for ch in raw_text:
        if ch == '{':
            if not in_obj:
                in_obj = True
                buf = ['{']
                brace = 1
            else:
                brace += 1
                buf.append(ch)
        elif ch == '}':
            if in_obj:
                brace -= 1
                buf.append('}')
                if brace == 0:
                    yield ''.join(buf)
                    in_obj = False
            else:
                # stray closing brace outside any object
                continue
        else:
            if in_obj:
                buf.append(ch)

def load_embeddings(
    path: str,
    sample_size: Optional[int] = None,
    sampling_seed: int = 42,
    max_objects: Optional[int] = None,
) -> List[EmbeddingRecord]:
    """Load embeddings from a possibly large JSON / JSONL / raw file.

    Args:
        path: file path
        sample_size: random sample (after load) if provided
        sampling_seed: RNG seed
        max_objects: hard cap to stop early (for speed)
    """
    # size = os.path.getsize(path)  # file size could be used later to optimize streaming reads
    # First attempt: JSON array
    records: List[EmbeddingRecord] = []

    def to_rec(obj) -> Optional[EmbeddingRecord]:
        if not isinstance(obj, dict):
            return None
        if 'embedding' in obj:
            fp = str(obj.get('filepath') or obj.get('file_path') or obj.get('path') or '')
            emb = obj['embedding']
            if isinstance(emb, list) and fp:
                return EmbeddingRecord(fp, emb)
        return None

    try:
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        text_stripped = text.strip()
        if text_stripped.startswith('[') and text_stripped.endswith(']'):
            arr = json.loads(text_stripped)
            for obj in arr:
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
        else:
            raise ValueError('Not a JSON array')
    except Exception:
        # Retry as JSON lines
        records = []
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line_strip = line.strip().rstrip(',')
                    if not line_strip:
                        continue
                    if not line_strip.startswith('{'):
                        continue
                    try:
                        obj = json.loads(line_strip)
                        rec = to_rec(obj)
                        if rec:
                            records.append(rec)
                            if max_objects and len(records) >= max_objects:
                                break
                    except json.JSONDecodeError:
                        continue
            if not records:
                raise ValueError('No JSONL records')
        except Exception:
            # Fallback: heuristic extraction
            records = []
            with open(path, 'r', encoding='utf-8') as f:
                raw = f.read()
            for obj_str in _smart_json_object_stream(raw):
                if 'embedding' not in obj_str:
                    continue
                # Clean possible trailing ','
                try:
                    # Attempt to fix malformed numbers like '1.2\n421875' (broken newline) by removing stray newlines inside arrays
                    fixed = re.sub(r"(\d)\n(\d)", r"\1\2", obj_str)
                    obj = json.loads(fixed)
                except Exception:
                    continue
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
    if not records:
        raise RuntimeError("Could not load any embeddings from the file.")

    # Random sample if needed
    if sample_size and sample_size < len(records):
        random.seed(sampling_seed)
        records = random.sample(records, sample_size)
    return records

def reduce_embeddings(
    X: np.ndarray,
    method: str,
    random_state: int = 42,
    umap_neighbors: int = 15,
    umap_min_dist: float = 0.1,
    tsne_perplexity: int = 30,
    tsne_learning_rate: str | float = 'auto',
) -> Tuple[np.ndarray, dict]:
    """Project high-dim embeddings to 3D.

    Returns (coords (n,3), meta_info)
    """
    meta = {"method": method}
    if method == 'PCA':
        pca = PCA(n_components=3, random_state=random_state)
        coords = pca.fit_transform(X)
        meta['explained_variance_ratio'] = pca.explained_variance_ratio_.tolist()
        return coords, meta
    if method == 'UMAP':
        if umap is None:
            raise RuntimeError("umap-learn is not installed: pip install umap-learn")
        reducer = umap.UMAP(
            n_components=3,
            n_neighbors=umap_neighbors,
            min_dist=umap_min_dist,
            metric='cosine',
            random_state=random_state,
        )
        coords = reducer.fit_transform(X)
        meta['umap_graph_connectivity'] = float(reducer.graph_.getnnz())
        return coords, meta
    if method == 't-SNE':
        # t-SNE requires perplexity < n_samples; clamp it for small inputs
        perplexity = min(tsne_perplexity, max(5, (X.shape[0] - 1) // 3))
        tsne = TSNE(
            n_components=3,
            perplexity=perplexity,
            learning_rate=tsne_learning_rate,
            init='pca',
            random_state=random_state,
            n_iter=1000,
            verbose=0,
        )
        coords = tsne.fit_transform(X)
        meta['effective_perplexity'] = perplexity
        return coords, meta
    raise ValueError(f"Unknown method {method}")

def kmeans_cluster(coords: np.ndarray, n_clusters: int, seed: int = 42) -> Tuple[np.ndarray, float]:
    if n_clusters <= 1:
        return np.zeros(coords.shape[0], dtype=int), float('nan')
    km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=seed)
    labels = km.fit_predict(coords)
    score = float('nan')
    if len(set(labels)) > 1 and coords.shape[0] >= n_clusters * 5:
        try:
            score = silhouette_score(coords, labels)
        except Exception:
            pass
    return labels, score

def build_dataframe(recs: List[EmbeddingRecord]) -> pd.DataFrame:
    return pd.DataFrame({
        'filepath': [r.filepath for r in recs],
        'embedding': [r.embedding for r in recs],
    })

def load_cluster_file(
    path: str,
    expected_n: int,
    noise_label: int = -1,
) -> Tuple[np.ndarray, Optional[Dict[str, int]]]:
    """Load cluster labels from a JSON result file.

    Supports formats:
    - {"results": [ ... ]}
    - [ ... ]
    Each item may contain one of: cluster, cluster_id, label, is_noise, filepath.
    If only is_noise exists: non-noise -> 0, noise -> noise_label.
    If filepath present, mapping is done by filepath, otherwise by index order.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = json.load(f)
    except Exception as e:
        raise RuntimeError(f'Could not read the cluster file: {e}')

    if isinstance(content, dict) and 'results' in content:
        items = content['results']
    elif isinstance(content, list):
        items = content
    else:
        raise RuntimeError('Invalid cluster file format (need a list or a "results" key).')

    # Detect if filepath-based mapping
    use_filepath = any(isinstance(it, dict) and 'filepath' in it for it in items)

    labels = np.full(expected_n, noise_label, dtype=int)
    if use_filepath:
        # Build path -> label
        mapping = {}
        for it in items:
            if not isinstance(it, dict):
                continue
            fp = it.get('filepath')
            if not fp:
                continue
            if 'cluster' in it:
                val = it['cluster']
            elif 'cluster_id' in it:
                val = it['cluster_id']
            elif 'label' in it:
                val = it['label']
            elif 'is_noise' in it:
                val = (0 if not it.get('is_noise') else noise_label)
            else:
                val = 0
            try:
                mapping[str(fp)] = int(val)
            except Exception:
                continue
        return labels, mapping  # second value used later to map onto df

    # Index-based mapping
    collected = []
    for it in items:
        if not isinstance(it, dict):
            # accept raw int labels
            if isinstance(it, int):
                collected.append(int(it))
            continue
        if 'cluster' in it:
            val = it['cluster']
        elif 'cluster_id' in it:
            val = it['cluster_id']
        elif 'label' in it:
            val = it['label']
        elif 'is_noise' in it:
            val = (0 if not it.get('is_noise') else noise_label)
        else:
            val = 0
        try:
            collected.append(int(val))
        except Exception:
            collected.append(noise_label)

    for i in range(min(expected_n, len(collected))):
        labels[i] = collected[i]
    return labels, None

def main():  # pragma: no cover - Streamlit entry
    st.set_page_config(page_title="Embedding 3D Viewer", layout="wide")
    st.title("🔍 Embedding 3D Viewer (Multi)")
    st.caption("Each file you load adds a new plot below.")

    if 'plots' not in st.session_state:
        st.session_state['plots']: List[Dict[str, Any]] = []  # type: ignore

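    # Editor's note: st.session_state persists across Streamlit's script reruns
    # within one browser session, which is what lets each "➕ Add plot" click
    # append a new figure here while the earlier ones stay rendered below.
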
    with st.sidebar:
        st.markdown('### ⚙️ General settings')
        default_path = 'embeddings_factures_osteopathie_1k_qwen.json'
        path = st.text_input('Embedding file path', value=default_path, key='path_input')
        sample_size = st.number_input('Sample (0 = all)', min_value=0, value=1000, step=100, key='sample_size')
        max_objects = st.number_input('Max objects to read (0 = no limit)', min_value=0, value=0, step=500, key='max_objects')
        seed = st.number_input('Seed', min_value=0, value=42, step=1, key='seed')
        show_raw = st.checkbox('Show raw table (first 100 rows of each plot only)', value=False, key='show_raw')

        algo = st.selectbox('Dimensionality-reduction algorithm', ['UMAP', 'PCA', 't-SNE'], index=0, key='algo')
        if algo == 'UMAP':
            umap_neighbors = st.slider('UMAP n_neighbors', 5, 100, 15, 1, key='umap_neighbors')
            umap_min_dist = st.slider('UMAP min_dist', 0.0, 1.0, 0.1, 0.01, key='umap_min_dist')
            tsne_perplexity = 30
        elif algo == 't-SNE':
            tsne_perplexity = st.slider('t-SNE perplexity', 5, 100, 30, 1, key='tsne_perplexity')
            umap_neighbors = 15
            umap_min_dist = 0.1
        else:
            umap_neighbors = 15
            umap_min_dist = 0.1
            tsne_perplexity = 30

        st.markdown('### 🎨 Colors & clusters')
        cluster_source = st.radio('Cluster source', ['KMeans', 'Load file', 'None'], index=0, horizontal=False, key='cluster_source')
        n_clusters = st.slider('Number of KMeans clusters', 2, 100, 10, 1, disabled=(cluster_source != 'KMeans'), key='n_clusters')
        cluster_file_path = st.text_input('Cluster JSON file', value='cluster/dbscan_results.json', disabled=(cluster_source != 'Load file'), key='cluster_file')
        noise_label = st.number_input('Noise label', value=-1, step=1, disabled=(cluster_source != 'Load file'), key='noise_label')
        path_filter = st.text_input('Filter filepath (comma OR)', value='', key='path_filter')
        color_by_substring = st.text_input('Color by substring', value='', key='color_by_substring')
        palette_name = st.selectbox('Color palette', ['auto','okabe-ito','tol','plotly','bold','dark24','set3','d3','hcl'], index=0, help='okabe-ito, tol = colorblind-friendly', key='palette')
        marker_size = st.slider('Point size', 2, 15, 5, 1, key='marker_size')
        st.caption('Cluster file with only is_noise: non-noise -> 0, noise -> noise_label.')

        add_btn = st.button('➕ Add plot', type='primary')
        clear_btn = st.button('🧹 Clear all plots')

    if clear_btn:
        st.session_state['plots'].clear()
        st.success('All plots cleared.')

    def process_and_add_plot():
        path_local = path
        if not os.path.isfile(path_local):
            st.error(f'File does not exist: {path_local}')
            return
        with st.spinner(f'Loading embeddings: {os.path.basename(path_local)}'):
            recs = load_embeddings(
                path_local,
                sample_size=sample_size or None,
                sampling_seed=int(seed),
                max_objects=max_objects or None,
            )
        df = build_dataframe(recs)
        if df.empty:
            st.warning('No embeddings found.')
            return
        # Filter
        if path_filter.strip():
            tokens = [t.strip() for t in path_filter.split(',') if t.strip()]
            if tokens:
                mask = df['filepath'].apply(lambda p: any(tok.lower() in p.lower() for tok in tokens))
                df = df[mask].reset_index(drop=True)
        if df.empty:
            st.warning('No records left after filtering.')
            return
        # Reduce
        X = np.vstack(df['embedding'].values).astype(np.float32)
        with st.spinner('Reducing dimensionality...'):
            coords, meta = reduce_embeddings(
                X,
                algo,
                random_state=int(seed),
                umap_neighbors=umap_neighbors,
                umap_min_dist=umap_min_dist,
                tsne_perplexity=tsne_perplexity,
            )
        df[['x', 'y', 'z']] = coords
        # Clustering
        if cluster_source == 'KMeans':
            with st.spinner('KMeans...'):
                labels, sil = kmeans_cluster(coords, n_clusters, int(seed))
            df['cluster'] = labels
            sil_msg = f'Silhouette: {sil:.4f}' if not math.isnan(sil) else 'Silhouette: N/A'
        elif cluster_source == 'Load file':
            if not os.path.isfile(cluster_file_path):
                st.warning(f'Cluster file not found: {cluster_file_path}; assigning -1')
                df['cluster'] = -1
                sil_msg = 'Silhouette: N/A'
            else:
                loaded_labels, mapping = load_cluster_file(cluster_file_path, len(df), noise_label=int(noise_label))
                if mapping is not None:
                    labs = []
                    miss = 0
                    for fp in df['filepath']:
                        val = mapping.get(fp, int(noise_label))
                        if fp not in mapping:
                            miss += 1
                        labs.append(val)
                    if miss:
                        st.info(f'{miss} filepaths assigned noise.')
                    df['cluster'] = np.array(labs, dtype=int)
                else:
                    df['cluster'] = loaded_labels
                uniq = set(df['cluster'])
                if len([u for u in uniq if u != int(noise_label)]) > 1:
                    try:
                        mask_valid = df['cluster'].to_numpy() != int(noise_label)
                        sil = silhouette_score(coords[mask_valid], df.loc[mask_valid, 'cluster'])
                        sil_msg = f'Silhouette (ex noise): {sil:.4f}'
                    except Exception:
                        sil_msg = 'Silhouette: N/A'
                else:
                    sil_msg = 'Silhouette: N/A'
        else:
            df['cluster'] = -1
            sil_msg = 'No clustering'

        # Color grouping
        if color_by_substring.strip():
            subs = [s.strip() for s in color_by_substring.split(',') if s.strip()]

            def color_from_sub(p: str) -> str:
                for ssub in subs:
                    if ssub.lower() in p.lower():
                        return ssub
                return 'other'
            df['color_group'] = df['filepath'].apply(color_from_sub)
            color_col = 'color_group'
        else:
            color_col = 'cluster'

        # Palette utilities (copy of earlier helpers)
        def get_base_palette(name: str) -> List[str]:
            name = name.lower()
            if name == 'okabe-ito':
                return ["#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
            if name == 'tol':
                return ["#4477AA", "#66CCEE", "#228833", "#CCBB44", "#EE6677", "#AA3377", "#BBBBBB", "#000000", "#EEDD88", "#FFAABB", "#99DDFF", "#44BB99"]
            if name == 'plotly':
                from plotly.colors import qualitative as q
                return q.Plotly
            if name == 'bold':
                from plotly.colors import qualitative as q
                return q.Bold
            if name == 'dark24':
                from plotly.colors import qualitative as q
                return q.Dark24
            if name == 'set3':
                from plotly.colors import qualitative as q
                return q.Set3
            if name == 'd3':
                from plotly.colors import qualitative as q
                return q.D3
            return []

        def generate_hcl_like(n: int) -> List[str]:
            cols = []
            for i in range(n):
                h = (i / n) % 1.0
                s = 0.55 + 0.35 * ((i * 37) % 2)
                v = 0.85 if (i % 3) else 0.98
                r, g, b = colorsys.hsv_to_rgb(h, s, v)
                cols.append('#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255)))
            return cols

        def build_palette(name: str, k: int) -> List[str]:
            if name == 'auto':
                return []
            if name == 'hcl':
                return generate_hcl_like(k)
            base = get_base_palette(name)
            if k <= len(base):
                return base[:k]
            cols = list(base)
            gold = 0.61803398875
            h = 0.1
            while len(cols) < k:
                h = (h + gold) % 1.0
                r, g, b = colorsys.hsv_to_rgb(h, 0.6, 0.95)
                newc = '#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255))
                if newc not in cols:
                    cols.append(newc)
            return cols
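        # Editor's note: stepping the hue by the golden ratio (~0.618, mod 1)
        # spreads successively generated hues far apart on the color wheel, so
        # colors created beyond the fixed base palette stay distinguishable,
        # e.g. build_palette('okabe-ito', 12) = 8 base colors + 4 generated extras.
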
        color_display_col = color_col + '_display'
        if np.issubdtype(df[color_col].dtype, np.number):
            df[color_display_col] = df[color_col].astype(int).astype(str)
        else:
            df[color_display_col] = df[color_col].astype(str)
        groups = df[color_display_col].unique()
        palette_seq = build_palette(palette_name, len(groups)) if palette_name else []
        import plotly.express as px
        fig = px.scatter_3d(
            df,
            x='x', y='y', z='z',
            color=color_display_col,
            color_discrete_sequence=palette_seq if palette_seq else None,
            hover_data={'filepath': True, 'cluster': True, 'x': ':.2f', 'y': ':.2f', 'z': ':.2f'},
            title=f'{os.path.basename(path_local)} ({algo})',
            opacity=0.9,
            height=700,
        )
        fig.update_traces(marker={'size': int(marker_size)})
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        out_csv = df[['filepath', 'x', 'y', 'z', 'cluster']].to_csv(index=False).encode('utf-8')
        st.session_state['plots'].append({
            'path': path_local,
            'df_head': df.head(100) if show_raw else None,
            'fig': fig,
            'meta': meta,
            'algo': algo,
            'sil_msg': sil_msg,
            'csv': out_csv,
            'n': len(df),
            'dim': len(df['embedding'].iloc[0]),
        })
        st.success(f'Added plot: {os.path.basename(path_local)}')

    if add_btn:
        process_and_add_plot()

    # Render existing plots
    if st.session_state['plots']:
        st.markdown('---')
        for idx, plot_data in enumerate(st.session_state['plots']):
            container = st.container()
            with container:
                cols = st.columns([0.8, 0.2])
                with cols[0]:
                    st.subheader(f'#{idx+1} {os.path.basename(plot_data["path"])}')
                with cols[1]:
                    if st.button('❌ Remove', key=f'remove_{idx}'):
                        st.session_state['plots'].pop(idx)
                        st.experimental_rerun()
                st.caption(f"Embeddings: {plot_data['n']} | Original dim: {plot_data['dim']} | {plot_data['sil_msg']}")
                st.plotly_chart(plot_data['fig'], use_container_width=True)
                with st.expander('Meta / extra info'):
                    st.json(plot_data['meta'])
                st.download_button('⬇️ CSV', plot_data['csv'], file_name=f"embedding_3d_{idx+1}.csv", mime='text/csv', key=f'dl_{idx}')
                if plot_data['df_head'] is not None:
                    st.dataframe(plot_data['df_head'])
    if not st.session_state['plots']:
        st.info('No plots yet. Pick a file in the sidebar and click "➕ Add plot".')


if __name__ == '__main__':  # pragma: no cover
    # When run via 'streamlit run', sys.argv only contains the file name, so we always call main().
    # For a quick CLI test, pass the '--cli-test' argument.
    if '--cli-test' in sys.argv:
        test_path = 'embeddings_factures_osteopathie_1k_qwen.json'
        if os.path.exists(test_path):
            recs = load_embeddings(test_path, sample_size=5)
            print(f'[CLI TEST] Loaded {len(recs)} embeddings dim={len(recs[0].embedding)}')
        else:
            print('[CLI TEST] Test file not found.')
    else:
        main()
783
check_filter/visual_data.py
Normal file
@@ -0,0 +1,783 @@
"""
Streamlit app for visualizing embeddings in 3D (PCA / UMAP / t-SNE) plus clustering.

Run:
    streamlit run visual_data.py --server.port 8501

Requirements (install once):
    pip install streamlit plotly scikit-learn umap-learn numpy pandas

Features:
- Load a large JSON file of objects {"filepath": ..., "embedding": [...]} or JSON-lines format.
- Optionally sample n items at random to speed things up.
- Choice of dimensionality-reduction algorithm: PCA, UMAP, t-SNE.
- Tunable parameters: n_neighbors, min_dist (UMAP); perplexity (t-SNE); n_components=3.
- Optional KMeans clustering to color the points, or coloring by regex/substring in the file name.
- Filtering by keyword in the file path.
- Download of the 3D coordinates plus cluster labels.

The embedding files seen in practice may not be well-formed JSON arrays; the script tries:
1. Parsing as a JSON array.
2. Parsing as JSON lines (one object per line).
3. Manual parsing by scanning for the pattern {"filepath": ... , "embedding": [ ... ]}.

If the file is larger than ~1e6 bytes, read it in streaming fashion to reduce RAM usage.
"""

from __future__ import annotations

import json
import os
import re
import math
import random
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Dict, Any
from datetime import datetime

import numpy as np
import pandas as pd
import streamlit as st
import colorsys
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

try:
    import umap  # type: ignore
except ImportError:  # pragma: no cover
    umap = None  # handled later

@dataclass
class EmbeddingRecord:
    filepath: str
    embedding: List[float]


def _smart_json_object_stream(raw_text: str) -> Iterable[str]:
    """Yield JSON object strings from a large raw buffer.

    Heuristic: find balanced braces starting with {"filepath": ...}.
    This is a fallback when content is not standard array / jsonlines.
    """
    brace = 0
    buf = []
    in_obj = False
    for ch in raw_text:
        if ch == '{':
            if not in_obj:
                in_obj = True
                buf = ['{']
                brace = 1
            else:
                brace += 1
                buf.append(ch)
        elif ch == '}':
            if in_obj:
                brace -= 1
                buf.append('}')
                if brace == 0:
                    yield ''.join(buf)
                    in_obj = False
            else:
                # stray closing brace outside any object
                continue
        else:
            if in_obj:
                buf.append(ch)

def load_embeddings(
    path: str,
    sample_size: Optional[int] = None,
    sampling_seed: int = 42,
    max_objects: Optional[int] = None,
) -> List[EmbeddingRecord]:
    """Load embeddings from a possibly large JSON / JSONL / raw file.

    Args:
        path: file path
        sample_size: random sample (after load) if provided
        sampling_seed: RNG seed
        max_objects: hard cap to stop early (for speed)
    """
    # size = os.path.getsize(path)  # file size could be used later to optimize streaming reads
    # First attempt: JSON array
    records: List[EmbeddingRecord] = []

    def to_rec(obj) -> Optional[EmbeddingRecord]:
        if not isinstance(obj, dict):
            return None
        if 'embedding' in obj:
            fp = str(obj.get('filepath') or obj.get('file_path') or obj.get('path') or '')
            emb = obj['embedding']
            if isinstance(emb, list) and fp:
                return EmbeddingRecord(fp, emb)
        return None

    try:
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        text_stripped = text.strip()
        if text_stripped.startswith('[') and text_stripped.endswith(']'):
            arr = json.loads(text_stripped)
            for obj in arr:
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
        else:
            raise ValueError('Not a JSON array')
    except Exception:
        # Retry as JSON lines
        records = []
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line_strip = line.strip().rstrip(',')
                    if not line_strip:
                        continue
                    if not line_strip.startswith('{'):
                        continue
                    try:
                        obj = json.loads(line_strip)
                        rec = to_rec(obj)
                        if rec:
                            records.append(rec)
                            if max_objects and len(records) >= max_objects:
                                break
                    except json.JSONDecodeError:
                        continue
            if not records:
                raise ValueError('No JSONL records')
        except Exception:
            # Fallback: heuristic extraction
            records = []
            with open(path, 'r', encoding='utf-8') as f:
                raw = f.read()
            for obj_str in _smart_json_object_stream(raw):
                if 'embedding' not in obj_str:
                    continue
                # Clean possible trailing ','
                try:
                    # Attempt to fix malformed numbers like '1.2\n421875' (broken newline) by removing stray newlines inside arrays
                    fixed = re.sub(r"(\d)\n(\d)", r"\1\2", obj_str)
                    obj = json.loads(fixed)
                except Exception:
                    continue
                rec = to_rec(obj)
                if rec:
                    records.append(rec)
                    if max_objects and len(records) >= max_objects:
                        break
    if not records:
        raise RuntimeError("Could not load any embeddings from the file.")

    # Random sample if needed
    if sample_size and sample_size < len(records):
        random.seed(sampling_seed)
        records = random.sample(records, sample_size)
    return records

def reduce_embeddings(
    X: np.ndarray,
    method: str,
    random_state: int = 42,
    umap_neighbors: int = 15,
    umap_min_dist: float = 0.1,
    tsne_perplexity: int = 30,
    tsne_learning_rate: str | float = 'auto',
) -> Tuple[np.ndarray, dict]:
    """Project high-dim embeddings to 3D.

    Returns (coords (n,3), meta_info)
    """
    meta = {"method": method}
    if method == 'PCA':
        pca = PCA(n_components=3, random_state=random_state)
        coords = pca.fit_transform(X)
        meta['explained_variance_ratio'] = pca.explained_variance_ratio_.tolist()
        return coords, meta
    if method == 'UMAP':
        if umap is None:
            raise RuntimeError("umap-learn is not installed: pip install umap-learn")
        reducer = umap.UMAP(
            n_components=3,
            n_neighbors=umap_neighbors,
            min_dist=umap_min_dist,
            metric='cosine',
            random_state=random_state,
        )
        coords = reducer.fit_transform(X)
        meta['umap_graph_connectivity'] = float(reducer.graph_.getnnz())
        return coords, meta
    if method == 't-SNE':
        # t-SNE requires perplexity < n_samples; clamp it for small inputs
        perplexity = min(tsne_perplexity, max(5, (X.shape[0] - 1) // 3))
        tsne = TSNE(
            n_components=3,
            perplexity=perplexity,
            learning_rate=tsne_learning_rate,
            init='pca',
            random_state=random_state,
            n_iter=1000,
            verbose=0,
        )
        coords = tsne.fit_transform(X)
        meta['effective_perplexity'] = perplexity
        return coords, meta
    raise ValueError(f"Unknown method {method}")

def kmeans_cluster(coords: np.ndarray, n_clusters: int, seed: int = 42) -> Tuple[np.ndarray, float]:
    if n_clusters <= 1:
        return np.zeros(coords.shape[0], dtype=int), float('nan')
    km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=seed)
    labels = km.fit_predict(coords)
    score = float('nan')
    if len(set(labels)) > 1 and coords.shape[0] >= n_clusters * 5:
        try:
            score = silhouette_score(coords, labels)
        except Exception:
            pass
    return labels, score

def build_dataframe(recs: List[EmbeddingRecord]) -> pd.DataFrame:
    return pd.DataFrame({
        'filepath': [r.filepath for r in recs],
        'embedding': [r.embedding for r in recs],
    })

def load_cluster_file(
    path: str,
    expected_n: int,
    noise_label: int = -1,
) -> Tuple[np.ndarray, Optional[Dict[str, int]]]:
    """Load cluster labels from a JSON result file.

    Supports formats:
    - {"results": [ ... ]}
    - [ ... ]
    Each item may contain one of: cluster, cluster_id, label, is_noise, filepath.
    If only is_noise exists: non-noise -> 0, noise -> noise_label.
    If filepath present, mapping is done by filepath, otherwise by index order.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            content = json.load(f)
    except Exception as e:
        raise RuntimeError(f'Could not read the cluster file: {e}')

    if isinstance(content, dict) and 'results' in content:
        items = content['results']
    elif isinstance(content, list):
        items = content
    else:
        raise RuntimeError('Invalid cluster file format (need a list or a "results" key).')

    # Detect if filepath-based mapping
    use_filepath = any(isinstance(it, dict) and 'filepath' in it for it in items)

    labels = np.full(expected_n, noise_label, dtype=int)
    if use_filepath:
        # Build path -> label
        mapping = {}
        for it in items:
            if not isinstance(it, dict):
                continue
            fp = it.get('filepath')
            if not fp:
                continue
            if 'cluster' in it:
                val = it['cluster']
            elif 'cluster_id' in it:
                val = it['cluster_id']
            elif 'label' in it:
                val = it['label']
            elif 'is_noise' in it:
                val = (0 if not it.get('is_noise') else noise_label)
            else:
                val = 0
            try:
                mapping[str(fp)] = int(val)
            except Exception:
                continue
        return labels, mapping  # second value used later to map onto df

    # Index-based mapping
    collected = []
    for it in items:
        if not isinstance(it, dict):
            # accept raw int labels
            if isinstance(it, int):
                collected.append(int(it))
            continue
        if 'cluster' in it:
            val = it['cluster']
        elif 'cluster_id' in it:
            val = it['cluster_id']
        elif 'label' in it:
            val = it['label']
        elif 'is_noise' in it:
            val = (0 if not it.get('is_noise') else noise_label)
        else:
            val = 0
        try:
            collected.append(int(val))
        except Exception:
            collected.append(noise_label)

    for i in range(min(expected_n, len(collected))):
        labels[i] = collected[i]
    return labels, None

def main(): # pragma: no cover - Streamlit entry
|
||||
st.set_page_config(page_title="Embedding 3D Viewer", layout="wide")
|
||||
st.title("🔍 Embedding 3D Viewer (Multi)")
|
||||
st.caption("Mỗi lần nạp file sẽ thêm một đồ thị mới ở bên dưới.")
|
||||
|
||||
if 'plots' not in st.session_state:
|
||||
st.session_state['plots']: List[Dict[str, Any]] = [] # type: ignore
|
||||
if 'cache_loaded' not in st.session_state:
|
||||
st.session_state['cache_loaded'] = False
|
||||
|
||||
CACHE_DIR = '.visual_cache'
|
||||
INDEX_FILE = os.path.join(CACHE_DIR, 'index.json')
|
||||
|
||||
def ensure_cache_dir():
|
||||
if not os.path.isdir(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
|
||||
    def save_cache():
        # Persist plot metadata + data (CSV) for future sessions.
        ensure_cache_dir()
        index: List[Dict[str, Any]] = []
        for i, p in enumerate(st.session_state['plots']):
            # The full dataset was stored as CSV bytes in p['csv'] when the
            # plot was created; write it to disk if not already persisted.
            data_csv_path = p.get('data_csv_path')
            if not data_csv_path:
                data_csv_path = os.path.join(CACHE_DIR, f'plot_{i+1}.csv')
                try:
                    with open(data_csv_path, 'wb') as fcsv:
                        fcsv.write(p['csv'])
                    p['data_csv_path'] = data_csv_path
                except Exception:
                    continue
            index.append({
                'path': p.get('path'),
                'name': p.get('name'),
                'algo': p.get('algo'),
                'sil_msg': p.get('sil_msg'),
                'meta': p.get('meta'),
                'marker_size': p.get('marker_size', 5),
                'palette_name': p.get('palette_name'),
                'data_csv': os.path.basename(data_csv_path),
                'timestamp': p.get('timestamp'),
            })
        try:
            with open(INDEX_FILE, 'w', encoding='utf-8') as f:
                json.dump({'plots': index, 'saved_at': datetime.utcnow().isoformat()}, f, ensure_ascii=False, indent=2)
        except Exception as e:
            st.warning(f'Failed to save cache: {e}')
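
    # The cache index written above looks roughly like this (illustrative
    # values only):
    # {
    #   "plots": [
    #     {"path": "...", "name": "...", "algo": "UMAP", "sil_msg": "Silhouette: 0.41",
    #      "meta": {...}, "marker_size": 5, "palette_name": "auto",
    #      "data_csv": "plot_1.csv", "timestamp": "2025-01-01T00:00:00"}
    #   ],
    #   "saved_at": "2025-01-01T00:00:00"
    # }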

    # Palette helpers, shared by load_cache below and process_and_add_plot.
    def get_base_palette(name: str) -> List[str]:
        name = name.lower()
        if name == 'okabe-ito':
            return ["#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
        if name == 'tol':
            return ["#4477AA", "#66CCEE", "#228833", "#CCBB44", "#EE6677", "#AA3377", "#BBBBBB", "#000000", "#EEDD88", "#FFAABB", "#99DDFF", "#44BB99"]
        if name in ('plotly', 'bold', 'dark24', 'set3', 'd3'):
            from plotly.colors import qualitative as q
            return {'plotly': q.Plotly, 'bold': q.Bold, 'dark24': q.Dark24, 'set3': q.Set3, 'd3': q.D3}[name]
        return []

    def generate_hcl_like(n: int) -> List[str]:
        # Evenly spaced hues with alternating saturation/value for extra contrast.
        cols = []
        for i in range(n):
            h = (i / n) % 1.0
            s = 0.55 + 0.35 * ((i * 37) % 2)
            v = 0.85 if (i % 3) else 0.98
            r, g, b = colorsys.hsv_to_rgb(h, s, v)
            cols.append('#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255)))
        return cols

    def build_palette(name: str, k: int) -> List[str]:
        if name == 'auto':
            return []  # empty sequence lets Plotly pick its default colors
        if name == 'hcl':
            return generate_hcl_like(k)
        base = get_base_palette(name)
        if k <= len(base):
            return base[:k]
        # Extend the base palette by stepping the hue with the golden-ratio
        # conjugate, which spreads extra colors roughly evenly around the wheel.
        cols = list(base)
        gold = 0.61803398875
        h = 0.1
        while len(cols) < k:
            h = (h + gold) % 1.0
            r, g, b = colorsys.hsv_to_rgb(h, 0.6, 0.95)
            newc = '#%02X%02X%02X' % (int(r * 255), int(g * 255), int(b * 255))
            if newc not in cols:
                cols.append(newc)
        return cols

    def load_cache():
        if st.session_state['cache_loaded']:
            return
        if not os.path.isfile(INDEX_FILE):
            st.session_state['cache_loaded'] = True
            return
        try:
            with open(INDEX_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
            plots_meta = data.get('plots', [])
            for meta_entry in plots_meta:
                csv_file = os.path.join(CACHE_DIR, meta_entry.get('data_csv', ''))
                if not os.path.isfile(csv_file):
                    continue
                try:
                    df = pd.read_csv(csv_file)
                except Exception:
                    continue
                # Reconstruct the figure from the cached coordinates
                import plotly.express as px
                color_display_col = 'color_display' if 'color_display' in df.columns else ('cluster_display' if 'cluster_display' in df.columns else 'cluster')
                palette_name_load = meta_entry.get('palette_name', 'auto')
                groups = df[color_display_col].astype(str).unique() if color_display_col in df.columns else []
                palette_seq = build_palette(palette_name_load, len(groups)) if palette_name_load else []
                fig = px.scatter_3d(
                    df,
                    x='x', y='y', z='z',
                    color=color_display_col if color_display_col in df.columns else None,
                    color_discrete_sequence=palette_seq if palette_seq else None,
                    hover_data={'filepath': True, 'cluster': True},
                    title=f'{os.path.basename(meta_entry.get("path", ""))} ({meta_entry.get("algo")})',
                    opacity=0.9,
                    height=700,
                )
                fig.update_traces(marker={'size': int(meta_entry.get('marker_size', 5))})
                fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
                # Re-read the CSV bytes so the download button keeps working
                with open(csv_file, 'rb') as fcsv:
                    csv_bytes = fcsv.read()
                st.session_state['plots'].append({
                    'path': meta_entry.get('path'),
                    'name': meta_entry.get('name'),
                    'fig': fig,
                    'algo': meta_entry.get('algo'),
                    'meta': meta_entry.get('meta', {}),
                    'sil_msg': meta_entry.get('sil_msg'),
                    'csv': csv_bytes,
                    'n': int(df.shape[0]),
                    'dim': int(len([c for c in df.columns if c.startswith('x') or c.startswith('y') or c.startswith('z')]) or 3),
                    'palette_name': meta_entry.get('palette_name'),
                    'marker_size': meta_entry.get('marker_size', 5),
                    'data_csv_path': csv_file,
                    'timestamp': meta_entry.get('timestamp'),
                    'df_head': None,  # raw rows are not cached, to stay light
                })
            st.session_state['cache_loaded'] = True
        except Exception as e:
            st.warning(f'Failed to load cache: {e}')
            st.session_state['cache_loaded'] = True
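
    # Example (illustrative): build_palette('okabe-ito', 3) returns
    # ['#000000', '#E69F00', '#56B4E9']; for k larger than the base palette,
    # the golden-ratio hue walk appends further distinct colors.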

    # Load the cache only once per session (initial render)
    load_cache()

    with st.sidebar:
        st.markdown('### ⚙️ General settings')
        default_path = 'embeddings_factures_osteopathie_1k_qwen.json'
        path = st.text_input('Embedding file path', value=default_path, key='path_input')
        plot_name_input = st.text_input('Plot name (optional)', value='', key='plot_name')
        sample_size = st.number_input('Sample (0 = all)', min_value=0, value=0, step=100, key='sample_size')
        max_objects = st.number_input('Max objects to read (0 = no limit)', min_value=0, value=0, step=500, key='max_objects')
        seed = st.number_input('Seed', min_value=0, value=42, step=1, key='seed')
        show_raw = st.checkbox('Show raw table (first 100 rows per plot)', value=False, key='show_raw')

        algo = st.selectbox('Dimensionality reduction algorithm', ['UMAP', 'PCA', 't-SNE'], index=0, key='algo')
        # Only expose the parameters relevant to the chosen algorithm; the
        # others fall back to their defaults.
        if algo == 'UMAP':
            umap_neighbors = st.slider('UMAP n_neighbors', 5, 100, 15, 1, key='umap_neighbors')
            umap_min_dist = st.slider('UMAP min_dist', 0.0, 1.0, 0.1, 0.01, key='umap_min_dist')
            tsne_perplexity = 30
        elif algo == 't-SNE':
            tsne_perplexity = st.slider('t-SNE perplexity', 5, 100, 30, 1, key='tsne_perplexity')
            umap_neighbors = 15
            umap_min_dist = 0.1
        else:
            umap_neighbors = 15
            umap_min_dist = 0.1
            tsne_perplexity = 30

        st.markdown('### 🎨 Colors & Clusters')
        cluster_source = st.radio('Cluster source', ['KMeans', 'Load file', 'None'], index=0, horizontal=False, key='cluster_source')
        n_clusters = st.slider('Number of KMeans clusters', 2, 100, 10, 1, disabled=(cluster_source != 'KMeans'), key='n_clusters')
        cluster_file_path = st.text_input('Cluster JSON file', value='cluster/dbscan_results.json', disabled=(cluster_source != 'Load file'), key='cluster_file')
        noise_label = st.number_input('Noise label', value=-1, step=1, disabled=(cluster_source != 'Load file'), key='noise_label')
        path_filter = st.text_input('Filter filepath (comma = OR)', value='', key='path_filter')
        color_by_substring = st.text_input('Color by substring', value='', key='color_by_substring')
        palette_name = st.selectbox('Color palette', ['auto', 'okabe-ito', 'tol', 'plotly', 'bold', 'dark24', 'set3', 'd3', 'hcl'], index=0, help='okabe-ito, tol = colorblind-friendly', key='palette')
        marker_size = st.slider('Marker size', 2, 15, 3, 1, key='marker_size')
        st.caption('If the cluster file only has is_noise: non-noise -> 0, noise -> noise_label.')

        add_btn = st.button('➕ Add plot', type='primary')
        clear_btn = st.button('🧹 Clear all plots')

    if clear_btn:
        st.session_state['plots'].clear()
        # Clear the cache directory as well
        try:
            if os.path.isdir(CACHE_DIR):
                for f in os.listdir(CACHE_DIR):
                    try:
                        os.remove(os.path.join(CACHE_DIR, f))
                    except Exception:
                        pass
                os.rmdir(CACHE_DIR)
        except Exception:
            pass
        st.success('Removed all plots (and the cache).')

    def process_and_add_plot():
        path_local = path
        if not os.path.isfile(path_local):
            st.error(f'File does not exist: {path_local}')
            return
        with st.spinner(f'Loading embeddings: {os.path.basename(path_local)}'):
            recs = load_embeddings(
                path_local,
                sample_size=sample_size or None,
                sampling_seed=int(seed),
                max_objects=max_objects or None,
            )
        df = build_dataframe(recs)
        if df.empty:
            st.warning('No embeddings found.')
            return
        # Filter by filepath substrings (comma-separated, OR semantics)
        if path_filter.strip():
            tokens = [t.strip() for t in path_filter.split(',') if t.strip()]
            if tokens:
                mask = df['filepath'].apply(lambda p: any(tok.lower() in p.lower() for tok in tokens))
                df = df[mask].reset_index(drop=True)
                if df.empty:
                    st.warning('No records remain after filtering.')
                    return
        # Dimensionality reduction
        X = np.vstack(df['embedding'].values).astype(np.float32)
        with st.spinner('Reducing dimensions...'):
            coords, meta = reduce_embeddings(
                X,
                algo,
                random_state=int(seed),
                umap_neighbors=umap_neighbors,
                umap_min_dist=umap_min_dist,
                tsne_perplexity=tsne_perplexity,
            )
        df[['x', 'y', 'z']] = coords
        # Clustering
        if cluster_source == 'KMeans':
            with st.spinner('KMeans...'):
                labels, sil = kmeans_cluster(coords, n_clusters, int(seed))
            df['cluster'] = labels
            sil_msg = f'Silhouette: {sil:.4f}' if not math.isnan(sil) else 'Silhouette: N/A'
        elif cluster_source == 'Load file':
            if not os.path.isfile(cluster_file_path):
                st.warning(f'Cluster file not found: {cluster_file_path}; assigning -1')
                df['cluster'] = -1
                sil_msg = 'Silhouette: N/A'
            else:
                loaded_labels, mapping = load_cluster_file(cluster_file_path, len(df), noise_label=int(noise_label))
                if mapping is not None:
                    labs = []
                    miss = 0
                    for fp in df['filepath']:
                        val = mapping.get(fp, int(noise_label))
                        if fp not in mapping:
                            miss += 1
                        labs.append(val)
                    if miss:
                        st.info(f'{miss} filepaths assigned to noise.')
                    df['cluster'] = np.array(labs, dtype=int)
                else:
                    df['cluster'] = loaded_labels
                uniq = set(df['cluster'])
                if len([u for u in uniq if u != int(noise_label)]) > 1:
                    try:
                        mask_valid = df['cluster'].to_numpy() != int(noise_label)
                        sil = silhouette_score(coords[mask_valid], df.loc[mask_valid, 'cluster'])
                        sil_msg = f'Silhouette (excl. noise): {sil:.4f}'
                    except Exception:
                        sil_msg = 'Silhouette: N/A'
                else:
                    sil_msg = 'Silhouette: N/A'
        else:
            df['cluster'] = -1
            sil_msg = 'No clustering'
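
        # Note: the silhouette above is computed on the reduced 3D coordinates,
        # not on the original high-dimensional embeddings, and noise points are
        # excluded when labels come from a file.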

        # Color grouping: optionally color by the first matching substring
        if color_by_substring.strip():
            subs = [s.strip() for s in color_by_substring.split(',') if s.strip()]

            def color_from_sub(p: str) -> str:
                for ssub in subs:
                    if ssub.lower() in p.lower():
                        return ssub
                return 'other'

            df['color_group'] = df['filepath'].apply(color_from_sub)
            color_col = 'color_group'
        else:
            color_col = 'cluster'
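
        # Example: with color_by_substring = "invoice,receipt", a filepath like
        # "scans/invoice_12.png" gets group "invoice"; paths matching nothing
        # fall into "other".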

        # Palette utilities: the shared get_base_palette / generate_hcl_like /
        # build_palette helpers defined at main() scope above are reused here.
        color_display_col = color_col + '_display'
        if np.issubdtype(df[color_col].dtype, np.number):
            df[color_display_col] = df[color_col].astype(int).astype(str)
        else:
            df[color_display_col] = df[color_col].astype(str)
        groups = df[color_display_col].unique()
        palette_seq = build_palette(palette_name, len(groups)) if palette_name else []
        import plotly.express as px
        fig = px.scatter_3d(
            df,
            x='x', y='y', z='z',
            color=color_display_col,
            color_discrete_sequence=palette_seq if palette_seq else None,
            hover_data={'filepath': True, 'cluster': True, 'x': ':.2f', 'y': ':.2f', 'z': ':.2f'},
            title=f'{os.path.basename(path_local)} ({algo})',
            opacity=0.9,
            height=700,
        )
        fig.update_traces(marker={'size': int(marker_size)})
        fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
        out_csv = df[['filepath', 'x', 'y', 'z', 'cluster']].to_csv(index=False).encode('utf-8')
        st.session_state['plots'].append({
            'path': path_local,
            'name': plot_name_input.strip() or os.path.basename(path_local),
            'df_head': df.head(100) if show_raw else None,
            'fig': fig,
            'meta': meta,
            'algo': algo,
            'sil_msg': sil_msg,
            'csv': out_csv,
            'n': len(df),
            'dim': len(df['embedding'].iloc[0]),
            'palette_name': palette_name,
            'marker_size': marker_size,
            'timestamp': datetime.utcnow().isoformat(),
        })
        # Persist the new state
        save_cache()
        st.success(f'Added plot: {os.path.basename(path_local)}')

    if add_btn:
        process_and_add_plot()

    # Render existing plots
    if st.session_state['plots']:
        st.markdown('---')
        for idx, plot_data in enumerate(st.session_state['plots']):
            container = st.container()
            with container:
                cols = st.columns([0.8, 0.2])
                with cols[0]:
                    st.subheader(f'#{idx+1} {plot_data.get("name", os.path.basename(plot_data["path"]))}')
                with cols[1]:
                    if st.button('❌ Remove', key=f'remove_{idx}'):
                        st.session_state['plots'].pop(idx)
                        save_cache()
                        st.experimental_rerun()
                # Inline rename
                new_name = st.text_input('Rename plot', value=plot_data.get('name', ''), key=f'rename_{idx}')
                if new_name.strip() and new_name.strip() != plot_data.get('name'):
                    plot_data['name'] = new_name.strip()
                    save_cache()
                st.caption(f"Embeddings: {plot_data['n']} | Original dim: {plot_data['dim']} | {plot_data['sil_msg']}")
                st.plotly_chart(plot_data['fig'], use_container_width=True)
                with st.expander('Meta / extra info'):
                    st.json(plot_data['meta'])
                st.download_button('⬇️ CSV', plot_data['csv'], file_name=f"embedding_3d_{idx+1}.csv", mime='text/csv', key=f'dl_{idx}')
                df_head_cached = plot_data.get('df_head')
                if df_head_cached is not None:
                    st.dataframe(df_head_cached)
    if not st.session_state['plots']:
        st.info('No plots yet. Choose a file in the sidebar and press "➕ Add plot".')


if __name__ == '__main__':  # pragma: no cover
    # Under 'streamlit run', sys.argv only contains the script name, so we
    # normally just call main(). For a quick CLI smoke test, pass '--cli-test'.
    import sys

    if '--cli-test' in sys.argv:
        test_path = 'embeddings_factures_osteopathie_1k_qwen.json'
        if os.path.exists(test_path):
            recs = load_embeddings(test_path, sample_size=5)
            print(f'[CLI TEST] Loaded {len(recs)} embeddings dim={len(recs[0].embedding)}')
        else:
            print('[CLI TEST] Test file not found.')
    else:
        main()
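
# Quick smoke test (assumes the test JSON above exists next to this script):
#   python "visual_data copy 2.py" --cli-test
# Normal usage:
#   streamlit run visual_data.py --server.port 8501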