init
This commit is contained in:
23
src/__init__.py
Normal file
23
src/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Data Augmentation Package
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "OCR Data Augmentation Tool"
|
||||
|
||||
from .utils import *
|
||||
from .image_processor import ImageProcessor
|
||||
from .data_augmentation import DataAugmentation
|
||||
from .config_manager import ConfigManager
|
||||
|
||||
__all__ = [
|
||||
"ImageProcessor",
|
||||
"DataAugmentation",
|
||||
"ConfigManager",
|
||||
"setup_logging",
|
||||
"get_image_files",
|
||||
"load_image",
|
||||
"save_image",
|
||||
"validate_image",
|
||||
"print_progress",
|
||||
]
|
BIN
src/__pycache__/__init__.cpython-313.pyc
Normal file
BIN
src/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
src/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/config_manager.cpython-39.pyc
Normal file
BIN
src/__pycache__/config_manager.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/data_augmentation.cpython-39.pyc
Normal file
BIN
src/__pycache__/data_augmentation.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/image_processor.cpython-39.pyc
Normal file
BIN
src/__pycache__/image_processor.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/utils.cpython-313.pyc
Normal file
BIN
src/__pycache__/utils.cpython-313.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/utils.cpython-39.pyc
Normal file
BIN
src/__pycache__/utils.cpython-39.pyc
Normal file
Binary file not shown.
40
src/config.py
Normal file
40
src/config.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Configuration file for data augmentation
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Paths
|
||||
BASE_DIR = Path(__file__).parent.parent
|
||||
DATA_DIR = BASE_DIR / "data"
|
||||
INPUT_IMAGES_DIR = DATA_DIR / "dataset" / "training_data" / "images"
|
||||
OUTPUT_DIR = DATA_DIR / "augmented_data"
|
||||
|
||||
# Data augmentation parameters
|
||||
AUGMENTATION_CONFIG = {
|
||||
"rotation_range": 15, # degrees
|
||||
"width_shift_range": 0.1, # fraction of total width
|
||||
"height_shift_range": 0.1, # fraction of total height
|
||||
"brightness_range": [0.8, 1.2], # brightness factor
|
||||
"zoom_range": [0.9, 1.1], # zoom factor
|
||||
"horizontal_flip": True,
|
||||
"vertical_flip": False,
|
||||
"fill_mode": "nearest",
|
||||
"cval": 0,
|
||||
"rescale": 1./255,
|
||||
}
|
||||
|
||||
# Processing parameters
|
||||
PROCESSING_CONFIG = {
|
||||
"target_size": (224, 224), # (width, height)
|
||||
"batch_size": 32,
|
||||
"num_augmentations": 3, # number of augmented versions per image
|
||||
"save_format": "jpg",
|
||||
"quality": 95,
|
||||
}
|
||||
|
||||
# Supported image formats
|
||||
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
175
src/config_manager.py
Normal file
175
src/config_manager.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Configuration manager for data augmentation
|
||||
"""
|
||||
import yaml
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
class ConfigManager:
|
||||
"""Manages configuration loading and validation"""
|
||||
|
||||
def __init__(self, config_path: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
Initialize ConfigManager
|
||||
|
||||
Args:
|
||||
config_path: Path to main config file
|
||||
"""
|
||||
self.config_path = Path(config_path) if config_path else Path("config/config.yaml")
|
||||
self.config = {}
|
||||
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load main configuration file"""
|
||||
try:
|
||||
if self.config_path.exists():
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
print(f"✅ Loaded configuration from {self.config_path}")
|
||||
else:
|
||||
print(f"⚠️ Config file not found: {self.config_path}")
|
||||
self.config = self._get_default_config()
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading config: {e}")
|
||||
self.config = self._get_default_config()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration"""
|
||||
return {
|
||||
"paths": {
|
||||
"input_dir": "data/dataset/training_data/images",
|
||||
"output_dir": "data/augmented_data",
|
||||
"log_file": "logs/data_augmentation.log"
|
||||
},
|
||||
"augmentation": {
|
||||
"rotation": {"enabled": True, "angles": [30, 60, 120, 150, 180, 210, 240, 300, 330], "probability": 1.0}
|
||||
},
|
||||
"processing": {
|
||||
"target_size": [224, 224],
|
||||
"batch_size": 32,
|
||||
"num_augmentations": 3,
|
||||
"save_format": "jpg",
|
||||
"quality": 95
|
||||
},
|
||||
"supported_formats": [".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
|
||||
"logging": {
|
||||
"level": "INFO",
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
},
|
||||
"performance": {
|
||||
"num_workers": 4,
|
||||
"prefetch_factor": 2,
|
||||
"pin_memory": True,
|
||||
"use_gpu": False
|
||||
}
|
||||
}
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Get current configuration"""
|
||||
return self.config
|
||||
|
||||
def get_augmentation_config(self) -> Dict[str, Any]:
|
||||
"""Get augmentation configuration"""
|
||||
return self.config.get("augmentation", {})
|
||||
|
||||
def get_processing_config(self) -> Dict[str, Any]:
|
||||
"""Get processing configuration"""
|
||||
return self.config.get("processing", {})
|
||||
|
||||
def get_paths_config(self) -> Dict[str, Any]:
|
||||
"""Get paths configuration"""
|
||||
return self.config.get("paths", {})
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
|
||||
"""Get logging configuration"""
|
||||
return self.config.get("logging", {})
|
||||
|
||||
def get_performance_config(self) -> Dict[str, Any]:
|
||||
"""Get performance configuration"""
|
||||
return self.config.get("performance", {})
|
||||
|
||||
|
||||
|
||||
def update_config(self, updates: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update configuration with new values
|
||||
|
||||
Args:
|
||||
updates: Dictionary with updates to apply
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
try:
|
||||
self.config = self._merge_configs(self.config, updates)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error updating config: {e}")
|
||||
return False
|
||||
|
||||
def _merge_configs(self, base_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge updates with base configuration"""
|
||||
merged = base_config.copy()
|
||||
|
||||
def deep_merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = base.copy()
|
||||
for key, value in update.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
return deep_merge(merged, updates)
|
||||
|
||||
def save_config(self, output_path: Optional[Union[str, Path]] = None) -> bool:
|
||||
"""
|
||||
Save current configuration to file
|
||||
|
||||
Args:
|
||||
output_path: Path to save config file
|
||||
|
||||
Returns:
|
||||
True if saved successfully
|
||||
"""
|
||||
try:
|
||||
output_path = Path(output_path) if output_path else self.config_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(self.config, f, default_flow_style=False, indent=2, allow_unicode=True)
|
||||
|
||||
print(f"✅ Configuration saved to {output_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving config: {e}")
|
||||
return False
|
||||
|
||||
def print_config_summary(self):
|
||||
"""Print configuration summary"""
|
||||
print("\n" + "="*50)
|
||||
print("CONFIGURATION SUMMARY")
|
||||
print("="*50)
|
||||
|
||||
# Paths
|
||||
paths = self.get_paths_config()
|
||||
print(f"Input directory: {paths.get('input_dir', 'Not set')}")
|
||||
print(f"Output directory: {paths.get('output_dir', 'Not set')}")
|
||||
|
||||
# Processing
|
||||
processing = self.get_processing_config()
|
||||
print(f"Target size: {processing.get('target_size', 'Not set')}")
|
||||
print(f"Number of augmentations: {processing.get('num_augmentations', 'Not set')}")
|
||||
|
||||
# Augmentation
|
||||
augmentation = self.get_augmentation_config()
|
||||
enabled_augmentations = []
|
||||
for name, config in augmentation.items():
|
||||
if isinstance(config, dict) and config.get('enabled', False):
|
||||
enabled_augmentations.append(name)
|
||||
|
||||
print(f"Enabled augmentations: {', '.join(enabled_augmentations) if enabled_augmentations else 'None'}")
|
||||
|
||||
print("="*50)
|
161
src/data_augmentation.py
Normal file
161
src/data_augmentation.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Data augmentation class for image augmentation - ONLY ROTATION
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import random
|
||||
import math
|
||||
from image_processor import ImageProcessor
|
||||
from utils import load_image, save_image, create_augmented_filename, print_progress
|
||||
|
||||
class DataAugmentation:
|
||||
"""Class for image data augmentation - ONLY ROTATION"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None):
|
||||
"""
|
||||
Initialize DataAugmentation
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary for augmentation parameters
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.image_processor = ImageProcessor()
|
||||
|
||||
def rotate_image(self, image: np.ndarray, angle: float) -> np.ndarray:
|
||||
"""
|
||||
Rotate image by given angle
|
||||
|
||||
Args:
|
||||
image: Input image
|
||||
angle: Rotation angle in degrees
|
||||
|
||||
Returns:
|
||||
Rotated image
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
center = (width // 2, height // 2)
|
||||
|
||||
# Create rotation matrix
|
||||
rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
|
||||
|
||||
# Perform rotation
|
||||
rotated = cv2.warpAffine(image, rotation_matrix, (width, height),
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
return rotated
|
||||
|
||||
def augment_single_image(self, image: np.ndarray, num_augmentations: int = None) -> List[np.ndarray]:
|
||||
"""
|
||||
Apply rotation augmentation to a single image
|
||||
|
||||
Args:
|
||||
image: Input image
|
||||
num_augmentations: Number of augmented versions to create
|
||||
|
||||
Returns:
|
||||
List of augmented images
|
||||
"""
|
||||
num_augmentations = num_augmentations or 3 # Default value
|
||||
augmented_images = []
|
||||
|
||||
# Get rotation configuration
|
||||
rotation_config = self.config.get("rotation", {})
|
||||
angles = rotation_config.get("angles", [30, 60, 120, 150, 180, 210, 240, 300, 330])
|
||||
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
|
||||
# Apply rotation with random angle from the specified list
|
||||
if rotation_config.get("enabled", False):
|
||||
angle = random.choice(angles)
|
||||
augmented = self.rotate_image(augmented, angle)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
return augmented_images
|
||||
|
||||
def augment_image_file(self, image_path: Path, output_dir: Path, num_augmentations: int = None) -> List[Path]:
|
||||
"""
|
||||
Augment a single image file and save results
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
output_dir: Output directory for augmented images
|
||||
num_augmentations: Number of augmented versions to create
|
||||
|
||||
Returns:
|
||||
List of paths to saved augmented images
|
||||
"""
|
||||
# Load image
|
||||
image = load_image(image_path, self.image_processor.target_size)
|
||||
if image is None:
|
||||
return []
|
||||
|
||||
# Apply augmentations
|
||||
augmented_images = self.augment_single_image(image, num_augmentations)
|
||||
|
||||
# Save augmented images
|
||||
saved_paths = []
|
||||
for i, aug_image in enumerate(augmented_images):
|
||||
# Create output filename
|
||||
output_filename = create_augmented_filename(image_path, i + 1)
|
||||
output_path = output_dir / output_filename.name
|
||||
|
||||
# Save image
|
||||
if save_image(aug_image, output_path):
|
||||
saved_paths.append(output_path)
|
||||
|
||||
return saved_paths
|
||||
|
||||
def batch_augment(self, input_dir: Path, output_dir: Path, num_augmentations: int = None) -> Dict[str, List[Path]]:
|
||||
"""
|
||||
Augment all images in a directory
|
||||
|
||||
Args:
|
||||
input_dir: Input directory containing images
|
||||
output_dir: Output directory for augmented images
|
||||
num_augmentations: Number of augmented versions per image
|
||||
|
||||
Returns:
|
||||
Dictionary mapping original images to their augmented versions
|
||||
"""
|
||||
from utils import get_image_files
|
||||
|
||||
image_files = get_image_files(input_dir)
|
||||
results = {}
|
||||
|
||||
print(f"Found {len(image_files)} images to augment")
|
||||
|
||||
for i, image_path in enumerate(image_files):
|
||||
print_progress(i + 1, len(image_files), "Augmenting images")
|
||||
|
||||
# Augment single image
|
||||
augmented_paths = self.augment_image_file(image_path, output_dir, num_augmentations)
|
||||
|
||||
if augmented_paths:
|
||||
results[str(image_path)] = augmented_paths
|
||||
|
||||
print(f"\nAugmented {len(results)} images successfully")
|
||||
return results
|
||||
|
||||
def get_augmentation_summary(self, results: Dict[str, List[Path]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Get summary of augmentation results
|
||||
|
||||
Args:
|
||||
results: Results from batch_augment
|
||||
|
||||
Returns:
|
||||
Summary dictionary
|
||||
"""
|
||||
total_original = len(results)
|
||||
total_augmented = sum(len(paths) for paths in results.values())
|
||||
|
||||
return {
|
||||
"total_original_images": total_original,
|
||||
"total_augmented_images": total_augmented,
|
||||
"augmentation_ratio": total_augmented / total_original if total_original > 0 else 0,
|
||||
"successful_augmentations": len([paths for paths in results.values() if paths])
|
||||
}
|
174
src/image_processor.py
Normal file
174
src/image_processor.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Image processing class for basic image operations
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional, List
|
||||
from utils import load_image, save_image, validate_image, get_image_files
|
||||
|
||||
class ImageProcessor:
|
||||
"""Class for basic image processing operations"""
|
||||
|
||||
def __init__(self, target_size: Tuple[int, int] = None):
|
||||
"""
|
||||
Initialize ImageProcessor
|
||||
|
||||
Args:
|
||||
target_size: Target size for image resizing (width, height)
|
||||
"""
|
||||
self.target_size = target_size or (224, 224) # Default size
|
||||
|
||||
def load_and_preprocess(self, image_path: Path) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Load and preprocess image
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
|
||||
Returns:
|
||||
Preprocessed image as numpy array or None if failed
|
||||
"""
|
||||
if not validate_image(image_path):
|
||||
print(f"Invalid image file: {image_path}")
|
||||
return None
|
||||
|
||||
image = load_image(image_path, self.target_size)
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
# Normalize pixel values
|
||||
image = image.astype(np.float32) / 255.0
|
||||
|
||||
return image
|
||||
|
||||
def resize_image(self, image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
|
||||
"""
|
||||
Resize image to target size
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array
|
||||
target_size: Target size (width, height)
|
||||
|
||||
Returns:
|
||||
Resized image
|
||||
"""
|
||||
return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
def normalize_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize image pixel values to [0, 1]
|
||||
|
||||
Args:
|
||||
image: Input image
|
||||
|
||||
Returns:
|
||||
Normalized image
|
||||
"""
|
||||
return image.astype(np.float32) / 255.0
|
||||
|
||||
def denormalize_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Denormalize image pixel values to [0, 255]
|
||||
|
||||
Args:
|
||||
image: Input image (normalized)
|
||||
|
||||
Returns:
|
||||
Denormalized image
|
||||
"""
|
||||
return (image * 255).astype(np.uint8)
|
||||
|
||||
def get_image_info(self, image_path: Path) -> dict:
|
||||
"""
|
||||
Get information about image
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
|
||||
Returns:
|
||||
Dictionary containing image information
|
||||
"""
|
||||
try:
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
return {}
|
||||
|
||||
height, width, channels = image.shape
|
||||
file_size = image_path.stat().st_size / (1024 * 1024) # MB
|
||||
|
||||
return {
|
||||
"path": str(image_path),
|
||||
"width": width,
|
||||
"height": height,
|
||||
"channels": channels,
|
||||
"file_size_mb": round(file_size, 2),
|
||||
"format": image_path.suffix
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error getting image info for {image_path}: {e}")
|
||||
return {}
|
||||
|
||||
def batch_process_images(self, input_dir: Path, output_dir: Path) -> List[Path]:
|
||||
"""
|
||||
Process all images in a directory
|
||||
|
||||
Args:
|
||||
input_dir: Input directory containing images
|
||||
output_dir: Output directory for processed images
|
||||
|
||||
Returns:
|
||||
List of processed image paths
|
||||
"""
|
||||
image_files = get_image_files(input_dir)
|
||||
processed_files = []
|
||||
|
||||
print(f"Found {len(image_files)} images to process")
|
||||
|
||||
for i, image_path in enumerate(image_files):
|
||||
print_progress(i + 1, len(image_files), "Processing images")
|
||||
|
||||
# Load and preprocess image
|
||||
image = self.load_and_preprocess(image_path)
|
||||
if image is None:
|
||||
continue
|
||||
|
||||
# Create output path
|
||||
output_path = output_dir / image_path.name
|
||||
|
||||
# Denormalize for saving
|
||||
image = self.denormalize_image(image)
|
||||
|
||||
# Save processed image
|
||||
if save_image(image, output_path):
|
||||
processed_files.append(output_path)
|
||||
|
||||
print(f"\nProcessed {len(processed_files)} images successfully")
|
||||
return processed_files
|
||||
|
||||
def create_thumbnail(self, image: np.ndarray, size: Tuple[int, int] = (100, 100)) -> np.ndarray:
|
||||
"""
|
||||
Create thumbnail of image
|
||||
|
||||
Args:
|
||||
image: Input image
|
||||
size: Thumbnail size (width, height)
|
||||
|
||||
Returns:
|
||||
Thumbnail image
|
||||
"""
|
||||
return cv2.resize(image, size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
def convert_to_grayscale(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert image to grayscale
|
||||
|
||||
Args:
|
||||
image: Input image (RGB)
|
||||
|
||||
Returns:
|
||||
Grayscale image
|
||||
"""
|
||||
if len(image.shape) == 3:
|
||||
return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
||||
return image
|
8
src/model/__init__.py
Normal file
8
src/model/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Model module for YOLO-based ID card detection and cropping
|
||||
"""
|
||||
|
||||
from .yolo_detector import YOLODetector
|
||||
from .id_card_processor import IDCardProcessor
|
||||
|
||||
__all__ = ['YOLODetector', 'IDCardProcessor']
|
BIN
src/model/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
src/model/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/model/__pycache__/id_card_processor.cpython-39.pyc
Normal file
BIN
src/model/__pycache__/id_card_processor.cpython-39.pyc
Normal file
Binary file not shown.
BIN
src/model/__pycache__/yolo_detector.cpython-39.pyc
Normal file
BIN
src/model/__pycache__/yolo_detector.cpython-39.pyc
Normal file
Binary file not shown.
343
src/model/id_card_processor.py
Normal file
343
src/model/id_card_processor.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
ID Card Processor for background removal and preprocessing
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import logging
|
||||
from .yolo_detector import YOLODetector
|
||||
|
||||
class IDCardProcessor:
|
||||
"""
|
||||
ID Card Processor for background removal and preprocessing
|
||||
"""
|
||||
|
||||
def __init__(self, yolo_detector: Optional[YOLODetector] = None):
|
||||
"""
|
||||
Initialize ID Card Processor
|
||||
|
||||
Args:
|
||||
yolo_detector: YOLO detector instance
|
||||
"""
|
||||
self.yolo_detector = yolo_detector or YOLODetector()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def remove_background(self, image: np.ndarray, method: str = 'grabcut') -> np.ndarray:
|
||||
"""
|
||||
Remove background from image
|
||||
|
||||
Args:
|
||||
image: Input image
|
||||
method: Background removal method ('grabcut', 'threshold', 'contour')
|
||||
|
||||
Returns:
|
||||
Image with background removed
|
||||
"""
|
||||
if method == 'grabcut':
|
||||
return self._grabcut_background_removal(image)
|
||||
elif method == 'threshold':
|
||||
return self._threshold_background_removal(image)
|
||||
elif method == 'contour':
|
||||
return self._contour_background_removal(image)
|
||||
else:
|
||||
self.logger.warning(f"Unknown method: {method}, using grabcut")
|
||||
return self._grabcut_background_removal(image)
|
||||
|
||||
def _grabcut_background_removal(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Remove background using GrabCut algorithm
|
||||
"""
|
||||
try:
|
||||
# Create mask
|
||||
mask = np.zeros(image.shape[:2], np.uint8)
|
||||
|
||||
# Create temporary arrays
|
||||
bgd_model = np.zeros((1, 65), np.float64)
|
||||
fgd_model = np.zeros((1, 65), np.float64)
|
||||
|
||||
# Define rectangle (assuming ID card is in center)
|
||||
height, width = image.shape[:2]
|
||||
rect = (width//8, height//8, width*3//4, height*3//4)
|
||||
|
||||
# Apply GrabCut
|
||||
cv2.grabCut(image, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
|
||||
|
||||
# Create mask
|
||||
mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
|
||||
|
||||
# Apply mask
|
||||
result = image * mask2[:, :, np.newaxis]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in grabcut background removal: {e}")
|
||||
return image
|
||||
|
||||
def _threshold_background_removal(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Remove background using thresholding
|
||||
"""
|
||||
try:
|
||||
# Convert to grayscale
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Apply Gaussian blur
|
||||
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
|
||||
|
||||
# Apply threshold
|
||||
_, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
|
||||
# Find contours
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Find largest contour (assumed to be the ID card)
|
||||
if contours:
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Create mask
|
||||
mask = np.zeros_like(gray)
|
||||
cv2.fillPoly(mask, [largest_contour], 255)
|
||||
|
||||
# Apply mask
|
||||
result = cv2.bitwise_and(image, image, mask=mask)
|
||||
return result
|
||||
|
||||
return image
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in threshold background removal: {e}")
|
||||
return image
|
||||
|
||||
def _contour_background_removal(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Remove background using contour detection
|
||||
"""
|
||||
try:
|
||||
# Convert to grayscale
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Apply edge detection
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
|
||||
# Find contours
|
||||
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Find largest contour
|
||||
if contours:
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Approximate contour to get rectangle
|
||||
epsilon = 0.02 * cv2.arcLength(largest_contour, True)
|
||||
approx = cv2.approxPolyDP(largest_contour, epsilon, True)
|
||||
|
||||
# Create mask
|
||||
mask = np.zeros_like(gray)
|
||||
cv2.fillPoly(mask, [approx], 255)
|
||||
|
||||
# Apply mask
|
||||
result = cv2.bitwise_and(image, image, mask=mask)
|
||||
return result
|
||||
|
||||
return image
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in contour background removal: {e}")
|
||||
return image
|
||||
|
||||
def enhance_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Enhance image quality for better OCR
|
||||
"""
|
||||
try:
|
||||
# Convert to LAB color space
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
||||
|
||||
# Apply CLAHE to L channel
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
lab[:, :, 0] = clahe.apply(lab[:, :, 0])
|
||||
|
||||
# Convert back to BGR
|
||||
enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
|
||||
|
||||
# Apply slight Gaussian blur to reduce noise
|
||||
enhanced = cv2.GaussianBlur(enhanced, (3, 3), 0)
|
||||
|
||||
return enhanced
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error enhancing image: {e}")
|
||||
return image
|
||||
|
||||
def normalize_image(self, image: np.ndarray, target_size: Tuple[int, int] = (800, 600)) -> np.ndarray:
|
||||
"""
|
||||
Normalize image size and orientation
|
||||
"""
|
||||
try:
|
||||
# Resize image
|
||||
resized = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
# Convert to grayscale if needed
|
||||
if len(resized.shape) == 3:
|
||||
gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray = resized
|
||||
|
||||
# Apply histogram equalization
|
||||
equalized = cv2.equalizeHist(gray)
|
||||
|
||||
# Convert back to BGR for consistency
|
||||
if len(image.shape) == 3:
|
||||
result = cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR)
|
||||
else:
|
||||
result = equalized
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error normalizing image: {e}")
|
||||
return image
|
||||
|
||||
def process_id_card(self, image_path: Path, output_dir: Path,
|
||||
remove_bg: bool = True, enhance: bool = True,
|
||||
normalize: bool = True, target_size: Tuple[int, int] = (800, 600)) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a single ID card image
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
output_dir: Output directory
|
||||
remove_bg: Whether to remove background
|
||||
enhance: Whether to enhance image
|
||||
normalize: Whether to normalize image
|
||||
target_size: Target size for normalization
|
||||
|
||||
Returns:
|
||||
Processing results
|
||||
"""
|
||||
result = {
|
||||
'input_path': str(image_path),
|
||||
'output_paths': [],
|
||||
'success': False
|
||||
}
|
||||
|
||||
try:
|
||||
# Load image
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
self.logger.error(f"Could not load image: {image_path}")
|
||||
return result
|
||||
|
||||
# Create output filename
|
||||
stem = image_path.stem
|
||||
processed_path = output_dir / f"{stem}_processed.jpg"
|
||||
|
||||
# Apply processing steps
|
||||
processed_image = image.copy()
|
||||
|
||||
if remove_bg:
|
||||
self.logger.info(f"Removing background from {image_path.name}")
|
||||
processed_image = self.remove_background(processed_image)
|
||||
|
||||
if enhance:
|
||||
self.logger.info(f"Enhancing {image_path.name}")
|
||||
processed_image = self.enhance_image(processed_image)
|
||||
|
||||
if normalize:
|
||||
self.logger.info(f"Normalizing {image_path.name}")
|
||||
processed_image = self.normalize_image(processed_image, target_size)
|
||||
|
||||
# Save processed image
|
||||
processed_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(processed_path), processed_image)
|
||||
result['output_paths'].append(str(processed_path))
|
||||
|
||||
result['success'] = True
|
||||
self.logger.info(f"Processed {image_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {image_path}: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def batch_process_id_cards(self, input_dir: Path, output_dir: Path,
|
||||
detect_first: bool = True, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Process all ID card images in a directory
|
||||
|
||||
Args:
|
||||
input_dir: Input directory
|
||||
output_dir: Output directory
|
||||
detect_first: Whether to detect ID cards first using YOLO
|
||||
**kwargs: Additional arguments for processing
|
||||
|
||||
Returns:
|
||||
Batch processing results
|
||||
"""
|
||||
# Create output directory
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if detect_first:
|
||||
# First detect and crop ID cards
|
||||
self.logger.info("Detecting and cropping ID cards...")
|
||||
detection_results = self.yolo_detector.batch_process(input_dir, output_dir / "cropped")
|
||||
|
||||
# Process cropped images
|
||||
cropped_dir = output_dir / "cropped"
|
||||
if cropped_dir.exists():
|
||||
self.logger.info("Processing cropped ID cards...")
|
||||
return self._process_cropped_images(cropped_dir, output_dir / "processed", **kwargs)
|
||||
else:
|
||||
self.logger.warning("No cropped images found, processing original images")
|
||||
return self._process_cropped_images(input_dir, output_dir / "processed", **kwargs)
|
||||
else:
|
||||
# Process original images directly
|
||||
return self._process_cropped_images(input_dir, output_dir / "processed", **kwargs)
|
||||
|
||||
def _process_cropped_images(self, input_dir: Path, output_dir: Path, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Process cropped ID card images recursively
|
||||
"""
|
||||
# Get all image files recursively from input directory and subdirectories
|
||||
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
||||
image_files = []
|
||||
|
||||
# Recursively find all image files
|
||||
for file_path in input_dir.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix.lower() in image_extensions:
|
||||
image_files.append(file_path)
|
||||
|
||||
if not image_files:
|
||||
self.logger.error(f"No images found in {input_dir} and subdirectories")
|
||||
return {'success': False, 'error': 'No images found'}
|
||||
|
||||
self.logger.info(f"Processing {len(image_files)} images from {input_dir} and subdirectories")
|
||||
|
||||
results = {
|
||||
'total_images': len(image_files),
|
||||
'processed_images': 0,
|
||||
'results': []
|
||||
}
|
||||
|
||||
# Process each image
|
||||
for i, image_path in enumerate(image_files):
|
||||
self.logger.info(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
|
||||
|
||||
# Create subdirectory structure in output to match input structure
|
||||
relative_path = image_path.relative_to(input_dir)
|
||||
output_subdir = output_dir / relative_path.parent
|
||||
output_subdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = self.process_id_card(image_path, output_subdir, **kwargs)
|
||||
results['results'].append(result)
|
||||
|
||||
if result['success']:
|
||||
results['processed_images'] += 1
|
||||
|
||||
# Summary
|
||||
self.logger.info(f"ID card processing completed:")
|
||||
self.logger.info(f" - Total images: {results['total_images']}")
|
||||
self.logger.info(f" - Processed: {results['processed_images']}")
|
||||
|
||||
return results
|
266
src/model/yolo_detector.py
Normal file
266
src/model/yolo_detector.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
YOLO Detector for ID Card Detection and Cropping
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import logging
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
class YOLODetector:
|
||||
"""
|
||||
YOLO-based detector for ID card detection and cropping
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: Optional[str] = None, confidence: float = 0.5):
|
||||
"""
|
||||
Initialize YOLO detector
|
||||
|
||||
Args:
|
||||
model_path: Path to YOLO model file (.pt)
|
||||
confidence: Confidence threshold for detection
|
||||
"""
|
||||
self.confidence = confidence
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize model
|
||||
if model_path and Path(model_path).exists():
|
||||
self.model = YOLO(model_path)
|
||||
self.logger.info(f"Loaded custom YOLO model from {model_path}")
|
||||
else:
|
||||
# Use pre-trained YOLO model for general object detection
|
||||
self.model = YOLO('yolov8n.pt')
|
||||
self.logger.info("Using pre-trained YOLOv8n model")
|
||||
|
||||
# Set device
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.logger.info(f"Using device: {self.device}")
|
||||
|
||||
def detect_id_cards(self, image_path: Path) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect ID cards in an image
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
|
||||
Returns:
|
||||
List of detection results with bounding boxes
|
||||
"""
|
||||
try:
|
||||
# Load image
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
self.logger.error(f"Could not load image: {image_path}")
|
||||
return []
|
||||
|
||||
# Run detection
|
||||
results = self.model(image, conf=self.confidence)
|
||||
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
if boxes is not None:
|
||||
for box in boxes:
|
||||
# Get coordinates
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
confidence = float(box.conf[0])
|
||||
class_id = int(box.cls[0])
|
||||
class_name = self.model.names[class_id]
|
||||
|
||||
detection = {
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'confidence': confidence,
|
||||
'class_id': class_id,
|
||||
'class_name': class_name,
|
||||
'area': (x2 - x1) * (y2 - y1)
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
# Sort by confidence and area (prefer larger, more confident detections)
|
||||
detections.sort(key=lambda x: (x['confidence'], x['area']), reverse=True)
|
||||
|
||||
self.logger.info(f"Found {len(detections)} detections in {image_path.name}")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error detecting ID cards in {image_path}: {e}")
|
||||
return []
|
||||
|
||||
def crop_id_card(self, image_path: Path, bbox: List[int],
|
||||
output_path: Optional[Path] = None,
|
||||
padding: int = 10) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Crop ID card from image using bounding box
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
bbox: Bounding box [x1, y1, x2, y2]
|
||||
output_path: Path to save cropped image
|
||||
padding: Padding around the bounding box
|
||||
|
||||
Returns:
|
||||
Cropped image as numpy array
|
||||
"""
|
||||
try:
|
||||
# Load image
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
self.logger.error(f"Could not load image: {image_path}")
|
||||
return None
|
||||
|
||||
height, width = image.shape[:2]
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
||||
# Add padding
|
||||
x1 = max(0, x1 - padding)
|
||||
y1 = max(0, y1 - padding)
|
||||
x2 = min(width, x2 + padding)
|
||||
y2 = min(height, y2 + padding)
|
||||
|
||||
# Crop image
|
||||
cropped = image[y1:y2, x1:x2]
|
||||
|
||||
# Save if output path provided
|
||||
if output_path:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(output_path), cropped)
|
||||
self.logger.info(f"Saved cropped image to {output_path}")
|
||||
|
||||
return cropped
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cropping ID card from {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def process_single_image(self, image_path: Path, output_dir: Path,
|
||||
save_original: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a single image: detect and crop ID cards
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
output_dir: Output directory for cropped images
|
||||
save_original: Whether to save original image with bounding boxes
|
||||
|
||||
Returns:
|
||||
Processing results
|
||||
"""
|
||||
result = {
|
||||
'input_path': str(image_path),
|
||||
'detections': [],
|
||||
'cropped_paths': [],
|
||||
'success': False
|
||||
}
|
||||
|
||||
try:
|
||||
# Detect ID cards
|
||||
detections = self.detect_id_cards(image_path)
|
||||
|
||||
if not detections:
|
||||
self.logger.warning(f"No ID cards detected in {image_path.name}")
|
||||
return result
|
||||
|
||||
# Process each detection
|
||||
for i, detection in enumerate(detections):
|
||||
bbox = detection['bbox']
|
||||
|
||||
# Create output filename
|
||||
stem = image_path.stem
|
||||
suffix = f"_card_{i+1}.jpg"
|
||||
output_path = output_dir / f"{stem}{suffix}"
|
||||
|
||||
# Crop ID card
|
||||
cropped = self.crop_id_card(image_path, bbox, output_path)
|
||||
|
||||
if cropped is not None:
|
||||
result['detections'].append(detection)
|
||||
result['cropped_paths'].append(str(output_path))
|
||||
|
||||
# Save original with bounding boxes if requested
|
||||
if save_original and detections:
|
||||
image = cv2.imread(str(image_path))
|
||||
for detection in detections:
|
||||
bbox = detection['bbox']
|
||||
cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
|
||||
cv2.putText(image, f"{detection['confidence']:.2f}",
|
||||
(bbox[0], bbox[1] - 10), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5, (0, 255, 0), 2)
|
||||
|
||||
annotated_path = output_dir / f"{image_path.stem}_annotated.jpg"
|
||||
cv2.imwrite(str(annotated_path), image)
|
||||
result['annotated_path'] = str(annotated_path)
|
||||
|
||||
result['success'] = True
|
||||
self.logger.info(f"Processed {image_path.name}: {len(result['cropped_paths'])} cards cropped")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {image_path}: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def batch_process(self, input_dir: Path, output_dir: Path,
|
||||
save_annotated: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Process all images in a directory and subdirectories
|
||||
|
||||
Args:
|
||||
input_dir: Input directory containing images
|
||||
output_dir: Output directory for cropped images
|
||||
save_annotated: Whether to save annotated images
|
||||
|
||||
Returns:
|
||||
Batch processing results
|
||||
"""
|
||||
# Create output directory
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get all image files recursively from input directory and subdirectories
|
||||
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
||||
image_files = []
|
||||
|
||||
# Recursively find all image files
|
||||
for file_path in input_dir.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix.lower() in image_extensions:
|
||||
image_files.append(file_path)
|
||||
|
||||
if not image_files:
|
||||
self.logger.error(f"No images found in {input_dir} and subdirectories")
|
||||
return {'success': False, 'error': 'No images found'}
|
||||
|
||||
self.logger.info(f"Processing {len(image_files)} images from {input_dir} and subdirectories")
|
||||
|
||||
results = {
|
||||
'total_images': len(image_files),
|
||||
'processed_images': 0,
|
||||
'total_detections': 0,
|
||||
'total_cropped': 0,
|
||||
'results': []
|
||||
}
|
||||
|
||||
# Process each image
|
||||
for i, image_path in enumerate(image_files):
|
||||
self.logger.info(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
|
||||
|
||||
# Create subdirectory structure in output to match input structure
|
||||
relative_path = image_path.relative_to(input_dir)
|
||||
output_subdir = output_dir / relative_path.parent
|
||||
output_subdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = self.process_single_image(image_path, output_subdir, save_annotated)
|
||||
results['results'].append(result)
|
||||
|
||||
if result['success']:
|
||||
results['processed_images'] += 1
|
||||
results['total_detections'] += len(result['detections'])
|
||||
results['total_cropped'] += len(result['cropped_paths'])
|
||||
|
||||
# Summary
|
||||
self.logger.info(f"Batch processing completed:")
|
||||
self.logger.info(f" - Total images: {results['total_images']}")
|
||||
self.logger.info(f" - Processed: {results['processed_images']}")
|
||||
self.logger.info(f" - Total detections: {results['total_detections']}")
|
||||
self.logger.info(f" - Total cropped: {results['total_cropped']}")
|
||||
|
||||
return results
|
98
src/utils.py
Normal file
98
src/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Utility functions for data augmentation
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> logging.Logger:
|
||||
"""Setup logging configuration"""
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level.upper()),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('data_augmentation.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def get_image_files(directory: Path) -> List[Path]:
|
||||
"""Get all image files from directory"""
|
||||
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
|
||||
|
||||
image_files = []
|
||||
if directory.exists():
|
||||
for ext in SUPPORTED_FORMATS:
|
||||
image_files.extend(directory.glob(f"*{ext}"))
|
||||
image_files.extend(directory.glob(f"*{ext.upper()}"))
|
||||
return sorted(image_files)
|
||||
|
||||
def validate_image(image_path: Path) -> bool:
|
||||
"""Validate if file is a valid image"""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
img.verify()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load_image(image_path: Path, target_size: Tuple[int, int] = None) -> Optional[np.ndarray]:
|
||||
"""Load and resize image"""
|
||||
try:
|
||||
# Load image using OpenCV
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
# Convert BGR to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize if target_size is provided
|
||||
if target_size:
|
||||
image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
return image
|
||||
except Exception as e:
|
||||
print(f"Error loading image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def save_image(image: np.ndarray, output_path: Path, quality: int = 95) -> bool:
|
||||
"""Save image to file"""
|
||||
try:
|
||||
# Convert RGB to BGR for OpenCV
|
||||
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save image
|
||||
cv2.imwrite(str(output_path), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving image {output_path}: {e}")
|
||||
return False
|
||||
|
||||
def create_augmented_filename(original_path: Path, index: int, suffix: str = "aug") -> Path:
|
||||
"""Create filename for augmented image"""
|
||||
stem = original_path.stem
|
||||
suffix = f"_{suffix}_{index:02d}"
|
||||
return original_path.parent / f"{stem}{suffix}{original_path.suffix}"
|
||||
|
||||
def get_file_size_mb(file_path: Path) -> float:
|
||||
"""Get file size in MB"""
|
||||
return file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
def print_progress(current: int, total: int, prefix: str = "Progress"):
|
||||
"""Print progress bar"""
|
||||
bar_length = 50
|
||||
filled_length = int(round(bar_length * current / float(total)))
|
||||
percents = round(100.0 * current / float(total), 1)
|
||||
bar = '=' * filled_length + '-' * (bar_length - filled_length)
|
||||
print(f'\r{prefix}: [{bar}] {percents}% ({current}/{total})', end='')
|
||||
if current == total:
|
||||
print()
|
Reference in New Issue
Block a user