#!/usr/bin/env python3
"""
Compare hallucination across three pipelines over five preprocessing methods:
1) Raw: all images
2) DeQA-filtered: keep images with DeQA score >= threshold (default 2.6)
3) Human-filtered: keep images labeled High in CSV labels
Inputs:
- One or more per_sample_eval.json files (or per_image_anls.csv already generated)
- DeQA score file (txt): lines like "3.9 - image (9)_0.png"
- Human labels CSV with columns: filename,label where label in {High,Low}
Outputs:
- Combined means CSV: method vs mean hallucination for each pipeline
- Line chart (3 lines): hallucination mean per method across the three pipelines
"""
from __future__ import annotations

import argparse
import csv
import json
import re
from pathlib import Path
from typing import Dict, List

import pandas as pd
import matplotlib.pyplot as plt


def canonical_key(name: str) -> str:
    """Map various filenames to a canonical key used by per_sample_eval 'image' field.

    Examples:
    - "image (9)_0.png" -> "image (9)"
    - "image (22)" -> "image (22)"
    - "foo/bar/image (15)_3.jpg" -> "image (15)"
    - other names -> stem without extension
    """
    if not name:
        return name
    # Keep only basename
    base = Path(name).name
    # Try pattern image (N)
    m = re.search(r"(image \(\d+\))", base, flags=re.IGNORECASE)
    if m:
        return m.group(1)
    # Fallback: remove extension
    return Path(base).stem


def read_deqa_scores(txt_path: Path) -> Dict[str, float]:
    scores: Dict[str, float] = {}
    with txt_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Accept formats: "3.9 - filename" or "filename,3.9" etc.
            m = re.match(r"\s*([0-9]+(?:\.[0-9]+)?)\s*[-,:]?\s*(.+)$", line)
            if m:
                score = float(m.group(1))
                filename = m.group(2)
            else:
                parts = re.split(r"[,\t]", line)
                if len(parts) >= 2:
                    try:
                        score = float(parts[1])
                        filename = parts[0]
                    except Exception:
                        continue
                else:
                    continue
            key = canonical_key(filename)
            scores[key] = score
    return scores
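
# Illustrative lines accepted by read_deqa_scores (values are examples), all yielding
# {"image (9)": 3.9}:
#   "3.9 - image (9)_0.png"
#   "image (9)_0.png,3.9"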


def read_human_labels(csv_path: Path) -> Dict[str, str]:
    labels: Dict[str, str] = {}
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            filename = (row.get("filename") or row.get("file") or "").strip()
            label = (row.get("label") or row.get("Label") or "").strip()
            if not filename:
                continue
            key = canonical_key(filename)
            if label:
                labels[key] = label
    return labels
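
# Illustrative human-labels CSV accepted by read_human_labels ("file"/"Label" headers
# are also recognized):
#   filename,label
#   image (9)_0.png,High
#   image (12)_0.png,Low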


def levenshtein_distance(a: str, b: str) -> int:
    if a == b:
        return 0
    if len(a) == 0:
        return len(b)
    if len(b) == 0:
        return len(a)
    previous_row = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current_row = [i]
        for j, cb in enumerate(b, start=1):
            insertions = previous_row[j] + 1
            deletions = current_row[j - 1] + 1
            substitutions = previous_row[j - 1] + (0 if ca == cb else 1)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]


def normalized_similarity(pred: str, gt: str) -> float:
    pred = pred or ""
    gt = gt or ""
    max_len = max(len(pred), len(gt))
    if max_len == 0:
        return 1.0
    dist = levenshtein_distance(pred, gt)
    sim = 1.0 - (dist / max_len)
    if sim < 0.0:
        return 0.0
    if sim > 1.0:
        return 1.0
    return sim
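
# Worked example (illustrative): pred="1234" vs gt="1243" has Levenshtein distance 2
# (two substitutions; plain Levenshtein counts a transposition as two edits), so with
# max_len=4 the normalized similarity is 1 - 2/4 = 0.5.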


def compute_anls_for_record(record: Dict) -> tuple[float, int]:
    fields = record.get("fields") or []
    if not isinstance(fields, list) or len(fields) == 0:
        return 0.0, 0
    sims: List[float] = []
    for f in fields:
        pred = str(f.get("pred", ""))
        gt = str(f.get("gt", ""))
        sims.append(normalized_similarity(pred, gt))
    anls = float(sum(sims) / len(sims)) if sims else 0.0
    return anls, len(sims)
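
# Illustrative (assumed) shape of one per_sample_eval.json record consumed above:
#   {"image": "image (9)", "fields": [{"pred": "ACME Ltd", "gt": "ACME Ltd."}]}
# Its ANLS is the mean normalized similarity over the "fields" entries, and the
# per-image hallucination score is 1 - ANLS.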


def load_per_image_anls(input_json: Path) -> pd.DataFrame:
    # Prefer existing per_image_anls.csv, otherwise compute quickly
    per_image_csv = input_json.parent / "per_image_anls.csv"
    if per_image_csv.exists():
        df = pd.read_csv(per_image_csv)
        return df
    # Fallback: compute minimal ANLS like in the other script
    with input_json.open("r", encoding="utf-8") as f:
        data = json.load(f)
    rows = []
    for rec in data:
        anls, num_fields = compute_anls_for_record(rec)
        rows.append({
            "image": rec.get("image"),
            "anls": anls,
            "hallucination_score": 1.0 - anls,
            "num_fields": int(num_fields),
        })
    return pd.DataFrame(rows)
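
# Note: when per_image_anls.csv already exists it is used as-is; the rest of the script
# assumes it provides at least the "image" and "hallucination_score" columns, i.e. the
# same layout as the fallback DataFrame built above.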


def main() -> None:
    p = argparse.ArgumentParser(description="Compare hallucination across raw/DeQA/Human pipelines over methods")
    p.add_argument("inputs", nargs="+", help="per_sample_eval.json files for each method")
    p.add_argument("--deqa_txt", required=True, help="Path to DeQA scores txt (e.g., cni.txt)")
    p.add_argument("--human_csv", required=True, help="Path to human labels CSV")
    p.add_argument("--deqa_threshold", type=float, default=2.6, help="DeQA threshold (>=)")
    args = p.parse_args()

    # Load filters
    deqa_scores = read_deqa_scores(Path(args.deqa_txt))
    human_labels = read_human_labels(Path(args.human_csv))

    # Aggregate per method
    method_to_df: Dict[str, pd.DataFrame] = {}
    for ip in args.inputs:
        path = Path(ip)
        df = load_per_image_anls(path)
        df["method"] = path.parent.name
        df["image_key"] = df["image"].apply(canonical_key)
        method_to_df[path.parent.name] = df

    # Compute means per pipeline (fair comparison: excluded images are kept but scored hallucination=0)
    records = []
    for method, df in method_to_df.items():
        raw_mean = float(df["hallucination_score"].mean()) if len(df) else float("nan")

        # DeQA filter: set hallucination=0 for images with DeQA < threshold (or missing DeQA); keep all images
        df_deqa = df.copy()
        mask_deqa = df_deqa["image_key"].map(lambda k: deqa_scores.get(k, None))
        df_deqa.loc[mask_deqa.isna() | (mask_deqa < args.deqa_threshold), "hallucination_score"] = 0.0
        deqa_mean = float(df_deqa["hallucination_score"].mean()) if len(df_deqa) else float("nan")

        # Human filter: set hallucination=0 for images labeled Low (or missing label); keep all images
        df_human = df.copy()
        mask_human = df_human["image_key"].map(lambda k: human_labels.get(k, "").lower())
        df_human.loc[mask_human != "high", "hallucination_score"] = 0.0
        human_mean = float(df_human["hallucination_score"].mean()) if len(df_human) else float("nan")

        records.append({
            "method": method,
            "raw_mean": raw_mean,
            "deqa_mean": deqa_mean,
            "human_mean": human_mean,
            "raw_count": int(len(df)),
            "deqa_count": int(len(df_deqa)),  # equal to raw_count since no rows are dropped
            "human_count": int(len(df_human)),  # equal to raw_count since no rows are dropped
        })
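
    # Illustrative arithmetic for the zeroing scheme above: with per-image hallucination
    # scores [0.4, 0.2, 0.6] and only the second image below the DeQA threshold,
    # deqa_mean = (0.4 + 0.0 + 0.6) / 3 ≈ 0.333, averaged over the same denominator as raw_mean.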

    outdir = Path(args.inputs[0]).parent.parent / "combined_anls" / "pipeline"
    outdir.mkdir(parents=True, exist_ok=True)
    out_csv = outdir / "pipeline_means.csv"
    means_df = pd.DataFrame(records).sort_values("method")
    means_df.to_csv(out_csv, index=False)

    # 3-line comparison plot over methods (narrower with score annotations)
    x = range(len(means_df))
    plt.figure(figsize=(7, 5))
    # Plot lines and add score annotations
    raw_vals = means_df["raw_mean"].values
    deqa_vals = means_df["deqa_mean"].values
    human_vals = means_df["human_mean"].values
    plt.plot(x, raw_vals, marker="o", label="Raw", linewidth=2, markersize=6)
    plt.plot(x, deqa_vals, marker="s", label=f"DeQA >= {args.deqa_threshold}", linewidth=2, markersize=6)
    plt.plot(x, human_vals, marker="^", label="Human High", linewidth=2, markersize=6)
    # Annotate each point with its score
    for i, (r, d, h) in enumerate(zip(raw_vals, deqa_vals, human_vals)):
        plt.annotate(f"{r:.3f}", (i, r), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)
        plt.annotate(f"{d:.3f}", (i, d), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)
        plt.annotate(f"{h:.3f}", (i, h), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=8)
    plt.xticks(list(x), means_df["method"].tolist(), rotation=25, ha="right")
    plt.ylabel("Mean hallucination (1 - ANLS)")
    plt.title("Pipeline comparison over preprocessing methods")
    plt.grid(axis="y", linestyle="--", alpha=0.3)
    plt.legend()
    plt.tight_layout()
    out_png = outdir / "pipeline_comparison.png"
    plt.savefig(out_png, dpi=160)
    plt.close()

    print(f"Saved: {out_csv}")
    print(f"Saved: {out_png}")


if __name__ == "__main__":
    main()