"""Main script for data augmentation.

Command-line entry point that loads the configuration, applies any
command-line overrides, and then either lists presets, shows image
information, previews the augmentation on one image, or runs the batch
augmentation (optionally preceded by ID card detection and cropping).
"""

import argparse
import sys
from pathlib import Path
from typing import Any, Dict, Optional

# Add src to path for imports
sys.path.append(str(Path(__file__).parent / "src"))

from src.config_manager import ConfigManager
from src.data_augmentation import DataAugmentation
from src.image_processor import ImageProcessor
from src.id_card_detector import IDCardDetector
from src.utils import setup_logging, get_image_files, print_progress

# Number of images whose details are printed by ``show_image_info``.
_INFO_PREVIEW_LIMIT = 10


def parse_arguments():
    """Parse command line arguments.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Image Data Augmentation Tool")

    parser.add_argument(
        "--config",
        type=str,
        default="config/config.yaml",
        help="Path to configuration file"
    )
    parser.add_argument(
        "--preset",
        type=str,
        help="Apply augmentation preset (light, medium, heavy, ocr_optimized, document)"
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        help="Input directory containing images (overrides config)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output directory for augmented images (overrides config)"
    )
    parser.add_argument(
        "--num-augmentations",
        type=int,
        help="Number of augmented versions per image (overrides config)"
    )
    parser.add_argument(
        "--target-size",
        type=str,
        help="Target size for images (width x height) (overrides config)"
    )
    parser.add_argument(
        "--preview",
        action="store_true",
        help="Preview augmentation on first image only"
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about images in input directory"
    )
    parser.add_argument(
        "--list-presets",
        action="store_true",
        help="List available presets and exit"
    )
    # No default here: ``None`` means "not given", so main() can fall back to
    # the level from the configuration file instead of always forcing INFO.
    parser.add_argument(
        "--log-level",
        type=str,
        default=None,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level (overrides config)"
    )

    # ID Card Detection arguments
    parser.add_argument(
        "--enable-id-detection",
        action="store_true",
        help="Enable ID card detection and cropping before augmentation"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        help="Path to YOLO model for ID card detection (overrides config)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        help="Confidence threshold for ID card detection (overrides config)"
    )
    parser.add_argument(
        "--crop-mode",
        type=str,
        choices=["bbox", "square", "aspect_ratio"],
        help="Crop mode for ID cards (overrides config)"
    )
    parser.add_argument(
        "--crop-target-size",
        type=str,
        help="Target size for cropped ID cards (widthxheight) (overrides config)"
    )

    return parser.parse_args()


def parse_range(range_str: str) -> tuple:
    """Parse a range string like '0.8-1.2' into the tuple (0.8, 1.2).

    The string is split on every '-', so negative bounds (e.g. '-0.5-0.5')
    are rejected as malformed.

    Args:
        range_str: Text in the form ``min-max``.

    Returns:
        ``(min_val, max_val)`` as floats.

    Exits the process with status 1 after printing a message when the text
    cannot be parsed.
    """
    try:
        min_val, max_val = map(float, range_str.split('-'))
        return (min_val, max_val)
    except ValueError:
        print(f"Invalid range format: {range_str}. Expected format: min-max")
        sys.exit(1)


def parse_size(size_str: str) -> tuple:
    """Parse a size string like '224x224' into the tuple (224, 224).

    The separator is matched case-insensitively ('224X224' is accepted) and
    whitespace around each number is ignored. Both dimensions must be
    positive integers.

    Args:
        size_str: Text in the form ``widthxheight``.

    Returns:
        ``(width, height)`` as ints.

    Exits the process with status 1 after printing a message when the text
    cannot be parsed or a dimension is not positive.
    """
    try:
        width, height = map(int, size_str.lower().split('x'))
        if width <= 0 or height <= 0:
            # Route through the same error path as a malformed string.
            raise ValueError("dimensions must be positive")
        return (width, height)
    except ValueError:
        print(f"Invalid size format: {size_str}. Expected format: widthxheight")
        sys.exit(1)


def show_image_info(input_dir: Path):
    """Show information about images in input directory.

    Prints details for at most the first ``_INFO_PREVIEW_LIMIT`` images and
    then the total and average file size of the images that were inspected.

    Args:
        input_dir: Directory passed to ``get_image_files``.
    """
    image_files = get_image_files(input_dir)

    if not image_files:
        print(f"No images found in {input_dir}")
        return

    print(f"\nFound {len(image_files)} images in {input_dir}")
    print("\nImage Information:")
    print("-" * 80)

    processor = ImageProcessor()
    total_size = 0
    measured = 0  # images that actually contributed to total_size

    # Show first 10 images
    for i, image_path in enumerate(image_files[:_INFO_PREVIEW_LIMIT]):
        info = processor.get_image_info(image_path)
        if info:
            print(f"{i+1:2d}. {image_path.name}")
            print(f" Size: {info['width']}x{info['height']} pixels")
            print(f" Channels: {info['channels']}")
            print(f" File size: {info['file_size_mb']} MB")
            print(f" Format: {info['format']}")
            total_size += info['file_size_mb']
            measured += 1

    if len(image_files) > _INFO_PREVIEW_LIMIT:
        print(f"\n... and {len(image_files) - _INFO_PREVIEW_LIMIT} more images")

    # The total only covers the inspected images, so the average must be
    # taken over that same count rather than over every file in the folder.
    average_size = total_size / measured if measured else 0
    print(f"\nTotal file size: {total_size:.2f} MB")
    print(f"Average file size: {average_size:.2f} MB")


def preview_augmentation(input_dir: Path, output_dir: Path, config: Dict[str, Any],
                         id_detection_config: Optional[Dict[str, Any]] = None):
    """Preview augmentation on first image.

    When ID card detection is enabled and yields detections, the detector's
    output is the preview; otherwise the first image is augmented directly.

    Args:
        input_dir: Directory passed to ``get_image_files``.
        output_dir: Directory that receives the preview output.
        config: Configuration handed to ``DataAugmentation`` and, when
            detection is enabled, to ``IDCardDetector``.
        id_detection_config: Optional ID card detection settings; detection
            is attempted only when its ``'enabled'`` entry is truthy.
    """
    image_files = get_image_files(input_dir)

    if not image_files:
        print(f"No images found in {input_dir}")
        return

    print(f"\nPreviewing augmentation on: {image_files[0].name}")

    # Create augmentation instance
    augmenter = DataAugmentation(config)

    # Process with ID detection if enabled
    if id_detection_config and id_detection_config.get('enabled', False):
        print("🔍 ID Card Detection enabled - processing with YOLO model...")

        # Initialize ID card detector
        detector = IDCardDetector(
            model_path=id_detection_config.get('model_path'),
            config=config
        )

        if not detector.model:
            print("❌ Failed to load YOLO model, proceeding with normal augmentation")
        else:
            # Process single image with ID detection
            result = detector.process_single_image(
                image_path=image_files[0],
                output_dir=output_dir,
                apply_augmentation=True,
                save_original=id_detection_config.get('save_original_crops', True),
                confidence=id_detection_config.get('confidence_threshold', 0.25),
                iou_threshold=id_detection_config.get('iou_threshold', 0.45),
                crop_mode=id_detection_config.get('crop_mode', 'bbox'),
                target_size=id_detection_config.get('target_size'),
                padding=id_detection_config.get('padding', 10)
            )

            if result and result.get('detections'):
                print(f"✅ Detected {len(result['detections'])} ID cards")
                print(f"💾 Saved {len(result['processed_cards'])} processed cards")
                return
            else:
                print("⚠️ No ID cards detected, proceeding with normal augmentation")

    # Normal augmentation (fallback) with new logic
    augmented_paths = augmenter.augment_image_file(
        image_files[0],
        output_dir,
        num_target_images=3
    )

    if augmented_paths:
        print(f"Created {len(augmented_paths)} augmented versions:")
        for i, path in enumerate(augmented_paths, 1):
            print(f" {i}. {path.name}")
    else:
        print("Failed to create augmented images")


def _log_augmentation_summary(logger, augment_results):
    """Log the summary returned by ``batch_augment``; no-op when it is falsy."""
    if not augment_results:
        return
    logger.info("Augmentation Summary:")
    logger.info(f" Input images: {augment_results.get('input_images', 0)}")
    logger.info(f" Selected for processing: {augment_results.get('selected_images', 0)}")
    logger.info(f" Target total: {augment_results.get('target_total', 0)}")
    logger.info(f" Actually generated: {augment_results.get('actual_generated', 0)}")
    logger.info(f" Efficiency: {augment_results.get('efficiency', 0):.1%}")


def _run_batch_augmentation(logger, source_dir, output_dir, augmentation_config,
                            data_strategy_config, multiplication_factor, random_seed):
    """Batch-augment ``source_dir`` into ``output_dir`` and log the summary."""
    augmenter = DataAugmentation(augmentation_config)

    # Pass the data strategy along so the augmenter can access it
    augmenter.config.update({"data_strategy": data_strategy_config})

    augment_results = augmenter.batch_augment(
        source_dir,
        output_dir,
        multiplication_factor=multiplication_factor,
        random_seed=random_seed
    )

    _log_augmentation_summary(logger, augment_results)


def main():
    """Main function.

    Loads the configuration, applies command-line overrides, and dispatches
    to the requested mode (list presets, info, preview, or full processing).
    Exits with status 1 when the preset cannot be applied, the input
    directory is missing, no images are found, or the YOLO model fails to
    load.
    """
    args = parse_arguments()

    # Initialize config manager
    config_manager = ConfigManager(args.config)

    # List presets if requested
    if args.list_presets:
        presets = config_manager.list_presets()
        print("\nAvailable presets:")
        for preset in presets:
            print(f" - {preset}")
        return

    # Apply preset if specified
    if args.preset:
        if not config_manager.apply_preset(args.preset):
            sys.exit(1)

    # Override config with command line arguments
    if args.input_dir:
        config_manager.update_config({"paths": {"input_dir": args.input_dir}})
    if args.output_dir:
        config_manager.update_config({"paths": {"output_dir": args.output_dir}})
    # Compare against None so an explicit 0 is not mistaken for "not given".
    if args.num_augmentations is not None:
        config_manager.update_config({"processing": {"num_augmentations": args.num_augmentations}})
    if args.target_size:
        target_size = parse_size(args.target_size)
        config_manager.update_config({"processing": {"target_size": list(target_size)}})

    # Get configuration
    config = config_manager.get_config()
    paths_config = config_manager.get_paths_config()
    processing_config = config_manager.get_processing_config()
    augmentation_config = config_manager.get_augmentation_config()
    logging_config = config_manager.get_logging_config()
    data_strategy_config = config.get("data_strategy", {})

    # Setup logging; an explicit --log-level wins over the configured level
    log_level = args.log_level or logging_config.get("level", "INFO")
    logger = setup_logging(log_level)
    logger.info("Starting data augmentation process")

    # Parse paths
    input_dir = Path(paths_config.get("input_dir", "data/dataset/training_data/images"))
    output_dir = Path(paths_config.get("output_dir", "data/augmented_data"))

    # Check if input directory exists
    if not input_dir.exists():
        logger.error(f"Input directory does not exist: {input_dir}")
        sys.exit(1)

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Show image information if requested
    if args.info:
        show_image_info(input_dir)
        return

    # Get ID detection config
    id_detection_config = config.get('id_card_detection', {})

    # Override ID detection config with command line arguments
    if args.enable_id_detection:
        id_detection_config['enabled'] = True
    if args.model_path:
        id_detection_config['model_path'] = args.model_path
    # Compare against None so a threshold of 0.0 is honoured.
    if args.confidence is not None:
        id_detection_config['confidence_threshold'] = args.confidence
    if args.crop_mode:
        id_detection_config['crop_mode'] = args.crop_mode
    if args.crop_target_size:
        target_size = parse_size(args.crop_target_size)
        id_detection_config['target_size'] = list(target_size)

    # Preview augmentation if requested
    if args.preview:
        preview_augmentation(input_dir, output_dir, augmentation_config, id_detection_config)
        return

    # Get image files
    image_files = get_image_files(input_dir)

    if not image_files:
        logger.error(f"No images found in {input_dir}")
        sys.exit(1)

    # Get data strategy parameters
    multiplication_factor = data_strategy_config.get("multiplication_factor", 3.0)
    random_seed = data_strategy_config.get("random_seed")

    logger.info(f"Found {len(image_files)} images to process")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Data strategy: multiplication_factor = {multiplication_factor}")
    if multiplication_factor < 1.0:
        logger.info(f"SAMPLING MODE: Will process {multiplication_factor*100:.1f}% of input images")
    else:
        logger.info(f"MULTIPLICATION MODE: Target {multiplication_factor}x dataset size")
    logger.info(f"Target size: {processing_config.get('target_size', [224, 224])}")
    # A seed of 0 is a valid seed and should be reported too.
    if random_seed is not None:
        logger.info(f"Random seed: {random_seed}")

    # Process with ID detection if enabled
    if id_detection_config.get('enabled', False):
        logger.info("ID Card Detection enabled - processing with YOLO model...")

        # Initialize ID card detector
        detector = IDCardDetector(
            model_path=id_detection_config.get('model_path'),
            config=config
        )

        if not detector.model:
            logger.error("Failed to load YOLO model")
            sys.exit(1)

        logger.info(f"YOLO model loaded: {detector.model_path}")
        logger.info(f"Confidence threshold: {id_detection_config.get('confidence_threshold', 0.25)}")
        logger.info(f"Crop mode: {id_detection_config.get('crop_mode', 'bbox')}")

        # Step 1: detect and crop ID cards into the "processed" directory
        processed_dir = output_dir / "processed"
        processed_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Step 1: Detect and crop ID cards...")
        detector.batch_process(
            input_dir=input_dir,
            output_dir=processed_dir,
            confidence=id_detection_config.get('confidence_threshold', 0.25),
            iou_threshold=id_detection_config.get('iou_threshold', 0.45),
            crop_mode=id_detection_config.get('crop_mode', 'bbox'),
            target_size=id_detection_config.get('target_size'),
            padding=id_detection_config.get('padding', 10)
        )

        # Step 2: augment the cropped cards with the new strategy
        logger.info("Step 2: Augment cropped ID cards with smart strategy...")
        _run_batch_augmentation(
            logger,
            processed_dir,
            output_dir,
            augmentation_config,
            data_strategy_config,
            multiplication_factor,
            random_seed
        )
    else:
        # Augment the original images directly with the new strategy
        logger.info("Starting smart batch augmentation (direct augmentation)...")
        _run_batch_augmentation(
            logger,
            input_dir,
            output_dir,
            augmentation_config,
            data_strategy_config,
            multiplication_factor,
            random_seed
        )

    logger.info("Data processing completed successfully")


if __name__ == "__main__":
    main()