423 lines
15 KiB
Python
423 lines
15 KiB
Python
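"""
Main script for data augmentation.

Command-line entry point that loads the configuration file, applies
command-line overrides, and augments the images in the input directory,
optionally detecting and cropping ID cards with a YOLO model first.
"""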
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Dict, Any
|
|
|
|
# Add src to path for imports
|
|
sys.path.append(str(Path(__file__).parent / "src"))
|
|
|
|
from src.config_manager import ConfigManager
|
|
from src.data_augmentation import DataAugmentation
|
|
from src.image_processor import ImageProcessor
|
|
from src.id_card_detector import IDCardDetector
|
|
from src.utils import setup_logging, get_image_files, print_progress
|
|
|
|
def parse_arguments(argv=None):
    """Parse the command-line options of the augmentation tool.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``; passing a list
            lets tests or other scripts drive the parser directly.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Options that are not given are ``None`` (``False`` for flags), except
        ``config`` and ``log_level``, which have explicit defaults.
    """
    parser = argparse.ArgumentParser(description="Image Data Augmentation Tool")

    # General configuration
    parser.add_argument(
        "--config",
        type=str,
        default="config/config.yaml",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--preset",
        type=str,
        help="Apply augmentation preset (light, medium, heavy, ocr_optimized, document)"
    )

    # Path / processing overrides
    parser.add_argument(
        "--input-dir",
        type=str,
        help="Input directory containing images (overrides config)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output directory for augmented images (overrides config)"
    )
    parser.add_argument(
        "--num-augmentations",
        type=int,
        help="Number of augmented versions per image (overrides config)"
    )
    parser.add_argument(
        "--target-size",
        type=str,
        help="Target size for images (width x height) (overrides config)"
    )

    # Alternative run modes
    parser.add_argument(
        "--preview",
        action="store_true",
        help="Preview augmentation on first image only"
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about images in input directory"
    )
    parser.add_argument(
        "--list-presets",
        action="store_true",
        help="List available presets and exit"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level"
    )

    # ID card detection arguments
    parser.add_argument(
        "--enable-id-detection",
        action="store_true",
        help="Enable ID card detection and cropping before augmentation"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        help="Path to YOLO model for ID card detection (overrides config)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        help="Confidence threshold for ID card detection (overrides config)"
    )
    parser.add_argument(
        "--crop-mode",
        type=str,
        choices=["bbox", "square", "aspect_ratio"],
        help="Crop mode for ID cards (overrides config)"
    )
    parser.add_argument(
        "--crop-target-size",
        type=str,
        help="Target size for cropped ID cards (widthxheight) (overrides config)"
    )

    return parser.parse_args(argv)
|
|
|
|
def parse_range(range_str: str) -> tuple:
    """Parse a ``min-max`` string such as ``'0.8-1.2'`` into ``(0.8, 1.2)``.

    Negative bounds and exponents are supported, e.g. ``'-0.5-0.5'`` gives
    ``(-0.5, 0.5)`` and ``'1e-3-2'`` gives ``(0.001, 2.0)``. Whitespace
    around the numbers is ignored.

    Prints an error to stderr and exits with status 1 if no split of the
    string yields two valid floats.
    """
    text = range_str.strip()
    # A plain split('-') breaks on signs and exponents. Instead, try each '-'
    # (skipping a leading sign at index 0) as the separator and accept the
    # first split where both halves parse as floats.
    for idx in range(1, len(text)):
        if text[idx] != '-':
            continue
        try:
            return (float(text[:idx]), float(text[idx + 1:]))
        except ValueError:
            continue

    print(f"Invalid range format: {range_str}. Expected format: min-max", file=sys.stderr)
    sys.exit(1)
|
|
|
|
def parse_size(size_str: str) -> tuple:
    """Parse a ``WIDTHxHEIGHT`` string such as ``'224x224'`` into ``(224, 224)``.

    The separator may be ``x`` or ``X``, and whitespace around the numbers
    is ignored. Both dimensions must be positive integers.

    Prints an error to stderr and exits with status 1 on malformed or
    non-positive input.
    """
    try:
        width, height = map(int, size_str.lower().split('x'))
    except ValueError:
        pass
    else:
        # A zero or negative dimension parses as an int but can never be a
        # valid image size, so reject it here rather than failing downstream.
        if width > 0 and height > 0:
            return (width, height)

    print(
        f"Invalid size format: {size_str}. "
        "Expected format: widthxheight (positive integers, e.g. 224x224)",
        file=sys.stderr,
    )
    sys.exit(1)
|
|
|
|
def show_image_info(input_dir: Path, max_display: int = 10):
    """Print a short report about the images found in ``input_dir``.

    Dimensions, channels, file size and format are printed for at most
    ``max_display`` images. These come from ``ImageProcessor.get_image_info``;
    files for which it returns nothing are skipped. The report ends with the
    total and average file size of the images that were actually inspected.

    Args:
        input_dir: Directory scanned with ``get_image_files``.
        max_display: Maximum number of images to inspect and list
            (default 10). Negative values are treated as 0.
    """
    image_files = get_image_files(input_dir)

    if not image_files:
        print(f"No images found in {input_dir}")
        return

    print(f"\nFound {len(image_files)} images in {input_dir}")
    print("\nImage Information:")
    print("-" * 80)

    max_display = max(0, max_display)
    processor = ImageProcessor()
    total_size = 0
    inspected = 0  # images whose info was read; the sizes are summed over these

    for i, image_path in enumerate(image_files[:max_display]):
        info = processor.get_image_info(image_path)
        if not info:
            continue
        print(f"{i+1:2d}. {image_path.name}")
        print(f" Size: {info['width']}x{info['height']} pixels")
        print(f" Channels: {info['channels']}")
        print(f" File size: {info['file_size_mb']} MB")
        print(f" Format: {info['format']}")
        total_size += info['file_size_mb']
        inspected += 1

    if len(image_files) > max_display:
        print(f"\n... and {len(image_files) - max_display} more images")

    if not inspected:
        print("\nCould not read information for any of the listed images")
        return

    # Sizes were only summed for the inspected images, so both figures must
    # refer to that subset (not to every file in the directory).
    scope = "" if inspected == len(image_files) else f" ({inspected} inspected images)"
    print(f"\nTotal file size{scope}: {total_size:.2f} MB")
    print(f"Average file size: {total_size / inspected:.2f} MB")
|
|
|
|
def preview_augmentation(input_dir: Path, output_dir: Path, config: Dict[str, Any],
                         id_detection_config: Dict[str, Any] = None,
                         num_preview: int = 3):
    """Run a quick augmentation preview on the first image in ``input_dir``.

    If ``id_detection_config`` has ``enabled`` set and the YOLO model loads,
    the first image is handed to ``IDCardDetector.process_single_image``, and
    the function returns once at least one card is detected. Otherwise
    (detection disabled, model not loaded, or no card found) it falls back to
    ``DataAugmentation.augment_image_file``.

    Args:
        input_dir: Directory scanned with ``get_image_files``; only the first
            file is previewed.
        output_dir: Directory that receives the preview outputs.
        config: Configuration given to ``DataAugmentation`` and
            ``IDCardDetector``.
        id_detection_config: Optional ID-card-detection settings (``enabled``,
            ``model_path``, thresholds, crop options). ``None`` or an empty
            dict disables detection.
        num_preview: Number of augmented images requested in the fallback
            path (default 3).
    """
    image_files = get_image_files(input_dir)

    if not image_files:
        print(f"No images found in {input_dir}")
        return

    first_image = image_files[0]
    print(f"\nPreviewing augmentation on: {first_image.name}")

    augmenter = DataAugmentation(config)

    if id_detection_config and id_detection_config.get('enabled', False):
        print("🔍 ID Card Detection enabled - processing with YOLO model...")

        # NOTE(review): main() passes only the augmentation section as
        # ``config`` here, while its full run gives IDCardDetector the whole
        # configuration. Confirm which one the detector expects.
        detector = IDCardDetector(
            model_path=id_detection_config.get('model_path'),
            config=config
        )

        if not detector.model:
            print("❌ Failed to load YOLO model, proceeding with normal augmentation")
        else:
            result = detector.process_single_image(
                image_path=first_image,
                output_dir=output_dir,
                apply_augmentation=True,
                save_original=id_detection_config.get('save_original_crops', True),
                confidence=id_detection_config.get('confidence_threshold', 0.25),
                iou_threshold=id_detection_config.get('iou_threshold', 0.45),
                crop_mode=id_detection_config.get('crop_mode', 'bbox'),
                target_size=id_detection_config.get('target_size'),
                padding=id_detection_config.get('padding', 10)
            )

            if result and result.get('detections'):
                print(f"✅ Detected {len(result['detections'])} ID cards")
                # Looked up defensively, like 'detections' above: a missing key
                # should not crash the preview after detection succeeded.
                print(f"💾 Saved {len(result.get('processed_cards', []))} processed cards")
                return

            print("⚠️ No ID cards detected, proceeding with normal augmentation")

    # Fallback: plain augmentation of the first image
    augmented_paths = augmenter.augment_image_file(
        first_image,
        output_dir,
        num_target_images=num_preview
    )

    if not augmented_paths:
        print("Failed to create augmented images")
        return

    print(f"Created {len(augmented_paths)} augmented versions:")
    for i, path in enumerate(augmented_paths, 1):
        print(f" {i}. {path.name}")
|
|
|
|
def _apply_config_overrides(config_manager, args):
    """Merge the path/processing command-line overrides into the loaded config.

    Exits with status 1 (via ``parse_size``) if ``--target-size`` is malformed.
    """
    if args.input_dir:
        config_manager.update_config({"paths": {"input_dir": args.input_dir}})

    if args.output_dir:
        config_manager.update_config({"paths": {"output_dir": args.output_dir}})

    # Compare with None so an explicit ``--num-augmentations 0`` is applied
    # instead of being mistaken for "option not given".
    if args.num_augmentations is not None:
        config_manager.update_config({"processing": {"num_augmentations": args.num_augmentations}})

    if args.target_size:
        target_size = parse_size(args.target_size)
        config_manager.update_config({"processing": {"target_size": list(target_size)}})


def _apply_id_detection_overrides(args, id_detection_config):
    """Write the ID-card-detection command-line options into ``id_detection_config`` in place.

    Exits with status 1 (via ``parse_size``) if ``--crop-target-size`` is malformed.
    """
    if args.enable_id_detection:
        id_detection_config['enabled'] = True

    if args.model_path:
        id_detection_config['model_path'] = args.model_path

    # 0.0 is a legitimate threshold, so test for None rather than truthiness.
    if args.confidence is not None:
        id_detection_config['confidence_threshold'] = args.confidence

    if args.crop_mode:
        id_detection_config['crop_mode'] = args.crop_mode

    if args.crop_target_size:
        target_size = parse_size(args.crop_target_size)
        id_detection_config['target_size'] = list(target_size)


def _detect_and_crop_id_cards(logger, config, id_detection_config, input_dir, output_dir):
    """Step 1 of the ID-card pipeline: crop detected cards into ``output_dir / "processed"``.

    Exits with status 1 if the YOLO model cannot be loaded.

    Returns:
        The directory holding the cropped cards, which is augmented next.
    """
    detector = IDCardDetector(
        model_path=id_detection_config.get('model_path'),
        config=config
    )

    if not detector.model:
        logger.error("Failed to load YOLO model")
        sys.exit(1)

    logger.info(f"YOLO model loaded: {detector.model_path}")
    logger.info(f"Confidence threshold: {id_detection_config.get('confidence_threshold', 0.25)}")
    logger.info(f"Crop mode: {id_detection_config.get('crop_mode', 'bbox')}")

    # Detect and crop ID cards into the "processed" sub-directory
    processed_dir = output_dir / "processed"
    processed_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Step 1: Detect and crop ID cards...")
    detector.batch_process(
        input_dir=input_dir,
        output_dir=processed_dir,
        confidence=id_detection_config.get('confidence_threshold', 0.25),
        iou_threshold=id_detection_config.get('iou_threshold', 0.45),
        crop_mode=id_detection_config.get('crop_mode', 'bbox'),
        target_size=id_detection_config.get('target_size'),
        padding=id_detection_config.get('padding', 10)
    )
    return processed_dir


def _log_augment_summary(logger, augment_results):
    """Log the statistics dict returned by ``DataAugmentation.batch_augment`` (if any)."""
    if not augment_results:
        return
    logger.info("Augmentation Summary:")
    logger.info(f" Input images: {augment_results.get('input_images', 0)}")
    logger.info(f" Selected for processing: {augment_results.get('selected_images', 0)}")
    logger.info(f" Target total: {augment_results.get('target_total', 0)}")
    logger.info(f" Actually generated: {augment_results.get('actual_generated', 0)}")
    logger.info(f" Efficiency: {augment_results.get('efficiency', 0):.1%}")


def main():
    """Command-line entry point: configure, then (optionally) crop and augment images.

    Order of operations:
      1. Load the config file; handle ``--list-presets`` and ``--preset``.
      2. Apply CLI overrides, set up logging, validate the input directory.
      3. Handle the ``--info`` and ``--preview`` shortcuts.
      4. Optionally crop ID cards with the YOLO detector, then batch-augment
         either the crops or the original images into the output directory.

    Exits with status 1 if applying the preset fails, the input directory is
    missing or empty, a size argument is malformed, or (with ID detection)
    the model fails to load.
    """
    args = parse_arguments()

    config_manager = ConfigManager(args.config)

    if args.list_presets:
        presets = config_manager.list_presets()
        print("\nAvailable presets:")
        for preset in presets:
            print(f" - {preset}")
        return

    if args.preset and not config_manager.apply_preset(args.preset):
        sys.exit(1)

    _apply_config_overrides(config_manager, args)

    config = config_manager.get_config()
    paths_config = config_manager.get_paths_config()
    processing_config = config_manager.get_processing_config()
    augmentation_config = config_manager.get_augmentation_config()
    logging_config = config_manager.get_logging_config()
    data_strategy_config = config.get("data_strategy", {})

    # --log-level overrides the config file's level. argparse fills in "INFO"
    # when the flag is absent, so only a non-default value can be recognised
    # as an explicit override (an explicit "--log-level INFO" defers to config).
    if args.log_level != "INFO":
        log_level = args.log_level
    else:
        log_level = logging_config.get("level", "INFO")
    logger = setup_logging(log_level)
    logger.info("Starting data augmentation process")

    input_dir = Path(paths_config.get("input_dir", "data/dataset/training_data/images"))
    output_dir = Path(paths_config.get("output_dir", "data/augmented_data"))

    if not input_dir.exists():
        logger.error(f"Input directory does not exist: {input_dir}")
        sys.exit(1)

    # --info only reads the input directory, so it is handled before the
    # output directory is created.
    if args.info:
        show_image_info(input_dir)
        return

    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): if get_config() returns the manager's own dict, the
    # overrides below are written into its 'id_card_detection' section in place.
    id_detection_config = config.get('id_card_detection', {})
    _apply_id_detection_overrides(args, id_detection_config)

    if args.preview:
        preview_augmentation(input_dir, output_dir, augmentation_config, id_detection_config)
        return

    image_files = get_image_files(input_dir)
    if not image_files:
        logger.error(f"No images found in {input_dir}")
        sys.exit(1)

    multiplication_factor = data_strategy_config.get("multiplication_factor", 3.0)
    random_seed = data_strategy_config.get("random_seed")

    logger.info(f"Found {len(image_files)} images to process")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Data strategy: multiplication_factor = {multiplication_factor}")
    if multiplication_factor < 1.0:
        logger.info(f"SAMPLING MODE: Will process {multiplication_factor*100:.1f}% of input images")
    else:
        logger.info(f"MULTIPLICATION MODE: Target {multiplication_factor}x dataset size")
    logger.info(f"Target size: {processing_config.get('target_size', [224, 224])}")
    # A seed of 0 is still a seed; only a missing one is skipped.
    if random_seed is not None:
        logger.info(f"Random seed: {random_seed}")

    # Choose what to augment: cropped ID cards, or the original images.
    if id_detection_config.get('enabled', False):
        logger.info("ID Card Detection enabled - processing with YOLO model...")
        augment_source = _detect_and_crop_id_cards(
            logger, config, id_detection_config, input_dir, output_dir
        )
        logger.info("Step 2: Augment cropped ID cards with smart strategy...")
    else:
        augment_source = input_dir
        logger.info("Starting smart batch augmentation (direct augmentation)...")

    augmenter = DataAugmentation(augmentation_config)
    # The augmenter is built from the augmentation section only; also give it
    # the data_strategy section so it can read the sampling/multiplication settings.
    augmenter.config.update({"data_strategy": data_strategy_config})

    augment_results = augmenter.batch_augment(
        augment_source,
        output_dir,
        multiplication_factor=multiplication_factor,
        random_seed=random_seed
    )
    _log_augment_summary(logger, augment_results)

    logger.info("Data processing completed successfully")


if __name__ == "__main__":
    main()