Files
IDcardsGenerator/src/model/ID_cards_detector/inference.py
Nguyễn Phước Thành 4ee14f17d3 update ID_cards_detector
2025-08-06 19:03:17 +07:00

130 lines
4.9 KiB
Python

#!/usr/bin/env python3
"""
YOLOv8 Inference Script for French ID Card Detection
"""
import os
import sys
import argparse
from pathlib import Path
import logging
# Import config
from config import (
INFERENCE_RESULTS_DIR, EVALUATION_RESULTS_DIR,
VISUALIZATION_RESULTS_DIR, create_directories, get_best_model_path
)
# Create the necessary directories (via config.create_directories) before
# anything else in this script runs.
create_directories()

# Root logging for the whole script: timestamp, level name, then the message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)
# Import modules
from modules.inference import YOLOv8Inference
def _build_arg_parser():
    """Create the argument parser for the inference CLI.

    Returns:
        argparse.ArgumentParser: parser exposing every option ``main`` reads.
    """
    parser = argparse.ArgumentParser(description='YOLOv8 Inference for French ID Card Detection')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to trained model (if None, uses best model from runs/train)')
    # choices mirror the sizes listed in the help text, so a typo is reported
    # at parse time instead of surfacing later during model lookup.
    parser.add_argument('--model-size', type=str, default='n',
                        choices=['n', 's', 'm', 'l', 'x'],
                        help='Model size (n, s, m, l, x) - used when --model is not specified')
    parser.add_argument('--input', type=str, required=True,
                        help='Input image or directory')
    parser.add_argument('--output', type=str, default=None,
                        help='Output directory (uses default if not specified)')
    parser.add_argument('--conf', type=float, default=0.25,
                        help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.45,
                        help='IoU threshold')
    parser.add_argument('--batch', action='store_true',
                        help='Process as batch (input is directory)')
    parser.add_argument('--evaluate', action='store_true',
                        help='Evaluate on test set')
    parser.add_argument('--export', type=str, default=None,
                        help='Export results to JSON file')
    parser.add_argument('--visualize', action='store_true',
                        help='Create visualizations')
    return parser


def _validate_args(parser, args):
    """Reject argument values that cannot produce a meaningful run.

    Calls ``parser.error`` (prints usage, exits with status 2) when a
    threshold lies outside [0, 1] or the ``--input`` path does not exist,
    so bad input is reported before the model is loaded.
    """
    for flag, value in (('--conf', args.conf), ('--iou', args.iou)):
        # Negated range check so that NaN is rejected as well.
        if not 0.0 <= value <= 1.0:
            parser.error(f"{flag} must be between 0 and 1, got {value}")
    if not Path(args.input).exists():
        parser.error(f"--input path does not exist: {args.input}")


def _create_visualizations(inference, results):
    """Render one visualization per result that has at least one detection."""
    for result in results['results']:
        # NOTE(review): .get() skips a result lacking 'detections' instead of
        # aborting the whole run; confirm what predict_* return on failure.
        detections = result.get('detections')
        if not detections:
            continue
        save_path = VISUALIZATION_RESULTS_DIR / f"viz_{Path(result['image_path']).stem}.png"
        inference.visualize_detections(
            result['image_path'],
            detections,
            str(save_path)
        )


def _log_summary(results, output_dir):
    """Log image and detection counts for a finished run.

    ``total_images`` / ``processed_images`` are read from ``results`` when
    present (presumably set by ``predict_batch``); otherwise they fall back
    to the number of entries in ``results['results']``.
    """
    per_image = results['results']
    total_images = results.get('total_images', len(per_image))
    processed_images = results.get('processed_images', len(per_image))
    total_detections = sum(len(r.get('detections') or []) for r in per_image)
    logger.info("\n[INFO] Results summary:")
    logger.info(f" - Total images: {total_images}")
    logger.info(f" - Processed: {processed_images}")
    logger.info(f" - Total detections: {total_detections}")
    logger.info(f" - Output directory: {output_dir}")


def main(argv=None):
    """Command-line entry point for YOLOv8 French ID card inference.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` means ``sys.argv[1:]``, as with argparse.

    Exits with status 2 on invalid arguments, and with status 1 when no
    trained model can be found or any error occurs during inference.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)

    logger.info("=" * 60)
    logger.info("YOLOv8 French ID Card Detection Inference")
    logger.info("=" * 60)

    try:
        # An explicit --model wins; otherwise look up the best trained
        # checkpoint for the requested --model-size.
        if args.model:
            model_path = args.model
        else:
            model_path = get_best_model_path(args.model_size)
            if not model_path:
                logger.error("[ERROR] No trained model found. Please train a model first.")
                sys.exit(1)

        logger.info(f"Loading model: {model_path}")
        inference = YOLOv8Inference(model_path, args.conf, args.iou)

        output_dir = args.output if args.output else INFERENCE_RESULTS_DIR

        # A directory input is processed as a batch even without --batch.
        if args.batch or Path(args.input).is_dir():
            logger.info(f"Processing batch from: {args.input}")
            results = inference.predict_batch(args.input, output_dir)
        else:
            logger.info(f"Processing single image: {args.input}")
            result = inference.predict_single_image(args.input, True, output_dir)
            # Wrap in the same {'results': [...]} shape used for batch output.
            results = {'results': [result]}

        if args.evaluate:
            logger.info("Evaluating on test set...")
            evaluation_results = inference.evaluate_on_test_set(args.input)
            # dict.update: evaluation keys overwrite same-named keys in
            # ``results``. NOTE(review): confirm it never returns a 'results' key.
            results.update(evaluation_results)

        if args.export:
            logger.info(f"Exporting results to {args.export}")
            inference.export_results(results, args.export)

        if args.visualize:
            logger.info("Creating visualizations...")
            _create_visualizations(inference, results)

        logger.info("\n" + "=" * 60)
        logger.info("[SUCCESS] Inference completed successfully!")
        logger.info("=" * 60)
        _log_summary(results, output_dir)

    except Exception as e:
        # Top-level boundary: keep the traceback so failures are diagnosable,
        # then exit non-zero. SystemExit from sys.exit(1) above is not an
        # Exception subclass, so it passes through untouched.
        logger.exception(f"[ERROR] Error: {e}")
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()