init
This commit is contained in:
98
src/utils.py
Normal file
98
src/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Utility functions for data augmentation
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> logging.Logger:
|
||||
"""Setup logging configuration"""
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level.upper()),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('data_augmentation.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
def get_image_files(directory: Path) -> List[Path]:
|
||||
"""Get all image files from directory"""
|
||||
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
|
||||
|
||||
image_files = []
|
||||
if directory.exists():
|
||||
for ext in SUPPORTED_FORMATS:
|
||||
image_files.extend(directory.glob(f"*{ext}"))
|
||||
image_files.extend(directory.glob(f"*{ext.upper()}"))
|
||||
return sorted(image_files)
|
||||
|
||||
def validate_image(image_path: Path) -> bool:
|
||||
"""Validate if file is a valid image"""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
img.verify()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load_image(image_path: Path, target_size: Tuple[int, int] = None) -> Optional[np.ndarray]:
|
||||
"""Load and resize image"""
|
||||
try:
|
||||
# Load image using OpenCV
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
# Convert BGR to RGB
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize if target_size is provided
|
||||
if target_size:
|
||||
image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
return image
|
||||
except Exception as e:
|
||||
print(f"Error loading image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def save_image(image: np.ndarray, output_path: Path, quality: int = 95) -> bool:
|
||||
"""Save image to file"""
|
||||
try:
|
||||
# Convert RGB to BGR for OpenCV
|
||||
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save image
|
||||
cv2.imwrite(str(output_path), image_bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving image {output_path}: {e}")
|
||||
return False
|
||||
|
||||
def create_augmented_filename(original_path: Path, index: int, suffix: str = "aug") -> Path:
|
||||
"""Create filename for augmented image"""
|
||||
stem = original_path.stem
|
||||
suffix = f"_{suffix}_{index:02d}"
|
||||
return original_path.parent / f"{stem}{suffix}{original_path.suffix}"
|
||||
|
||||
def get_file_size_mb(file_path: Path) -> float:
|
||||
"""Get file size in MB"""
|
||||
return file_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
def print_progress(current: int, total: int, prefix: str = "Progress"):
|
||||
"""Print progress bar"""
|
||||
bar_length = 50
|
||||
filled_length = int(round(bar_length * current / float(total)))
|
||||
percents = round(100.0 * current / float(total), 1)
|
||||
bar = '=' * filled_length + '-' * (bar_length - filled_length)
|
||||
print(f'\r{prefix}: [{bar}] {percents}% ({current}/{total})', end='')
|
||||
if current == total:
|
||||
print()
|
Reference in New Issue
Block a user