update augment + YOLO pipeline
This commit is contained in:
170
main.py
170
main.py
@@ -12,6 +12,7 @@ sys.path.append(str(Path(__file__).parent / "src"))
|
||||
from src.config_manager import ConfigManager
|
||||
from src.data_augmentation import DataAugmentation
|
||||
from src.image_processor import ImageProcessor
|
||||
from src.id_card_detector import IDCardDetector
|
||||
from src.utils import setup_logging, get_image_files, print_progress
|
||||
|
||||
def parse_arguments():
|
||||
@@ -83,6 +84,38 @@ def parse_arguments():
|
||||
help="Logging level"
|
||||
)
|
||||
|
||||
# ID Card Detection arguments
|
||||
parser.add_argument(
|
||||
"--enable-id-detection",
|
||||
action="store_true",
|
||||
help="Enable ID card detection and cropping before augmentation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
help="Path to YOLO model for ID card detection (overrides config)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence",
|
||||
type=float,
|
||||
help="Confidence threshold for ID card detection (overrides config)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--crop-mode",
|
||||
type=str,
|
||||
choices=["bbox", "square", "aspect_ratio"],
|
||||
help="Crop mode for ID cards (overrides config)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--crop-target-size",
|
||||
type=str,
|
||||
help="Target size for cropped ID cards (widthxheight) (overrides config)"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def parse_range(range_str: str) -> tuple:
|
||||
@@ -134,7 +167,8 @@ def show_image_info(input_dir: Path):
|
||||
print(f"\nTotal file size: {total_size:.2f} MB")
|
||||
print(f"Average file size: {total_size/len(image_files):.2f} MB")
|
||||
|
||||
def preview_augmentation(input_dir: Path, output_dir: Path, config: Dict[str, Any]):
|
||||
def preview_augmentation(input_dir: Path, output_dir: Path, config: Dict[str, Any],
|
||||
id_detection_config: Dict[str, Any] = None):
|
||||
"""Preview augmentation on first image"""
|
||||
image_files = get_image_files(input_dir)
|
||||
|
||||
@@ -147,7 +181,40 @@ def preview_augmentation(input_dir: Path, output_dir: Path, config: Dict[str, An
|
||||
# Create augmentation instance
|
||||
augmenter = DataAugmentation(config)
|
||||
|
||||
# Augment first image
|
||||
# Process with ID detection if enabled
|
||||
if id_detection_config and id_detection_config.get('enabled', False):
|
||||
print("🔍 ID Card Detection enabled - processing with YOLO model...")
|
||||
|
||||
# Initialize ID card detector
|
||||
detector = IDCardDetector(
|
||||
model_path=id_detection_config.get('model_path'),
|
||||
config=config
|
||||
)
|
||||
|
||||
if not detector.model:
|
||||
print("❌ Failed to load YOLO model, proceeding with normal augmentation")
|
||||
else:
|
||||
# Process single image with ID detection
|
||||
result = detector.process_single_image(
|
||||
image_path=image_files[0],
|
||||
output_dir=output_dir,
|
||||
apply_augmentation=True,
|
||||
save_original=id_detection_config.get('save_original_crops', True),
|
||||
confidence=id_detection_config.get('confidence_threshold', 0.25),
|
||||
iou_threshold=id_detection_config.get('iou_threshold', 0.45),
|
||||
crop_mode=id_detection_config.get('crop_mode', 'bbox'),
|
||||
target_size=id_detection_config.get('target_size'),
|
||||
padding=id_detection_config.get('padding', 10)
|
||||
)
|
||||
|
||||
if result and result.get('detections'):
|
||||
print(f"✅ Detected {len(result['detections'])} ID cards")
|
||||
print(f"💾 Saved {len(result['processed_cards'])} processed cards")
|
||||
return
|
||||
else:
|
||||
print("⚠️ No ID cards detected, proceeding with normal augmentation")
|
||||
|
||||
# Normal augmentation (fallback)
|
||||
augmented_paths = augmenter.augment_image_file(
|
||||
image_files[0],
|
||||
output_dir,
|
||||
@@ -225,9 +292,29 @@ def main():
|
||||
show_image_info(input_dir)
|
||||
return
|
||||
|
||||
# Get ID detection config
|
||||
id_detection_config = config.get('id_card_detection', {})
|
||||
|
||||
# Override ID detection config with command line arguments
|
||||
if args.enable_id_detection:
|
||||
id_detection_config['enabled'] = True
|
||||
|
||||
if args.model_path:
|
||||
id_detection_config['model_path'] = args.model_path
|
||||
|
||||
if args.confidence:
|
||||
id_detection_config['confidence_threshold'] = args.confidence
|
||||
|
||||
if args.crop_mode:
|
||||
id_detection_config['crop_mode'] = args.crop_mode
|
||||
|
||||
if args.crop_target_size:
|
||||
target_size = parse_size(args.crop_target_size)
|
||||
id_detection_config['target_size'] = list(target_size)
|
||||
|
||||
# Preview augmentation if requested
|
||||
if args.preview:
|
||||
preview_augmentation(input_dir, output_dir, augmentation_config)
|
||||
preview_augmentation(input_dir, output_dir, augmentation_config, id_detection_config)
|
||||
return
|
||||
|
||||
# Get image files
|
||||
@@ -242,35 +329,56 @@ def main():
|
||||
logger.info(f"Number of augmentations per image: {processing_config.get('num_augmentations', 3)}")
|
||||
logger.info(f"Target size: {processing_config.get('target_size', [224, 224])}")
|
||||
|
||||
# Create augmentation instance with new config
|
||||
augmenter = DataAugmentation(augmentation_config)
|
||||
# Process with ID detection if enabled
|
||||
if id_detection_config.get('enabled', False):
|
||||
logger.info("ID Card Detection enabled - processing with YOLO model...")
|
||||
|
||||
# Initialize ID card detector
|
||||
detector = IDCardDetector(
|
||||
model_path=id_detection_config.get('model_path'),
|
||||
config=config
|
||||
)
|
||||
|
||||
if not detector.model:
|
||||
logger.error("Failed to load YOLO model")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"YOLO model loaded: {detector.model_path}")
|
||||
logger.info(f"Confidence threshold: {id_detection_config.get('confidence_threshold', 0.25)}")
|
||||
logger.info(f"Crop mode: {id_detection_config.get('crop_mode', 'bbox')}")
|
||||
|
||||
# Bước 1: Detect và crop ID cards vào thư mục processed
|
||||
processed_dir = output_dir / "processed"
|
||||
processed_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Step 1: Detect and crop ID cards...")
|
||||
detector.batch_process(
|
||||
input_dir=input_dir,
|
||||
output_dir=processed_dir,
|
||||
confidence=id_detection_config.get('confidence_threshold', 0.25),
|
||||
iou_threshold=id_detection_config.get('iou_threshold', 0.45),
|
||||
crop_mode=id_detection_config.get('crop_mode', 'bbox'),
|
||||
target_size=id_detection_config.get('target_size'),
|
||||
padding=id_detection_config.get('padding', 10)
|
||||
)
|
||||
# Bước 2: Augment các card đã crop
|
||||
logger.info("Step 2: Augment cropped ID cards...")
|
||||
augmenter = DataAugmentation(augmentation_config)
|
||||
augmenter.batch_augment(
|
||||
processed_dir,
|
||||
output_dir,
|
||||
num_augmentations=processing_config.get("num_augmentations", 3)
|
||||
)
|
||||
else:
|
||||
# Augment trực tiếp ảnh gốc
|
||||
logger.info("Starting normal batch augmentation (direct augmentation)...")
|
||||
augmenter = DataAugmentation(augmentation_config)
|
||||
augmenter.batch_augment(
|
||||
input_dir,
|
||||
output_dir,
|
||||
num_augmentations=processing_config.get("num_augmentations", 3)
|
||||
)
|
||||
|
||||
# Update target size
|
||||
target_size = tuple(processing_config.get("target_size", [224, 224]))
|
||||
augmenter.image_processor.target_size = target_size
|
||||
|
||||
# Perform batch augmentation
|
||||
logger.info("Starting batch augmentation...")
|
||||
results = augmenter.batch_augment(
|
||||
input_dir,
|
||||
output_dir,
|
||||
num_augmentations=processing_config.get("num_augmentations", 3)
|
||||
)
|
||||
|
||||
# Get and display summary
|
||||
summary = augmenter.get_augmentation_summary(results)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("AUGMENTATION SUMMARY")
|
||||
print("="*50)
|
||||
print(f"Original images: {summary['total_original_images']}")
|
||||
print(f"Augmented images: {summary['total_augmented_images']}")
|
||||
print(f"Augmentation ratio: {summary['augmentation_ratio']:.2f}")
|
||||
print(f"Successful augmentations: {summary['successful_augmentations']}")
|
||||
print(f"Output directory: {output_dir}")
|
||||
print("="*50)
|
||||
|
||||
logger.info("Data augmentation completed successfully")
|
||||
logger.info("Data processing completed successfully")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user