#!/usr/bin/env python3
"""
Threshold analysis for DeQA scores vs. human labels (High/Low).

Inputs (defaults for the facture task):
- results/facture.txt                      # lines like: "4.2 - filename.jpg"
- data/facture/labels.csv                  # columns: filename,label with label in {High,Low}

Outputs:
- results/facture_thresholds_summary.json  # best thresholds for accuracy/precision/recall/F1
- results/facture_metric_curves.png        # metrics vs. threshold
- results/facture_score_distributions.png  # score histograms by label
- results/facture_decisions.csv            # per-image decisions at each operating point
- results/facture_sorted_scores_with_thr.png,
  facture_precision_recall_curve.png,
  facture_roc_like_curve.png,
  facture_confusion_matrix.png             # additional diagnostic plots
- results/facture_deqa_images.xlsx         # one-sheet Excel export at the F1 operating point
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def read_deqa_results_txt(path: Path) -> pd.DataFrame:
    """Read TXT results of the form "<score> - <filename>" into a DataFrame."""
    rows: List[Dict[str, str | float]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Expect pattern: "<score> - <filename>"
            try:
                score_part, fname = line.split(" - ", 1)
                score = float(score_part)
                rows.append({"filename": fname, "score": score})
            except Exception:
                # Skip malformed lines silently
                continue
    df = pd.DataFrame(rows)
    if not df.empty:
        df["filename"] = df["filename"].astype(str)
        df["stem"] = df["filename"].apply(lambda x: Path(x).stem.lower())
    return df
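

# Illustrative contents of the scores TXT file parsed above (scores and filenames
# are made up here; the real file is produced by the DeQA scoring step):
#   4.2 - invoice_001.jpg
#   1.7 - invoice_002.jpg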


def read_labels_csv(path: Path) -> pd.DataFrame:
    """Read labels CSV with columns: filename,label (High/Low)."""
    df = pd.read_csv(path)
    # Normalize
    df["filename"] = df["filename"].astype(str)
    df["label"] = df["label"].astype(str).str.strip().str.capitalize()
    # Map High->1, Low->0
    label_map = {"High": 1, "Low": 0}
    df["y_true"] = df["label"].map(label_map)
    df["stem"] = df["filename"].apply(lambda x: Path(x).stem.lower())
    return df[["filename", "label", "y_true", "stem"]]
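

# Illustrative labels.csv layout (filenames are made up; only the two columns
# "filename" and "label" are required, with label values High or Low):
#   filename,label
#   invoice_001.jpg,High
#   invoice_002.jpg,Low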


def confusion_from_threshold(scores: np.ndarray, y_true: np.ndarray, thr: float) -> Tuple[int, int, int, int]:
    pred = (scores >= thr).astype(int)
    tp = int(np.sum((pred == 1) & (y_true == 1)))
    fp = int(np.sum((pred == 1) & (y_true == 0)))
    fn = int(np.sum((pred == 0) & (y_true == 1)))
    tn = int(np.sum((pred == 0) & (y_true == 0)))
    return tp, fp, fn, tn


def metric_from_confusion(tp: int, fp: int, fn: int, tn: int, metric: str) -> float:
    if metric == "accuracy":
        denom = tp + fp + fn + tn
        return (tp + tn) / denom if denom > 0 else 0.0
    if metric == "precision":
        denom = tp + fp
        return tp / denom if denom > 0 else 0.0
    if metric == "recall":
        denom = tp + fn
        return tp / denom if denom > 0 else 0.0
    if metric == "f1":
        p_denom = tp + fp
        r_denom = tp + fn
        precision = tp / p_denom if p_denom > 0 else 0.0
        recall = tp / r_denom if r_denom > 0 else 0.0
        denom = precision + recall
        return (2 * precision * recall / denom) if denom > 0 else 0.0
    raise ValueError(f"Unsupported metric: {metric}")
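

# Worked example (for orientation only, not used by the code): with tp=8, fp=2, fn=4, tn=6
# the formulas above give accuracy = 14/20 = 0.70, precision = 8/10 = 0.80,
# recall = 8/12 ~= 0.667, and F1 = 2 * 0.80 * 0.667 / (0.80 + 0.667) ~= 0.727.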


def pick_threshold(scores: np.ndarray, y_true: np.ndarray, metric: str = "f1") -> Tuple[float, float, Dict[str, int]]:
    thr_candidates = np.unique(scores)
    best_thr: float | None = None
    best_val: float = -1.0
    best_conf: Tuple[int, int, int, int] | None = None

    for t in thr_candidates:
        tp, fp, fn, tn = confusion_from_threshold(scores, y_true, t)
        val = metric_from_confusion(tp, fp, fn, tn, metric)
        # Tie-breaker: prefer higher threshold if metric ties (safer for downstream)
        if (val > best_val) or (np.isclose(val, best_val) and (best_thr is None or t > best_thr)):
            best_val = val
            best_thr = t
            best_conf = (tp, fp, fn, tn)

    assert best_thr is not None and best_conf is not None
    tp, fp, fn, tn = best_conf
    return float(best_thr), float(best_val), {"TP": tp, "FP": fp, "FN": fn, "TN": tn}
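

# Illustrative call (made-up numbers): for scores [3.1, 4.0, 4.5] and labels [0, 1, 1],
# pick_threshold(..., metric="f1") selects threshold 4.0 with F1 = 1.0 and confusion
# {"TP": 2, "FP": 0, "FN": 0, "TN": 1}, since scores >= 4.0 separate the classes exactly.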


def compute_metric_curves(scores: np.ndarray, y_true: np.ndarray) -> pd.DataFrame:
    data: List[Dict[str, float]] = []
    for t in np.unique(scores):
        tp, fp, fn, tn = confusion_from_threshold(scores, y_true, t)
        row = {
            "threshold": float(t),
            "accuracy": metric_from_confusion(tp, fp, fn, tn, "accuracy"),
            "precision": metric_from_confusion(tp, fp, fn, tn, "precision"),
            "recall": metric_from_confusion(tp, fp, fn, tn, "recall"),
            "f1": metric_from_confusion(tp, fp, fn, tn, "f1"),
            "TP": tp,
            "FP": fp,
            "FN": fn,
            "TN": tn,
        }
        data.append(row)
    return pd.DataFrame(data).sort_values("threshold").reset_index(drop=True)


def plot_distributions(df: pd.DataFrame, out_path: Path) -> None:
    plt.figure(figsize=(8, 5))
    sns.histplot(data=df, x="score", hue="label", bins=30, kde=True, stat="density", common_norm=False)
    plt.title("DeQA score distributions by label")
    plt.xlabel("DeQA score")
    plt.ylabel("Density")
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def plot_metric_curves(curve_df: pd.DataFrame, out_path: Path) -> None:
    plt.figure(figsize=(8, 5))
    for metric in ["accuracy", "precision", "recall", "f1"]:
        plt.plot(curve_df["threshold"], curve_df[metric], label=metric)
    plt.xlabel("Threshold (score >= t => HIGH)")
    plt.ylabel("Metric value")
    plt.ylim(0.0, 1.05)
    plt.title("Metrics vs threshold")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def plot_sorted_scores_with_threshold(df: pd.DataFrame, thr: float, out_path: Path) -> None:
    tmp = df.sort_values("score").reset_index(drop=True)
    x = np.arange(len(tmp))
    y = tmp["score"].to_numpy()
    plt.figure(figsize=(9, 4))
    plt.scatter(x, y, s=6, alpha=0.6)
    plt.axhline(thr, color="red", linestyle="--", label=f"threshold={thr:.3f}")
    plt.xlabel("Images sorted by score")
    plt.ylabel("DeQA score")
    plt.title("Sorted scores with operating threshold")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def plot_pr_curve(curves: pd.DataFrame, out_path: Path) -> None:
    plt.figure(figsize=(6, 5))
    plt.plot(curves["recall"], curves["precision"], marker="o", ms=3, lw=1)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall across thresholds")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def plot_roc_like(curves: pd.DataFrame, out_path: Path) -> None:
    # TPR=recall, FPR=FP/(FP+TN)
    denom = (curves["FP"] + curves["TN"]).replace(0, np.nan)
    fpr = curves["FP"] / denom
    tpr = curves["recall"]
    plt.figure(figsize=(6, 5))
    plt.plot(fpr.fillna(0), tpr, marker="o", ms=3, lw=1)
    plt.xlabel("False Positive Rate (FPR)")
    plt.ylabel("True Positive Rate (TPR)")
    plt.title("ROC-like curve across thresholds")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()
def plot_confusion_heatmap(tp: int, fp: int, fn: int, tn: int, out_path: Path) -> None:
    # Rows are the true class, columns the predicted class, matching the tick labels below.
    cm = np.array([[tp, fn], [fp, tn]])
    plt.figure(figsize=(4, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=["Pred High", "Pred Low"], yticklabels=["True High", "True Low"])
    plt.title("Confusion matrix at operating threshold")
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def main() -> None:
    parser = argparse.ArgumentParser(description="Threshold analysis for DeQA scores vs labels")
    parser.add_argument("--scores", type=str, default="results/facture.txt", help="Path to deqa scores txt")
    parser.add_argument("--labels", type=str, default="data/facture/labels.csv", help="Path to labels csv")
    parser.add_argument("--outdir", type=str, default="results", help="Directory to write outputs")
    parser.add_argument("--sample-per-class", type=int, default=0,
                        help="If >0, randomly sample N High and N Low for a quick benchmark")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    args = parser.parse_args()

    scores_path = Path(args.scores)
    labels_path = Path(args.labels)
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Load
    df_scores = read_deqa_results_txt(scores_path)
    df_labels = read_labels_csv(labels_path)

    # Join on lowercase stem to tolerate extension differences
    df = df_scores.merge(df_labels, on="stem", how="inner", suffixes=("_score", "_label"))
    # Prefer label-side filename when available
    df["filename"] = df["filename_label"].where(df["filename_label"].notna(), df["filename_score"])
    df.drop(columns=[c for c in ["filename_label", "filename_score"] if c in df.columns], inplace=True)
    if df.empty:
        raise RuntimeError("No overlap between scores and labels. Check filenames.")

    # Optional sampling per class
    if args.sample_per_class and args.sample_per_class > 0:
        rng = np.random.default_rng(args.seed)
        high_df = df[df["y_true"] == 1]
        low_df = df[df["y_true"] == 0]
        n_high = min(args.sample_per_class, len(high_df))
        n_low = min(args.sample_per_class, len(low_df))
        high_sample = high_df.sample(n=n_high, random_state=args.seed)
        low_sample = low_df.sample(n=n_low, random_state=args.seed)
        df = pd.concat([high_sample, low_sample], ignore_index=True)
        df = df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)

    scores = df["score"].to_numpy(dtype=float)
    y_true = df["y_true"].to_numpy(dtype=int)

    # Compute best thresholds
    thr_f1, best_f1, conf_f1 = pick_threshold(scores, y_true, metric="f1")
    thr_acc, best_acc, conf_acc = pick_threshold(scores, y_true, metric="accuracy")
    thr_prec, best_prec, conf_prec = pick_threshold(scores, y_true, metric="precision")
    thr_rec, best_rec, conf_rec = pick_threshold(scores, y_true, metric="recall")

    summary = {
        "positive_definition": "HIGH when score >= threshold",
        "best_thresholds": {
            "f1": {"threshold": thr_f1, "value": best_f1, "confusion": conf_f1},
            "accuracy": {"threshold": thr_acc, "value": best_acc, "confusion": conf_acc},
            "precision": {"threshold": thr_prec, "value": best_prec, "confusion": conf_prec},
            "recall": {"threshold": thr_rec, "value": best_rec, "confusion": conf_rec},
        },
        "counts": {
            "total": int(len(df)),
            "positives": int(df["y_true"].sum()),
            "negatives": int(len(df) - int(df["y_true"].sum())),
        },
    }

    # Metric curves and figures
    curves = compute_metric_curves(scores, y_true)
    plot_distributions(df, outdir / "facture_score_distributions.png")
    plot_metric_curves(curves, outdir / "facture_metric_curves.png")
    # Extra plots
    plot_sorted_scores_with_threshold(df, thr_f1, outdir / "facture_sorted_scores_with_thr.png")
    plot_pr_curve(curves, outdir / "facture_precision_recall_curve.png")
    plot_roc_like(curves, outdir / "facture_roc_like_curve.png")
    plot_confusion_heatmap(conf_f1["TP"], conf_f1["FP"], conf_f1["FN"], conf_f1["TN"], outdir / "facture_confusion_matrix.png")

    # Decisions CSV (accuracy/precision/recall operating points plus F1)
    def decide(thr: float) -> np.ndarray:
        return (scores >= thr).astype(int)

    df_out = df.copy()
    df_out["decision_f1"] = decide(thr_f1)
    df_out["decision_acc"] = decide(thr_acc)
    df_out["decision_prec"] = decide(thr_prec)
    df_out["decision_rec"] = decide(thr_rec)
    # Map 1/0 to textual action
    to_action = {1: "implement", 0: "reject"}
    for col in ["decision_f1", "decision_acc", "decision_prec", "decision_rec"]:
        df_out[col] = df_out[col].map(to_action)
    df_out.rename(columns={"score": "deqa_score"}, inplace=True)
    df_out = df_out[["filename", "deqa_score", "label", "decision_f1", "decision_acc", "decision_prec", "decision_rec"]]
    df_out.to_csv(outdir / "facture_decisions.csv", index=False)

    # Save summary JSON
    with open(outdir / "facture_thresholds_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    # Save a single Excel file with one sheet containing all rows and decisions (F1 operating point)
    try:
        excel_path = outdir / "facture_deqa_images.xlsx"
        one_sheet_df = df_out.copy()
        # Keep core columns only
        keep_cols = ["filename", "deqa_score", "label", "decision_f1"]
        one_sheet_df = one_sheet_df[keep_cols]
        one_sheet_df.rename(columns={"decision_f1": "decision"}, inplace=True)
        with pd.ExcelWriter(excel_path, engine="openpyxl") as writer:
            one_sheet_df.to_excel(writer, sheet_name="DeQA_Images", index=False)
    except Exception as e:
        print(f"Warning: Failed to write Excel file: {e}")

    # Also print a concise console summary
    print("Best thresholds (score >= thr => HIGH):")
    for k in ["f1", "accuracy", "precision", "recall"]:
        info = summary["best_thresholds"][k]
        print(f"- {k}: thr={info['threshold']:.3f}, value={info['value']:.3f}, conf={info['confusion']}")


if __name__ == "__main__":
    main()