86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
"""
|
|
Utility functions for data augmentation
|
|
"""
|
|
import os
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Optional
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure root logging (file + console) and return this module's logger.

    On first use this installs two handlers on the root logger: a
    ``FileHandler`` writing to ``data_augmentation.log`` (relative to the
    current working directory) and a ``StreamHandler`` (stderr by default).
    If the root logger already has handlers, the existing configuration is
    left alone, exactly as ``logging.basicConfig`` would do.

    Args:
        log_level: Case-insensitive level name, e.g. ``"debug"`` or ``"INFO"``.

    Returns:
        The logger named after this module.

    Raises:
        AttributeError: If ``log_level`` is not a level name defined on the
            ``logging`` module.
    """
    # Resolve the level first so an unknown name fails the same way whether
    # or not logging has already been configured.
    level = getattr(logging, log_level.upper())

    # basicConfig() is a no-op once the root logger has handlers. Only build
    # the handlers when they will actually be installed: constructing a
    # FileHandler opens the log file immediately, and an unattached handler
    # would leave that file handle open (and the file created) for nothing.
    if not logging.getLogger().handlers:
        logging.basicConfig(
            level=level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('data_augmentation.log'),
                logging.StreamHandler(),
            ],
        )

    return logging.getLogger(__name__)
|
|
|
|
def get_image_files(directory: Path) -> List[Path]:
    """Return the image files directly inside *directory*, sorted by path.

    A file counts as an image when its extension is one of the supported
    formats, compared case-insensitively (``.jpg``, ``.JPG`` and ``.Jpg`` all
    match). Subdirectories are neither searched nor returned, and each file
    appears at most once.

    Args:
        directory: Directory to scan; a ``str`` path is also accepted.

    Returns:
        Sorted list of matching file paths, or an empty list if *directory*
        does not exist or is not a directory.
    """
    SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    directory = Path(directory)
    if not directory.is_dir():
        return []

    # A single pass over the directory with a lowercase suffix check avoids
    # the duplicate hits that separate "*.jpg"/"*.JPG" globs produce on
    # case-insensitive filesystems, and also catches mixed-case suffixes.
    return sorted(
        path
        for path in directory.iterdir()
        if path.is_file() and path.suffix.lower() in SUPPORTED_FORMATS
    )
|
|
|
|
|
|
|
|
def load_image(
    image_path: Path,
    target_size: Optional[Tuple[int, int]] = None,
) -> Optional[np.ndarray]:
    """Load an image from disk as an RGB array, optionally resized.

    Args:
        image_path: File to read.
        target_size: Optional output size, passed straight to ``cv2.resize``
            (which interprets it as ``(width, height)``). A falsy value
            skips resizing.

    Returns:
        The image in RGB channel order, or ``None`` if the file could not be
        read or decoded, or if any error occurred. Failures are logged.
    """
    logger = logging.getLogger(__name__)
    try:
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising, so report that case explicitly.
        image = cv2.imread(str(image_path))
        if image is None:
            logger.warning("Could not read image %s", image_path)
            return None

        # OpenCV decodes to BGR; convert to RGB (save_image converts back).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if target_size:
            image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

        return image

    except Exception as e:
        # Best-effort loader: log the problem and signal failure with None.
        logger.error("Error loading image %s: %s", image_path, e)
        return None
|
|
|
|
def save_image(image: np.ndarray, output_path: Path, quality: int = 95) -> bool:
    """Write an RGB (or single-channel grayscale) image to *output_path*.

    The parent directory is created if needed. The output format is chosen by
    OpenCV from the file extension.

    Args:
        image: RGB image of shape ``(H, W, 3)``, or a 2-D grayscale image.
        output_path: Destination file.
        quality: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY``; only
            meaningful for JPEG output.

    Returns:
        ``True`` if OpenCV reports that the file was written, ``False``
        otherwise. Failures are logged.
    """
    logger = logging.getLogger(__name__)
    try:
        # OpenCV expects BGR order; a single-channel image needs no swap
        # (and cvtColor would reject it with COLOR_RGB2BGR).
        if image.ndim == 2:
            image_bgr = image
        else:
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # imwrite reports many failures (e.g. an unwritable path) through its
        # return value rather than by raising, so that value must be checked.
        written = cv2.imwrite(
            str(output_path), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, quality]
        )
    except Exception as e:
        logger.error("Error saving image %s: %s", output_path, e)
        return False

    if not written:
        logger.error("Error saving image %s: cv2.imwrite reported failure", output_path)
        return False
    return True
|
|
|
|
def create_augmented_filename(original_path: Path, index: int, method: str = "aug") -> Path:
    """Build the path for an augmented copy of *original_path*.

    The new file sits in the same directory and keeps the original extension.
    ``_<method>_<index>`` is inserted before the extension, with the index
    zero-padded to at least two digits: ``img.png`` -> ``img_flip_03.png``.
    """
    augmented_name = "{}_{}_{:02d}{}".format(
        original_path.stem,
        method,
        index,
        original_path.suffix,
    )
    return original_path.parent / augmented_name
|
|
|
|
|
|
|
|
def print_progress(current: int, total: int, prefix: str = "Progress", *, bar_length: int = 50):
    """Draw an in-place text progress bar on stdout.

    The line is redrawn using a leading carriage return and is flushed
    immediately, so it appears even without a trailing newline. A newline is
    printed once ``current`` reaches (or passes) ``total``.

    Args:
        current: Number of completed items.
        total: Total number of items. A value of ``0`` or less is rendered as
            100% complete instead of dividing by zero.
        prefix: Label printed before the bar.
        bar_length: Width of the bar in characters (keyword-only).
    """
    if total > 0:
        filled_length = int(round(bar_length * current / float(total)))
        percents = round(100.0 * current / float(total), 1)
    else:
        # Nothing to do counts as done; avoids ZeroDivisionError.
        filled_length = bar_length
        percents = 100.0

    # Keep the bar at a fixed width even if current overshoots total.
    filled_length = max(0, min(bar_length, filled_length))
    bar = '=' * filled_length + '-' * (bar_length - filled_length)

    # flush=True: without a newline, line-buffered stdout would hold the bar back.
    print(f'\r{prefix}: [{bar}] {percents}% ({current}/{total})', end='', flush=True)

    if current >= total:
        print()