YOLO crop model
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
"""
|
||||
Model module for YOLO-based ID card detection and cropping
|
||||
Model module for Roboflow-based ID card detection and cropping
|
||||
"""
|
||||
|
||||
from .yolo_detector import YOLODetector
|
||||
from .id_card_processor import IDCardProcessor
|
||||
from .roboflow_id_detector import RoboflowIDDetector
|
||||
|
||||
__all__ = ['YOLODetector', 'IDCardProcessor']
|
||||
__all__ = ['RoboflowIDDetector']
|
BIN
src/model/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
src/model/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
src/model/__pycache__/roboflow_id_detector.cpython-313.pyc
Normal file
BIN
src/model/__pycache__/roboflow_id_detector.cpython-313.pyc
Normal file
Binary file not shown.
BIN
src/model/__pycache__/roboflow_id_detector.cpython-39.pyc
Normal file
BIN
src/model/__pycache__/roboflow_id_detector.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
@@ -1,46 +1,104 @@
|
||||
"""
|
||||
YOLO Detector for ID Card Detection and Cropping
|
||||
Roboflow ID Card Detector using French Card ID Detection Model
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import logging
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
|
||||
class YOLODetector:
|
||||
class RoboflowIDDetector:
|
||||
"""
|
||||
YOLO-based detector for ID card detection and cropping
|
||||
Roboflow-based detector for French ID card detection using the french-card-id-detect model
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: Optional[str] = None, confidence: float = 0.5):
|
||||
def __init__(self, api_key: str, model_id: str = "french-card-id-detect",
|
||||
version: int = 3, confidence: float = 0.5):
|
||||
"""
|
||||
Initialize YOLO detector
|
||||
Initialize Roboflow ID detector
|
||||
|
||||
Args:
|
||||
model_path: Path to YOLO model file (.pt)
|
||||
api_key: Roboflow API key
|
||||
model_id: Model identifier (default: french-card-id-detect)
|
||||
version: Model version (default: 3)
|
||||
confidence: Confidence threshold for detection
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model_id = model_id
|
||||
self.version = version
|
||||
self.confidence = confidence
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize model
|
||||
if model_path and Path(model_path).exists():
|
||||
self.model = YOLO(model_path)
|
||||
self.logger.info(f"Loaded custom YOLO model from {model_path}")
|
||||
else:
|
||||
# Use pre-trained YOLO model for general object detection
|
||||
self.model = YOLO('yolov8n.pt')
|
||||
self.logger.info("Using pre-trained YOLOv8n model")
|
||||
# API endpoint
|
||||
self.api_url = f"https://serverless.roboflow.com/{model_id}/{version}"
|
||||
|
||||
# Set device
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.logger.info(f"Using device: {self.device}")
|
||||
self.logger.info(f"Initialized Roboflow ID detector with model: {model_id}/{version}")
|
||||
|
||||
def _encode_image(self, image_path: Path) -> Optional[str]:
    """
    Read an image file and return its contents as base64 text.

    Args:
        image_path: Path to the image file to read.

    Returns:
        The file's bytes encoded as a base64 UTF-8 string, or None when the
        file could not be opened or read (the failure is logged).
    """
    try:
        # Binary mode: the bytes are encoded as-is, with no image decoding.
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
    except (OSError, TypeError, ValueError) as exc:
        # OSError covers missing/unreadable files; TypeError and ValueError
        # cover an unusable path argument (e.g. None or an embedded NUL).
        # Callers treat a falsy result as "skip this image", so report the
        # problem and return None instead of raising.
        self.logger.error("Error encoding image %s: %s", image_path, exc)
        return None

    return base64.b64encode(raw_bytes).decode('utf-8')
|
||||
|
||||
def _make_api_request(self, image_data: str, image_name: str = "image.jpg") -> Optional[Dict]:
    """
    POST an encoded image to the detector's API endpoint.

    The request goes to ``self.api_url`` with ``self.api_key`` and the image
    name as query parameters, the image data as the request body, and a
    30 second timeout.

    Args:
        image_data: Base64 encoded image data sent as the request body.
        image_name: File name reported to the API in the ``name`` parameter.

    Returns:
        The decoded JSON response when the API answers with status 200;
        None for any other status or when the request raises (both cases
        are logged).
    """
    try:
        request_headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        query = {'api_key': self.api_key, 'name': image_name}

        reply = requests.post(
            self.api_url,
            params=query,
            data=image_data,
            headers=request_headers,
            timeout=30,
        )

        # Anything other than 200 is treated as a failed request.
        if reply.status_code != 200:
            self.logger.error(f"API request failed with status {reply.status_code}: {reply.text}")
            return None

        return reply.json()

    except Exception as e:
        # Best-effort call: network and JSON-decoding failures are logged
        # and reported to the caller as None.
        self.logger.error(f"Error making API request: {e}")
        return None
|
||||
|
||||
def detect_id_cards(self, image_path: Path) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect ID cards in an image
|
||||
Detect ID cards in an image using Roboflow API
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
@@ -49,39 +107,50 @@ class YOLODetector:
|
||||
List of detection results with bounding boxes
|
||||
"""
|
||||
try:
|
||||
# Load image
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
self.logger.error(f"Could not load image: {image_path}")
|
||||
# Encode image
|
||||
image_data = self._encode_image(image_path)
|
||||
if not image_data:
|
||||
return []
|
||||
|
||||
# Run detection
|
||||
results = self.model(image, conf=self.confidence)
|
||||
# Make API request
|
||||
response = self._make_api_request(image_data, image_path.name)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
if boxes is not None:
|
||||
for box in boxes:
|
||||
# Get coordinates
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
confidence = float(box.conf[0])
|
||||
class_id = int(box.cls[0])
|
||||
class_name = self.model.names[class_id]
|
||||
|
||||
detection = {
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'confidence': confidence,
|
||||
'class_id': class_id,
|
||||
'class_name': class_name,
|
||||
'area': (x2 - x1) * (y2 - y1)
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
# Sort by confidence and area (prefer larger, more confident detections)
|
||||
# Parse predictions from response
|
||||
if 'predictions' in response:
|
||||
for prediction in response['predictions']:
|
||||
# Check confidence threshold
|
||||
if prediction.get('confidence', 0) < self.confidence:
|
||||
continue
|
||||
|
||||
# Extract bounding box coordinates
|
||||
x = prediction.get('x', 0)
|
||||
y = prediction.get('y', 0)
|
||||
width = prediction.get('width', 0)
|
||||
height = prediction.get('height', 0)
|
||||
|
||||
# Convert to [x1, y1, x2, y2] format
|
||||
x1 = int(x - width / 2)
|
||||
y1 = int(y - height / 2)
|
||||
x2 = int(x + width / 2)
|
||||
y2 = int(y + height / 2)
|
||||
|
||||
detection = {
|
||||
'bbox': [x1, y1, x2, y2],
|
||||
'confidence': prediction.get('confidence', 0),
|
||||
'class_id': prediction.get('class_id', 0),
|
||||
'class_name': prediction.get('class', 'id_card'),
|
||||
'area': width * height
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
# Sort by confidence and area
|
||||
detections.sort(key=lambda x: (x['confidence'], x['area']), reverse=True)
|
||||
|
||||
self.logger.info(f"Found {len(detections)} detections in {image_path.name}")
|
||||
self.logger.info(f"Found {len(detections)} ID card detections in {image_path.name}")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
@@ -201,7 +270,7 @@ class YOLODetector:
|
||||
return result
|
||||
|
||||
def batch_process(self, input_dir: Path, output_dir: Path,
|
||||
save_annotated: bool = False) -> Dict[str, Any]:
|
||||
save_annotated: bool = False, delay: float = 1.0) -> Dict[str, Any]:
|
||||
"""
|
||||
Process all images in a directory and subdirectories
|
||||
|
||||
@@ -209,6 +278,7 @@ class YOLODetector:
|
||||
input_dir: Input directory containing images
|
||||
output_dir: Output directory for cropped images
|
||||
save_annotated: Whether to save annotated images
|
||||
delay: Delay between API requests (seconds)
|
||||
|
||||
Returns:
|
||||
Batch processing results
|
||||
@@ -216,11 +286,10 @@ class YOLODetector:
|
||||
# Create output directory
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get all image files recursively from input directory and subdirectories
|
||||
# Get all image files recursively
|
||||
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
||||
image_files = []
|
||||
|
||||
# Recursively find all image files
|
||||
for file_path in input_dir.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix.lower() in image_extensions:
|
||||
image_files.append(file_path)
|
||||
@@ -255,6 +324,10 @@ class YOLODetector:
|
||||
results['processed_images'] += 1
|
||||
results['total_detections'] += len(result['detections'])
|
||||
results['total_cropped'] += len(result['cropped_paths'])
|
||||
|
||||
# Add delay between requests to avoid rate limiting
|
||||
if i < len(image_files) - 1: # Don't delay after the last image
|
||||
time.sleep(delay)
|
||||
|
||||
# Summary
|
||||
self.logger.info(f"Batch processing completed:")
|
Reference in New Issue
Block a user