"""
Configuration manager for data augmentation.
"""

import copy
import os
from pathlib import Path
from typing import Dict, Any, Optional, Union

import yaml


class ConfigManager:
    """Manages loading, updating and saving the augmentation configuration.

    The configuration is kept in ``self.config`` as a nested dict. It is read
    from a YAML file at construction time; when the file is missing, cannot be
    read/parsed, or does not contain a YAML mapping, the built-in defaults
    from ``_get_default_config`` are used instead.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """
        Initialize ConfigManager and load the configuration immediately.

        Args:
            config_path: Path to main config file. Any falsy value (``None``
                or ``""``) selects the relative path ``config/config.yaml``.
        """
        self.config_path = Path(config_path) if config_path else Path("config/config.yaml")
        self.config: Dict[str, Any] = {}

        self._load_config()

    def _load_config(self) -> None:
        """Load the main configuration file into ``self.config``.

        Falls back to ``_get_default_config()`` when the file does not exist,
        when reading or parsing raises, or when the parsed document is not a
        mapping. The last case covers an empty file, for which
        ``yaml.safe_load`` returns ``None``; storing that would make every
        ``get_*_config`` accessor fail with ``AttributeError``.
        """
        try:
            if not self.config_path.exists():
                print(f"⚠️ Config file not found: {self.config_path}")
                self.config = self._get_default_config()
                return

            with open(self.config_path, 'r', encoding='utf-8') as f:
                loaded: Any = yaml.safe_load(f)
        except Exception as e:
            # Deliberate best-effort: any I/O or YAML error degrades to defaults.
            print(f"❌ Error loading config: {e}")
            self.config = self._get_default_config()
            return

        if isinstance(loaded, dict):
            self.config = loaded
            print(f"✅ Loaded configuration from {self.config_path}")
        else:
            print(f"⚠️ Config file {self.config_path} does not contain a mapping; using defaults")
            self.config = self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """Return a freshly built default configuration dict.

        A new dict is created on every call, so callers may mutate the result
        without affecting later defaults.
        """
        return {
            "paths": {
                "input_dir": "data/dataset/training_data/images",
                "output_dir": "data/augmented_data",
                "log_file": "logs/data_augmentation.log",
            },
            "augmentation": {
                "rotation": {
                    "enabled": True,
                    "angles": [30, 60, 120, 150, 180, 210, 240, 300, 330],
                    "probability": 1.0,
                },
            },
            "processing": {
                "target_size": [224, 224],
                "batch_size": 32,
                "num_augmentations": 3,
                "save_format": "jpg",
                "quality": 95,
            },
            "supported_formats": [".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            },
            "performance": {
                "num_workers": 4,
                "prefetch_factor": 2,
                "pin_memory": True,
                "use_gpu": False,
            },
        }

    def get_config(self) -> Dict[str, Any]:
        """Get current configuration (the live dict, not a copy)."""
        return self.config

    def _get_section(self, name: str) -> Dict[str, Any]:
        """Return top-level section *name*, or a new empty dict if absent/null.

        A key that is present but null in YAML (e.g. ``augmentation:`` with
        nothing under it) comes back as ``None`` from ``dict.get``; mapping it
        to ``{}`` lets callers always use ``.get``/``.items`` on the result.
        An existing section is returned as-is (same object).
        """
        section = self.config.get(name)
        return section if section is not None else {}

    def get_augmentation_config(self) -> Dict[str, Any]:
        """Get augmentation configuration"""
        return self._get_section("augmentation")

    def get_processing_config(self) -> Dict[str, Any]:
        """Get processing configuration"""
        return self._get_section("processing")

    def get_paths_config(self) -> Dict[str, Any]:
        """Get paths configuration"""
        return self._get_section("paths")

    def get_logging_config(self) -> Dict[str, Any]:
        """Get logging configuration"""
        return self._get_section("logging")

    def get_performance_config(self) -> Dict[str, Any]:
        """Get performance configuration"""
        return self._get_section("performance")

    def update_config(self, updates: Dict[str, Any]) -> bool:
        """
        Update configuration with new values (recursive merge).

        Args:
            updates: Dictionary with updates to apply; nested dicts are merged
                into the matching sections, other values replace them.

        Returns:
            True if updated successfully, False otherwise (e.g. ``updates`` is
            not a mapping); on failure ``self.config`` is left unchanged and
            the error is printed.
        """
        try:
            self.config = self._merge_configs(self.config, updates)
            return True
        except Exception as e:
            print(f"❌ Error updating config: {e}")
            return False

    def _merge_configs(self, base_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict with *updates* recursively merged over *base_config*.

        Where both sides hold a dict for the same key the two are merged key by
        key; any other value from *updates* replaces the base value. Neither
        input is mutated. Values taken from *updates* are deep-copied so that
        later changes to the caller's ``updates`` object cannot leak into the
        configuration; untouched nested values are shared with *base_config*.
        """
        def deep_merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
            result = base.copy()
            for key, value in update.items():
                current = result.get(key)
                if isinstance(current, dict) and isinstance(value, dict):
                    result[key] = deep_merge(current, value)
                else:
                    result[key] = copy.deepcopy(value)
            return result

        return deep_merge(base_config, updates)

    def save_config(self, output_path: Optional[Union[str, Path]] = None) -> bool:
        """
        Save current configuration to a YAML file.

        The configuration is serialized with ``yaml.safe_dump`` so the file
        can be read back by ``_load_config`` (which uses ``yaml.safe_load``);
        ``yaml.dump`` would emit Python-specific tags such as
        ``!!python/tuple`` that ``safe_load`` rejects. Serialization happens
        before the file is opened, so a value that cannot be represented
        leaves any existing file untouched instead of truncating it.

        Args:
            output_path: Path to save config file; defaults to
                ``self.config_path``. Missing parent directories are created.

        Returns:
            True if saved successfully, False if serialization or writing
            failed (the error is printed).
        """
        try:
            output_path = Path(output_path) if output_path else self.config_path
            text = yaml.safe_dump(self.config, default_flow_style=False, indent=2, allow_unicode=True)

            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(text)

            print(f"✅ Configuration saved to {output_path}")
            return True
        except Exception as e:
            print(f"❌ Error saving config: {e}")
            return False

    def print_config_summary(self) -> None:
        """Print a human-readable summary of the main configuration values."""
        separator = "=" * 50
        print("\n" + separator)
        print("CONFIGURATION SUMMARY")
        print(separator)

        # Paths
        paths = self.get_paths_config()
        print(f"Input directory: {paths.get('input_dir', 'Not set')}")
        print(f"Output directory: {paths.get('output_dir', 'Not set')}")

        # Processing
        processing = self.get_processing_config()
        print(f"Target size: {processing.get('target_size', 'Not set')}")
        print(f"Number of augmentations: {processing.get('num_augmentations', 'Not set')}")

        # Augmentation: an entry counts as enabled only if it is a mapping with
        # a truthy 'enabled' flag. Names are str()-ed so non-string YAML keys
        # cannot break the join.
        enabled_augmentations = [
            str(name)
            for name, settings in self.get_augmentation_config().items()
            if isinstance(settings, dict) and settings.get('enabled', False)
        ]
        enabled_text = ', '.join(enabled_augmentations) if enabled_augmentations else 'None'
        print(f"Enabled augmentations: {enabled_text}")

        print(separator)