update ID_cards_detector
This commit is contained in:
130
src/model/ID_cards_detector/inference.py
Normal file
130
src/model/ID_cards_detector/inference.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
YOLOv8 Inference Script for French ID Card Detection
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Import config
|
||||
from config import (
|
||||
INFERENCE_RESULTS_DIR, EVALUATION_RESULTS_DIR,
|
||||
VISUALIZATION_RESULTS_DIR, create_directories, get_best_model_path
|
||||
)
|
||||
|
||||
# Create necessary directories first
|
||||
create_directories()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import modules
|
||||
from modules.inference import YOLOv8Inference
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
parser = argparse.ArgumentParser(description='YOLOv8 Inference for French ID Card Detection')
|
||||
parser.add_argument('--model', type=str, default=None,
|
||||
help='Path to trained model (if None, uses best model from runs/train)')
|
||||
parser.add_argument('--model-size', type=str, default='n',
|
||||
help='Model size (n, s, m, l, x) - used when --model is not specified')
|
||||
parser.add_argument('--input', type=str, required=True,
|
||||
help='Input image or directory')
|
||||
parser.add_argument('--output', type=str, default=None,
|
||||
help='Output directory (uses default if not specified)')
|
||||
parser.add_argument('--conf', type=float, default=0.25,
|
||||
help='Confidence threshold')
|
||||
parser.add_argument('--iou', type=float, default=0.45,
|
||||
help='IoU threshold')
|
||||
parser.add_argument('--batch', action='store_true',
|
||||
help='Process as batch (input is directory)')
|
||||
parser.add_argument('--evaluate', action='store_true',
|
||||
help='Evaluate on test set')
|
||||
parser.add_argument('--export', type=str, default=None,
|
||||
help='Export results to JSON file')
|
||||
parser.add_argument('--visualize', action='store_true',
|
||||
help='Create visualizations')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("YOLOv8 French ID Card Detection Inference")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
# Get model path
|
||||
if args.model:
|
||||
model_path = args.model
|
||||
else:
|
||||
model_path = get_best_model_path(args.model_size)
|
||||
if not model_path:
|
||||
logger.error("[ERROR] No trained model found. Please train a model first.")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize inference
|
||||
logger.info(f"Loading model: {model_path}")
|
||||
inference = YOLOv8Inference(model_path, args.conf, args.iou)
|
||||
|
||||
# Set output directory
|
||||
output_dir = args.output if args.output else INFERENCE_RESULTS_DIR
|
||||
|
||||
if args.batch or Path(args.input).is_dir():
|
||||
# Batch processing
|
||||
logger.info(f"Processing batch from: {args.input}")
|
||||
results = inference.predict_batch(args.input, output_dir)
|
||||
else:
|
||||
# Single image processing
|
||||
logger.info(f"Processing single image: {args.input}")
|
||||
result = inference.predict_single_image(args.input, True, output_dir)
|
||||
results = {'results': [result]}
|
||||
|
||||
# Evaluate if requested
|
||||
if args.evaluate:
|
||||
logger.info("Evaluating on test set...")
|
||||
evaluation_results = inference.evaluate_on_test_set(args.input)
|
||||
results.update(evaluation_results)
|
||||
|
||||
# Export results
|
||||
if args.export:
|
||||
logger.info(f"Exporting results to {args.export}")
|
||||
inference.export_results(results, args.export)
|
||||
|
||||
# Create visualizations
|
||||
if args.visualize:
|
||||
logger.info("Creating visualizations...")
|
||||
for result in results['results']:
|
||||
if result['detections']:
|
||||
save_path = VISUALIZATION_RESULTS_DIR / f"viz_{Path(result['image_path']).stem}.png"
|
||||
inference.visualize_detections(
|
||||
result['image_path'],
|
||||
result['detections'],
|
||||
str(save_path)
|
||||
)
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("[SUCCESS] Inference completed successfully!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Summary
|
||||
total_images = results.get('total_images', len(results['results']))
|
||||
processed_images = results.get('processed_images', len(results['results']))
|
||||
total_detections = sum(len(r['detections']) for r in results['results'])
|
||||
|
||||
logger.info(f"\n[INFO] Results summary:")
|
||||
logger.info(f" - Total images: {total_images}")
|
||||
logger.info(f" - Processed: {processed_images}")
|
||||
logger.info(f" - Total detections: {total_detections}")
|
||||
logger.info(f" - Output directory: {output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user