update PaddleOCR result & docTR result

Nguyễn Phước Thành
2025-08-09 22:29:33 +07:00
parent f63589a10a
commit 028e3237bb
15 changed files with 12838 additions and 0 deletions


@@ -0,0 +1,32 @@
# Re-include all image files and JSON within this folder (and subfolders)
# PNG
!*.png
!**/*.png
# JPG/JPEG
!*.jpg
!**/*.jpg
!*.jpeg
!**/*.jpeg
# BMP/GIF/TIFF/WEBP
!*.bmp
!**/*.bmp
!*.gif
!**/*.gif
!*.tif
!**/*.tif
!*.tiff
!**/*.tiff
!*.webp
!**/*.webp
# JSON
!*.json
!**/*.json
# Ensure this file itself is tracked
!.gitignore

Binary image file not shown (added, 1.9 MiB).

File diff suppressed because it is too large.

Binary image file not shown (added, 1.6 MiB).

Binary image file not shown (added, 5.6 MiB).


@@ -0,0 +1,643 @@
import argparse
import json
import os
import sys
from typing import List, Tuple, Any, Optional, Dict

from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageOps
import numpy as np
import cv2
from paddleocr import PaddleOCR

try:  # Some paddleocr builds do not expose draw_ocr at top-level
    from paddleocr import draw_ocr as paddle_draw_ocr  # type: ignore
except Exception:  # pragma: no cover - environment dependent
    paddle_draw_ocr = None


def resolve_font_path() -> str:
    """Return a system font path that supports Latin accents (Windows-friendly)."""
    candidate_fonts: List[str] = []
    # Prefer Arial Unicode or Arial on Windows
    windir = os.environ.get("WINDIR", r"C:\\Windows")
    candidate_fonts.extend(
        [
            os.path.join(windir, "Fonts", "arialuni.ttf"),
            os.path.join(windir, "Fonts", "arial.ttf"),
            os.path.join(windir, "Fonts", "tahoma.ttf"),
            os.path.join(windir, "Fonts", "seguiemj.ttf"),
        ]
    )
    # Common fallbacks on other platforms
    candidate_fonts.extend(
        [
            "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",  # macOS
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
        ]
    )
    for font in candidate_fonts:
        if os.path.exists(font):
            return font
    return ""  # paddleocr will still draw, but non-ASCII may not render correctly
def build_argparser() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Simple PaddleOCR runner for single image"
    )
    parser.add_argument(
        "--image",
        default="im1.png",
        help="Path to input image (default: im1.png)",
    )
    parser.add_argument(
        "--lang",
        default="fr",
        help="Language code for PaddleOCR (try: fr, en, vi, etc.). Default: fr",
    )
    parser.add_argument(
        "--out-json",
        default="ocr_result.json",
        help="Path to save raw OCR results as JSON",
    )
    parser.add_argument(
        "--out-image",
        default="ocr_vis.png",
        help="Path to save visualization image with boxes and text",
    )
    parser.add_argument(
        "--poly-source",
        choices=["dt", "rec", "auto"],
        default="auto",
        help="Which polygons to visualize: detection(dt), recognition(rec) or auto (prefer dt)",
    )
    parser.add_argument(
        "--box-color",
        default="#FF0000",
        help="Outline color for polygons (hex like #RRGGBB or 'R,G,B').",
    )
    parser.add_argument(
        "--fill-color",
        default="#FF000033",
        help="Fill color with alpha for polygons (e.g., #FF000033).",
    )
    parser.add_argument(
        "--box-width",
        type=int,
        default=3,
        help="Outline width for polygons.",
    )
    parser.add_argument(
        "--scale-ratio",
        type=float,
        default=None,
        help="Scale polygons around centroid (e.g., 0.96 to shrink, 1.05 to expand).",
    )
    parser.add_argument(
        "--shrink-ratio",
        type=float,
        default=None,
        help="Deprecated: use --scale-ratio. If given, will be used when --scale-ratio is not set.",
    )
    parser.add_argument(
        "--offset-x",
        type=int,
        default=0,
        help="Shift polygons horizontally (pixels). Positive moves right.",
    )
    parser.add_argument(
        "--offset-y",
        type=int,
        default=0,
        help="Shift polygons vertically (pixels). Positive moves down.",
    )
    parser.add_argument(
        "--scale-x",
        type=float,
        default=None,
        help="Non-uniform scale along X around centroid (e.g., 1.1). Overrides --scale-ratio on X if set.",
    )
    parser.add_argument(
        "--scale-y",
        type=float,
        default=None,
        help="Non-uniform scale along Y around centroid (e.g., 1.05). Overrides --scale-ratio on Y if set.",
    )
    parser.add_argument(
        "--pad-x",
        type=int,
        default=0,
        help="Pixels to pad left/right after scaling.",
    )
    parser.add_argument(
        "--pad-y",
        type=int,
        default=0,
        help="Pixels to pad top/bottom after scaling.",
    )
    parser.add_argument(
        "--rect",
        choices=["poly", "axis", "rotated"],
        default="poly",
        help="Visualization mode: polygon, axis-aligned rectangle, or rotated min-area rectangle.",
    )
    parser.add_argument(
        "--enhance",
        action="store_true",
        help="Apply light image enhancement (contrast/sharpen) before OCR.",
    )
    parser.add_argument(
        "--rec-model",
        default="",
        help=(
            "Recognition model name to force (e.g. latin_PP-OCRv5_server_rec, "
            "latin_PP-OCRv5_mobile_rec, en_PP-OCRv4_rec, en_PP-OCRv3_rec)."
        ),
    )
    parser.add_argument(
        "--det-model",
        default="",
        help=(
            "Detection model name to force (e.g. PP-OCRv5_server_det, "
            "PP-OCRv5_mobile_det)."
        ),
    )
    parser.add_argument(
        "--ocr-version",
        default="",
        help="Optional OCR version hint (e.g., PP-OCRv5).",
    )
    parser.add_argument(
        "--rec-score-thresh",
        type=float,
        default=0.0,
        help="Minimum recognition score to keep a text (default 0.0 to keep all)",
    )
    parser.add_argument(
        "--det-db-thresh",
        type=float,
        default=None,
        help="DB detector binary threshold (e.g., 0.2).",
    )
    parser.add_argument(
        "--det-db-box-thresh",
        type=float,
        default=None,
        help="DB detector box threshold (e.g., 0.3-0.6). Lower to get more boxes.",
    )
    parser.add_argument(
        "--det-db-unclip",
        type=float,
        default=None,
        help="DB detector unclip ratio (e.g., 1.5).",
    )
    parser.add_argument(
        "--det-limit-side",
        type=int,
        default=None,
        help="text_det_limit_side_len. Smaller may speed up and merge boxes.",
    )
    return parser.parse_args()
def initialize_ocr(preferred_lang: str, args: argparse.Namespace) -> Tuple[PaddleOCR, str]:
    """Initialize PaddleOCR, trying the preferred language first then fallbacks."""
    tried_errors: List[str] = []
    candidate_langs: List[str] = []
    # Keep order but deduplicate
    for code in [preferred_lang, "fr", "en"]:
        if code not in candidate_langs:
            candidate_langs.append(code)
    for code in candidate_langs:
        # First, try with requested model overrides (if any)
        try:
            init_kwargs: Dict[str, Any] = {
                "lang": code,
                "use_textline_orientation": True,
            }
            if args.rec_model:
                init_kwargs["text_recognition_model_name"] = args.rec_model
            if args.det_model:
                init_kwargs["text_detection_model_name"] = args.det_model
            if args.ocr_version:
                init_kwargs["ocr_version"] = args.ocr_version
            # thresholds & limits (kwargs accepted by PaddleOCR >= 3.x)
            if args.rec_score_thresh is not None:
                init_kwargs["text_rec_score_thresh"] = float(args.rec_score_thresh)
            if args.det_db_thresh is not None:
                init_kwargs["text_det_db_thresh"] = float(args.det_db_thresh)
            if args.det_db_box_thresh is not None:
                init_kwargs["text_det_db_box_thresh"] = float(args.det_db_box_thresh)
            # Some builds may not support unclip ratio as kwarg; skip if unsupported
            if args.det_limit_side is not None:
                init_kwargs["text_det_limit_side_len"] = int(args.det_limit_side)
            ocr = PaddleOCR(**init_kwargs)
            return ocr, code
        except Exception as exc_with_models:  # pragma: no cover
            tried_errors.append(f"{code} (with models): {exc_with_models}")
        # Fallback: try default models for this language
        try:
            base_kwargs: Dict[str, Any] = {
                "lang": code,
                "use_textline_orientation": True,
            }
            if args.rec_score_thresh is not None:
                base_kwargs["text_rec_score_thresh"] = float(args.rec_score_thresh)
            if args.det_limit_side is not None:
                base_kwargs["text_det_limit_side_len"] = int(args.det_limit_side)
            if args.det_db_thresh is not None:
                base_kwargs["text_det_db_thresh"] = float(args.det_db_thresh)
            if args.det_db_box_thresh is not None:
                base_kwargs["text_det_db_box_thresh"] = float(args.det_db_box_thresh)
            # Skip setting unclip if unsupported in this build
            ocr = PaddleOCR(**base_kwargs)
            return ocr, code
        except Exception as exc_default:
            tried_errors.append(f"{code} (default): {exc_default}")
    raise RuntimeError(
        "Failed to initialize PaddleOCR. Tried languages: "
        + ", ".join(candidate_langs)
        + "\nErrors: "
        + " | ".join(tried_errors)
    )
def main() -> None:
    args = build_argparser()
    if not os.path.exists(args.image):
        print(f"[ERROR] Image not found: {args.image}")
        sys.exit(1)
    try:
        ocr, used_lang = initialize_ocr(args.lang, args)
    except Exception as init_exc:
        print(f"[ERROR] {init_exc}")
        sys.exit(2)
    chosen_rec = args.rec_model if args.rec_model else "auto"
    chosen_det = args.det_model if args.det_model else "auto"
    print(
        f"[INFO] Running OCR on '{args.image}' with lang='{used_lang}', "
        f"rec='{chosen_rec}', det='{chosen_det}' ..."
    )

    # Optional light enhancement for Latin documents to make OCR tighter
    def maybe_enhance_image(path: str) -> str:
        if not args.enhance:
            return path
        try:
            img = Image.open(path).convert("RGB")
            # Slight auto-contrast and sharpening
            img = ImageOps.autocontrast(img, cutoff=1)
            img = ImageEnhance.Sharpness(img).enhance(1.4)
            img = ImageEnhance.Contrast(img).enhance(1.1)
            tmp_path = os.path.splitext(path)[0] + "_enh.png"
            img.save(tmp_path)
            return tmp_path
        except Exception:
            return path

    # Apply the optional enhancement before OCR so --enhance actually takes effect
    ocr_input = maybe_enhance_image(args.image)

    # Prefer new API for PaddleOCR >=3.x, fallback to legacy .ocr
    try:
        result: Any = ocr.predict(ocr_input)  # type: ignore[assignment]
    except Exception:
        result = ocr.ocr(ocr_input)  # type: ignore[assignment]
    # Prefer to draw detection polygons only (no text) for clarity
    def extract_polygons(res: Any) -> List[List[Tuple[int, int]]]:
        def to_tuple_list(poly_any: Any) -> List[Tuple[int, int]]:
            try:
                return [(int(p[0]), int(p[1])) for p in list(poly_any)]
            except Exception:
                return []

        # Common outputs in new pipelines
        if isinstance(res, dict):
            if args.poly_source in ("auto", "dt"):
                polys = res.get("dt_polys") or res.get("det_polygons")
            else:
                polys = None
            if polys is None and args.poly_source in ("auto", "rec"):
                polys = res.get("rec_polys")
            if polys is None:
                polys = res.get("polygons") or res.get("boxes")
            if polys is None:
                return []
            return [to_tuple_list(poly) for poly in list(polys)]
        if isinstance(res, list) and len(res) > 0:
            # Often result is a list with a single dict
            if isinstance(res[0], dict):
                return extract_polygons(res[0])
        return []

    # Normalize result across PaddleOCR versions
    def poly_to_list(poly_any: Any) -> List[List[int]]:
        try:
            # numpy array path
            if hasattr(poly_any, "tolist"):
                lst = poly_any.tolist()
                return [[int(p[0]), int(p[1])] for p in lst]
        except Exception:
            pass
        try:
            return [[int(p[0]), int(p[1])] for p in list(poly_any)]
        except Exception:
            return []
    def parse_lines(res: Any) -> List[Dict[str, Any]]:
        items: List[Dict[str, Any]] = []
        # Case 0: top-level dict output (some 3.x pipelines)
        if isinstance(res, dict):
            polys = res.get("det_polygons") or res.get("boxes") or res.get("polygons")
            texts = res.get("rec_texts") or res.get("rec_text") or res.get("texts") or []
            scores = res.get("rec_scores") or res.get("scores") or []
            if isinstance(texts, str):
                texts = [texts]
            if polys is None:
                polys = [None] * len(texts)
            if not isinstance(scores, list):
                try:
                    scores = list(scores)
                except Exception:
                    scores = [None] * len(texts)
            if len(scores) < len(texts):
                scores = list(scores) + [None] * (len(texts) - len(scores))
            for poly, text, score in zip(polys, texts, scores):
                try:
                    score_val = float(score) if score is not None else None
                except Exception:
                    score_val = None
                items.append({"text": str(text), "score": score_val, "box": poly_to_list(poly) if poly is not None else None})
            return items
        if isinstance(res, list) and len(res) > 0:
            # Special: list with a single dict that holds batched arrays (rec_texts, rec_scores, dt_polys, ...)
            if len(res) == 1 and isinstance(res[0], dict) and (
                "rec_texts" in res[0] or "texts" in res[0]
            ):
                obj = res[0]
                texts = obj.get("rec_texts") or obj.get("texts") or []
                scores = obj.get("rec_scores") or obj.get("scores") or []
                boxes = obj.get("rec_polys") or obj.get("dt_polys") or []
                # Normalize lengths
                n = min(len(texts), len(scores) if hasattr(scores, "__len__") else len(texts), len(boxes) if hasattr(boxes, "__len__") else len(texts))
                out: List[Dict[str, Any]] = []
                for i in range(n):
                    txt = texts[i]
                    try:
                        sc = float(scores[i])
                    except Exception:
                        sc = None
                    bx = boxes[i] if i < len(boxes) else None
                    out.append({"text": str(txt), "score": sc, "box": poly_to_list(bx) if bx is not None else None})
                return out
            # Case A: legacy format [[ [poly], (text, score) ], ...] wrapped by [ ... ]
            if isinstance(res[0], list) and len(res[0]) > 0 and isinstance(res[0][0], list):
                lines_local = res[0]
                for line in lines_local:
                    if not isinstance(line, (list, tuple)) or len(line) < 2:
                        continue
                    box = line[0]
                    text: str = ""
                    score: Optional[float] = None
                    payload = line[1]
                    if isinstance(payload, (list, tuple)) and len(payload) >= 1:
                        text = str(payload[0])
                        if len(payload) >= 2:
                            try:
                                score = float(payload[1])
                            except Exception:
                                score = None
                    elif isinstance(payload, str):
                        text = payload
                        if len(line) >= 3:
                            try:
                                score = float(line[2])
                            except Exception:
                                score = None
                    items.append({"text": text, "score": score, "box": box})
                return items
            # Case B: new format already a flat list of dicts or lists per detection
            # Try dict format first
            if isinstance(res[0], dict):
                for obj in res:
                    box = obj.get("box") or obj.get("poly") or obj.get("bbox") or obj.get("det_polygons")
                    text = obj.get("text") or obj.get("rec_text") or ""
                    score = obj.get("score") or obj.get("rec_score")
                    try:
                        score = float(score) if score is not None else None
                    except Exception:
                        score = None
                    items.append({"text": str(text), "score": score, "box": poly_to_list(box) if box is not None else None})
                return items
            # Case C: flat list of [poly, text, (maybe score)]
            if isinstance(res[0], (list, tuple)):
                for line in res:
                    if not isinstance(line, (list, tuple)) or len(line) < 2:
                        continue
                    box = line[0]
                    text = str(line[1])
                    score: Optional[float] = None
                    if len(line) >= 3:
                        try:
                            score = float(line[2])
                        except Exception:
                            score = None
                    items.append({"text": text, "score": score, "box": poly_to_list(box) if box is not None else None})
                return items
        return items
    parsed = parse_lines(result)
    det_polys = extract_polygons(result)

    # Additionally collect both dt and rec polygons for JSON output
    def extract_both(res: Any) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
        def to_tuple_list(poly_any: Any) -> List[Tuple[int, int]]:
            try:
                if hasattr(poly_any, "tolist"):
                    poly_any = poly_any.tolist()
                return [(int(p[0]), int(p[1])) for p in list(poly_any)]
            except Exception:
                return []

        if isinstance(res, dict):
            dt = res.get("dt_polys") or res.get("det_polygons") or []
            rc = res.get("rec_polys") or []
            return [to_tuple_list(p) for p in list(dt)], [to_tuple_list(p) for p in list(rc)]
        if isinstance(res, list) and len(res) > 0 and isinstance(res[0], dict):
            return extract_both(res[0])
        return [], []

    all_dt_polys, all_rec_polys = extract_both(result)

    # Print quick summary to console
    print("\n[TEXT]\n" + "\n".join([p["text"] for p in parsed]))

    # Save JSON
    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(
            {
                "image": os.path.abspath(args.image),
                "language": used_lang,
                "num_items": len(parsed),
                "items": parsed,
                "poly_source": args.poly_source,
                "det_polygons": [[list(pt) for pt in poly] for poly in all_dt_polys],
                "rec_polygons": [[list(pt) for pt in poly] for poly in all_rec_polys],
                "enhance": bool(args.enhance),
                "recognition_model": chosen_rec,
                "detection_model": chosen_det,
                "ocr_version": args.ocr_version or "auto",
                "box_color": args.box_color,
                "fill_color": args.fill_color,
                "box_width": int(args.box_width),
                "scale_ratio": float(args.scale_ratio) if args.scale_ratio is not None else None,
                "offset_x": int(args.offset_x),
                "offset_y": int(args.offset_y),
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    print(f"[INFO] Saved JSON: {args.out_json}")

    # Also store raw result for debugging purposes
    try:
        with open("raw_result.txt", "w", encoding="utf-8") as rf:
            rf.write(repr(result))
    except Exception:
        pass
    # Draw and save visualization (only polygons)
    image = Image.open(args.image).convert("RGBA")
    canvas = image.copy()
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # Prefer detection polygons from pipeline; fallback to parsed boxes
    polygons: List[List[Tuple[int, int]]] = det_polys
    if not polygons:
        polygons = [
            [tuple(p) for p in (box or [])]  # type: ignore[misc]
            for box in [p.get("box") for p in parsed]
            if box
        ]

    def parse_color(color_str: str, default=(255, 0, 0, 255)) -> Tuple[int, int, int, int]:
        try:
            if color_str.startswith("#"):
                hexv = color_str.lstrip("#")
                if len(hexv) == 6:
                    r = int(hexv[0:2], 16)
                    g = int(hexv[2:4], 16)
                    b = int(hexv[4:6], 16)
                    return (r, g, b, 255)
                if len(hexv) == 8:
                    r = int(hexv[0:2], 16)
                    g = int(hexv[2:4], 16)
                    b = int(hexv[4:6], 16)
                    a = int(hexv[6:8], 16)
                    return (r, g, b, a)
            else:
                parts = [int(x) for x in color_str.split(",")]
                if len(parts) == 3:
                    return (parts[0], parts[1], parts[2], 255)
                if len(parts) == 4:
                    return (parts[0], parts[1], parts[2], parts[3])
        except Exception:
            pass
        return default

    outline_rgba = parse_color(args.box_color)
    fill_rgba = parse_color(args.fill_color, default=(255, 0, 0, 51))

    def transform_polygon(
        poly: List[Tuple[int, int]],
        scale: Optional[float],
        dx: int,
        dy: int,
        scale_x: Optional[float] = None,
        scale_y: Optional[float] = None,
        pad_x: int = 0,
        pad_y: int = 0,
    ) -> List[Tuple[int, int]]:
        if not poly:
            return poly
        cx = sum(p[0] for p in poly) / len(poly)
        cy = sum(p[1] for p in poly) / len(poly)
        out: List[Tuple[int, int]] = []
        for (x, y) in poly:
            sx = scale_x if scale_x is not None else scale
            sy = scale_y if scale_y is not None else scale
            if sx is not None:
                x = cx + sx * (x - cx)
            if sy is not None:
                y = cy + sy * (y - cy)
            out.append((int(round(x + dx)), int(round(y + dy))))
        # pad expands rect-like by moving points outwards along axes
        if pad_x or pad_y:
            out = [(x - pad_x if x < cx else x + pad_x, y - pad_y if y < cy else y + pad_y) for (x, y) in out]
        return out
    vis_polys: List[List[Tuple[int, int]]] = []

    def draw_axis_aligned(draw_obj, pts: List[Tuple[int, int]]):
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        box = [(min(xs), min(ys)), (max(xs), min(ys)), (max(xs), max(ys)), (min(xs), max(ys))]
        draw_obj.polygon(box, outline=outline_rgba, fill=fill_rgba)
        draw_obj.line(box + [box[0]], fill=outline_rgba, width=args.box_width)

    def draw_rotated(draw_obj, pts: List[Tuple[int, int]]):
        cnt = np.array(pts, dtype=np.int32).reshape(-1, 1, 2)
        rect = cv2.minAreaRect(cnt)
        box = cv2.boxPoints(rect)
        box = box.astype(int)  # np.int0 was removed in recent NumPy; astype(int) is equivalent here
        poly = [(int(x), int(y)) for x, y in box.tolist()]
        draw_obj.polygon(poly, outline=outline_rgba, fill=fill_rgba)
        draw_obj.line(poly + [poly[0]], fill=outline_rgba, width=args.box_width)
    for poly in polygons:
        if len(poly) >= 3:
            # Backward-compat: if scale-ratio not provided, use shrink-ratio (<1.0)
            scale = args.scale_ratio if args.scale_ratio is not None else args.shrink_ratio
            sp = transform_polygon(
                poly,
                scale,
                args.offset_x,
                args.offset_y,
                args.scale_x,
                args.scale_y,
                args.pad_x,
                args.pad_y,
            )
            vis_polys.append(sp)
            if args.rect == "axis":
                draw_axis_aligned(draw, sp)
            elif args.rect == "rotated":
                draw_rotated(draw, sp)
            else:
                draw.polygon(sp, outline=outline_rgba, fill=fill_rgba)
                draw.line(sp + [sp[0]], fill=outline_rgba, width=args.box_width)

    out = Image.alpha_composite(canvas, overlay).convert("RGB")
    out.save(args.out_image)
    print(f"[INFO] Saved visualization: {args.out_image}")

    # Append the visualization polygons to JSON file for exact reproducibility
    try:
        with open(args.out_json, "r", encoding="utf-8") as fjson:
            data = json.load(fjson)
        data["vis_polygons"] = [[list(pt) for pt in poly] for poly in vis_polys]
        with open(args.out_json, "w", encoding="utf-8") as fjson:
            json.dump(data, fjson, ensure_ascii=False, indent=2)
    except Exception:
        pass


if __name__ == "__main__":
    main()
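
For reference, a minimal sketch of driving the same pipeline directly from Python, assuming PaddleOCR >= 3.x is installed. Only keyword arguments already used in initialize_ocr() above are shown; im1.png is just the script's default image name, not a required file.

# Minimal sketch (assumptions: PaddleOCR >= 3.x, input image present in the working directory).
from paddleocr import PaddleOCR

# Mirrors the base kwargs the script builds in initialize_ocr().
ocr = PaddleOCR(lang="fr", use_textline_orientation=True)

# Same call path the script tries first; ocr.ocr() is its legacy fallback.
result = ocr.predict("im1.png")
print(result)  # raw structure that the script's parse_lines() normalizes into items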

File diff suppressed because it is too large.

Binary image file not shown (added, 5.7 MiB).


@@ -0,0 +1,32 @@
# Re-include all image files and JSON within this folder (and subfolders)
# PNG
!*.png
!**/*.png
# JPG/JPEG
!*.jpg
!**/*.jpg
!*.jpeg
!**/*.jpeg
# BMP/GIF/TIFF/WEBP
!*.bmp
!**/*.bmp
!*.gif
!**/*.gif
!*.tif
!**/*.tif
!*.tiff
!**/*.tiff
!*.webp
!**/*.webp
# JSON
!*.json
!**/*.json
# Ensure this file itself is tracked
!.gitignore

Binary image file not shown (added, 1.9 MiB).

File diff suppressed because it is too large.

Binary image file not shown (added, 3.0 MiB).

File diff suppressed because it is too large.

Binary image file not shown (added, 1.6 MiB).


@@ -0,0 +1,95 @@
import argparse
import json
import os
from pathlib import Path

import matplotlib

# Use non-interactive backend for headless execution/environments
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import cv2
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.utils.visualization import visualize_page


def run_doctr(
    image_path: str,
    output_dir: str = "docTR_outputs",
    det_arch: str = "db_resnet50",
    reco_arch: str = "crnn_vgg16_bn",
) -> Path:
    image_path = Path(image_path)
    if not image_path.is_file():
        raise FileNotFoundError(f"Image not found: {image_path}")
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    # Load image for visualization
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        raise RuntimeError(f"Failed to read image with OpenCV: {image_path}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Build predictor
    predictor = ocr_predictor(det_arch=det_arch, reco_arch=reco_arch, pretrained=True)

    # Inference
    doc = DocumentFile.from_images(str(image_path))
    result = predictor(doc)

    # Export structured result to JSON
    export = result.export()
    json_path = output_dir_path / f"{image_path.stem}_doctr.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(export, f, ensure_ascii=False, indent=2)

    # Visualization
    page = result.pages[0]
    page_dict = export["pages"][0]
    fig = visualize_page(page_dict, image=image)
    vis_path = output_dir_path / f"{image_path.stem}_doctr_vis.png"
    fig.savefig(vis_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    # Write a simple text file with detected lines (if any)
    lines = []
    for block in page.blocks:
        for line in block.lines:
            text = " ".join([word.value for word in line.words])
            if text:
                lines.append(text)
    if lines:
        txt_path = output_dir_path / f"{image_path.stem}_doctr.txt"
        txt_path.write_text("\n".join(lines), encoding="utf-8")
    return vis_path


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run docTR OCR on an image")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument(
        "--output-dir",
        default="docTR_outputs",
        help="Directory to store outputs (JSON/visualization)",
    )
    parser.add_argument("--det-arch", default="db_resnet50", help="Detection architecture")
    parser.add_argument("--reco-arch", default="crnn_vgg16_bn", help="Recognition architecture")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    vis_path = run_doctr(
        image_path=args.image,
        output_dir=args.output_dir,
        det_arch=args.det_arch,
        reco_arch=args.reco_arch,
    )
    print(f"Saved visualization to: {vis_path}")