This commit is contained in:
Nguyễn Phước Thành
2025-08-05 19:09:55 +07:00
commit 24060e4ce7
25 changed files with 2268 additions and 0 deletions

23
src/__init__.py Normal file
View File

@@ -0,0 +1,23 @@
"""
Data Augmentation Package

Package entry point: exposes ImageProcessor, DataAugmentation, ConfigManager
and the helper functions of the utils module at package level.
"""
# Package metadata.
__version__ = "1.0.0"
__author__ = "OCR Data Augmentation Tool"
# The wildcard import is what makes the utils helpers named in __all__ below
# (setup_logging, get_image_files, ...) available on the package itself.
from .utils import *
from .image_processor import ImageProcessor
from .data_augmentation import DataAugmentation
from .config_manager import ConfigManager
# Names exported by a wildcard import of this package.
__all__ = [
"ImageProcessor",
"DataAugmentation",
"ConfigManager",
"setup_logging",
"get_image_files",
"load_image",
"save_image",
"validate_image",
"print_progress",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

40
src/config.py Normal file
View File

@@ -0,0 +1,40 @@
"""
Configuration file for data augmentation
"""
import os
from pathlib import Path
# Filesystem layout, anchored at the parent of the directory that contains
# this module.
BASE_DIR = Path(__file__).parent.parent
DATA_DIR = BASE_DIR.joinpath("data")
INPUT_IMAGES_DIR = DATA_DIR.joinpath("dataset", "training_data", "images")
OUTPUT_DIR = DATA_DIR.joinpath("augmented_data")

# Parameters describing the random transformations applied during augmentation.
AUGMENTATION_CONFIG = dict(
    rotation_range=15,              # degrees
    width_shift_range=0.1,          # fraction of total width
    height_shift_range=0.1,         # fraction of total height
    brightness_range=[0.8, 1.2],    # brightness factor
    zoom_range=[0.9, 1.1],          # zoom factor
    horizontal_flip=True,
    vertical_flip=False,
    fill_mode="nearest",
    cval=0,
    rescale=1.0 / 255,
)

# Parameters controlling how images are resized, batched and written out.
PROCESSING_CONFIG = dict(
    target_size=(224, 224),         # (width, height)
    batch_size=32,
    num_augmentations=3,            # number of augmented versions per image
    save_format="jpg",
    quality=95,
)

# File extensions recognised as images.
SUPPORTED_FORMATS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff"]

# Import-time side effect: make sure the output directory exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

175
src/config_manager.py Normal file
View File

@@ -0,0 +1,175 @@
"""
Configuration manager for data augmentation
"""
import yaml
import os
from pathlib import Path
from typing import Dict, Any, Optional, Union
class ConfigManager:
    """Load, query, update and persist the tool's YAML configuration.

    The configuration is kept as a plain nested dictionary in ``self.config``.
    When the config file is missing, cannot be parsed, is empty, or does not
    contain a mapping at top level, the built-in defaults returned by
    ``_get_default_config`` are used instead.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """
        Initialize ConfigManager

        Args:
            config_path: Path to main config file. Falls back to
                ``config/config.yaml`` (relative to the working directory)
                when not given.
        """
        self.config_path = Path(config_path) if config_path else Path("config/config.yaml")
        self.config = {}
        self._load_config()

    def _load_config(self):
        """Load main configuration file, falling back to the defaults."""
        try:
            if self.config_path.exists():
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    loaded = yaml.safe_load(f)
                # safe_load yields None for an empty document and may yield a
                # list/scalar; only a mapping is a usable configuration.
                if isinstance(loaded, dict):
                    self.config = loaded
                    print(f"✅ Loaded configuration from {self.config_path}")
                else:
                    print(f"⚠️ Config file is empty or not a mapping: {self.config_path}")
                    self.config = self._get_default_config()
            else:
                print(f"⚠️ Config file not found: {self.config_path}")
                self.config = self._get_default_config()
        except Exception as e:
            print(f"❌ Error loading config: {e}")
            self.config = self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration (a fresh dictionary on every call)."""
        return {
            "paths": {
                "input_dir": "data/dataset/training_data/images",
                "output_dir": "data/augmented_data",
                "log_file": "logs/data_augmentation.log"
            },
            "augmentation": {
                "rotation": {"enabled": True, "angles": [30, 60, 120, 150, 180, 210, 240, 300, 330], "probability": 1.0}
            },
            "processing": {
                "target_size": [224, 224],
                "batch_size": 32,
                "num_augmentations": 3,
                "save_format": "jpg",
                "quality": 95
            },
            "supported_formats": [".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            },
            "performance": {
                "num_workers": 4,
                "prefetch_factor": 2,
                "pin_memory": True,
                "use_gpu": False
            }
        }

    def _get_section(self, name: str) -> Dict[str, Any]:
        """Return the top-level section ``name``, or ``{}`` if absent or not a mapping."""
        section = self.config.get(name) if isinstance(self.config, dict) else None
        # A YAML key with no value parses to None; callers expect a dictionary.
        return section if isinstance(section, dict) else {}

    def get_config(self) -> Dict[str, Any]:
        """Get current configuration"""
        return self.config

    def get_augmentation_config(self) -> Dict[str, Any]:
        """Get augmentation configuration"""
        return self._get_section("augmentation")

    def get_processing_config(self) -> Dict[str, Any]:
        """Get processing configuration"""
        return self._get_section("processing")

    def get_paths_config(self) -> Dict[str, Any]:
        """Get paths configuration"""
        return self._get_section("paths")

    def get_logging_config(self) -> Dict[str, Any]:
        """Get logging configuration"""
        return self._get_section("logging")

    def get_performance_config(self) -> Dict[str, Any]:
        """Get performance configuration"""
        return self._get_section("performance")

    def update_config(self, updates: Dict[str, Any]) -> bool:
        """
        Update configuration with new values

        Nested dictionaries are merged key by key; any other value replaces
        the existing one.

        Args:
            updates: Dictionary with updates to apply

        Returns:
            True if updated successfully, False if the merge raised
        """
        try:
            self.config = self._merge_configs(self.config, updates)
            return True
        except Exception as e:
            print(f"❌ Error updating config: {e}")
            return False

    def _merge_configs(self, base_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
        """Merge updates with base configuration

        Returns a new dictionary; ``base_config`` itself is not modified.
        """
        merged = base_config.copy()
        for key, value in updates.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                # Both sides are mappings: descend instead of overwriting.
                merged[key] = self._merge_configs(current, value)
            else:
                merged[key] = value
        return merged

    def save_config(self, output_path: Optional[Union[str, Path]] = None) -> bool:
        """
        Save current configuration to file

        Args:
            output_path: Path to save config file; defaults to the path the
                configuration was loaded from

        Returns:
            True if saved successfully
        """
        try:
            output_path = Path(output_path) if output_path else self.config_path
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.config, f, default_flow_style=False, indent=2, allow_unicode=True)
            print(f"✅ Configuration saved to {output_path}")
            return True
        except Exception as e:
            print(f"❌ Error saving config: {e}")
            return False

    def print_config_summary(self):
        """Print configuration summary"""
        print("\n" + "="*50)
        print("CONFIGURATION SUMMARY")
        print("="*50)
        # Paths
        paths = self.get_paths_config()
        print(f"Input directory: {paths.get('input_dir', 'Not set')}")
        print(f"Output directory: {paths.get('output_dir', 'Not set')}")
        # Processing
        processing = self.get_processing_config()
        print(f"Target size: {processing.get('target_size', 'Not set')}")
        print(f"Number of augmentations: {processing.get('num_augmentations', 'Not set')}")
        # Augmentation: list every sub-section that is switched on
        enabled_augmentations = [
            name
            for name, config in self.get_augmentation_config().items()
            if isinstance(config, dict) and config.get('enabled', False)
        ]
        print(f"Enabled augmentations: {', '.join(enabled_augmentations) if enabled_augmentations else 'None'}")
        print("="*50)

161
src/data_augmentation.py Normal file
View File

@@ -0,0 +1,161 @@
"""
Data augmentation class for image augmentation - ONLY ROTATION
"""
import cv2
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
import random
import math
from image_processor import ImageProcessor
from utils import load_image, save_image, create_augmented_filename, print_progress
class DataAugmentation:
    """Class for image data augmentation - ONLY ROTATION"""

    # Angles (degrees) used when the configuration does not provide a list.
    _DEFAULT_ANGLES = (30, 60, 120, 150, 180, 210, 240, 300, 330)

    def __init__(self, config: Dict[str, Any] = None):
        """
        Initialize DataAugmentation

        Args:
            config: Configuration dictionary for augmentation parameters.
                Only the ``"rotation"`` entry (keys ``enabled`` and
                ``angles``) is read by this class.
        """
        self.config = config or {}
        self.image_processor = ImageProcessor()

    def rotate_image(self, image: np.ndarray, angle: float) -> np.ndarray:
        """
        Rotate image by given angle

        The rotation is about the image centre and the output keeps the
        input's width and height; areas not covered by the rotated image are
        filled by replicating border pixels.

        Args:
            image: Input image
            angle: Rotation angle in degrees

        Returns:
            Rotated image
        """
        height, width = image.shape[:2]
        center = (width // 2, height // 2)
        # Scale factor 1.0: pure rotation, no zoom.
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
        return cv2.warpAffine(image, rotation_matrix, (width, height),
                              borderMode=cv2.BORDER_REPLICATE)

    def augment_single_image(self, image: np.ndarray, num_augmentations: int = None) -> List[np.ndarray]:
        """
        Apply rotation augmentation to a single image

        Args:
            image: Input image
            num_augmentations: Number of augmented versions to create.
                ``None`` means the default of 3; ``0`` yields an empty list.

        Returns:
            List of augmented images. When rotation is disabled, or the
            configured angle list is empty, the entries are unmodified copies
            of the input.
        """
        # Only None selects the default, so an explicit 0 is honoured.
        if num_augmentations is None:
            num_augmentations = 3

        rotation_config = self.config.get("rotation") or {}
        angles = rotation_config.get("angles")
        if angles is None:
            angles = self._DEFAULT_ANGLES
        # random.choice raises on an empty sequence, so skip rotation then.
        rotate = bool(rotation_config.get("enabled", False)) and len(angles) > 0

        augmented_images = []
        for _ in range(num_augmentations):
            augmented = image.copy()
            if rotate:
                augmented = self.rotate_image(augmented, random.choice(angles))
            augmented_images.append(augmented)
        return augmented_images

    def augment_image_file(self, image_path: Path, output_dir: Path, num_augmentations: int = None) -> List[Path]:
        """
        Augment a single image file and save results

        Args:
            image_path: Path to input image
            output_dir: Output directory for augmented images
            num_augmentations: Number of augmented versions to create

        Returns:
            List of paths to saved augmented images (empty if the image
            could not be loaded)
        """
        # The image is resized to the processor's target size while loading.
        image = load_image(image_path, self.image_processor.target_size)
        if image is None:
            return []

        augmented_images = self.augment_single_image(image, num_augmentations)

        saved_paths = []
        for index, aug_image in enumerate(augmented_images, start=1):
            # Only the generated file name is used; the file goes to output_dir.
            output_filename = create_augmented_filename(image_path, index)
            output_path = output_dir / output_filename.name
            if save_image(aug_image, output_path):
                saved_paths.append(output_path)
        return saved_paths

    def batch_augment(self, input_dir: Path, output_dir: Path, num_augmentations: int = None) -> Dict[str, List[Path]]:
        """
        Augment all images in a directory

        Args:
            input_dir: Input directory containing images
            output_dir: Output directory for augmented images
            num_augmentations: Number of augmented versions per image

        Returns:
            Dictionary mapping original images to their augmented versions;
            images for which nothing was saved are omitted
        """
        from utils import get_image_files

        image_files = get_image_files(input_dir)
        results = {}
        print(f"Found {len(image_files)} images to augment")
        for position, image_path in enumerate(image_files, start=1):
            print_progress(position, len(image_files), "Augmenting images")
            augmented_paths = self.augment_image_file(image_path, output_dir, num_augmentations)
            if augmented_paths:
                results[str(image_path)] = augmented_paths
        print(f"\nAugmented {len(results)} images successfully")
        return results

    def get_augmentation_summary(self, results: Dict[str, List[Path]]) -> Dict[str, Any]:
        """
        Get summary of augmentation results

        Args:
            results: Results from batch_augment

        Returns:
            Summary dictionary with the number of originals, the number of
            augmented files, their ratio (0 when there are no originals) and
            the number of originals with at least one augmented file
        """
        total_original = len(results)
        total_augmented = sum(len(paths) for paths in results.values())
        return {
            "total_original_images": total_original,
            "total_augmented_images": total_augmented,
            "augmentation_ratio": total_augmented / total_original if total_original > 0 else 0,
            "successful_augmentations": sum(1 for paths in results.values() if paths)
        }

174
src/image_processor.py Normal file
View File

@@ -0,0 +1,174 @@
"""
Image processing class for basic image operations
"""
import cv2
import numpy as np
from pathlib import Path
from typing import Tuple, Optional, List
from utils import load_image, save_image, validate_image, get_image_files
class ImageProcessor:
    """Class for basic image processing operations"""

    def __init__(self, target_size: Tuple[int, int] = None):
        """
        Initialize ImageProcessor

        Args:
            target_size: Target size for image resizing (width, height);
                defaults to (224, 224)
        """
        self.target_size = target_size or (224, 224)  # Default size

    def load_and_preprocess(self, image_path: Path) -> Optional[np.ndarray]:
        """
        Load and preprocess image

        The file is validated, loaded at ``self.target_size`` and scaled to
        float32 values in [0, 1].

        Args:
            image_path: Path to image file

        Returns:
            Preprocessed image as numpy array or None if failed
        """
        if not validate_image(image_path):
            print(f"Invalid image file: {image_path}")
            return None
        image = load_image(image_path, self.target_size)
        if image is None:
            return None
        return self.normalize_image(image)

    def resize_image(self, image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
        """
        Resize image to target size

        Args:
            image: Input image as numpy array
            target_size: Target size (width, height)

        Returns:
            Resized image
        """
        return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

    def normalize_image(self, image: np.ndarray) -> np.ndarray:
        """
        Normalize image pixel values to [0, 1]

        Args:
            image: Input image

        Returns:
            Normalized image (float32)
        """
        return image.astype(np.float32) / 255.0

    def denormalize_image(self, image: np.ndarray) -> np.ndarray:
        """
        Denormalize image pixel values to [0, 255]

        Values are rounded to the nearest integer and clipped to the uint8
        range, so a normalize/denormalize round trip reproduces the original
        pixels and out-of-range inputs saturate instead of wrapping around.

        Args:
            image: Input image (normalized)

        Returns:
            Denormalized image (uint8)
        """
        return np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)

    def get_image_info(self, image_path: Path) -> dict:
        """
        Get information about image

        Args:
            image_path: Path to image file

        Returns:
            Dictionary containing image information (path, width, height,
            channels, file size in MB, file extension), or an empty
            dictionary if the image cannot be read
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                return {}
            height, width = image.shape[:2]
            # A 2-D array is a single-channel image.
            channels = image.shape[2] if image.ndim == 3 else 1
            file_size = image_path.stat().st_size / (1024 * 1024)  # MB
            return {
                "path": str(image_path),
                "width": width,
                "height": height,
                "channels": channels,
                "file_size_mb": round(file_size, 2),
                "format": image_path.suffix
            }
        except Exception as e:
            print(f"Error getting image info for {image_path}: {e}")
            return {}

    def batch_process_images(self, input_dir: Path, output_dir: Path) -> List[Path]:
        """
        Process all images in a directory

        Each image is loaded, resized to ``self.target_size`` and written to
        ``output_dir`` under its original file name.

        Args:
            input_dir: Input directory containing images
            output_dir: Output directory for processed images

        Returns:
            List of processed image paths
        """
        # Imported here because the module's import line does not bring
        # print_progress into scope.
        from utils import print_progress

        image_files = get_image_files(input_dir)
        processed_files = []
        print(f"Found {len(image_files)} images to process")
        for position, image_path in enumerate(image_files, start=1):
            print_progress(position, len(image_files), "Processing images")
            image = self.load_and_preprocess(image_path)
            if image is None:
                continue
            output_path = output_dir / image_path.name
            # Back to uint8 before writing to disk.
            image = self.denormalize_image(image)
            if save_image(image, output_path):
                processed_files.append(output_path)
        print(f"\nProcessed {len(processed_files)} images successfully")
        return processed_files

    def create_thumbnail(self, image: np.ndarray, size: Tuple[int, int] = (100, 100)) -> np.ndarray:
        """
        Create thumbnail of image

        The image is resized to exactly ``size``; the aspect ratio is not
        preserved.

        Args:
            image: Input image
            size: Thumbnail size (width, height)

        Returns:
            Thumbnail image
        """
        return cv2.resize(image, size, interpolation=cv2.INTER_AREA)

    def convert_to_grayscale(self, image: np.ndarray) -> np.ndarray:
        """
        Convert image to grayscale

        Args:
            image: Input image (RGB)

        Returns:
            Grayscale image; an input that is not 3-dimensional is returned
            unchanged
        """
        if len(image.shape) == 3:
            return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return image

8
src/model/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Model module for YOLO-based ID card detection and cropping

Re-exports YOLODetector and IDCardProcessor at package level.
"""
from .yolo_detector import YOLODetector
from .id_card_processor import IDCardProcessor
# Names exported by a wildcard import of this package.
__all__ = ['YOLODetector', 'IDCardProcessor']

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,343 @@
"""
ID Card Processor for background removal and preprocessing
"""
import cv2
import numpy as np
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
import logging
from .yolo_detector import YOLODetector
class IDCardProcessor:
    """
    ID Card Processor for background removal and preprocessing

    Combines an optional YOLO-based detect-and-crop stage with background
    removal, contrast enhancement and size/contrast normalization.
    """

    def __init__(self, yolo_detector: Optional["YOLODetector"] = None):
        """
        Initialize ID Card Processor

        Args:
            yolo_detector: YOLO detector instance; a default YOLODetector is
                created when none is given
        """
        self.yolo_detector = yolo_detector or YOLODetector()
        self.logger = logging.getLogger(__name__)

    def remove_background(self, image: np.ndarray, method: str = 'grabcut') -> np.ndarray:
        """
        Remove background from image

        Args:
            image: Input image
            method: Background removal method ('grabcut', 'threshold', 'contour');
                any other value logs a warning and uses 'grabcut'

        Returns:
            Image with background removed
        """
        handlers = {
            'grabcut': self._grabcut_background_removal,
            'threshold': self._threshold_background_removal,
            'contour': self._contour_background_removal,
        }
        handler = handlers.get(method)
        if handler is None:
            self.logger.warning(f"Unknown method: {method}, using grabcut")
            handler = self._grabcut_background_removal
        return handler(image)

    def _grabcut_background_removal(self, image: np.ndarray) -> np.ndarray:
        """
        Remove background using GrabCut algorithm

        The foreground is assumed to lie inside a centred rectangle covering
        three quarters of the width and height. Returns the input unchanged
        if GrabCut fails.
        """
        try:
            mask = np.zeros(image.shape[:2], np.uint8)
            # Scratch model buffers required by cv2.grabCut.
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)
            height, width = image.shape[:2]
            # (x, y, w, h): assumes the ID card is in the centre of the frame.
            rect = (width//8, height//8, width*3//4, height*3//4)
            cv2.grabCut(image, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
            # Mask labels 0 and 2 are treated as background, the rest as foreground.
            mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
            return image * mask2[:, :, np.newaxis]
        except Exception as e:
            self.logger.error(f"Error in grabcut background removal: {e}")
            return image

    def _threshold_background_removal(self, image: np.ndarray) -> np.ndarray:
        """
        Remove background using thresholding

        Keeps only the largest external contour of the Otsu-thresholded,
        blurred grayscale image. Returns the input unchanged when no contour
        is found or an error occurs.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            blurred = cv2.GaussianBlur(gray, (5, 5), 0)
            _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return image
            # The largest contour is assumed to be the ID card.
            largest_contour = max(contours, key=cv2.contourArea)
            mask = np.zeros_like(gray)
            cv2.fillPoly(mask, [largest_contour], 255)
            return cv2.bitwise_and(image, image, mask=mask)
        except Exception as e:
            self.logger.error(f"Error in threshold background removal: {e}")
            return image

    def _contour_background_removal(self, image: np.ndarray) -> np.ndarray:
        """
        Remove background using contour detection

        Keeps only the polygon approximating the largest external contour of
        the Canny edge map. Returns the input unchanged when no contour is
        found or an error occurs.
        """
        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return image
            largest_contour = max(contours, key=cv2.contourArea)
            # Simplify the outline; tolerance is 2% of the contour perimeter.
            epsilon = 0.02 * cv2.arcLength(largest_contour, True)
            approx = cv2.approxPolyDP(largest_contour, epsilon, True)
            mask = np.zeros_like(gray)
            cv2.fillPoly(mask, [approx], 255)
            return cv2.bitwise_and(image, image, mask=mask)
        except Exception as e:
            self.logger.error(f"Error in contour background removal: {e}")
            return image

    def enhance_image(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance image quality for better OCR

        Applies CLAHE to the lightness channel in LAB space, then a 3x3
        Gaussian blur. Returns the input unchanged on error.
        """
        try:
            lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            lab[:, :, 0] = clahe.apply(lab[:, :, 0])
            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
            # Slight blur to reduce noise.
            return cv2.GaussianBlur(enhanced, (3, 3), 0)
        except Exception as e:
            self.logger.error(f"Error enhancing image: {e}")
            return image

    def normalize_image(self, image: np.ndarray, target_size: Tuple[int, int] = (800, 600)) -> np.ndarray:
        """
        Normalize image size and contrast

        Resizes to ``target_size``, converts to grayscale, applies histogram
        equalization and, for a colour input, converts the result back to a
        3-channel BGR image. Orientation is not changed. Returns the input
        unchanged on error.
        """
        try:
            resized = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
            if len(resized.shape) == 3:
                gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
            else:
                gray = resized
            equalized = cv2.equalizeHist(gray)
            # Keep the channel layout of the input for consistency.
            if len(image.shape) == 3:
                return cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR)
            return equalized
        except Exception as e:
            self.logger.error(f"Error normalizing image: {e}")
            return image

    def process_id_card(self, image_path: Path, output_dir: Path,
                        remove_bg: bool = True, enhance: bool = True,
                        normalize: bool = True, target_size: Tuple[int, int] = (800, 600)) -> Dict[str, Any]:
        """
        Process a single ID card image

        Args:
            image_path: Path to input image
            output_dir: Output directory
            remove_bg: Whether to remove background
            enhance: Whether to enhance image
            normalize: Whether to normalize image
            target_size: Target size for normalization

        Returns:
            Processing results with keys 'input_path', 'output_paths' and
            'success'; 'success' is True only when the processed image was
            actually written
        """
        result = {
            'input_path': str(image_path),
            'output_paths': [],
            'success': False
        }
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                self.logger.error(f"Could not load image: {image_path}")
                return result

            stem = image_path.stem
            processed_path = output_dir / f"{stem}_processed.jpg"

            processed_image = image.copy()
            if remove_bg:
                self.logger.info(f"Removing background from {image_path.name}")
                processed_image = self.remove_background(processed_image)
            if enhance:
                self.logger.info(f"Enhancing {image_path.name}")
                processed_image = self.enhance_image(processed_image)
            if normalize:
                self.logger.info(f"Normalizing {image_path.name}")
                processed_image = self.normalize_image(processed_image, target_size)

            processed_path.parent.mkdir(parents=True, exist_ok=True)
            # imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(str(processed_path), processed_image):
                self.logger.error(f"Could not write processed image: {processed_path}")
                return result
            result['output_paths'].append(str(processed_path))
            result['success'] = True
            self.logger.info(f"Processed {image_path.name}")
        except Exception as e:
            self.logger.error(f"Error processing {image_path}: {e}")
        return result

    def batch_process_id_cards(self, input_dir: Path, output_dir: Path,
                               detect_first: bool = True, **kwargs) -> Dict[str, Any]:
        """
        Process all ID card images in a directory

        Args:
            input_dir: Input directory
            output_dir: Output directory
            detect_first: Whether to detect ID cards first using YOLO
            **kwargs: Additional arguments for processing

        Returns:
            Batch processing results
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        processed_dir = output_dir / "processed"

        if not detect_first:
            # Process original images directly
            return self._process_cropped_images(input_dir, processed_dir, **kwargs)

        self.logger.info("Detecting and cropping ID cards...")
        cropped_dir = output_dir / "cropped"
        detection_results = self.yolo_detector.batch_process(input_dir, cropped_dir)

        # The detector creates cropped_dir even when nothing was cropped, so
        # the crop count (not the directory's existence) decides the fallback.
        if detection_results.get('total_cropped', 0) > 0:
            self.logger.info("Processing cropped ID cards...")
            return self._process_cropped_images(cropped_dir, processed_dir, **kwargs)

        self.logger.warning("No cropped images found, processing original images")
        return self._process_cropped_images(input_dir, processed_dir, **kwargs)

    def _process_cropped_images(self, input_dir: Path, output_dir: Path, **kwargs) -> Dict[str, Any]:
        """
        Process cropped ID card images recursively

        Every image below ``input_dir`` is processed with ``process_id_card``
        and written to the matching sub-directory of ``output_dir``.
        """
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
        # Collect the full list up front, before any output is written.
        image_files = sorted(
            file_path
            for file_path in input_dir.rglob('*')
            if file_path.is_file() and file_path.suffix.lower() in image_extensions
        )
        if not image_files:
            self.logger.error(f"No images found in {input_dir} and subdirectories")
            return {'success': False, 'error': 'No images found'}

        self.logger.info(f"Processing {len(image_files)} images from {input_dir} and subdirectories")
        results = {
            'total_images': len(image_files),
            'processed_images': 0,
            'results': []
        }
        for i, image_path in enumerate(image_files):
            self.logger.info(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
            # Mirror the input's sub-directory structure in the output.
            relative_path = image_path.relative_to(input_dir)
            output_subdir = output_dir / relative_path.parent
            output_subdir.mkdir(parents=True, exist_ok=True)
            result = self.process_id_card(image_path, output_subdir, **kwargs)
            results['results'].append(result)
            if result['success']:
                results['processed_images'] += 1

        # Summary
        self.logger.info(f"ID card processing completed:")
        self.logger.info(f" - Total images: {results['total_images']}")
        self.logger.info(f" - Processed: {results['processed_images']}")
        return results

266
src/model/yolo_detector.py Normal file
View File

@@ -0,0 +1,266 @@
"""
YOLO Detector for ID Card Detection and Cropping
"""
import cv2
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
import logging
from ultralytics import YOLO
import torch
class YOLODetector:
    """
    YOLO-based detector for ID card detection and cropping
    """

    def __init__(self, model_path: Optional[str] = None, confidence: float = 0.5):
        """
        Initialize YOLO detector

        Args:
            model_path: Path to YOLO model file (.pt); when missing or not
                found, the pre-trained 'yolov8n.pt' model is used
            confidence: Confidence threshold for detection
        """
        self.confidence = confidence
        self.logger = logging.getLogger(__name__)
        if model_path and Path(model_path).exists():
            self.model = YOLO(model_path)
            self.logger.info(f"Loaded custom YOLO model from {model_path}")
        else:
            if model_path:
                # Make the fallback visible instead of silently ignoring the path.
                self.logger.warning(f"YOLO model not found at {model_path}, falling back to pre-trained model")
            # Use pre-trained YOLO model for general object detection
            self.model = YOLO('yolov8n.pt')
            self.logger.info("Using pre-trained YOLOv8n model")
        # Recorded for logging; not passed to the model anywhere in this class.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.logger.info(f"Using device: {self.device}")

    def detect_id_cards(self, image_path: Path) -> List[Dict[str, Any]]:
        """
        Detect ID cards in an image

        Every box returned by the model above the confidence threshold is
        reported, whatever its class.

        Args:
            image_path: Path to image file

        Returns:
            List of detection results with bounding boxes, sorted by
            confidence and then area, both descending. Empty on any error.
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                self.logger.error(f"Could not load image: {image_path}")
                return []

            results = self.model(image, conf=self.confidence)
            detections = []
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    continue
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    class_id = int(box.cls[0])
                    detections.append({
                        'bbox': [int(x1), int(y1), int(x2), int(y2)],
                        'confidence': float(box.conf[0]),
                        'class_id': class_id,
                        'class_name': self.model.names[class_id],
                        # Plain float rather than a numpy scalar, so the
                        # record serializes cleanly.
                        'area': float((x2 - x1) * (y2 - y1))
                    })
            # Sort by confidence and area (prefer larger, more confident detections)
            detections.sort(key=lambda x: (x['confidence'], x['area']), reverse=True)
            self.logger.info(f"Found {len(detections)} detections in {image_path.name}")
            return detections
        except Exception as e:
            self.logger.error(f"Error detecting ID cards in {image_path}: {e}")
            return []

    def crop_id_card(self, image_path: Path, bbox: List[int],
                     output_path: Optional[Path] = None,
                     padding: int = 10) -> Optional[np.ndarray]:
        """
        Crop ID card from image using bounding box

        Args:
            image_path: Path to input image
            bbox: Bounding box [x1, y1, x2, y2]
            output_path: Path to save cropped image
            padding: Padding around the bounding box

        Returns:
            Cropped image as numpy array, or None when the image cannot be
            read, the crop is empty, or saving to ``output_path`` fails
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                self.logger.error(f"Could not load image: {image_path}")
                return None

            height, width = image.shape[:2]
            x1, y1, x2, y2 = bbox
            # Add padding, clamped to the image bounds.
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(width, x2 + padding)
            y2 = min(height, y2 + padding)

            cropped = image[y1:y2, x1:x2]
            if cropped.size == 0:
                self.logger.error(f"Empty crop for bbox {bbox} in {image_path}")
                return None

            if output_path:
                output_path.parent.mkdir(parents=True, exist_ok=True)
                # imwrite signals failure through its return value.
                if not cv2.imwrite(str(output_path), cropped):
                    self.logger.error(f"Could not write cropped image to {output_path}")
                    return None
                self.logger.info(f"Saved cropped image to {output_path}")
            return cropped
        except Exception as e:
            self.logger.error(f"Error cropping ID card from {image_path}: {e}")
            return None

    def process_single_image(self, image_path: Path, output_dir: Path,
                             save_original: bool = False) -> Dict[str, Any]:
        """
        Process a single image: detect and crop ID cards

        Args:
            image_path: Path to input image
            output_dir: Output directory for cropped images
            save_original: Whether to save original image with bounding boxes

        Returns:
            Processing results with keys 'input_path', 'detections',
            'cropped_paths', 'success' and, when an annotated copy was
            written, 'annotated_path'
        """
        result = {
            'input_path': str(image_path),
            'detections': [],
            'cropped_paths': [],
            'success': False
        }
        try:
            detections = self.detect_id_cards(image_path)
            if not detections:
                self.logger.warning(f"No ID cards detected in {image_path.name}")
                return result

            stem = image_path.stem
            for i, detection in enumerate(detections):
                output_path = output_dir / f"{stem}_card_{i+1}.jpg"
                cropped = self.crop_id_card(image_path, detection['bbox'], output_path)
                if cropped is not None:
                    result['detections'].append(detection)
                    result['cropped_paths'].append(str(output_path))

            if save_original:
                self._save_annotated(image_path, output_dir, detections, result)

            result['success'] = True
            self.logger.info(f"Processed {image_path.name}: {len(result['cropped_paths'])} cards cropped")
        except Exception as e:
            self.logger.error(f"Error processing {image_path}: {e}")
        return result

    def _save_annotated(self, image_path: Path, output_dir: Path,
                        detections: List[Dict[str, Any]], result: Dict[str, Any]) -> None:
        """Write a copy of the image with boxes and confidences drawn; record its path in result."""
        image = cv2.imread(str(image_path))
        if image is None:
            self.logger.error(f"Could not load image: {image_path}")
            return
        for detection in detections:
            bbox = detection['bbox']
            cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
            # Keep the label inside the image when the box touches the top edge.
            cv2.putText(image, f"{detection['confidence']:.2f}",
                        (bbox[0], max(bbox[1] - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 255, 0), 2)
        annotated_path = output_dir / f"{image_path.stem}_annotated.jpg"
        if cv2.imwrite(str(annotated_path), image):
            result['annotated_path'] = str(annotated_path)
        else:
            self.logger.error(f"Could not write annotated image to {annotated_path}")

    def batch_process(self, input_dir: Path, output_dir: Path,
                      save_annotated: bool = False) -> Dict[str, Any]:
        """
        Process all images in a directory and subdirectories

        Args:
            input_dir: Input directory containing images
            output_dir: Output directory for cropped images
            save_annotated: Whether to save annotated images

        Returns:
            Batch processing results, or ``{'success': False, 'error': ...}``
            when no images are found
        """
        output_dir.mkdir(parents=True, exist_ok=True)

        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
        # Collect the full list up front, before any output is written.
        image_files = [
            file_path
            for file_path in input_dir.rglob('*')
            if file_path.is_file() and file_path.suffix.lower() in image_extensions
        ]
        if not image_files:
            self.logger.error(f"No images found in {input_dir} and subdirectories")
            return {'success': False, 'error': 'No images found'}

        self.logger.info(f"Processing {len(image_files)} images from {input_dir} and subdirectories")
        results = {
            'total_images': len(image_files),
            'processed_images': 0,
            'total_detections': 0,
            'total_cropped': 0,
            'results': []
        }
        for i, image_path in enumerate(image_files):
            self.logger.info(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
            # Mirror the input's sub-directory structure in the output.
            relative_path = image_path.relative_to(input_dir)
            output_subdir = output_dir / relative_path.parent
            output_subdir.mkdir(parents=True, exist_ok=True)
            result = self.process_single_image(image_path, output_subdir, save_annotated)
            results['results'].append(result)
            if result['success']:
                results['processed_images'] += 1
                results['total_detections'] += len(result['detections'])
                results['total_cropped'] += len(result['cropped_paths'])

        # Summary
        self.logger.info(f"Batch processing completed:")
        self.logger.info(f" - Total images: {results['total_images']}")
        self.logger.info(f" - Processed: {results['processed_images']}")
        self.logger.info(f" - Total detections: {results['total_detections']}")
        self.logger.info(f" - Total cropped: {results['total_cropped']}")
        return results

98
src/utils.py Normal file
View File

@@ -0,0 +1,98 @@
"""
Utility functions for data augmentation
"""
import os
import logging
from pathlib import Path
from typing import List, Tuple, Optional
import cv2
import numpy as np
from PIL import Image
def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure root logging to a file and the console; return this module's logger.

    ``log_level`` is a level name such as "INFO" or "debug" (case-insensitive).
    Records go to 'data_augmentation.log' in the working directory and to the
    console stream.
    """
    numeric_level = getattr(logging, log_level.upper())
    sinks = [
        logging.FileHandler('data_augmentation.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=sinks,
    )
    module_logger = logging.getLogger(__name__)
    return module_logger
def get_image_files(directory: Path) -> List[Path]:
    """Get all image files from directory

    Returns the regular files directly inside ``directory`` (no recursion)
    whose extension, compared case-insensitively, is one of the supported
    image formats. The result is sorted and contains each file once. A
    missing directory yields an empty list.
    """
    SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    if not directory.is_dir():
        return []
    # Filtering on the lower-cased suffix also catches mixed-case extensions
    # such as ".Jpg" and cannot list a file twice on case-insensitive
    # filesystems.
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_FORMATS
    )
def validate_image(image_path: Path) -> bool:
    """Report whether the file can be opened and verified as an image.

    Any error raised while opening or verifying the file counts as "not a
    valid image" and yields False.
    """
    try:
        with Image.open(image_path) as candidate:
            candidate.verify()
    except Exception:
        return False
    else:
        return True
def load_image(image_path: Path, target_size: Tuple[int, int] = None) -> Optional[np.ndarray]:
    """Read an image as an RGB array, optionally resized to ``target_size``.

    Returns None when the file cannot be read or any step raises; in the
    latter case the error is printed.
    """
    try:
        bgr = cv2.imread(str(image_path))
        if bgr is None:
            return None
        # OpenCV reads in BGR order; callers work with RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if not target_size:
            return rgb
        return cv2.resize(rgb, target_size, interpolation=cv2.INTER_AREA)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None
def save_image(image: np.ndarray, output_path: Path, quality: int = 95) -> bool:
    """Save image to file

    Colour images are expected in RGB order and are converted to BGR for
    OpenCV; 2-D (single-channel) images are written as they are. The parent
    directory is created when needed.

    Args:
        image: Image to write
        output_path: Destination file
        quality: JPEG quality passed to the encoder

    Returns:
        True only if the file was actually written; False on any error
    """
    try:
        # A 2-D array has no channels to swap, and cvtColor would reject it.
        if image.ndim == 2:
            image_bgr = image
        else:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # imwrite reports failure through its return value, not an exception.
        written = cv2.imwrite(str(output_path), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
        if not written:
            print(f"Error saving image {output_path}: image could not be written")
            return False
        return True
    except Exception as e:
        print(f"Error saving image {output_path}: {e}")
        return False
def create_augmented_filename(original_path: Path, index: int, suffix: str = "aug") -> Path:
    """Build the sibling path ``<stem>_<suffix>_<index><ext>`` for an augmented copy.

    The index is zero-padded to at least two digits and the original
    extension is kept.
    """
    tagged_name = f"{original_path.stem}_{suffix}_{index:02d}{original_path.suffix}"
    return original_path.parent / tagged_name
def get_file_size_mb(file_path: Path) -> float:
    """Return the size of ``file_path`` in mebibytes (bytes / 1024**2)."""
    size_in_bytes = file_path.stat().st_size
    bytes_per_mb = 1024 * 1024
    return size_in_bytes / bytes_per_mb
def print_progress(current: int, total: int, prefix: str = "Progress"):
    """Print progress bar

    Redraws a 50-character bar on the current line (carriage return, no
    newline) and ends the line once ``current`` equals ``total``.

    Args:
        current: Number of items done so far
        total: Total number of items; a value of 0 or less is shown as a
            completed bar instead of raising ZeroDivisionError
        prefix: Label printed before the bar
    """
    bar_length = 50
    if total > 0:
        filled_length = int(round(bar_length * current / float(total)))
        percents = round(100.0 * current / float(total), 1)
    else:
        # Nothing to do counts as done.
        filled_length = bar_length
        percents = 100.0
    # Keep the bar at its fixed width even if current runs outside 0..total.
    filled_length = max(0, min(bar_length, filled_length))
    bar = '=' * filled_length + '-' * (bar_length - filled_length)
    print(f'\r{prefix}: [{bar}] {percents}% ({current}/{total})', end='')
    if current == total:
        print()