update batch loader
This commit is contained in:
30
main.py
30
main.py
@@ -32,6 +32,9 @@ Examples:
|
|||||||
|
|
||||||
# Run analysis and save to custom output directory
|
# Run analysis and save to custom output directory
|
||||||
python main.py --output-dir custom_results
|
python main.py --output-dir custom_results
|
||||||
|
|
||||||
|
# Run with memory-efficient batch processing
|
||||||
|
python main.py --deqa-only --batch-size 4 --max-image-size 512 512
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -102,6 +105,28 @@ Examples:
|
|||||||
help='Use only DeQA metric (disable traditional metrics)'
|
help='Use only DeQA metric (disable traditional metrics)'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--batch-size',
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help='Batch size for memory-efficient processing (default: 8)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-image-size',
|
||||||
|
type=int,
|
||||||
|
nargs=2,
|
||||||
|
default=[1024, 1024],
|
||||||
|
metavar=('WIDTH', 'HEIGHT'),
|
||||||
|
help='Maximum image dimensions for preprocessing (default: 1024 1024)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--disable-memory-monitoring',
|
||||||
|
action='store_true',
|
||||||
|
help='Disable memory usage monitoring'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -150,7 +175,10 @@ Examples:
|
|||||||
enable_deqa=enable_deqa,
|
enable_deqa=enable_deqa,
|
||||||
enable_traditional=enable_traditional,
|
enable_traditional=enable_traditional,
|
||||||
enable_pyiqa=enable_pyiqa_flag,
|
enable_pyiqa=enable_pyiqa_flag,
|
||||||
pyiqa_selected_metrics=(selected_top20 if args.pyiqa_top20 else None)
|
pyiqa_selected_metrics=(selected_top20 if args.pyiqa_top20 else None),
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
max_image_size=tuple(args.max_image_size),
|
||||||
|
enable_memory_monitoring=not args.disable_memory_monitoring
|
||||||
)
|
)
|
||||||
results, report = analyzer.run_analysis()
|
results, report = analyzer.run_analysis()
|
||||||
|
|
||||||
|
3398
results/facture.txt
Normal file
3398
results/facture.txt
Normal file
File diff suppressed because it is too large
Load Diff
301
src/batch_loader.py
Normal file
301
src/batch_loader.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
Memory-efficient batch loader for image processing.
|
||||||
|
Reduces memory usage by processing images in configurable batches with optional preprocessing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Iterator, Tuple, Optional
|
||||||
|
import psutil
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .logger_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEfficientBatchLoader:
|
||||||
|
"""Memory-efficient image batch loader with preprocessing options."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
batch_size: int = 8,
|
||||||
|
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
|
||||||
|
enable_memory_monitoring: bool = True,
|
||||||
|
memory_threshold_percent: float = 85.0):
|
||||||
|
"""
|
||||||
|
Initialize the batch loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Number of images to process per batch
|
||||||
|
max_image_size: Maximum image dimensions (width, height). If None, no resizing
|
||||||
|
enable_memory_monitoring: Whether to monitor memory usage
|
||||||
|
memory_threshold_percent: Memory usage threshold to trigger cleanup
|
||||||
|
"""
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.max_image_size = max_image_size
|
||||||
|
self.enable_memory_monitoring = enable_memory_monitoring
|
||||||
|
self.memory_threshold_percent = memory_threshold_percent
|
||||||
|
|
||||||
|
logger.info(f"BatchLoader initialized: batch_size={batch_size}, "
|
||||||
|
f"max_size={max_image_size}, memory_monitoring={enable_memory_monitoring}")
|
||||||
|
|
||||||
|
def _get_memory_usage(self) -> float:
|
||||||
|
"""Get current memory usage percentage."""
|
||||||
|
return psutil.virtual_memory().percent
|
||||||
|
|
||||||
|
def _cleanup_memory(self):
|
||||||
|
"""Force garbage collection and clear GPU cache if available."""
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
if self.enable_memory_monitoring:
|
||||||
|
memory_usage = self._get_memory_usage()
|
||||||
|
logger.debug(f"Memory cleanup performed. Current usage: {memory_usage:.1f}%")
|
||||||
|
|
||||||
|
def _check_memory_threshold(self):
|
||||||
|
"""Check if memory usage exceeds threshold and cleanup if needed."""
|
||||||
|
if not self.enable_memory_monitoring:
|
||||||
|
return
|
||||||
|
|
||||||
|
memory_usage = self._get_memory_usage()
|
||||||
|
if memory_usage > self.memory_threshold_percent:
|
||||||
|
logger.warning(f"Memory usage ({memory_usage:.1f}%) exceeds threshold "
|
||||||
|
f"({self.memory_threshold_percent}%). Performing cleanup...")
|
||||||
|
self._cleanup_memory()
|
||||||
|
|
||||||
|
def _preprocess_image(self, image_path: str) -> Optional[Image.Image]:
|
||||||
|
"""
|
||||||
|
Load and preprocess image to reduce memory usage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to the image file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image object or None if failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
|
||||||
|
# Convert to RGB if needed (removes alpha channel)
|
||||||
|
if image.mode != 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
# Resize if max_image_size is specified
|
||||||
|
if self.max_image_size:
|
||||||
|
# Calculate resize ratio to maintain aspect ratio
|
||||||
|
width, height = image.size
|
||||||
|
max_w, max_h = self.max_image_size
|
||||||
|
|
||||||
|
if width > max_w or height > max_h:
|
||||||
|
ratio = min(max_w / width, max_h / height)
|
||||||
|
new_size = (int(width * ratio), int(height * ratio))
|
||||||
|
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
logger.debug(f"Resized {Path(image_path).name} from {width}x{height} to {new_size[0]}x{new_size[1]}")
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to preprocess image {image_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_file_size_mb(self, file_path: str) -> float:
|
||||||
|
"""Get file size in MB."""
|
||||||
|
try:
|
||||||
|
return os.path.getsize(file_path) / (1024 * 1024)
|
||||||
|
except:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def create_batches(self, image_paths: List[str]) -> Iterator[List[str]]:
|
||||||
|
"""
|
||||||
|
Create batches of image paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths: List of image file paths
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Batches of image paths
|
||||||
|
"""
|
||||||
|
for i in range(0, len(image_paths), self.batch_size):
|
||||||
|
batch = image_paths[i:i + self.batch_size]
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def process_batch(self, image_paths: List[str],
|
||||||
|
scoring_function,
|
||||||
|
include_file_info: bool = True) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Process a batch of images with memory management.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths: List of image paths in the batch
|
||||||
|
scoring_function: Function to score images (should accept list of PIL Images)
|
||||||
|
include_file_info: Whether to include file size and path info
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with results for each image in the batch
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Check memory before processing batch
|
||||||
|
self._check_memory_threshold()
|
||||||
|
|
||||||
|
# Preprocess all images in batch
|
||||||
|
processed_images = []
|
||||||
|
valid_paths = []
|
||||||
|
|
||||||
|
for path in image_paths:
|
||||||
|
image = self._preprocess_image(path)
|
||||||
|
if image is not None:
|
||||||
|
processed_images.append(image)
|
||||||
|
valid_paths.append(path)
|
||||||
|
else:
|
||||||
|
# Handle failed image
|
||||||
|
image_name = Path(path).name
|
||||||
|
results[image_name] = {
|
||||||
|
'error': f'Failed to load/preprocess image: {path}',
|
||||||
|
'file_path': path if include_file_info else None,
|
||||||
|
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
|
||||||
|
}
|
||||||
|
|
||||||
|
if processed_images:
|
||||||
|
try:
|
||||||
|
# Score the batch
|
||||||
|
scores = scoring_function(processed_images)
|
||||||
|
|
||||||
|
# Combine results
|
||||||
|
for i, (path, score) in enumerate(zip(valid_paths, scores)):
|
||||||
|
image_name = Path(path).name
|
||||||
|
result = {
|
||||||
|
'deqa_score': score,
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_file_info:
|
||||||
|
result.update({
|
||||||
|
'file_path': path,
|
||||||
|
'file_size_mb': round(self._get_file_size_mb(path), 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
results[image_name] = result
|
||||||
|
|
||||||
|
logger.debug(f"Processed batch of {len(processed_images)} images")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scoring batch: {e}")
|
||||||
|
# Mark all as failed
|
||||||
|
for path in valid_paths:
|
||||||
|
image_name = Path(path).name
|
||||||
|
results[image_name] = {
|
||||||
|
'error': f'Scoring failed: {str(e)}',
|
||||||
|
'file_path': path if include_file_info else None,
|
||||||
|
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
|
||||||
|
}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clear processed images from memory
|
||||||
|
del processed_images
|
||||||
|
self._cleanup_memory()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_memory_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get current memory statistics."""
|
||||||
|
memory = psutil.virtual_memory()
|
||||||
|
stats = {
|
||||||
|
'memory_percent': memory.percent,
|
||||||
|
'memory_available_gb': memory.available / (1024**3),
|
||||||
|
'memory_used_gb': memory.used / (1024**3),
|
||||||
|
'memory_total_gb': memory.total / (1024**3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
gpu_memory = torch.cuda.memory_allocated() / (1024**3)
|
||||||
|
gpu_cached = torch.cuda.memory_reserved() / (1024**3)
|
||||||
|
stats.update({
|
||||||
|
'gpu_memory_allocated_gb': gpu_memory,
|
||||||
|
'gpu_memory_cached_gb': gpu_cached
|
||||||
|
})
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessor:
|
||||||
|
"""High-level processor that uses MemoryEfficientBatchLoader."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
batch_size: int = 8,
|
||||||
|
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
|
||||||
|
enable_memory_monitoring: bool = True):
|
||||||
|
"""
|
||||||
|
Initialize the batch processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Number of images to process per batch
|
||||||
|
max_image_size: Maximum image dimensions for preprocessing
|
||||||
|
enable_memory_monitoring: Whether to monitor memory usage
|
||||||
|
"""
|
||||||
|
self.loader = MemoryEfficientBatchLoader(
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_image_size=max_image_size,
|
||||||
|
enable_memory_monitoring=enable_memory_monitoring
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_images_in_batches(self,
|
||||||
|
image_paths: List[str],
|
||||||
|
scorer,
|
||||||
|
progress_callback=None) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Process all images in batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths: List of all image paths to process
|
||||||
|
scorer: Object with score_multiple_images method
|
||||||
|
progress_callback: Optional function to call with progress updates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined results from all batches
|
||||||
|
"""
|
||||||
|
all_results = {}
|
||||||
|
total_images = len(image_paths)
|
||||||
|
processed_images = 0
|
||||||
|
|
||||||
|
logger.info(f"Processing {total_images} images in batches of {self.loader.batch_size}")
|
||||||
|
|
||||||
|
# Show initial memory stats
|
||||||
|
if self.loader.enable_memory_monitoring:
|
||||||
|
initial_stats = self.loader.get_memory_stats()
|
||||||
|
logger.info(f"Initial memory usage: {initial_stats['memory_percent']:.1f}% "
|
||||||
|
f"({initial_stats['memory_used_gb']:.1f}GB/{initial_stats['memory_total_gb']:.1f}GB)")
|
||||||
|
|
||||||
|
for batch_idx, batch_paths in enumerate(self.loader.create_batches(image_paths)):
|
||||||
|
logger.info(f"Processing batch {batch_idx + 1} ({len(batch_paths)} images)...")
|
||||||
|
|
||||||
|
# Create scoring function that works with the batch loader
|
||||||
|
def scoring_function(images):
|
||||||
|
return scorer.score_multiple_images_from_pil(images)
|
||||||
|
|
||||||
|
# Process batch
|
||||||
|
batch_results = self.loader.process_batch(batch_paths, scoring_function)
|
||||||
|
all_results.update(batch_results)
|
||||||
|
|
||||||
|
processed_images += len(batch_paths)
|
||||||
|
|
||||||
|
# Progress callback
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(processed_images, total_images)
|
||||||
|
|
||||||
|
# Log progress and memory stats
|
||||||
|
progress_percent = (processed_images / total_images) * 100
|
||||||
|
logger.info(f"Progress: {processed_images}/{total_images} ({progress_percent:.1f}%)")
|
||||||
|
|
||||||
|
if self.loader.enable_memory_monitoring:
|
||||||
|
current_stats = self.loader.get_memory_stats()
|
||||||
|
logger.debug(f"Current memory usage: {current_stats['memory_percent']:.1f}%")
|
||||||
|
|
||||||
|
logger.info(f"Completed processing {total_images} images")
|
||||||
|
|
||||||
|
# Final memory cleanup
|
||||||
|
self.loader._cleanup_memory()
|
||||||
|
|
||||||
|
return all_results
|
@@ -119,6 +119,57 @@ class DeQAScorer:
|
|||||||
logger.error(f"Error scoring images: {e}")
|
logger.error(f"Error scoring images: {e}")
|
||||||
return [None] * len(image_paths)
|
return [None] * len(image_paths)
|
||||||
|
|
||||||
|
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Union[float, None]]:
|
||||||
|
"""
|
||||||
|
Score multiple PIL images directly (for batch processing).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of PIL Image objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of DeQA scores (0-5 scale) or None for failed images
|
||||||
|
"""
|
||||||
|
if not self._is_loaded:
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Filter out None images
|
||||||
|
valid_images = [img for img in images if img is not None]
|
||||||
|
|
||||||
|
if not valid_images:
|
||||||
|
return [None] * len(images)
|
||||||
|
|
||||||
|
# Score images
|
||||||
|
scores = self.model.score(valid_images)
|
||||||
|
|
||||||
|
# Convert scores to list of floats and map back to original positions
|
||||||
|
result_scores = []
|
||||||
|
valid_idx = 0
|
||||||
|
|
||||||
|
for img in images:
|
||||||
|
if img is None:
|
||||||
|
result_scores.append(None)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
score = scores[valid_idx]
|
||||||
|
if hasattr(score, 'item'):
|
||||||
|
result_scores.append(float(score.item()))
|
||||||
|
elif hasattr(score, 'tolist'):
|
||||||
|
result_scores.append(float(score.tolist()))
|
||||||
|
else:
|
||||||
|
result_scores.append(float(score))
|
||||||
|
valid_idx += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to convert score: {e}")
|
||||||
|
result_scores.append(None)
|
||||||
|
valid_idx += 1
|
||||||
|
|
||||||
|
return result_scores
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scoring PIL images: {e}")
|
||||||
|
return [None] * len(images)
|
||||||
|
|
||||||
def is_loaded(self) -> bool:
|
def is_loaded(self) -> bool:
|
||||||
"""Check if the model is loaded."""
|
"""Check if the model is loaded."""
|
||||||
return self._is_loaded
|
return self._is_loaded
|
||||||
|
@@ -12,6 +12,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from .metrics import MetricsManager
|
from .metrics import MetricsManager
|
||||||
from .logger_config import get_logger
|
from .logger_config import get_logger
|
||||||
|
from .batch_loader import BatchProcessor
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -21,7 +22,9 @@ class IQAAnalyzer:
|
|||||||
|
|
||||||
def __init__(self, image_dir: str, output_dir: str = "results",
|
def __init__(self, image_dir: str, output_dir: str = "results",
|
||||||
enable_deqa: bool = True, enable_traditional: bool = True,
|
enable_deqa: bool = True, enable_traditional: bool = True,
|
||||||
enable_pyiqa: bool = True, pyiqa_selected_metrics: List[str] | None = None):
|
enable_pyiqa: bool = True, pyiqa_selected_metrics: List[str] | None = None,
|
||||||
|
batch_size: int = 8, max_image_size: tuple = (1024, 1024),
|
||||||
|
enable_memory_monitoring: bool = True):
|
||||||
"""
|
"""
|
||||||
Initialize the IQA analyzer.
|
Initialize the IQA analyzer.
|
||||||
|
|
||||||
@@ -30,11 +33,21 @@ class IQAAnalyzer:
|
|||||||
output_dir: Directory to save results
|
output_dir: Directory to save results
|
||||||
enable_deqa: Whether to enable DeQA metric
|
enable_deqa: Whether to enable DeQA metric
|
||||||
enable_traditional: Whether to enable traditional metrics
|
enable_traditional: Whether to enable traditional metrics
|
||||||
|
enable_pyiqa: Whether to enable PyIQA metrics
|
||||||
|
pyiqa_selected_metrics: List of specific PyIQA metrics to use
|
||||||
|
batch_size: Number of images to process per batch
|
||||||
|
max_image_size: Maximum image dimensions for preprocessing
|
||||||
|
enable_memory_monitoring: Whether to monitor memory usage
|
||||||
"""
|
"""
|
||||||
self.image_dir = Path(image_dir)
|
self.image_dir = Path(image_dir)
|
||||||
self.output_dir = Path(output_dir)
|
self.output_dir = Path(output_dir)
|
||||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Store batch processing settings
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.max_image_size = max_image_size
|
||||||
|
self.enable_memory_monitoring = enable_memory_monitoring
|
||||||
|
|
||||||
# Initialize metrics manager
|
# Initialize metrics manager
|
||||||
self.metrics_manager = MetricsManager(
|
self.metrics_manager = MetricsManager(
|
||||||
enable_deqa=enable_deqa,
|
enable_deqa=enable_deqa,
|
||||||
@@ -43,12 +56,23 @@ class IQAAnalyzer:
|
|||||||
pyiqa_selected_metrics=pyiqa_selected_metrics
|
pyiqa_selected_metrics=pyiqa_selected_metrics
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize batch processor (only if DeQA is enabled for memory efficiency)
|
||||||
|
self.batch_processor = None
|
||||||
|
if enable_deqa:
|
||||||
|
self.batch_processor = BatchProcessor(
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_image_size=max_image_size,
|
||||||
|
enable_memory_monitoring=enable_memory_monitoring
|
||||||
|
)
|
||||||
|
|
||||||
# Results storage
|
# Results storage
|
||||||
self.results = {}
|
self.results = {}
|
||||||
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
logger.info(f"IQA Analyzer initialized for: {image_dir}")
|
logger.info(f"IQA Analyzer initialized for: {image_dir}")
|
||||||
logger.info(f"Available metrics: {self.metrics_manager.get_available_metrics()}")
|
logger.info(f"Available metrics: {self.metrics_manager.get_available_metrics()}")
|
||||||
|
logger.info(f"Batch processing: {'enabled' if self.batch_processor else 'disabled'} "
|
||||||
|
f"(batch_size={batch_size}, max_size={max_image_size})")
|
||||||
|
|
||||||
def _get_image_files(self) -> List[Path]:
|
def _get_image_files(self) -> List[Path]:
|
||||||
"""Get all image files from the directory."""
|
"""Get all image files from the directory."""
|
||||||
@@ -81,8 +105,30 @@ class IQAAnalyzer:
|
|||||||
# Prepare image paths for analysis
|
# Prepare image paths for analysis
|
||||||
image_paths = [str(img_path) for img_path in image_files]
|
image_paths = [str(img_path) for img_path in image_files]
|
||||||
|
|
||||||
# Calculate all metrics for all images
|
# Use batch processing if available and only DeQA is enabled
|
||||||
logger.info("Calculating comprehensive metrics...")
|
available_metrics = self.metrics_manager.get_available_metrics()
|
||||||
|
use_batch_processing = (self.batch_processor is not None and
|
||||||
|
available_metrics == ['deqa'])
|
||||||
|
|
||||||
|
if use_batch_processing:
|
||||||
|
logger.info(f"Using memory-efficient batch processing (batch_size={self.batch_size})...")
|
||||||
|
|
||||||
|
# Get DeQA scorer
|
||||||
|
deqa_scorer = self.metrics_manager.get_deqa_scorer()
|
||||||
|
if deqa_scorer is None:
|
||||||
|
logger.error("DeQA scorer not available for batch processing")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Process with batch loader
|
||||||
|
def progress_callback(processed, total):
|
||||||
|
logger.info(f"Batch progress: {processed}/{total} images processed")
|
||||||
|
|
||||||
|
self.results = self.batch_processor.process_images_in_batches(
|
||||||
|
image_paths, deqa_scorer, progress_callback
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fall back to original processing
|
||||||
|
logger.info("Using standard processing...")
|
||||||
self.results = self.metrics_manager.calculate_metrics_batch(image_paths)
|
self.results = self.metrics_manager.calculate_metrics_batch(image_paths)
|
||||||
|
|
||||||
logger.info(f"Analysis completed for {len(self.results)} images")
|
logger.info(f"Analysis completed for {len(self.results)} images")
|
||||||
|
@@ -102,6 +102,57 @@ class DeQAMetric:
|
|||||||
logger.error(f"Error scoring multiple images: {e}")
|
logger.error(f"Error scoring multiple images: {e}")
|
||||||
return [None] * len(image_paths)
|
return [None] * len(image_paths)
|
||||||
|
|
||||||
|
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Optional[float]]:
|
||||||
|
"""
|
||||||
|
Score multiple PIL images directly (for batch processing).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of PIL Image objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of DeQA scores (0-5 scale) or None for failed images
|
||||||
|
"""
|
||||||
|
if not self._is_loaded:
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Filter out None images
|
||||||
|
valid_images = [img for img in images if img is not None]
|
||||||
|
|
||||||
|
if not valid_images:
|
||||||
|
return [None] * len(images)
|
||||||
|
|
||||||
|
# Score images
|
||||||
|
scores = self.model.score(valid_images)
|
||||||
|
|
||||||
|
# Convert scores to list of floats and map back to original positions
|
||||||
|
result_scores = []
|
||||||
|
valid_idx = 0
|
||||||
|
|
||||||
|
for img in images:
|
||||||
|
if img is None:
|
||||||
|
result_scores.append(None)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
score = scores[valid_idx]
|
||||||
|
if hasattr(score, 'item'):
|
||||||
|
result_scores.append(float(score.item()))
|
||||||
|
elif hasattr(score, 'tolist'):
|
||||||
|
result_scores.append(float(score.tolist()))
|
||||||
|
else:
|
||||||
|
result_scores.append(float(score))
|
||||||
|
valid_idx += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to convert score: {e}")
|
||||||
|
result_scores.append(None)
|
||||||
|
valid_idx += 1
|
||||||
|
|
||||||
|
return result_scores
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scoring PIL images: {e}")
|
||||||
|
return [None] * len(images)
|
||||||
|
|
||||||
def get_metric_name(self) -> str:
|
def get_metric_name(self) -> str:
|
||||||
"""Get the name of this metric."""
|
"""Get the name of this metric."""
|
||||||
return "DeQA"
|
return "DeQA"
|
||||||
|
@@ -196,6 +196,12 @@ class MetricsManager:
|
|||||||
return self.metrics['deqa'].score_multiple_images(image_paths)
|
return self.metrics['deqa'].score_multiple_images(image_paths)
|
||||||
return [None] * len(image_paths)
|
return [None] * len(image_paths)
|
||||||
|
|
||||||
|
def get_deqa_scorer(self):
|
||||||
|
"""Get the DeQA scorer object for batch processing."""
|
||||||
|
if 'deqa' in self.metrics and self.metrics['deqa'] is not None:
|
||||||
|
return self.metrics['deqa']
|
||||||
|
return None
|
||||||
|
|
||||||
def get_traditional_metrics_only(self, image_path: str) -> Dict[str, float]:
|
def get_traditional_metrics_only(self, image_path: str) -> Dict[str, float]:
|
||||||
"""Get only traditional metrics for an image."""
|
"""Get only traditional metrics for an image."""
|
||||||
if 'traditional' in self.metrics and self.metrics['traditional'] is not None:
|
if 'traditional' in self.metrics and self.metrics['traditional'] is not None:
|
||||||
|
63
test_batch_loader.py
Normal file
63
test_batch_loader.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for memory-efficient batch loader.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add src directory to path for imports
|
||||||
|
sys.path.append(str(Path(__file__).parent / "src"))
|
||||||
|
|
||||||
|
from src.batch_loader import MemoryEfficientBatchLoader, BatchProcessor
|
||||||
|
from src.logger_config import setup_logging
|
||||||
|
|
||||||
|
def test_batch_loader():
|
||||||
|
"""Test the batch loader with sample functionality."""
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logger = setup_logging(log_dir="logs", log_level="INFO")
|
||||||
|
|
||||||
|
# Test configuration
|
||||||
|
batch_size = 4
|
||||||
|
max_image_size = (512, 512)
|
||||||
|
|
||||||
|
# Create batch loader
|
||||||
|
loader = MemoryEfficientBatchLoader(
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_image_size=max_image_size,
|
||||||
|
enable_memory_monitoring=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show initial memory stats
|
||||||
|
memory_stats = loader.get_memory_stats()
|
||||||
|
logger.info(f"Initial memory stats: {memory_stats}")
|
||||||
|
|
||||||
|
# Sample image paths (for testing - these would be real image paths)
|
||||||
|
sample_image_paths = [
|
||||||
|
f"/home/nguyendc/thanh-dev/IQA-Metric-Benchmark/data/task/cni/images/image_{i}.png"
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test batch creation
|
||||||
|
batches = list(loader.create_batches(sample_image_paths))
|
||||||
|
logger.info(f"Created {len(batches)} batches from {len(sample_image_paths)} images")
|
||||||
|
for i, batch in enumerate(batches):
|
||||||
|
logger.info(f"Batch {i+1}: {len(batch)} images")
|
||||||
|
|
||||||
|
# Mock scoring function for testing
|
||||||
|
def mock_scoring_function(images):
|
||||||
|
"""Mock function that returns random scores."""
|
||||||
|
import random
|
||||||
|
return [random.uniform(1.0, 5.0) for _ in images]
|
||||||
|
|
||||||
|
# Test processing a single batch (with mock data)
|
||||||
|
if batches:
|
||||||
|
logger.info("Testing batch processing with mock data...")
|
||||||
|
# Note: This will fail with actual file paths that don't exist
|
||||||
|
# but demonstrates the interface
|
||||||
|
|
||||||
|
logger.info("Batch loader test completed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_batch_loader()
|
Reference in New Issue
Block a user