update ressult
This commit is contained in:
303
scripts/compute_hallucination_from_anls.py
Normal file
303
scripts/compute_hallucination_from_anls.py
Normal file
@@ -0,0 +1,303 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compute ANLS and hallucination scores from per-sample evaluation JSONs and plot results.
|
||||
|
||||
Inputs: one or more JSON files with schema like:
|
||||
[
|
||||
{
|
||||
"image": "image (22)",
|
||||
"num_pred_fields": 10,
|
||||
"num_gt_fields": 12,
|
||||
"num_correct": 9,
|
||||
"all_correct": false,
|
||||
"fields": [
|
||||
{"field": "address", "pred": "...", "gt": "...", "correct": true},
|
||||
...
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
ANLS definition: average normalized Levenshtein similarity over fields present in ground truth.
|
||||
Here we approximate per-field similarity as:
|
||||
sim = 1 - (levenshtein_distance(pred, gt) / max(len(pred), len(gt)))
|
||||
clipped into [0, 1], and treat empty max length as exact match (1.0).
|
||||
|
||||
Per-image ANLS is the mean of field similarities for that image. Hallucination is 1 - ANLS.
|
||||
|
||||
Outputs:
|
||||
- CSV per input JSON placed next to it: per_image_anls.csv with columns [image, anls, hallucination_score, num_fields]
|
||||
- PNG bar chart per input JSON: hallucination_per_image.png with mean line and title.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def levenshtein_distance(a: str, b: str) -> int:
|
||||
"""Compute Levenshtein distance between two strings (iterative DP)."""
|
||||
if a == b:
|
||||
return 0
|
||||
if len(a) == 0:
|
||||
return len(b)
|
||||
if len(b) == 0:
|
||||
return len(a)
|
||||
previous_row = list(range(len(b) + 1))
|
||||
for i, ca in enumerate(a, start=1):
|
||||
current_row = [i]
|
||||
for j, cb in enumerate(b, start=1):
|
||||
insertions = previous_row[j] + 1
|
||||
deletions = current_row[j - 1] + 1
|
||||
substitutions = previous_row[j - 1] + (0 if ca == cb else 1)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
return previous_row[-1]
|
||||
|
||||
|
||||
def normalized_similarity(pred: str, gt: str) -> float:
|
||||
"""Return 1 - normalized edit distance in [0, 1]."""
|
||||
pred = pred or ""
|
||||
gt = gt or ""
|
||||
max_len = max(len(pred), len(gt))
|
||||
if max_len == 0:
|
||||
return 1.0
|
||||
dist = levenshtein_distance(pred, gt)
|
||||
sim = 1.0 - (dist / max_len)
|
||||
if sim < 0.0:
|
||||
return 0.0
|
||||
if sim > 1.0:
|
||||
return 1.0
|
||||
return sim
|
||||
|
||||
|
||||
def compute_anls_for_record(record: Dict) -> Tuple[float, int]:
|
||||
"""Compute ANLS and number of fields for a single record object."""
|
||||
fields = record.get("fields") or []
|
||||
if not isinstance(fields, list) or len(fields) == 0:
|
||||
return 0.0, 0
|
||||
sims: List[float] = []
|
||||
for f in fields:
|
||||
pred = str(f.get("pred", ""))
|
||||
gt = str(f.get("gt", ""))
|
||||
sims.append(normalized_similarity(pred, gt))
|
||||
anls = float(sum(sims) / len(sims)) if sims else 0.0
|
||||
return anls, len(sims)
|
||||
|
||||
|
||||
def process_json(json_path: Path) -> Path:
|
||||
with json_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
rows = []
|
||||
for rec in data:
|
||||
image_name = rec.get("image")
|
||||
anls, num_fields = compute_anls_for_record(rec)
|
||||
hallucination = 1.0 - anls
|
||||
rows.append({
|
||||
"image": image_name,
|
||||
"anls": anls,
|
||||
"hallucination_score": hallucination,
|
||||
"num_fields": int(num_fields),
|
||||
})
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
out_csv = json_path.parent / "per_image_anls.csv"
|
||||
df.to_csv(out_csv, index=False)
|
||||
|
||||
# Plot hallucination bar chart with mean line
|
||||
if len(df) > 0:
|
||||
sorted_df = df.sort_values("hallucination_score", ascending=False).reset_index(drop=True)
|
||||
plt.figure(figsize=(max(8, len(sorted_df) * 0.12), 5))
|
||||
plt.bar(range(len(sorted_df)), sorted_df["hallucination_score"].values, color="#1f77b4")
|
||||
mean_val = float(sorted_df["hallucination_score"].mean())
|
||||
plt.axhline(mean_val, color="red", linestyle="--", label=f"Mean={mean_val:.3f}")
|
||||
plt.xlabel("Image (sorted by hallucination)")
|
||||
plt.ylabel("Hallucination = 1 - ANLS")
|
||||
plt.title(f"Hallucination per image: {json_path.parent.name}")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
out_png = json_path.parent / "hallucination_per_image.png"
|
||||
plt.savefig(out_png, dpi=150)
|
||||
plt.close()
|
||||
|
||||
return out_csv
|
||||
|
||||
|
||||
def common_parent(paths: List[Path]) -> Path:
|
||||
if not paths:
|
||||
return Path.cwd()
|
||||
common = Path(Path(paths[0]).anchor)
|
||||
parts = list(Path(paths[0]).resolve().parts)
|
||||
for i in range(1, len(paths)):
|
||||
other_parts = list(Path(paths[i]).resolve().parts)
|
||||
# shrink parts to common prefix
|
||||
new_parts: List[str] = []
|
||||
for a, b in zip(parts, other_parts):
|
||||
if a == b:
|
||||
new_parts.append(a)
|
||||
else:
|
||||
break
|
||||
parts = new_parts
|
||||
if not parts:
|
||||
return Path.cwd()
|
||||
return Path(*parts)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Compute ANLS and hallucination from per-sample JSONs and plot results.")
|
||||
parser.add_argument("inputs", nargs="+", help="Paths to per_sample_eval.json files")
|
||||
args = parser.parse_args()
|
||||
|
||||
any_error = False
|
||||
combined_rows: List[Dict] = []
|
||||
input_paths: List[Path] = []
|
||||
for in_path_str in args.inputs:
|
||||
path = Path(in_path_str)
|
||||
if not path.exists():
|
||||
print(f"[WARN] File does not exist: {path}", file=sys.stderr)
|
||||
any_error = True
|
||||
continue
|
||||
try:
|
||||
out_csv = process_json(path)
|
||||
print(f"Processed: {path} -> {out_csv}")
|
||||
# Load just-written CSV to aggregate and tag method
|
||||
df = pd.read_csv(out_csv)
|
||||
method_name = path.parent.name
|
||||
df["method"] = method_name
|
||||
combined_rows.extend(df.to_dict(orient="records"))
|
||||
input_paths.append(path)
|
||||
except Exception as exc:
|
||||
print(f"[ERROR] Failed to process {path}: {exc}", file=sys.stderr)
|
||||
any_error = True
|
||||
# Create combined outputs if we have multiple inputs
|
||||
if combined_rows:
|
||||
combo_df = pd.DataFrame(combined_rows)
|
||||
# Reorder columns
|
||||
cols = ["image", "method", "anls", "hallucination_score", "num_fields"]
|
||||
combo_df = combo_df[cols]
|
||||
base_outdir = common_parent(input_paths)
|
||||
combined_dir = base_outdir / "combined_anls"
|
||||
combined_dir.mkdir(parents=True, exist_ok=True)
|
||||
combined_csv = combined_dir / "combined_per_image_anls.csv"
|
||||
combo_df.to_csv(combined_csv, index=False)
|
||||
|
||||
# Mean hallucination per method (bar chart)
|
||||
means = combo_df.groupby("method")["hallucination_score"].mean().sort_values(ascending=False)
|
||||
stds = combo_df.groupby("method")["hallucination_score"].std().reindex(means.index)
|
||||
plt.figure(figsize=(max(6, len(means) * 1.2), 5))
|
||||
plt.bar(means.index, means.values, yerr=stds.values, capsize=4, color="#2ca02c")
|
||||
overall_mean = float(combo_df["hallucination_score"].mean())
|
||||
plt.axhline(overall_mean, color="red", linestyle="--", label=f"Overall mean={overall_mean:.3f}")
|
||||
plt.ylabel("Mean hallucination (1 - ANLS)")
|
||||
plt.title("Mean hallucination by method")
|
||||
plt.xticks(rotation=20, ha="right")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
bar_png = combined_dir / "mean_hallucination_by_method.png"
|
||||
plt.savefig(bar_png, dpi=160)
|
||||
plt.close()
|
||||
|
||||
# Heatmap: images x methods (hallucination)
|
||||
pivot = combo_df.pivot_table(index="image", columns="method", values="hallucination_score", aggfunc="mean")
|
||||
# Sort images by average hallucination descending for readability
|
||||
pivot = pivot.reindex(pivot.mean(axis=1).sort_values(ascending=False).index)
|
||||
plt.figure(figsize=(max(8, len(pivot.columns) * 1.0), max(6, len(pivot.index) * 0.25)))
|
||||
im = plt.imshow(pivot.values, aspect="auto", cmap="viridis")
|
||||
plt.colorbar(im, label="Hallucination (1 - ANLS)")
|
||||
plt.xticks(range(len(pivot.columns)), pivot.columns, rotation=30, ha="right")
|
||||
plt.yticks(range(len(pivot.index)), pivot.index)
|
||||
plt.title("Hallucination per image across methods")
|
||||
plt.tight_layout()
|
||||
heatmap_png = combined_dir / "hallucination_heatmap.png"
|
||||
plt.savefig(heatmap_png, dpi=160)
|
||||
plt.close()
|
||||
|
||||
print(f"Combined CSV: {combined_csv}")
|
||||
print(f"Saved: {bar_png}")
|
||||
print(f"Saved: {heatmap_png}")
|
||||
|
||||
# Line chart: 1 line per method over images, hide image names
|
||||
# Use same image order as pivot
|
||||
methods = list(pivot.columns)
|
||||
x = list(range(len(pivot.index)))
|
||||
plt.figure(figsize=(max(10, len(x) * 0.12), 5))
|
||||
colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd'])
|
||||
for idx, method in enumerate(methods):
|
||||
y = pivot[method].to_numpy()
|
||||
plt.plot(x, y, label=method, linewidth=1.8, color=colors[idx % len(colors)])
|
||||
plt.ylim(0.0, 1.0)
|
||||
plt.xlabel("Images (sorted by overall hallucination)")
|
||||
plt.ylabel("Hallucination (1 - ANLS)")
|
||||
plt.title("Hallucination across images by method")
|
||||
plt.xticks([], []) # hide image names
|
||||
# Mean note box
|
||||
mean_lines = []
|
||||
for method in methods:
|
||||
m = float(combo_df[combo_df["method"] == method]["hallucination_score"].mean())
|
||||
mean_lines.append(f"{method}: {m:.3f}")
|
||||
text = "\n".join(mean_lines)
|
||||
plt.gca().text(0.99, 0.01, text, transform=plt.gca().transAxes,
|
||||
fontsize=9, va='bottom', ha='right',
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='gray'))
|
||||
plt.legend(loc="upper right", ncol=min(3, len(methods)))
|
||||
plt.tight_layout()
|
||||
line_png = combined_dir / "hallucination_lines_by_method.png"
|
||||
plt.savefig(line_png, dpi=160)
|
||||
plt.close()
|
||||
|
||||
# Grouped-by-image interlocking line chart with image labels
|
||||
# Build a consistent x position per image, with small offsets per method
|
||||
base_x = list(range(len(pivot.index)))
|
||||
offsets = {
|
||||
m: ((i - (len(methods) - 1) / 2) * 0.12) for i, m in enumerate(methods)
|
||||
}
|
||||
# Cap width to avoid extremely long images; dynamic but limited
|
||||
width = min(16, max(10, len(base_x) * 0.12))
|
||||
plt.figure(figsize=(width, 6))
|
||||
for idx, method in enumerate(methods):
|
||||
# Fill missing values with 0 to connect lines seamlessly
|
||||
y = pivot[method].fillna(0.0).to_numpy()
|
||||
x_shifted = [bx + offsets[method] for bx in base_x]
|
||||
plt.plot(x_shifted, y, label=method, linewidth=1.8, marker='o', markersize=3,
|
||||
color=colors[idx % len(colors)])
|
||||
plt.ylim(0.0, 1.0)
|
||||
plt.xlim(-0.5, len(base_x) - 0.5)
|
||||
# Hide image names; keep index ticks sparse for readability
|
||||
plt.xticks([], [])
|
||||
plt.xlabel("Images (index)")
|
||||
plt.ylabel("Hallucination (1 - ANLS)")
|
||||
plt.title("Hallucination by image (interlocked methods)")
|
||||
plt.grid(axis='y', linestyle='--', alpha=0.3)
|
||||
# Add box with per-method mean
|
||||
text2 = "\n".join([f"{m}: {float(combo_df[combo_df['method']==m]['hallucination_score'].mean()):.3f}" for m in methods])
|
||||
plt.gca().text(0.99, 0.01, text2, transform=plt.gca().transAxes,
|
||||
fontsize=9, va='bottom', ha='right',
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor='gray'))
|
||||
plt.legend(loc='upper right', ncol=min(3, len(methods)))
|
||||
plt.tight_layout()
|
||||
group_line_png = combined_dir / "hallucination_interlocked_by_image.png"
|
||||
plt.savefig(group_line_png, dpi=160)
|
||||
plt.close()
|
||||
|
||||
print(f"Combined CSV: {combined_csv}")
|
||||
print(f"Saved: {bar_png}")
|
||||
print(f"Saved: {heatmap_png}")
|
||||
print(f"Saved: {line_png}")
|
||||
print(f"Saved: {group_line_png}")
|
||||
|
||||
if any_error:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user