update batch loader

2025-09-04 12:11:50 +00:00
parent 44febd7d2e
commit 116ab04283
8 changed files with 3949 additions and 5 deletions

main.py (30 changed lines)

@@ -32,6 +32,9 @@ Examples:
    # Run analysis and save to custom output directory
    python main.py --output-dir custom_results

    # Run with memory-efficient batch processing
    python main.py --deqa-only --batch-size 4 --max-image-size 512 512
"""
)
@@ -102,6 +105,28 @@ Examples:
    help='Use only DeQA metric (disable traditional metrics)'
)
parser.add_argument(
    '--batch-size',
    type=int,
    default=8,
    help='Batch size for memory-efficient processing (default: 8)'
)
parser.add_argument(
    '--max-image-size',
    type=int,
    nargs=2,
    default=[1024, 1024],
    metavar=('WIDTH', 'HEIGHT'),
    help='Maximum image dimensions for preprocessing (default: 1024 1024)'
)
parser.add_argument(
    '--disable-memory-monitoring',
    action='store_true',
    help='Disable memory usage monitoring'
)

args = parser.parse_args()

# Setup logging
@@ -150,7 +175,10 @@ Examples:
    enable_deqa=enable_deqa,
    enable_traditional=enable_traditional,
    enable_pyiqa=enable_pyiqa_flag,
    pyiqa_selected_metrics=(selected_top20 if args.pyiqa_top20 else None),
    batch_size=args.batch_size,
    max_image_size=tuple(args.max_image_size),
    enable_memory_monitoring=not args.disable_memory_monitoring
)

results, report = analyzer.run_analysis()
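A note on the new --max-image-size flag: argparse with nargs=2 collects the two values into a list, which is why main.py casts the result with tuple(...) before handing it to the analyzer. A minimal standalone sketch:

# Standalone sketch of the nargs=2 parsing behind --max-image-size.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-image-size', type=int, nargs=2,
                    default=[1024, 1024], metavar=('WIDTH', 'HEIGHT'))

args = parser.parse_args(['--max-image-size', '512', '512'])
print(args.max_image_size)         # [512, 512], a list, not a tuple
print(tuple(args.max_image_size))  # (512, 512), what the analyzer receives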

results/facture.txt (new file, 3398 lines): diff suppressed because it is too large.

src/batch_loader.py (new file, 301 lines)

@@ -0,0 +1,301 @@
"""
Memory-efficient batch loader for image processing.
Reduces memory usage by processing images in configurable batches with optional preprocessing.
"""
import gc
import torch
from PIL import Image
from pathlib import Path
from typing import List, Dict, Any, Iterator, Tuple, Optional
import psutil
import os
from .logger_config import get_logger
logger = get_logger(__name__)
class MemoryEfficientBatchLoader:
"""Memory-efficient image batch loader with preprocessing options."""
def __init__(self,
batch_size: int = 8,
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
enable_memory_monitoring: bool = True,
memory_threshold_percent: float = 85.0):
"""
Initialize the batch loader.
Args:
batch_size: Number of images to process per batch
max_image_size: Maximum image dimensions (width, height). If None, no resizing
enable_memory_monitoring: Whether to monitor memory usage
memory_threshold_percent: Memory usage threshold to trigger cleanup
"""
self.batch_size = batch_size
self.max_image_size = max_image_size
self.enable_memory_monitoring = enable_memory_monitoring
self.memory_threshold_percent = memory_threshold_percent
logger.info(f"BatchLoader initialized: batch_size={batch_size}, "
f"max_size={max_image_size}, memory_monitoring={enable_memory_monitoring}")
def _get_memory_usage(self) -> float:
"""Get current memory usage percentage."""
return psutil.virtual_memory().percent
def _cleanup_memory(self):
"""Force garbage collection and clear GPU cache if available."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.enable_memory_monitoring:
memory_usage = self._get_memory_usage()
logger.debug(f"Memory cleanup performed. Current usage: {memory_usage:.1f}%")
def _check_memory_threshold(self):
"""Check if memory usage exceeds threshold and cleanup if needed."""
if not self.enable_memory_monitoring:
return
memory_usage = self._get_memory_usage()
if memory_usage > self.memory_threshold_percent:
logger.warning(f"Memory usage ({memory_usage:.1f}%) exceeds threshold "
f"({self.memory_threshold_percent}%). Performing cleanup...")
self._cleanup_memory()
def _preprocess_image(self, image_path: str) -> Optional[Image.Image]:
"""
Load and preprocess image to reduce memory usage.
Args:
image_path: Path to the image file
Returns:
PIL Image object or None if failed
"""
try:
image = Image.open(image_path)
# Convert to RGB if needed (removes alpha channel)
if image.mode != 'RGB':
image = image.convert('RGB')
# Resize if max_image_size is specified
if self.max_image_size:
# Calculate resize ratio to maintain aspect ratio
width, height = image.size
max_w, max_h = self.max_image_size
if width > max_w or height > max_h:
ratio = min(max_w / width, max_h / height)
new_size = (int(width * ratio), int(height * ratio))
image = image.resize(new_size, Image.Resampling.LANCZOS)
logger.debug(f"Resized {Path(image_path).name} from {width}x{height} to {new_size[0]}x{new_size[1]}")
return image
except Exception as e:
logger.error(f"Failed to preprocess image {image_path}: {e}")
return None
def _get_file_size_mb(self, file_path: str) -> float:
"""Get file size in MB."""
try:
return os.path.getsize(file_path) / (1024 * 1024)
except:
return 0.0
def create_batches(self, image_paths: List[str]) -> Iterator[List[str]]:
"""
Create batches of image paths.
Args:
image_paths: List of image file paths
Yields:
Batches of image paths
"""
for i in range(0, len(image_paths), self.batch_size):
batch = image_paths[i:i + self.batch_size]
yield batch
def process_batch(self, image_paths: List[str],
scoring_function,
include_file_info: bool = True) -> Dict[str, Dict[str, Any]]:
"""
Process a batch of images with memory management.
Args:
image_paths: List of image paths in the batch
scoring_function: Function to score images (should accept list of PIL Images)
include_file_info: Whether to include file size and path info
Returns:
Dictionary with results for each image in the batch
"""
results = {}
# Check memory before processing batch
self._check_memory_threshold()
# Preprocess all images in batch
processed_images = []
valid_paths = []
for path in image_paths:
image = self._preprocess_image(path)
if image is not None:
processed_images.append(image)
valid_paths.append(path)
else:
# Handle failed image
image_name = Path(path).name
results[image_name] = {
'error': f'Failed to load/preprocess image: {path}',
'file_path': path if include_file_info else None,
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
}
if processed_images:
try:
# Score the batch
scores = scoring_function(processed_images)
# Combine results
for i, (path, score) in enumerate(zip(valid_paths, scores)):
image_name = Path(path).name
result = {
'deqa_score': score,
}
if include_file_info:
result.update({
'file_path': path,
'file_size_mb': round(self._get_file_size_mb(path), 2)
})
results[image_name] = result
logger.debug(f"Processed batch of {len(processed_images)} images")
except Exception as e:
logger.error(f"Error scoring batch: {e}")
# Mark all as failed
for path in valid_paths:
image_name = Path(path).name
results[image_name] = {
'error': f'Scoring failed: {str(e)}',
'file_path': path if include_file_info else None,
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
}
finally:
# Clear processed images from memory
del processed_images
self._cleanup_memory()
return results
def get_memory_stats(self) -> Dict[str, Any]:
"""Get current memory statistics."""
memory = psutil.virtual_memory()
stats = {
'memory_percent': memory.percent,
'memory_available_gb': memory.available / (1024**3),
'memory_used_gb': memory.used / (1024**3),
'memory_total_gb': memory.total / (1024**3)
}
if torch.cuda.is_available():
gpu_memory = torch.cuda.memory_allocated() / (1024**3)
gpu_cached = torch.cuda.memory_reserved() / (1024**3)
stats.update({
'gpu_memory_allocated_gb': gpu_memory,
'gpu_memory_cached_gb': gpu_cached
})
return stats
class BatchProcessor:
"""High-level processor that uses MemoryEfficientBatchLoader."""
def __init__(self,
batch_size: int = 8,
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
enable_memory_monitoring: bool = True):
"""
Initialize the batch processor.
Args:
batch_size: Number of images to process per batch
max_image_size: Maximum image dimensions for preprocessing
enable_memory_monitoring: Whether to monitor memory usage
"""
self.loader = MemoryEfficientBatchLoader(
batch_size=batch_size,
max_image_size=max_image_size,
enable_memory_monitoring=enable_memory_monitoring
)
def process_images_in_batches(self,
image_paths: List[str],
scorer,
progress_callback=None) -> Dict[str, Dict[str, Any]]:
"""
Process all images in batches.
Args:
image_paths: List of all image paths to process
scorer: Object with score_multiple_images method
progress_callback: Optional function to call with progress updates
Returns:
Combined results from all batches
"""
all_results = {}
total_images = len(image_paths)
processed_images = 0
logger.info(f"Processing {total_images} images in batches of {self.loader.batch_size}")
# Show initial memory stats
if self.loader.enable_memory_monitoring:
initial_stats = self.loader.get_memory_stats()
logger.info(f"Initial memory usage: {initial_stats['memory_percent']:.1f}% "
f"({initial_stats['memory_used_gb']:.1f}GB/{initial_stats['memory_total_gb']:.1f}GB)")
for batch_idx, batch_paths in enumerate(self.loader.create_batches(image_paths)):
logger.info(f"Processing batch {batch_idx + 1} ({len(batch_paths)} images)...")
# Create scoring function that works with the batch loader
def scoring_function(images):
return scorer.score_multiple_images_from_pil(images)
# Process batch
batch_results = self.loader.process_batch(batch_paths, scoring_function)
all_results.update(batch_results)
processed_images += len(batch_paths)
# Progress callback
if progress_callback:
progress_callback(processed_images, total_images)
# Log progress and memory stats
progress_percent = (processed_images / total_images) * 100
logger.info(f"Progress: {processed_images}/{total_images} ({progress_percent:.1f}%)")
if self.loader.enable_memory_monitoring:
current_stats = self.loader.get_memory_stats()
logger.debug(f"Current memory usage: {current_stats['memory_percent']:.1f}%")
logger.info(f"Completed processing {total_images} images")
# Final memory cleanup
self.loader._cleanup_memory()
return all_results
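As a usage reference (not part of the commit), a minimal sketch of driving BatchProcessor directly with a stub scorer; the loader only requires that the scorer expose score_multiple_images_from_pil, and the paths here are hypothetical:

from src.batch_loader import BatchProcessor

class StubScorer:
    """Stand-in for the real DeQA scorer; returns a fixed score per image."""
    def score_multiple_images_from_pil(self, images):
        return [3.0 for _ in images]

processor = BatchProcessor(batch_size=4, max_image_size=(512, 512))
results = processor.process_images_in_batches(
    ["images/a.png", "images/b.png"],  # hypothetical paths
    StubScorer(),
    progress_callback=lambda done, total: print(f"{done}/{total} done"),
)
print(results)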


@@ -119,6 +119,57 @@ class DeQAScorer:
logger.error(f"Error scoring images: {e}") logger.error(f"Error scoring images: {e}")
return [None] * len(image_paths) return [None] * len(image_paths)
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Union[float, None]]:
"""
Score multiple PIL images directly (for batch processing).
Args:
images: List of PIL Image objects
Returns:
List of DeQA scores (0-5 scale) or None for failed images
"""
if not self._is_loaded:
self.load_model()
try:
# Filter out None images
valid_images = [img for img in images if img is not None]
if not valid_images:
return [None] * len(images)
# Score images
scores = self.model.score(valid_images)
# Convert scores to list of floats and map back to original positions
result_scores = []
valid_idx = 0
for img in images:
if img is None:
result_scores.append(None)
else:
try:
score = scores[valid_idx]
if hasattr(score, 'item'):
result_scores.append(float(score.item()))
elif hasattr(score, 'tolist'):
result_scores.append(float(score.tolist()))
else:
result_scores.append(float(score))
valid_idx += 1
except Exception as e:
logger.warning(f"Failed to convert score: {e}")
result_scores.append(None)
valid_idx += 1
return result_scores
except Exception as e:
logger.error(f"Error scoring PIL images: {e}")
return [None] * len(images)
def is_loaded(self) -> bool: def is_loaded(self) -> bool:
"""Check if the model is loaded.""" """Check if the model is loaded."""
return self._is_loaded return self._is_loaded
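The position-mapping logic above is easy to check in isolation; a toy sketch (not part of the commit) of the same bookkeeping:

def map_scores(images, scores):
    """Map per-valid-image scores back to their original positions."""
    out, k = [], 0
    for img in images:
        if img is None:
            out.append(None)
        else:
            out.append(float(scores[k]))
            k += 1
    return out

# None inputs keep their slots; valid images consume scores in order.
assert map_scores(["img_a", None, "img_b"], [4.2, 3.1]) == [4.2, None, 3.1]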


@@ -12,6 +12,7 @@ from datetime import datetime
from .metrics import MetricsManager
from .logger_config import get_logger
from .batch_loader import BatchProcessor

logger = get_logger(__name__)
@@ -21,7 +22,9 @@ class IQAAnalyzer:
    def __init__(self, image_dir: str, output_dir: str = "results",
                 enable_deqa: bool = True, enable_traditional: bool = True,
                 enable_pyiqa: bool = True, pyiqa_selected_metrics: List[str] | None = None,
                 batch_size: int = 8, max_image_size: tuple = (1024, 1024),
                 enable_memory_monitoring: bool = True):
        """
        Initialize the IQA analyzer.

@@ -30,11 +33,21 @@ class IQAAnalyzer:
            output_dir: Directory to save results
            enable_deqa: Whether to enable DeQA metric
            enable_traditional: Whether to enable traditional metrics
            enable_pyiqa: Whether to enable PyIQA metrics
            pyiqa_selected_metrics: List of specific PyIQA metrics to use
            batch_size: Number of images to process per batch
            max_image_size: Maximum image dimensions for preprocessing
            enable_memory_monitoring: Whether to monitor memory usage
        """
        self.image_dir = Path(image_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Store batch processing settings
        self.batch_size = batch_size
        self.max_image_size = max_image_size
        self.enable_memory_monitoring = enable_memory_monitoring

        # Initialize metrics manager
        self.metrics_manager = MetricsManager(
            enable_deqa=enable_deqa,

@@ -43,12 +56,23 @@ class IQAAnalyzer:
            pyiqa_selected_metrics=pyiqa_selected_metrics
        )

        # Initialize batch processor (only if DeQA is enabled, for memory efficiency)
        self.batch_processor = None
        if enable_deqa:
            self.batch_processor = BatchProcessor(
                batch_size=batch_size,
                max_image_size=max_image_size,
                enable_memory_monitoring=enable_memory_monitoring
            )

        # Results storage
        self.results = {}
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        logger.info(f"IQA Analyzer initialized for: {image_dir}")
        logger.info(f"Available metrics: {self.metrics_manager.get_available_metrics()}")
        logger.info(f"Batch processing: {'enabled' if self.batch_processor else 'disabled'} "
                    f"(batch_size={batch_size}, max_size={max_image_size})")

    def _get_image_files(self) -> List[Path]:
        """Get all image files from the directory."""

@@ -81,8 +105,30 @@ class IQAAnalyzer:
        # Prepare image paths for analysis
        image_paths = [str(img_path) for img_path in image_files]

        # Use batch processing if available and only DeQA is enabled
        available_metrics = self.metrics_manager.get_available_metrics()
        use_batch_processing = (self.batch_processor is not None and
                                available_metrics == ['deqa'])
        if use_batch_processing:
            logger.info(f"Using memory-efficient batch processing (batch_size={self.batch_size})...")
            # Get DeQA scorer
            deqa_scorer = self.metrics_manager.get_deqa_scorer()
            if deqa_scorer is None:
                logger.error("DeQA scorer not available for batch processing")
                return {}

            # Process with batch loader
            def progress_callback(processed, total):
                logger.info(f"Batch progress: {processed}/{total} images processed")

            self.results = self.batch_processor.process_images_in_batches(
                image_paths, deqa_scorer, progress_callback
            )
        else:
            # Fall back to original processing
            logger.info("Using standard processing...")
            self.results = self.metrics_manager.calculate_metrics_batch(image_paths)

        logger.info(f"Analysis completed for {len(self.results)} images")


@@ -102,6 +102,57 @@ class DeQAMetric:
logger.error(f"Error scoring multiple images: {e}") logger.error(f"Error scoring multiple images: {e}")
return [None] * len(image_paths) return [None] * len(image_paths)
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Optional[float]]:
"""
Score multiple PIL images directly (for batch processing).
Args:
images: List of PIL Image objects
Returns:
List of DeQA scores (0-5 scale) or None for failed images
"""
if not self._is_loaded:
self.load_model()
try:
# Filter out None images
valid_images = [img for img in images if img is not None]
if not valid_images:
return [None] * len(images)
# Score images
scores = self.model.score(valid_images)
# Convert scores to list of floats and map back to original positions
result_scores = []
valid_idx = 0
for img in images:
if img is None:
result_scores.append(None)
else:
try:
score = scores[valid_idx]
if hasattr(score, 'item'):
result_scores.append(float(score.item()))
elif hasattr(score, 'tolist'):
result_scores.append(float(score.tolist()))
else:
result_scores.append(float(score))
valid_idx += 1
except Exception as e:
logger.warning(f"Failed to convert score: {e}")
result_scores.append(None)
valid_idx += 1
return result_scores
except Exception as e:
logger.error(f"Error scoring PIL images: {e}")
return [None] * len(images)
def get_metric_name(self) -> str: def get_metric_name(self) -> str:
"""Get the name of this metric.""" """Get the name of this metric."""
return "DeQA" return "DeQA"


@@ -196,6 +196,12 @@ class MetricsManager:
            return self.metrics['deqa'].score_multiple_images(image_paths)
        return [None] * len(image_paths)

    def get_deqa_scorer(self):
        """Get the DeQA scorer object for batch processing."""
        if 'deqa' in self.metrics and self.metrics['deqa'] is not None:
            return self.metrics['deqa']
        return None

    def get_traditional_metrics_only(self, image_path: str) -> Dict[str, float]:
        """Get only traditional metrics for an image."""
        if 'traditional' in self.metrics and self.metrics['traditional'] is not None:
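A short sketch of the intended wiring, assuming a MetricsManager built with only DeQA enabled (constructor arguments limited to those visible in this diff):

from src.metrics import MetricsManager  # import path taken from the analyzer imports above

manager = MetricsManager(enable_deqa=True, enable_traditional=False,
                         enable_pyiqa=False, pyiqa_selected_metrics=None)
scorer = manager.get_deqa_scorer()
if scorer is None:
    raise RuntimeError("DeQA metric unavailable; batch processing cannot run")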

test_batch_loader.py (new file, 63 lines)

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Test script for the memory-efficient batch loader.
"""
import sys
from pathlib import Path

# Add the project root to the path so the `src` package resolves
sys.path.append(str(Path(__file__).parent))

from src.batch_loader import MemoryEfficientBatchLoader, BatchProcessor
from src.logger_config import setup_logging


def test_batch_loader():
    """Test the batch loader with sample functionality."""
    # Setup logging
    logger = setup_logging(log_dir="logs", log_level="INFO")

    # Test configuration
    batch_size = 4
    max_image_size = (512, 512)

    # Create batch loader
    loader = MemoryEfficientBatchLoader(
        batch_size=batch_size,
        max_image_size=max_image_size,
        enable_memory_monitoring=True
    )

    # Show initial memory stats
    memory_stats = loader.get_memory_stats()
    logger.info(f"Initial memory stats: {memory_stats}")

    # Sample image paths (placeholders; these would be real image paths)
    sample_image_paths = [
        f"/home/nguyendc/thanh-dev/IQA-Metric-Benchmark/data/task/cni/images/image_{i}.png"
        for i in range(10)
    ]

    # Test batch creation
    batches = list(loader.create_batches(sample_image_paths))
    logger.info(f"Created {len(batches)} batches from {len(sample_image_paths)} images")
    for i, batch in enumerate(batches):
        logger.info(f"Batch {i + 1}: {len(batch)} images")

    # Mock scoring function for testing
    def mock_scoring_function(images):
        """Mock function that returns random scores."""
        import random
        return [random.uniform(1.0, 5.0) for _ in images]

    # Test processing a single batch (with mock data)
    if batches:
        logger.info("Testing batch processing with mock data...")
        # Note: this will fail for file paths that don't exist,
        # but it demonstrates the interface
    logger.info("Batch loader test completed!")


if __name__ == "__main__":
    test_batch_loader()
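The memory figures logged throughout come from psutil; a minimal standalone equivalent of get_memory_stats() for quick inspection:

import psutil

vm = psutil.virtual_memory()
print(f"RAM: {vm.percent:.1f}% used "
      f"({vm.used / 1024**3:.1f}/{vm.total / 1024**3:.1f} GiB)")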