Files
embedding-clustering/filter/analyze_labels.py

561 lines
20 KiB
Python

#!/usr/bin/env python3
"""
Analyze 'label' fields in a JSON dataset and produce summaries.
- Handles entries where 'label' is either an object or a list of objects.
- Computes distributions (is_bill, bill_paid, profession, currency, identifier presence, handwriting/rotation flags).
- Computes numeric stats (total_billed, amount_paid, remaining_payment, coverages).
- Parses dates and shows temporal distribution.
- Analyzes items: count, sum of amounts and coverages, and mismatches vs total_billed.
- Emits a concise stdout summary and writes CSVs, a Markdown report and, when matplotlib is available and --no-plots is not set, plot images.
Usage:
python analyze_labels.py --input 008_label_data_sample_seed_1997.json --out-dir .
"""
from __future__ import annotations
import argparse
import csv
import json
import math
import re
from collections import Counter
from datetime import datetime
from pathlib import Path
from statistics import mean, median
from typing import Any, Dict, Iterable, List, Optional, Tuple
# Label fields that main() coerces with to_float() and summarizes numerically;
# their raw values are also handed to produce_plots() for histograms.
NUMERIC_FIELDS = [
    "total_billed", "amount_paid", "remaining_payment",
    "client_part", "mandatory_coverage", "complementary_coverage",
]
def parse_args() -> argparse.Namespace:
    """Parse ``sys.argv`` into an ``argparse.Namespace``.

    Recognised options: ``--input`` (required JSON path), ``--out-dir``,
    ``--max-professions``, ``--no-plots``, ``--plot-top-k`` and ``--plot-format``.
    """
    parser = argparse.ArgumentParser(description="Analyze 'label' fields in JSON dataset")
    # (flag, keyword arguments for add_argument), registered in this order.
    option_table = (
        ("--input", {"required": True, "help": "Path to JSON file (list of records)"}),
        (
            "--out-dir",
            {"default": None, "help": "Output directory for reports (default: alongside input)"},
        ),
        (
            "--max-professions",
            {"type": int, "default": 50, "help": "Max professions to list in report"},
        ),
        (
            "--no-plots",
            {
                "action": "store_true",
                "help": "Disable generating plots (PNG) and embedding into report",
            },
        ),
        (
            "--plot-top-k",
            {
                "type": int,
                "default": 20,
                "help": "Top-K categories to visualize for profession/currency",
            },
        ),
        (
            "--plot-format",
            {
                "type": str,
                "default": "png",
                "choices": ["png", "jpg", "jpeg"],
                "help": "Image format for plots",
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def load_json(path: Path) -> List[Dict[str, Any]]:
    """Decode *path* as UTF-8 JSON and return its top-level array of records.

    Raises:
        ValueError: if the decoded document is not a JSON array.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, list):
        return payload
    raise ValueError("Top-level JSON must be a list of records")
def to_bool(value: Any) -> Optional[bool]:
    """Coerce a loosely typed flag to ``True``/``False``; return ``None`` when unknown.

    - ``None`` and float NaN give ``None``.
    - Booleans are returned unchanged; other ints/floats use ``bool()`` (non-zero is True).
    - Strings are matched case-insensitively after stripping: true/t/1/yes/y give True,
      false/f/0/no/n give False.  Any other string, and any other type, gives ``None``.
    """
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        # bool(float("nan")) is True, which would silently turn a missing value into a flag.
        if isinstance(value, float) and math.isnan(value):
            return None
        return bool(value)
    if isinstance(value, str):
        v = value.strip().lower()
        if v in {"true", "t", "1", "yes", "y"}:
            return True
        if v in {"false", "f", "0", "no", "n"}:
            return False
    return None
def to_float(value: Any) -> Optional[float]:
    """Convert *value* with ``float()``.

    Returns ``None`` for ``None``, the empty string, or anything ``float()`` rejects
    with ``TypeError``/``ValueError``.
    """
    if value is None or value == "":
        return None
    try:
        converted = float(value)
    except (TypeError, ValueError):
        converted = None
    return converted
# Accepted date layouts, tried in this order; the first that parses wins.
_DATE_FORMATS = (
    "%d-%m-%Y",
    "%d/%m/%Y",
    "%Y-%m-%d",
    "%d-%m-%y",
    "%d/%m/%y",
    "%d.%m.%Y",
    "%d.%m.%y",
)
# Date-looking tokens inside free text: D/M/YYYY with "/", "-" or "." separators
# (one- or two-digit day/month), or ISO YYYY-MM-DD.
_DATE_TOKEN_RE = re.compile(r"\d{1,2}[./-]\d{1,2}[./-]\d{4}|\d{4}-\d{2}-\d{2}")


def _strptime_any(text: str) -> Optional[datetime]:
    """Return ``datetime.strptime(text, fmt)`` for the first matching ``_DATE_FORMATS`` entry, else ``None``."""
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(text, fmt)
        except ValueError:
            continue
    return None


def parse_date(value: Any) -> Optional[datetime]:
    """Parse a (mostly day-first) date string into a naive ``datetime``.

    The stripped string is first matched as a whole against ``_DATE_FORMATS``.  If that
    fails, every date-looking token found in the text (e.g. ``"le 07/02/2025"`` or the
    date part of ``"2025-02-07T10:00"``) is tried in order, so a malformed or impossible
    first token no longer hides a valid later one.

    Returns ``None`` for non-strings, empty/blank strings, or text with no parseable date.
    """
    if not value or not isinstance(value, str):
        return None
    s = value.strip()
    if not s:
        return None
    parsed = _strptime_any(s)
    if parsed is not None:
        return parsed
    for match in _DATE_TOKEN_RE.finditer(s):
        parsed = _strptime_any(match.group(0))
        if parsed is not None:
            return parsed
    return None
def safe_get(d: Dict[str, Any], key: str, default=None):
    """Look up *key* in *d*, falling back to *default* when the key is absent.

    A key that is present with value ``None`` yields ``None``, not *default*.
    """
    found = d.get(key, default)
    return found
def flatten_labels(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten each record's ``label`` (a dict or a list of dicts) into one list of label dicts.

    Every output dict is a shallow copy of a label with bookkeeping keys added:

    - ``__source_image__``: the record's ``image``, otherwise its ``image_files`` joined
      by commas (a single non-list value is used as-is; entries are ``str()``-ed).
    - ``__multi_index__``: only for list-valued labels, the position in that list.

    Skipped silently: non-dict records, records whose ``label`` is missing/None or of
    another type, and non-dict entries inside a label list.
    """
    out: List[Dict[str, Any]] = []
    for rec in records:
        if not isinstance(rec, dict):
            # Malformed top-level entry (previously crashed with AttributeError on .get()).
            continue
        label = rec.get("label")
        if label is None:
            continue
        files = rec.get("image_files") or []
        if not isinstance(files, (list, tuple)):
            # A lone string would otherwise be joined character by character.
            files = [files]
        src_image = rec.get("image") or ",".join(str(f) for f in files)
        if isinstance(label, list):
            for idx, lab in enumerate(label):
                if not isinstance(lab, dict):
                    continue
                o = dict(lab)
                o["__source_image__"] = src_image
                o["__multi_index__"] = idx
                out.append(o)
        elif isinstance(label, dict):
            o = dict(label)
            o["__source_image__"] = src_image
            out.append(o)
    return out
def presence_counts(labels: List[Dict[str, Any]], fields: Iterable[str]) -> Dict[str, int]:
    """Count, per field, how many labels hold a value other than ``None`` or ``""``.

    A missing key counts as absent; falsy but real values such as ``0`` or ``False``
    count as present.  Keys of the result follow the order of *fields*.
    """
    return {
        field: sum(1 for lbl in labels if lbl.get(field) not in (None, ""))
        for field in fields
    }
def numeric_summary(values: List[Optional[float]]) -> Dict[str, Any]:
    """Summarize the finite numbers in *values*.

    Entries that are not int/float, or that are NaN or ±inf, are treated as missing
    (an infinity would otherwise turn the mean and interpolated percentiles into
    inf/NaN).  The result always has ``count`` and ``missing``.  When at least one
    finite value exists, it also has ``min``, ``p25``, ``median``, ``p75``, ``max``,
    ``mean`` and ``sum``; the quartiles come from ``percentile`` (linear interpolation).
    """
    clean = [v for v in values if isinstance(v, (int, float)) and math.isfinite(v)]
    missing = len(values) - len(clean)
    if not clean:
        # Still report the missing count so an all-missing field is visible in the report.
        return {"count": 0, "missing": missing}
    return {
        "count": len(clean),
        "min": min(clean),
        "p25": percentile(clean, 25),
        "median": median(clean),
        "p75": percentile(clean, 75),
        "max": max(clean),
        "mean": mean(clean),
        "sum": sum(clean),
        "missing": missing,
    }
def percentile(arr: List[float], p: float) -> float:
    """Return the *p*-th percentile of *arr*, interpolating linearly between neighbours.

    *p* is expected in [0, 100] and is not validated.  The rank is ``(n - 1) * p / 100``
    over the sorted values.  An empty input gives NaN.  *arr* itself is not modified.
    """
    if not arr:
        return float("nan")
    ordered = sorted(arr)
    rank = (len(ordered) - 1) * (p / 100.0)
    lo, hi = math.floor(rank), math.ceil(rank)
    if lo == hi:
        # Integral rank: lands exactly on an element.
        return ordered[lo]
    lower_part = ordered[lo] * (hi - rank)
    upper_part = ordered[hi] * (rank - lo)
    return lower_part + upper_part
def write_csv(path: Path, headers: List[str], rows: Iterable[Iterable[Any]]) -> None:
    """Write *headers* then every row of *rows* to *path* as UTF-8 CSV.

    Missing parent directories are created.  The file is opened with ``newline=""``
    as the ``csv`` module requires, so the writer controls line endings.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(headers)
        writer.writerows(rows)
def try_import_matplotlib():
    """Return ``matplotlib.pyplot`` set up for the headless "Agg" backend, or ``None``.

    Every failure (matplotlib missing, backend selection error, ...) is swallowed on
    purpose: plotting is an optional extra and must not abort the analysis.
    """
    try:
        import matplotlib  # type: ignore[import-not-found]

        matplotlib.use("Agg")  # render to files only; no display required
        from matplotlib import pyplot  # type: ignore[import-not-found]
    except Exception:
        return None
    return pyplot
def save_bar_plot(plt, x_labels: List[str], values: List[float], title: str, out_path: Path, rotation: int = 0):
    """Draw one bar per entry of *values*, label the x ticks with *x_labels*, and save to *out_path*.

    The figure is 4 inches tall.  Its width is ``0.4 * len(x_labels) + 3`` inches,
    clamped to the range 6-14.  Rotated tick labels are right-aligned; unrotated ones
    are centred.  The output directory is created if needed, and the figure is
    closed after saving (150 dpi).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    width = max(6, min(14, 0.4 * len(x_labels) + 3))
    fig, ax = plt.subplots(figsize=(width, 4))
    ax.bar(range(len(values)), values, color="#4C78A8")
    ax.set_title(title)
    ax.set_ylabel("count")
    tick_positions = range(len(x_labels))
    ax.set_xticks(tick_positions)
    alignment = "right" if rotation else "center"
    ax.set_xticklabels(x_labels, rotation=rotation, ha=alignment)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def save_hist_plot(plt, values: List[float], title: str, out_path: Path, bins: int = 30):
    """Draw a *bins*-bin histogram of *values* on a 7x4 inch figure and save it to *out_path*.

    The output directory is created if needed.  The y axis is labelled "count" and the
    x axis "value".  The figure is saved at 150 dpi and then closed.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(values, bins=bins, color="#72B7B2", edgecolor="white")
    annotations = (
        (ax.set_title, title),
        (ax.set_ylabel, "count"),
        (ax.set_xlabel, "value"),
    )
    for setter, text in annotations:
        setter(text)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
def produce_plots(
out_dir: Path,
args: argparse.Namespace,
is_bill_counter: Counter,
bill_paid_counter: Counter,
handwriting_counter: Counter,
rotation_counter: Counter,
profession_counter: Counter,
currency_counter: Counter,
year_month_counter: Counter,
numeric_data: Dict[str, List[Optional[float]]],
items_per_label: List[int],
) -> List[Path]:
"""Generate plots and return list of created file paths."""
if args.no_plots:
return []
plt = try_import_matplotlib()
if plt is None:
# matplotlib not available; skip plotting gracefully
return []
created: List[Path] = []
plots_dir = out_dir / "plots"
ext = args.plot_format
# is_bill
if is_bill_counter:
labels = [str(k) for k, _ in is_bill_counter.items()]
vals = [v for _, v in is_bill_counter.items()]
p = plots_dir / f"is_bill.{ext}"
save_bar_plot(plt, labels, vals, "is_bill distribution", p)
created.append(p)
# bill_paid
if bill_paid_counter:
labels = [str(k) for k, _ in bill_paid_counter.items()]
vals = [v for _, v in bill_paid_counter.items()]
p = plots_dir / f"bill_paid.{ext}"
save_bar_plot(plt, labels, vals, "bill_paid distribution", p)
created.append(p)
# Flags
if handwriting_counter:
labels = [str(k) for k, _ in handwriting_counter.items()]
vals = [v for _, v in handwriting_counter.items()]
p = plots_dir / f"is_handwriting.{ext}"
save_bar_plot(plt, labels, vals, "is_handwriting", p)
created.append(p)
if rotation_counter:
labels = [str(k) for k, _ in rotation_counter.items()]
vals = [v for _, v in rotation_counter.items()]
p = plots_dir / f"is_rotated.{ext}"
save_bar_plot(plt, labels, vals, "is_rotated", p)
created.append(p)
# Professions (top-K)
if profession_counter:
top = profession_counter.most_common(max(1, min(args.plot_top_k, len(profession_counter))))
labels = [k if len(str(k)) <= 20 else str(k)[:17] + "" for k, _ in top]
vals = [v for _, v in top]
p = plots_dir / f"professions_top{len(labels)}.{ext}"
save_bar_plot(plt, labels, vals, f"Top {len(labels)} professions", p, rotation=45)
created.append(p)
# Currency
if currency_counter:
top = currency_counter.most_common(max(1, min(args.plot_top_k, len(currency_counter))))
labels = [str(k) for k, _ in top]
vals = [v for _, v in top]
p = plots_dir / f"currency.{ext}"
save_bar_plot(plt, labels, vals, "Currency distribution", p)
created.append(p)
# Year-month
if year_month_counter:
items = sorted(year_month_counter.items(), key=lambda x: (x[0][0], x[0][1]))
labels = [f"{y:04d}-{m:02d}" for (y, m), _ in items]
vals = [v for _, v in items]
p = plots_dir / f"invoice_year_month.{ext}"
save_bar_plot(plt, labels, vals, "Invoices by year-month", p, rotation=45)
created.append(p)
# Items per label
if items_per_label:
p = plots_dir / f"items_per_label.{ext}"
save_hist_plot(plt, items_per_label, "Items per label (histogram)", p, bins=min(30, max(5, int(len(items_per_label) ** 0.5))))
created.append(p)
# Numeric fields histograms
for k, vals_all in numeric_data.items():
vals = [float(v) for v in vals_all if isinstance(v, (int, float)) and not math.isnan(v)]
if not vals:
continue
p = plots_dir / f"hist_{k}.{ext}"
save_hist_plot(plt, vals, f"{k} (histogram)", p)
created.append(p)
return created
def _analyze_items(
    labels: List[Dict[str, Any]],
) -> Tuple[
    List[int],
    List[Optional[float]],
    List[Optional[float]],
    List[Tuple[str, Optional[float], Optional[float], Optional[float]]],
]:
    """Summarize every label's ``items`` list.

    Returns ``(items_per_label, sum_item_amount, sum_item_mandatory, mismatch_records)``.
    The first three are per-label lists in label order.  A non-list ``items`` counts as
    empty, and non-dict items are ignored.  The sums are ``None`` when a label has no
    dict items; a missing or unparseable item amount/coverage counts as 0.
    ``mismatch_records`` holds ``(source_image, total_billed, sum_item_amount, diff)``
    for labels whose ``total_billed`` differs from the item sum by more than 1e-6.
    """
    items_per_label: List[int] = []
    sum_item_amount: List[Optional[float]] = []
    sum_item_mandatory: List[Optional[float]] = []
    mismatch_records: List[Tuple[str, Optional[float], Optional[float], Optional[float]]] = []
    for lbl in labels:
        items = lbl.get("items") or []
        if not isinstance(items, list):
            items = []
        items_per_label.append(len(items))
        dict_items = [it for it in items if isinstance(it, dict)]
        s_amount: Optional[float] = None
        s_mand: Optional[float] = None
        if dict_items:
            s_amount = sum((to_float(it.get("amount")) or 0.0 for it in dict_items), 0.0)
            s_mand = sum((to_float(it.get("mandatory_coverage")) or 0.0 for it in dict_items), 0.0)
        sum_item_amount.append(s_amount)
        sum_item_mandatory.append(s_mand)
        total_billed = to_float(lbl.get("total_billed"))
        if total_billed is not None and s_amount is not None:
            diff = total_billed - s_amount
            if abs(diff) > 1e-6:
                mismatch_records.append(
                    (str(lbl.get("__source_image__")), total_billed, s_amount, diff)
                )
    return items_per_label, sum_item_amount, sum_item_mandatory, mismatch_records


def _collect_issues(labels: List[Dict[str, Any]], raw_is_bill: List[Any]) -> List[Dict[str, Any]]:
    """Return data-quality findings as ``{"source": ..., "issue": ...}`` dicts.

    Expects *labels* already normalized by ``main`` (``is_bill``/``bill_paid`` as
    bool/None).  *raw_is_bill* holds each label's ``is_bill`` value from before that
    normalization, in the same order, so unparseable values can still be reported.
    """
    issues: List[Dict[str, Any]] = []
    for lbl, raw in zip(labels, raw_is_bill):
        src = str(lbl.get("__source_image__"))
        # is_bill must be True/False or None (unknown); flag values that were given
        # but could not be parsed as a boolean.
        if raw not in (None, "") and to_bool(raw) is None:
            issues.append({"source": src, "issue": "is_bill not boolean"})
        # bill_paid True but amount_paid missing
        if lbl.get("bill_paid") is True and to_float(lbl.get("amount_paid")) is None:
            issues.append({"source": src, "issue": "bill_paid True but amount_paid missing"})
        # bill_paid True but something still remains to be paid
        rp = to_float(lbl.get("remaining_payment"))
        if lbl.get("bill_paid") is True and (rp or 0) > 0:
            issues.append({"source": src, "issue": "bill_paid True but remaining_payment > 0"})
        # Per-item checks: negative amounts and missing quantities
        items = lbl.get("items") or []
        if isinstance(items, list):
            for idx, it in enumerate(items):
                if not isinstance(it, dict):
                    continue
                a = to_float(it.get("amount"))
                if a is not None and a < 0:
                    issues.append({"source": src, "issue": f"item[{idx}].amount negative: {a}"})
                if to_float(it.get("quantity")) is None:
                    # Not strictly an error, but reported for completeness.
                    issues.append({"source": src, "issue": f"item[{idx}].quantity missing"})
        # Missing currency on bill
        if lbl.get("is_bill") is True and not lbl.get("currency"):
            issues.append({"source": src, "issue": "currency missing for bill"})
    return issues


def main() -> None:
    """Command-line entry point: analyze the labels of ``--input`` and write all outputs.

    Writes to the output directory (``--out-dir``, defaulting to the input's folder):
    professions_counts.csv, currency_counts.csv, is_bill_counts.csv,
    bill_paid_counts.csv, id_presence.csv, item_total_billed_mismatches.csv,
    issues.csv and label_analysis_report.md.  It also writes plot images under
    ``plots/`` unless plotting is disabled or unavailable.  A short summary is printed
    to stdout.
    """
    args = parse_args()
    in_path = Path(args.input).resolve()
    out_dir = Path(args.out_dir).resolve() if args.out_dir else in_path.parent
    out_dir.mkdir(parents=True, exist_ok=True)

    records = load_json(in_path)
    labels = flatten_labels(records)
    n_total_rec = len(records)
    n_labels = len(labels)

    # Normalize flags and numeric fields in place.  Raw is_bill values are kept so that
    # values which could not be parsed are still reported as data-quality issues.
    raw_is_bill: List[Any] = []
    for lbl in labels:
        raw_is_bill.append(lbl.get("is_bill"))
        lbl["is_bill"] = to_bool(lbl.get("is_bill"))
        lbl["bill_paid"] = to_bool(lbl.get("bill_paid"))
        for k in NUMERIC_FIELDS:
            lbl[k] = to_float(lbl.get(k))

    # Basic distributions
    is_bill_counter = Counter(lbl.get("is_bill") for lbl in labels)
    bill_paid_counter = Counter(lbl.get("bill_paid") for lbl in labels)
    currency_counter = Counter(lbl.get("currency") for lbl in labels if lbl.get("currency"))
    # str() guards against non-string profession values, which have no .strip().
    profession_counter = Counter(
        str(lbl.get("profession") or "").strip() or "(missing)" for lbl in labels
    )

    # Presence of identifiers and key fields
    id_presence = presence_counts(labels, [
        "adeli_number",
        "rpps_number",
        "finess_number",
        "prescripteur_finess_number",
        "doctor_name",
        "invoice_issuer",
        "insured_name",
        "beneficiary_name",
        "security_number",
        "currency",
    ])

    # Handwriting/rotation flags
    handwriting_counter = Counter(to_bool(lbl.get("is_handwriting")) for lbl in labels)
    rotation_counter = Counter(to_bool(lbl.get("is_rotated")) for lbl in labels)

    # Raw numeric values (also used for histograms) and their summaries
    numeric_raw: Dict[str, List[Optional[float]]] = {
        k: [lbl.get(k) for lbl in labels] for k in NUMERIC_FIELDS
    }
    numeric_stats: Dict[str, Dict[str, Any]] = {
        k: numeric_summary(vals) for k, vals in numeric_raw.items()
    }

    # Dates
    invoice_dates_clean = [
        d for d in (parse_date(lbl.get("invoice_date")) for lbl in labels) if d is not None
    ]
    year_month_counter = Counter((d.year, d.month) for d in invoice_dates_clean)

    # Items analysis and data quality issues
    items_per_label, sum_item_amount, sum_item_mandatory, mismatch_records = _analyze_items(labels)
    issues = _collect_issues(labels, raw_is_bill)

    # Outputs
    # 1) CSVs
    write_csv(out_dir / "professions_counts.csv", ["profession", "count"], profession_counter.most_common())
    write_csv(out_dir / "currency_counts.csv", ["currency", "count"], currency_counter.most_common())
    write_csv(out_dir / "is_bill_counts.csv", ["is_bill", "count"], is_bill_counter.items())
    write_csv(out_dir / "bill_paid_counts.csv", ["bill_paid", "count"], bill_paid_counter.items())
    write_csv(out_dir / "id_presence.csv", ["field", "present_count"], id_presence.items())
    write_csv(
        out_dir / "item_total_billed_mismatches.csv",
        ["source_image", "total_billed", "sum_item_amount", "diff"],
        mismatch_records,
    )
    write_csv(out_dir / "issues.csv", ["source", "issue"], ((i["source"], i["issue"]) for i in issues))

    # 2) Markdown report
    md: List[str] = []
    md.append("# Label Analysis Report\n")
    md.append(f"Input: `{in_path.name}`\n")
    md.append("")
    md.append("## Overview\n")
    md.append(f"- Total records: {n_total_rec}")
    md.append(f"- Total labels (flattened): {n_labels}")
    md.append(f"- is_bill distribution: {dict(is_bill_counter)}")
    md.append(f"- bill_paid distribution: {dict(bill_paid_counter)}")
    if invoice_dates_clean:
        md.append(
            f"- Invoice dates span: {min(invoice_dates_clean).date()} .. {max(invoice_dates_clean).date()}"
        )
        md.append(f"- Unique year-month pairs: {len(year_month_counter)}")
    else:
        md.append("- Invoice dates: none parseable")
    md.append("\n## Professions (top)\n")
    for prof, cnt in profession_counter.most_common(args.max_professions):
        md.append(f"- {prof}: {cnt}")
    md.append("\n## Currency distribution\n")
    for cur, cnt in currency_counter.most_common():
        md.append(f"- {cur}: {cnt}")
    md.append("\n## Identifier and key field presence\n")
    for k, v in id_presence.items():
        md.append(f"- {k}: {v} present")
    md.append("\n## Flags\n")
    md.append(f"- is_handwriting: {dict(handwriting_counter)}")
    md.append(f"- is_rotated: {dict(rotation_counter)}")
    md.append("\n## Numeric summaries\n")
    for k, stats in numeric_stats.items():
        md.append(f"- {k}: {stats}")
    if items_per_label:
        md.append("\n## Items analysis\n")
        md.append(
            f"- Items per label: count={len(items_per_label)}, min={min(items_per_label)}, "
            f"max={max(items_per_label)}, mean={mean(items_per_label):.2f}"
        )
        # Per-label item sums were previously computed but never reported.
        md.append(f"- sum(items.amount) per label: {numeric_summary(sum_item_amount)}")
        md.append(f"- sum(items.mandatory_coverage) per label: {numeric_summary(sum_item_mandatory)}")
        md.append(f"- total_billed vs sum(items.amount) mismatches: {len(mismatch_records)}")
    if issues:
        md.append("\n## Data quality issues (sample)\n")
        for row in issues[:50]:
            md.append(f"- {row['source']}: {row['issue']}")

    # 3) Plots (if enabled)
    created_plots = produce_plots(
        out_dir=out_dir,
        args=args,
        is_bill_counter=is_bill_counter,
        bill_paid_counter=bill_paid_counter,
        handwriting_counter=handwriting_counter,
        rotation_counter=rotation_counter,
        profession_counter=profession_counter,
        currency_counter=currency_counter,
        year_month_counter=year_month_counter,
        numeric_data=numeric_raw,
        items_per_label=items_per_label,
    )
    if created_plots:
        md.append("\n## Plots\n")
        for p in created_plots:
            rel = p.relative_to(out_dir)
            md.append(f"- {p.stem}")
            md.append(f"![]({rel.as_posix()})\n")
    elif not args.no_plots:
        md.append("\n## Plots\n")
        md.append("- matplotlib not available or no data to plot.")

    report_path = out_dir / "label_analysis_report.md"
    report_path.write_text("\n".join(md), encoding="utf-8")

    # Console summary
    print("Label analysis complete.")
    print(f"- Records: {n_total_rec}, Labels: {n_labels}")
    print(f"- is_bill: {dict(is_bill_counter)} | bill_paid: {dict(bill_paid_counter)}")
    print(f"- Professions (top 10): {profession_counter.most_common(10)}")
    print(f"- Currency: {dict(currency_counter)}")
    print(f"Report written to: {report_path}")
    if created_plots:
        print(f"- Plots saved under: {(out_dir / 'plots').as_posix()} ({len(created_plots)} files)")


if __name__ == "__main__":
    main()