This commit is contained in:
Nguyễn Phước Thành
2025-08-05 19:09:55 +07:00
commit 24060e4ce7
25 changed files with 2268 additions and 0 deletions

98
src/utils.py Normal file
View File

@@ -0,0 +1,98 @@
"""
Utility functions for data augmentation
"""
import os
import logging
from pathlib import Path
from typing import List, Tuple, Optional
import cv2
import numpy as np
from PIL import Image
def setup_logging(log_level: str = "INFO") -> logging.Logger:
"""Setup logging configuration"""
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('data_augmentation.log'),
logging.StreamHandler()
]
)
return logging.getLogger(__name__)
def get_image_files(directory: Path) -> List[Path]:
"""Get all image files from directory"""
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
image_files = []
if directory.exists():
for ext in SUPPORTED_FORMATS:
image_files.extend(directory.glob(f"*{ext}"))
image_files.extend(directory.glob(f"*{ext.upper()}"))
return sorted(image_files)
def validate_image(image_path: Path) -> bool:
"""Validate if file is a valid image"""
try:
with Image.open(image_path) as img:
img.verify()
return True
except Exception:
return False
def load_image(image_path: Path, target_size: Tuple[int, int] = None) -> Optional[np.ndarray]:
"""Load and resize image"""
try:
# Load image using OpenCV
image = cv2.imread(str(image_path))
if image is None:
return None
# Convert BGR to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Resize if target_size is provided
if target_size:
image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
return image
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None
def save_image(image: np.ndarray, output_path: Path, quality: int = 95) -> bool:
"""Save image to file"""
try:
# Convert RGB to BGR for OpenCV
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# Create output directory if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save image
cv2.imwrite(str(output_path), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
return True
except Exception as e:
print(f"Error saving image {output_path}: {e}")
return False
def create_augmented_filename(original_path: Path, index: int, suffix: str = "aug") -> Path:
"""Create filename for augmented image"""
stem = original_path.stem
suffix = f"_{suffix}_{index:02d}"
return original_path.parent / f"{stem}{suffix}{original_path.suffix}"
def get_file_size_mb(file_path: Path) -> float:
"""Get file size in MB"""
return file_path.stat().st_size / (1024 * 1024)
def print_progress(current: int, total: int, prefix: str = "Progress"):
"""Print progress bar"""
bar_length = 50
filled_length = int(round(bar_length * current / float(total)))
percents = round(100.0 * current / float(total), 1)
bar = '=' * filled_length + '-' * (bar_length - filled_length)
print(f'\r{prefix}: [{bar}] {percents}% ({current}/{total})', end='')
if current == total:
print()