combine augment
This commit is contained in:
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import random
|
||||
import math
|
||||
import logging
|
||||
from image_processor import ImageProcessor
|
||||
from utils import load_image, save_image, create_augmented_filename, print_progress
|
||||
|
||||
@@ -22,6 +23,7 @@ class DataAugmentation:
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.image_processor = ImageProcessor()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def random_crop_preserve_quality(self, image: np.ndarray, crop_ratio_range: Tuple[float, float] = (0.7, 1.0)) -> np.ndarray:
|
||||
"""
|
||||
@@ -363,21 +365,306 @@ class DataAugmentation:
|
||||
|
||||
return result
|
||||
|
||||
def augment_single_image(self, image: np.ndarray, num_target_images: int = None) -> List[np.ndarray]:
    """
    Generate augmented versions of one image using the configured strategy

    The strategy is read from ``self.config["strategy"]["mode"]``:
        - "random_combine" (default): delegate to ``_augment_random_combine``
        - "sequential": delegate to ``_augment_sequential``
        - "individual" or any other value: delegate to
          ``_augment_individual_legacy``

    Args:
        image: Input image
        num_target_images: Number of target augmented images to generate.
            Defaults to 3 when omitted; an explicit 0 is passed through
            unchanged instead of being replaced by the default.

    Returns:
        List of augmented images produced by the selected strategy
    """
    # Only a missing value falls back to the default; ``or`` would also
    # silently replace an explicit 0.
    if num_target_images is None:
        num_target_images = 3

    # Get strategy config
    strategy_config = self.config.get("strategy", {})
    methods_config = self.config.get("methods", {})
    final_config = self.config.get("final_processing", {})

    mode = strategy_config.get("mode", "random_combine")

    if mode == "random_combine":
        # The method-count bounds are only meaningful for this mode
        min_methods = strategy_config.get("min_methods", 2)
        max_methods = strategy_config.get("max_methods", 4)
        return self._augment_random_combine(
            image, num_target_images, methods_config, final_config, min_methods, max_methods
        )

    if mode == "sequential":
        return self._augment_sequential(image, num_target_images, methods_config, final_config)

    # "individual" and every unrecognised mode share the legacy path
    return self._augment_individual_legacy(image, num_target_images)
|
||||
|
||||
def _augment_random_combine(self, image: np.ndarray, num_target_images: int,
                            methods_config: dict, final_config: dict,
                            min_methods: int, max_methods: int) -> List[np.ndarray]:
    """
    Build each output image by chaining a random subset of the enabled methods

    For every output image a method count is drawn, that many methods are
    picked with ``_select_methods_by_probability``, and each picked method
    is then applied only if a second random draw falls below its
    "probability" setting (default 0.5). An output can therefore receive
    fewer methods than were picked, possibly none.

    Args:
        image: Input image (never modified; every output starts from a copy)
        num_target_images: Number of output images to generate
        methods_config: Mapping of method name to its config dict; only
            entries whose "enabled" flag is truthy are considered
        final_config: Configuration passed to ``_apply_final_processing``
        min_methods: Lower bound on the number of methods picked per image
        max_methods: Upper bound on the number of methods picked per image

    Returns:
        List of augmented images. If no method is enabled, a warning is
        logged and unmodified copies of the input are returned.
    """
    # Get enabled methods together with their configuration
    available_methods = [
        (method_name, method_config)
        for method_name, method_config in methods_config.items()
        if method_config.get("enabled", False)
    ]

    if not available_methods:
        self.logger.warning("No augmentation methods enabled!")
        return [image.copy() for _ in range(num_target_images)]

    # Clamp both bounds to the number of enabled methods. Without this,
    # enabling fewer methods than min_methods makes random.randint raise
    # ValueError because the range is empty.
    upper_bound = max(0, min(max_methods, len(available_methods)))
    lower_bound = max(0, min(min_methods, upper_bound))

    augmented_images = []
    for _ in range(num_target_images):
        # Decide number of methods for this image
        num_methods = random.randint(lower_bound, upper_bound)

        # Select methods based on probability
        selected_methods = self._select_methods_by_probability(available_methods, num_methods)

        # Apply selected methods in sequence
        augmented = image.copy()
        for method_name, method_config in selected_methods:
            if random.random() < method_config.get("probability", 0.5):
                augmented = self._apply_single_method(augmented, method_name, method_config)

        # Apply final processing
        augmented = self._apply_final_processing(augmented, final_config)

        # Resize preserving aspect ratio (skipped when no target size is set)
        target_size = self.image_processor.target_size
        if target_size:
            augmented = self.resize_preserve_aspect(augmented, target_size)

        augmented_images.append(augmented)

    return augmented_images
|
||||
|
||||
def _select_methods_by_probability(self, available_methods: List[Tuple], num_methods: int) -> List[Tuple]:
    """
    Draw up to ``num_methods`` distinct methods, weighted by their probability

    Each method's weight is its "probability" setting (0.5 when absent).
    Methods are drawn one at a time without replacement; a method with a
    larger weight is more likely to be drawn. When every remaining weight
    is zero the draw is uniform.

    Args:
        available_methods: ``(method_name, method_config)`` pairs to choose from
        num_methods: Maximum number of methods to return

    Returns:
        List of ``(method_name, method_config)`` pairs in the order drawn;
        shorter than ``num_methods`` if the candidates run out
    """
    # Candidates carry their weight as a third element and are ordered
    # heaviest first (stable for ties) before any drawing happens.
    pool = sorted(
        ((name, cfg, cfg.get("probability", 0.5)) for name, cfg in available_methods),
        key=lambda entry: entry[2],
        reverse=True,
    )

    chosen = []
    for _ in range(num_methods):
        if not pool:
            break

        weight_sum = sum(entry[2] for entry in pool)
        if weight_sum == 0:
            # Nothing to weight by: pick uniformly among what is left
            pick = random.choice(pool)
        else:
            threshold = random.uniform(0, weight_sum)
            running = 0
            for entry in pool:
                running += entry[2]
                if threshold <= running:
                    pick = entry
                    break
            else:
                # Threshold was never reached; fall back to the last candidate
                pick = pool[-1]

        chosen.append((pick[0], pick[1]))
        pool.remove(pick)

    return chosen
|
||||
|
||||
def _apply_single_method(self, image: np.ndarray, method_name: str, method_config: dict) -> np.ndarray:
    """
    Apply a single augmentation method

    Dispatches on ``method_name`` to the matching augmentation routine,
    reading that routine's parameters from ``method_config`` with the
    defaults shown below. This is best effort: if the routine raises, the
    error is logged with its traceback and the input image is returned
    unchanged so one failing method does not abort the whole run.

    Args:
        image: Image to augment
        method_name: One of "rotation", "random_cropping", "random_noise",
            "partial_blockage", "blurring", "brightness_contrast",
            "color_jitter", "perspective"
        method_config: Parameters for the chosen method

    Returns:
        The augmented image, or the input image when the method name is
        not recognised or the method fails
    """
    try:
        if method_name == "rotation":
            angles = method_config.get("angles", [30, 60, 90, 120, 150, 180, 210, 240, 300, 330])
            angle = random.choice(angles)
            return self.rotate_image_preserve_quality(image, angle)

        if method_name == "random_cropping":
            ratio_range = method_config.get("ratio_range", (0.7, 1.0))
            return self.random_crop_preserve_quality(image, ratio_range)

        if method_name == "random_noise":
            mean_range = method_config.get("mean_range", (0.0, 0.7))
            variance_range = method_config.get("variance_range", (0.0, 0.1))
            return self.add_random_noise_preserve_quality(image, mean_range, variance_range)

        if method_name == "partial_blockage":
            num_range = method_config.get("num_occlusions_range", (1, 100))
            coverage_range = method_config.get("coverage_range", (0.0, 0.25))
            variance_range = method_config.get("variance_range", (0.0, 0.1))
            return self.add_partial_blockage_preserve_quality(image, num_range, coverage_range, variance_range)

        if method_name == "blurring":
            kernel_range = method_config.get("kernel_ratio_range", (0.0, 0.0084))
            return self.apply_blurring_preserve_quality(image, kernel_range)

        if method_name == "brightness_contrast":
            alpha_range = method_config.get("alpha_range", (0.4, 3.0))
            beta_range = method_config.get("beta_range", (1, 100))
            return self.adjust_brightness_contrast_preserve_quality(image, alpha_range, beta_range)

        if method_name == "color_jitter":
            return self.apply_color_jitter(image, method_config)

        if method_name == "perspective":
            distortion_scale = method_config.get("distortion_scale", 0.2)
            return self.apply_perspective_transform(image, distortion_scale)

        # A name that matches nothing is most likely a config typo; report
        # it instead of silently doing nothing.
        self.logger.warning("Unknown augmentation method '%s'; image left unchanged", method_name)
        return image

    except Exception:
        # Deliberately broad: any failure inside a method degrades to a
        # no-op for this step. Logged through the class logger (with
        # traceback) rather than printed.
        self.logger.exception("Error applying method %s", method_name)
        return image
|
||||
|
||||
def _apply_final_processing(self, image: np.ndarray, final_config: dict) -> np.ndarray:
    """
    Run the post-augmentation steps that every output goes through

    Args:
        image: Image to post-process
        final_config: Final-processing configuration. Recognised keys are
            "grayscale" and "quality_enhancement", each a dict with an
            "enabled" flag.

    Returns:
        The grayscale-converted image when grayscale is enabled, otherwise
        the image as passed in
    """
    # Grayscale is unconditional once enabled - there is no random draw here
    if final_config.get("grayscale", {}).get("enabled", False):
        image = self.convert_to_grayscale_preserve_quality(image)

    # Quality enhancement is only a placeholder so far; enabling it has no effect
    if final_config.get("quality_enhancement", {}).get("enabled", False):
        # TODO: Implement quality enhancement
        pass

    return image
|
||||
|
||||
def apply_color_jitter(self, image: np.ndarray, config: dict) -> np.ndarray:
    """
    Randomly perturb brightness, saturation, hue and contrast of an image

    Brightness, saturation and hue are adjusted in HSV space; the result is
    converted back and then scaled by a contrast factor. Each factor is
    drawn uniformly from its configured ``[low, high]`` range.

    Args:
        image: Input image. It is converted with ``COLOR_RGB2HSV``, so a
            3-channel RGB layout is presumably expected (confirm with callers).
        config: Color jitter configuration with optional keys
            "brightness_range", "contrast_range", "saturation_range"
            (default [0.8, 1.2] each) and "hue_range" (default [-0.1, 0.1])

    Returns:
        Color-jittered image
    """
    brightness_range = config.get("brightness_range", [0.8, 1.2])
    contrast_range = config.get("contrast_range", [0.8, 1.2])
    saturation_range = config.get("saturation_range", [0.8, 1.2])
    hue_range = config.get("hue_range", [-0.1, 0.1])

    # Work in float so the scaling below is not truncated channel by channel
    hsv_pixels = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    hue_idx, sat_idx, val_idx = 0, 1, 2

    # Brightness: scale the V channel
    value_gain = random.uniform(brightness_range[0], brightness_range[1])
    hsv_pixels[:, :, val_idx] = np.clip(hsv_pixels[:, :, val_idx] * value_gain, 0, 255)

    # Saturation: scale the S channel
    saturation_gain = random.uniform(saturation_range[0], saturation_range[1])
    hsv_pixels[:, :, sat_idx] = np.clip(hsv_pixels[:, :, sat_idx] * saturation_gain, 0, 255)

    # Hue: the configured fraction is scaled by 179 (OpenCV hue range is
    # 0-179) and the shifted hue wraps around modulo 180
    hue_offset = random.uniform(hue_range[0], hue_range[1]) * 179
    hsv_pixels[:, :, hue_idx] = (hsv_pixels[:, :, hue_idx] + hue_offset) % 180

    rgb_pixels = cv2.cvtColor(hsv_pixels.astype(np.uint8), cv2.COLOR_HSV2RGB)

    # Contrast is applied last, on the RGB result, as a pure gain (beta=0)
    contrast_gain = random.uniform(contrast_range[0], contrast_range[1])
    return cv2.convertScaleAbs(rgb_pixels, alpha=contrast_gain, beta=0)
|
||||
|
||||
def apply_perspective_transform(self, image: np.ndarray, distortion_scale: float = 0.2) -> np.ndarray:
    """
    Warp an image with a random perspective transform

    Each of the four image corners is moved inward by a random amount of up
    to ``min(width, height) * distortion_scale`` pixels per axis, and the
    image is warped so that its corners land on the moved positions. The
    output keeps the input's width and height; uncovered areas are filled
    with white.

    Args:
        image: Input image
        distortion_scale: Scale of perspective distortion (0.0 to 1.0)

    Returns:
        Perspective-transformed image
    """
    height, width = image.shape[:2]
    right, bottom = width - 1, height - 1

    # Corners of the original image: top-left, top-right, bottom-right, bottom-left
    source_corners = np.float32([[0, 0], [right, 0], [right, bottom], [0, bottom]])

    max_distortion = min(width, height) * distortion_scale

    # Eight inward offsets, drawn in corner order (x then y for each corner)
    offsets = [random.uniform(0, max_distortion) for _ in range(8)]
    target_corners = np.float32([
        [offsets[0], offsets[1]],
        [right - offsets[2], offsets[3]],
        [right - offsets[4], bottom - offsets[5]],
        [offsets[6], bottom - offsets[7]],
    ])

    warp_matrix = cv2.getPerspectiveTransform(source_corners, target_corners)

    return cv2.warpPerspective(image, warp_matrix, (width, height),
                               borderMode=cv2.BORDER_CONSTANT,
                               borderValue=(255, 255, 255))
|
||||
|
||||
def _augment_sequential(self, image: np.ndarray, num_target_images: int,
                        methods_config: dict, final_config: dict) -> List[np.ndarray]:
    """
    Produce outputs by running every enabled method as one pipeline

    For each output image the enabled methods are visited in the order they
    appear in ``methods_config``; each one is applied only if a random draw
    falls below its "probability" setting (default 0.5). Final processing
    and the optional resize follow.

    Args:
        image: Input image (each output starts from a fresh copy)
        num_target_images: Number of output images to generate
        methods_config: Mapping of method name to its config dict; entries
            without a truthy "enabled" flag are skipped
        final_config: Configuration passed to ``_apply_final_processing``

    Returns:
        List of augmented images
    """
    pipeline = []
    for step_name, step_config in methods_config.items():
        if step_config.get("enabled", False):
            pipeline.append((step_name, step_config))

    def build_one() -> np.ndarray:
        # One pass over the pipeline, starting from an untouched copy
        current = image.copy()
        for step_name, step_config in pipeline:
            if random.random() < step_config.get("probability", 0.5):
                current = self._apply_single_method(current, step_name, step_config)

        current = self._apply_final_processing(current, final_config)

        # Resize preserving aspect ratio, only when a target size is configured
        size = self.image_processor.target_size
        if size:
            current = self.resize_preserve_aspect(current, size)
        return current

    return [build_one() for _ in range(num_target_images)]
|
||||
|
||||
def _augment_individual_legacy(self, image: np.ndarray, num_target_images: int) -> List[np.ndarray]:
|
||||
"""Legacy individual method application (backward compatibility)"""
|
||||
# This is the old implementation for backward compatibility
|
||||
augmented_images = []
|
||||
|
||||
# Get old-style configuration
|
||||
rotation_config = self.config.get("rotation", {})
|
||||
cropping_config = self.config.get("random_cropping", {})
|
||||
noise_config = self.config.get("random_noise", {})
|
||||
@@ -386,177 +673,272 @@ class DataAugmentation:
|
||||
blurring_config = self.config.get("blurring", {})
|
||||
brightness_contrast_config = self.config.get("brightness_contrast", {})
|
||||
|
||||
# Configuration parameters
|
||||
angles = rotation_config.get("angles", [30, 60, 120, 150, 180, 210, 240, 300, 330])
|
||||
crop_ratio_range = cropping_config.get("ratio_range", (0.7, 1.0))
|
||||
mean_range = noise_config.get("mean_range", (0.0, 0.7))
|
||||
variance_range = noise_config.get("variance_range", (0.0, 0.1))
|
||||
num_occlusions_range = blockage_config.get("num_occlusions_range", (1, 100))
|
||||
coverage_range = blockage_config.get("coverage_range", (0.0, 0.25))
|
||||
blockage_variance_range = blockage_config.get("variance_range", (0.0, 0.1))
|
||||
kernel_ratio_range = blurring_config.get("kernel_ratio_range", (0.0, 0.0084))
|
||||
alpha_range = brightness_contrast_config.get("alpha_range", (0.4, 3.0))
|
||||
beta_range = brightness_contrast_config.get("beta_range", (1, 100))
|
||||
# Apply individual methods (old logic)
|
||||
methods = [
|
||||
("rotation", rotation_config, self.rotate_image_preserve_quality),
|
||||
("cropping", cropping_config, self.random_crop_preserve_quality),
|
||||
("noise", noise_config, self.add_random_noise_preserve_quality),
|
||||
("blockage", blockage_config, self.add_partial_blockage_preserve_quality),
|
||||
("blurring", blurring_config, self.apply_blurring_preserve_quality),
|
||||
("brightness_contrast", brightness_contrast_config, self.adjust_brightness_contrast_preserve_quality)
|
||||
]
|
||||
|
||||
# Apply each method separately to create independent versions
|
||||
for method_name, method_config, method_func in methods:
|
||||
if method_config.get("enabled", False):
|
||||
for i in range(num_target_images):
|
||||
augmented = image.copy()
|
||||
# Apply single method with appropriate parameters
|
||||
if method_name == "rotation":
|
||||
angles = method_config.get("angles", [30, 60, 90, 120, 150, 180, 210, 240, 300, 330])
|
||||
angle = random.choice(angles)
|
||||
augmented = method_func(augmented, angle)
|
||||
elif method_name == "cropping":
|
||||
ratio_range = method_config.get("ratio_range", (0.7, 1.0))
|
||||
augmented = method_func(augmented, ratio_range)
|
||||
# Add other method parameter handling as needed
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 1. Rotation only
|
||||
if rotation_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
angle = random.choice(angles)
|
||||
augmented = self.rotate_image_preserve_quality(augmented, angle)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 2. Random cropping only
|
||||
if cropping_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
augmented = self.random_crop_preserve_quality(augmented, crop_ratio_range)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 3. Random noise only
|
||||
if noise_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
augmented = self.add_random_noise_preserve_quality(augmented, mean_range, variance_range)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 4. Partial blockage only
|
||||
if blockage_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
augmented = self.add_partial_blockage_preserve_quality(augmented, num_occlusions_range, coverage_range, blockage_variance_range)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 5. Blurring only
|
||||
if blurring_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
augmented = self.apply_blurring_preserve_quality(augmented, kernel_ratio_range)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 6. Brightness/Contrast only
|
||||
if brightness_contrast_config.get("enabled", False):
|
||||
for i in range(num_augmentations):
|
||||
augmented = image.copy()
|
||||
augmented = self.adjust_brightness_contrast_preserve_quality(augmented, alpha_range, beta_range)
|
||||
|
||||
# Resize preserving aspect ratio
|
||||
target_size = self.image_processor.target_size
|
||||
if target_size:
|
||||
augmented = self.resize_preserve_aspect(augmented, target_size)
|
||||
|
||||
augmented_images.append(augmented)
|
||||
|
||||
# 7. Apply grayscale as final step to ALL augmented images
|
||||
# Apply grayscale to all images
|
||||
if grayscale_config.get("enabled", False):
|
||||
for i in range(len(augmented_images)):
|
||||
augmented_images[i] = self.convert_to_grayscale_preserve_quality(augmented_images[i])
|
||||
|
||||
return augmented_images
|
||||
|
||||
def augment_image_file(self, image_path: Path, output_dir: Path, num_target_images: int = None) -> List[Path]:
    """
    Augment a single image file and save the results

    Args:
        image_path: Path to input image
        output_dir: Output directory for augmented images
        num_target_images: Number of target augmented images to generate
            (forwarded to ``augment_single_image``)

    Returns:
        List of paths to saved augmented images; empty when the image
        cannot be loaded. Only images for which ``save_image`` reports
        success are included.
    """
    # None is passed as the second argument so the image is not resized on
    # load (presumably the target size - confirm against utils.load_image)
    image = load_image(image_path, None)
    if image is None:
        return []

    # Apply augmentations
    augmented_images = self.augment_single_image(image, num_target_images)

    # Save augmented images as <stem>_aug_001.jpg, <stem>_aug_002.jpg, ...
    base_name = image_path.stem
    saved_paths = []
    for index, aug_image in enumerate(augmented_images, start=1):
        output_path = output_dir / f"{base_name}_aug_{index:03d}.jpg"
        if save_image(aug_image, output_path):
            saved_paths.append(output_path)

    return saved_paths
|
||||
|
||||
def batch_augment(self, input_dir: Path, output_dir: Path, num_augmentations: int = None) -> Dict[str, List[Path]]:
|
||||
def augment_image_file_with_raw(self, image_path: Path, output_dir: Path,
                                num_total_versions: int = None) -> List[Path]:
    """
    Augment a single image file including raw/original version

    The first saved version is always the "raw" one: the original image
    with only final processing and the optional resize applied. Any
    remaining versions come from ``augment_single_image``.

    Args:
        image_path: Path to input image
        output_dir: Output directory for all image versions
        num_total_versions: Total number of versions (including raw). When
            omitted, the integer part of the configured
            ``data_strategy.multiplication_factor`` (default 3.0) is used,
            but never less than 1 - the same per-image budget that
            ``batch_augment`` derives in multiplication mode.

    Returns:
        List of paths to saved images (raw + augmented); empty when the
        image cannot be loaded or no versions are requested
    """
    if num_total_versions is None:
        # The declared default used to reach ``None > 0`` and raise
        # TypeError; resolve it to a concrete count instead.
        data_strategy = self.config.get("data_strategy", {})
        num_total_versions = max(1, int(data_strategy.get("multiplication_factor", 3.0)))

    # Load original image
    image = load_image(image_path, None)
    if image is None:
        return []

    if num_total_versions <= 0:
        return []

    saved_paths = []
    base_name = image_path.stem

    # Raw version first: final processing (e.g. grayscale) but no augmentation
    final_config = self.config.get("final_processing", {})
    raw_image = self._apply_final_processing(image.copy(), final_config)

    # Resize to target size
    target_size = self.image_processor.target_size
    if target_size:
        raw_image = self.resize_preserve_aspect(raw_image, target_size)

    raw_path = output_dir / f"{base_name}_raw_001.jpg"
    if save_image(raw_image, raw_path):
        saved_paths.append(raw_path)

    # Generate augmented versions for remaining slots
    num_augmented = num_total_versions - 1
    if num_augmented > 0:
        augmented_images = self.augment_single_image(image, num_augmented)

        for index, aug_image in enumerate(augmented_images, start=1):
            aug_path = output_dir / f"{base_name}_aug_{index:03d}.jpg"
            if save_image(aug_image, aug_path):
                saved_paths.append(aug_path)

    return saved_paths
|
||||
|
||||
def batch_augment(self, input_dir: Path, output_dir: Path,
                  multiplication_factor: float = None, random_seed: int = None) -> Dict[str, Any]:
    """
    Augment images in a directory with smart sampling and multiplication strategy

    Args:
        input_dir: Input directory containing images
        output_dir: Output directory for augmented images
        multiplication_factor:
            - If < 1.0: Sample percentage of input data to augment
            - If >= 1.0: Target multiplication factor for output data size
            When omitted, ``data_strategy.multiplication_factor`` from the
            config is used (default 3.0).
        random_seed: Random seed for reproducibility; seeds both ``random``
            and ``np.random`` when given

    Returns:
        Dictionary containing results and statistics with the keys
        "input_images", "selected_images", "target_total",
        "actual_generated", "multiplication_factor", "mode", "results"
        (source path -> saved paths) and "efficiency". An empty dict is
        returned when the input directory contains no images.
    """
    from utils import get_image_files

    # Set random seed for reproducibility
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)

    # Get all input images
    all_image_files = get_image_files(input_dir)
    if not all_image_files:
        print("No images found in input directory")
        return {}

    # Get multiplication factor from config if not provided
    if multiplication_factor is None:
        data_strategy = self.config.get("data_strategy", {})
        multiplication_factor = data_strategy.get("multiplication_factor", 3.0)

    print(f"Found {len(all_image_files)} total images")
    print(f"Multiplication factor: {multiplication_factor}")

    # Determine sampling strategy
    if multiplication_factor < 1.0:
        # Sampling mode: Take a percentage of input data.
        # At least one image is selected; a small factor would otherwise
        # round down to zero and the per-image budget below would divide
        # by zero.
        num_selected = max(1, int(len(all_image_files) * multiplication_factor))
        selected_images = self._sample_images(all_image_files, num_selected)
        target_total_images = len(all_image_files)  # Keep original dataset size
        images_per_input = max(1, target_total_images // max(1, len(selected_images)))
        print(f"SAMPLING MODE: Selected {len(selected_images)} images ({multiplication_factor*100:.1f}%)")
        print(f"Target: {target_total_images} total images, {images_per_input} per selected image")
    else:
        # Multiplication mode: Multiply dataset size
        selected_images = all_image_files
        target_total_images = int(len(all_image_files) * multiplication_factor)
        images_per_input = max(1, target_total_images // len(selected_images))
        print(f"MULTIPLICATION MODE: Processing all {len(selected_images)} images")
        print(f"Target: {target_total_images} total images ({multiplication_factor}x original), {images_per_input} per image")

    # Process selected images
    results = {}
    total_generated = 0

    for i, image_path in enumerate(selected_images):
        print_progress(i + 1, len(selected_images), f"Processing {image_path.name}")

        # Number of versions for this image (including raw), capped by what
        # is still missing from the overall target.
        # NOTE(review): the per-image budget comes from an integer division,
        # so any remainder of the target is never generated.
        remaining_images = target_total_images - total_generated
        total_versions_needed = min(images_per_input, remaining_images)

        # Always include raw image, then augmented ones
        augmented_paths = self.augment_image_file_with_raw(
            image_path, output_dir, total_versions_needed
        )

        if augmented_paths:
            results[str(image_path)] = augmented_paths
            total_generated += len(augmented_paths)

    # Generate summary
    summary = {
        "input_images": len(all_image_files),
        "selected_images": len(selected_images),
        "target_total": target_total_images,
        "actual_generated": total_generated,
        "multiplication_factor": multiplication_factor,
        "mode": "sampling" if multiplication_factor < 1.0 else "multiplication",
        "results": results,
        "efficiency": total_generated / target_total_images if target_total_images > 0 else 0
    }

    print(f"\n✅ Augmentation completed!")
    print(f"Generated {total_generated} images from {len(selected_images)} selected images")
    print(f"Target vs Actual: {target_total_images} → {total_generated} ({summary['efficiency']:.1%} efficiency)")

    return summary
|
||||
|
||||
def _sample_images(self, image_files: List[Path], num_selected: int) -> List[Path]:
    """
    Sample images from the input list based on the configured strategy

    The strategy is read from ``self.config["data_strategy"]["sampling"]``:
        - "random" (default): simple random sample without replacement
        - "stratified" (only when "preserve_distribution" is truthy, which
          is the default): sample proportionally from each group of files
          sharing a lower-cased file extension, topping up randomly if the
          groups yield too few
        - "uniform": evenly spaced picks in input order
        - anything else: falls back to random sampling

    Args:
        image_files: Candidate image paths
        num_selected: Number of images to select; values above the number
            of candidates are capped

    Returns:
        A new list with the selected paths (never the caller's list
        itself); empty when ``num_selected`` is not positive or there are
        no candidates
    """
    # Nothing to pick. Handling this up front also keeps the uniform
    # branch from dividing by zero when num_selected is 0.
    if num_selected <= 0 or not image_files:
        return []

    data_strategy = self.config.get("data_strategy", {})
    sampling_config = data_strategy.get("sampling", {})

    method = sampling_config.get("method", "random")
    preserve_distribution = sampling_config.get("preserve_distribution", True)

    if method == "stratified" and preserve_distribution:
        # Stratified sampling by file extension
        extension_groups = {}
        for img_file in image_files:
            extension_groups.setdefault(img_file.suffix.lower(), []).append(img_file)

        selected = []
        for files in extension_groups.values():
            # Sample proportionally from each extension group, at least one
            # per group
            group_size = max(1, int(num_selected * len(files) / len(image_files)))
            selected.extend(random.sample(files, min(group_size, len(files))))

        # If we have too few, add more randomly
        if len(selected) < num_selected:
            already_selected = set(selected)  # set lookup instead of list scan
            remaining = [f for f in image_files if f not in already_selected]
            selected.extend(
                random.sample(remaining, min(num_selected - len(selected), len(remaining)))
            )

        # The one-per-group minimum can overshoot; trim to the requested size
        return selected[:num_selected]

    if method == "uniform":
        # Uniform sampling - evenly spaced
        if num_selected >= len(image_files):
            return list(image_files)

        step = len(image_files) / num_selected
        return [image_files[int(i * step)] for i in range(num_selected)]

    # "random", "stratified" without preserve_distribution, and unknown methods
    return random.sample(image_files, min(num_selected, len(image_files)))
|
||||
|
||||
def get_augmentation_summary(self, results: Dict[str, List[Path]]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user