init
This commit is contained in:
175
src/config_manager.py
Normal file
175
src/config_manager.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Configuration manager for data augmentation
|
||||
"""
|
||||
import yaml
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
class ConfigManager:
|
||||
"""Manages configuration loading and validation"""
|
||||
|
||||
def __init__(self, config_path: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
Initialize ConfigManager
|
||||
|
||||
Args:
|
||||
config_path: Path to main config file
|
||||
"""
|
||||
self.config_path = Path(config_path) if config_path else Path("config/config.yaml")
|
||||
self.config = {}
|
||||
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load main configuration file"""
|
||||
try:
|
||||
if self.config_path.exists():
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
print(f"✅ Loaded configuration from {self.config_path}")
|
||||
else:
|
||||
print(f"⚠️ Config file not found: {self.config_path}")
|
||||
self.config = self._get_default_config()
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading config: {e}")
|
||||
self.config = self._get_default_config()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration"""
|
||||
return {
|
||||
"paths": {
|
||||
"input_dir": "data/dataset/training_data/images",
|
||||
"output_dir": "data/augmented_data",
|
||||
"log_file": "logs/data_augmentation.log"
|
||||
},
|
||||
"augmentation": {
|
||||
"rotation": {"enabled": True, "angles": [30, 60, 120, 150, 180, 210, 240, 300, 330], "probability": 1.0}
|
||||
},
|
||||
"processing": {
|
||||
"target_size": [224, 224],
|
||||
"batch_size": 32,
|
||||
"num_augmentations": 3,
|
||||
"save_format": "jpg",
|
||||
"quality": 95
|
||||
},
|
||||
"supported_formats": [".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
|
||||
"logging": {
|
||||
"level": "INFO",
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
},
|
||||
"performance": {
|
||||
"num_workers": 4,
|
||||
"prefetch_factor": 2,
|
||||
"pin_memory": True,
|
||||
"use_gpu": False
|
||||
}
|
||||
}
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Get current configuration"""
|
||||
return self.config
|
||||
|
||||
def get_augmentation_config(self) -> Dict[str, Any]:
|
||||
"""Get augmentation configuration"""
|
||||
return self.config.get("augmentation", {})
|
||||
|
||||
def get_processing_config(self) -> Dict[str, Any]:
|
||||
"""Get processing configuration"""
|
||||
return self.config.get("processing", {})
|
||||
|
||||
def get_paths_config(self) -> Dict[str, Any]:
|
||||
"""Get paths configuration"""
|
||||
return self.config.get("paths", {})
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
|
||||
"""Get logging configuration"""
|
||||
return self.config.get("logging", {})
|
||||
|
||||
def get_performance_config(self) -> Dict[str, Any]:
|
||||
"""Get performance configuration"""
|
||||
return self.config.get("performance", {})
|
||||
|
||||
|
||||
|
||||
def update_config(self, updates: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update configuration with new values
|
||||
|
||||
Args:
|
||||
updates: Dictionary with updates to apply
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
try:
|
||||
self.config = self._merge_configs(self.config, updates)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error updating config: {e}")
|
||||
return False
|
||||
|
||||
def _merge_configs(self, base_config: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge updates with base configuration"""
|
||||
merged = base_config.copy()
|
||||
|
||||
def deep_merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = base.copy()
|
||||
for key, value in update.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
return deep_merge(merged, updates)
|
||||
|
||||
def save_config(self, output_path: Optional[Union[str, Path]] = None) -> bool:
|
||||
"""
|
||||
Save current configuration to file
|
||||
|
||||
Args:
|
||||
output_path: Path to save config file
|
||||
|
||||
Returns:
|
||||
True if saved successfully
|
||||
"""
|
||||
try:
|
||||
output_path = Path(output_path) if output_path else self.config_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(self.config, f, default_flow_style=False, indent=2, allow_unicode=True)
|
||||
|
||||
print(f"✅ Configuration saved to {output_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving config: {e}")
|
||||
return False
|
||||
|
||||
def print_config_summary(self):
|
||||
"""Print configuration summary"""
|
||||
print("\n" + "="*50)
|
||||
print("CONFIGURATION SUMMARY")
|
||||
print("="*50)
|
||||
|
||||
# Paths
|
||||
paths = self.get_paths_config()
|
||||
print(f"Input directory: {paths.get('input_dir', 'Not set')}")
|
||||
print(f"Output directory: {paths.get('output_dir', 'Not set')}")
|
||||
|
||||
# Processing
|
||||
processing = self.get_processing_config()
|
||||
print(f"Target size: {processing.get('target_size', 'Not set')}")
|
||||
print(f"Number of augmentations: {processing.get('num_augmentations', 'Not set')}")
|
||||
|
||||
# Augmentation
|
||||
augmentation = self.get_augmentation_config()
|
||||
enabled_augmentations = []
|
||||
for name, config in augmentation.items():
|
||||
if isinstance(config, dict) and config.get('enabled', False):
|
||||
enabled_augmentations.append(name)
|
||||
|
||||
print(f"Enabled augmentations: {', '.join(enabled_augmentations) if enabled_augmentations else 'None'}")
|
||||
|
||||
print("="*50)
|
Reference in New Issue
Block a user