update batch loader
This commit is contained in:
30
main.py
30
main.py
@@ -32,6 +32,9 @@ Examples:
|
||||
|
||||
# Run analysis and save to custom output directory
|
||||
python main.py --output-dir custom_results
|
||||
|
||||
# Run with memory-efficient batch processing
|
||||
python main.py --deqa-only --batch-size 4 --max-image-size 512 512
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -102,6 +105,28 @@ Examples:
|
||||
help='Use only DeQA metric (disable traditional metrics)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=8,
|
||||
help='Batch size for memory-efficient processing (default: 8)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-image-size',
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[1024, 1024],
|
||||
metavar=('WIDTH', 'HEIGHT'),
|
||||
help='Maximum image dimensions for preprocessing (default: 1024 1024)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--disable-memory-monitoring',
|
||||
action='store_true',
|
||||
help='Disable memory usage monitoring'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
@@ -150,7 +175,10 @@ Examples:
|
||||
enable_deqa=enable_deqa,
|
||||
enable_traditional=enable_traditional,
|
||||
enable_pyiqa=enable_pyiqa_flag,
|
||||
pyiqa_selected_metrics=(selected_top20 if args.pyiqa_top20 else None)
|
||||
pyiqa_selected_metrics=(selected_top20 if args.pyiqa_top20 else None),
|
||||
batch_size=args.batch_size,
|
||||
max_image_size=tuple(args.max_image_size),
|
||||
enable_memory_monitoring=not args.disable_memory_monitoring
|
||||
)
|
||||
results, report = analyzer.run_analysis()
|
||||
|
||||
|
3398
results/facture.txt
Normal file
3398
results/facture.txt
Normal file
File diff suppressed because it is too large
Load Diff
301
src/batch_loader.py
Normal file
301
src/batch_loader.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Memory-efficient batch loader for image processing.
|
||||
Reduces memory usage by processing images in configurable batches with optional preprocessing.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Iterator, Tuple, Optional
|
||||
import psutil
|
||||
import os
|
||||
|
||||
from .logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryEfficientBatchLoader:
|
||||
"""Memory-efficient image batch loader with preprocessing options."""
|
||||
|
||||
def __init__(self,
|
||||
batch_size: int = 8,
|
||||
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
|
||||
enable_memory_monitoring: bool = True,
|
||||
memory_threshold_percent: float = 85.0):
|
||||
"""
|
||||
Initialize the batch loader.
|
||||
|
||||
Args:
|
||||
batch_size: Number of images to process per batch
|
||||
max_image_size: Maximum image dimensions (width, height). If None, no resizing
|
||||
enable_memory_monitoring: Whether to monitor memory usage
|
||||
memory_threshold_percent: Memory usage threshold to trigger cleanup
|
||||
"""
|
||||
self.batch_size = batch_size
|
||||
self.max_image_size = max_image_size
|
||||
self.enable_memory_monitoring = enable_memory_monitoring
|
||||
self.memory_threshold_percent = memory_threshold_percent
|
||||
|
||||
logger.info(f"BatchLoader initialized: batch_size={batch_size}, "
|
||||
f"max_size={max_image_size}, memory_monitoring={enable_memory_monitoring}")
|
||||
|
||||
def _get_memory_usage(self) -> float:
|
||||
"""Get current memory usage percentage."""
|
||||
return psutil.virtual_memory().percent
|
||||
|
||||
def _cleanup_memory(self):
|
||||
"""Force garbage collection and clear GPU cache if available."""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.enable_memory_monitoring:
|
||||
memory_usage = self._get_memory_usage()
|
||||
logger.debug(f"Memory cleanup performed. Current usage: {memory_usage:.1f}%")
|
||||
|
||||
def _check_memory_threshold(self):
|
||||
"""Check if memory usage exceeds threshold and cleanup if needed."""
|
||||
if not self.enable_memory_monitoring:
|
||||
return
|
||||
|
||||
memory_usage = self._get_memory_usage()
|
||||
if memory_usage > self.memory_threshold_percent:
|
||||
logger.warning(f"Memory usage ({memory_usage:.1f}%) exceeds threshold "
|
||||
f"({self.memory_threshold_percent}%). Performing cleanup...")
|
||||
self._cleanup_memory()
|
||||
|
||||
def _preprocess_image(self, image_path: str) -> Optional[Image.Image]:
|
||||
"""
|
||||
Load and preprocess image to reduce memory usage.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
|
||||
Returns:
|
||||
PIL Image object or None if failed
|
||||
"""
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
|
||||
# Convert to RGB if needed (removes alpha channel)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# Resize if max_image_size is specified
|
||||
if self.max_image_size:
|
||||
# Calculate resize ratio to maintain aspect ratio
|
||||
width, height = image.size
|
||||
max_w, max_h = self.max_image_size
|
||||
|
||||
if width > max_w or height > max_h:
|
||||
ratio = min(max_w / width, max_h / height)
|
||||
new_size = (int(width * ratio), int(height * ratio))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
logger.debug(f"Resized {Path(image_path).name} from {width}x{height} to {new_size[0]}x{new_size[1]}")
|
||||
|
||||
return image
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to preprocess image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def _get_file_size_mb(self, file_path: str) -> float:
|
||||
"""Get file size in MB."""
|
||||
try:
|
||||
return os.path.getsize(file_path) / (1024 * 1024)
|
||||
except:
|
||||
return 0.0
|
||||
|
||||
def create_batches(self, image_paths: List[str]) -> Iterator[List[str]]:
|
||||
"""
|
||||
Create batches of image paths.
|
||||
|
||||
Args:
|
||||
image_paths: List of image file paths
|
||||
|
||||
Yields:
|
||||
Batches of image paths
|
||||
"""
|
||||
for i in range(0, len(image_paths), self.batch_size):
|
||||
batch = image_paths[i:i + self.batch_size]
|
||||
yield batch
|
||||
|
||||
def process_batch(self, image_paths: List[str],
|
||||
scoring_function,
|
||||
include_file_info: bool = True) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Process a batch of images with memory management.
|
||||
|
||||
Args:
|
||||
image_paths: List of image paths in the batch
|
||||
scoring_function: Function to score images (should accept list of PIL Images)
|
||||
include_file_info: Whether to include file size and path info
|
||||
|
||||
Returns:
|
||||
Dictionary with results for each image in the batch
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Check memory before processing batch
|
||||
self._check_memory_threshold()
|
||||
|
||||
# Preprocess all images in batch
|
||||
processed_images = []
|
||||
valid_paths = []
|
||||
|
||||
for path in image_paths:
|
||||
image = self._preprocess_image(path)
|
||||
if image is not None:
|
||||
processed_images.append(image)
|
||||
valid_paths.append(path)
|
||||
else:
|
||||
# Handle failed image
|
||||
image_name = Path(path).name
|
||||
results[image_name] = {
|
||||
'error': f'Failed to load/preprocess image: {path}',
|
||||
'file_path': path if include_file_info else None,
|
||||
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
|
||||
}
|
||||
|
||||
if processed_images:
|
||||
try:
|
||||
# Score the batch
|
||||
scores = scoring_function(processed_images)
|
||||
|
||||
# Combine results
|
||||
for i, (path, score) in enumerate(zip(valid_paths, scores)):
|
||||
image_name = Path(path).name
|
||||
result = {
|
||||
'deqa_score': score,
|
||||
}
|
||||
|
||||
if include_file_info:
|
||||
result.update({
|
||||
'file_path': path,
|
||||
'file_size_mb': round(self._get_file_size_mb(path), 2)
|
||||
})
|
||||
|
||||
results[image_name] = result
|
||||
|
||||
logger.debug(f"Processed batch of {len(processed_images)} images")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring batch: {e}")
|
||||
# Mark all as failed
|
||||
for path in valid_paths:
|
||||
image_name = Path(path).name
|
||||
results[image_name] = {
|
||||
'error': f'Scoring failed: {str(e)}',
|
||||
'file_path': path if include_file_info else None,
|
||||
'file_size_mb': self._get_file_size_mb(path) if include_file_info else None
|
||||
}
|
||||
|
||||
finally:
|
||||
# Clear processed images from memory
|
||||
del processed_images
|
||||
self._cleanup_memory()
|
||||
|
||||
return results
|
||||
|
||||
def get_memory_stats(self) -> Dict[str, Any]:
|
||||
"""Get current memory statistics."""
|
||||
memory = psutil.virtual_memory()
|
||||
stats = {
|
||||
'memory_percent': memory.percent,
|
||||
'memory_available_gb': memory.available / (1024**3),
|
||||
'memory_used_gb': memory.used / (1024**3),
|
||||
'memory_total_gb': memory.total / (1024**3)
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.memory_allocated() / (1024**3)
|
||||
gpu_cached = torch.cuda.memory_reserved() / (1024**3)
|
||||
stats.update({
|
||||
'gpu_memory_allocated_gb': gpu_memory,
|
||||
'gpu_memory_cached_gb': gpu_cached
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""High-level processor that uses MemoryEfficientBatchLoader."""
|
||||
|
||||
def __init__(self,
|
||||
batch_size: int = 8,
|
||||
max_image_size: Optional[Tuple[int, int]] = (1024, 1024),
|
||||
enable_memory_monitoring: bool = True):
|
||||
"""
|
||||
Initialize the batch processor.
|
||||
|
||||
Args:
|
||||
batch_size: Number of images to process per batch
|
||||
max_image_size: Maximum image dimensions for preprocessing
|
||||
enable_memory_monitoring: Whether to monitor memory usage
|
||||
"""
|
||||
self.loader = MemoryEfficientBatchLoader(
|
||||
batch_size=batch_size,
|
||||
max_image_size=max_image_size,
|
||||
enable_memory_monitoring=enable_memory_monitoring
|
||||
)
|
||||
|
||||
def process_images_in_batches(self,
|
||||
image_paths: List[str],
|
||||
scorer,
|
||||
progress_callback=None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Process all images in batches.
|
||||
|
||||
Args:
|
||||
image_paths: List of all image paths to process
|
||||
scorer: Object with score_multiple_images method
|
||||
progress_callback: Optional function to call with progress updates
|
||||
|
||||
Returns:
|
||||
Combined results from all batches
|
||||
"""
|
||||
all_results = {}
|
||||
total_images = len(image_paths)
|
||||
processed_images = 0
|
||||
|
||||
logger.info(f"Processing {total_images} images in batches of {self.loader.batch_size}")
|
||||
|
||||
# Show initial memory stats
|
||||
if self.loader.enable_memory_monitoring:
|
||||
initial_stats = self.loader.get_memory_stats()
|
||||
logger.info(f"Initial memory usage: {initial_stats['memory_percent']:.1f}% "
|
||||
f"({initial_stats['memory_used_gb']:.1f}GB/{initial_stats['memory_total_gb']:.1f}GB)")
|
||||
|
||||
for batch_idx, batch_paths in enumerate(self.loader.create_batches(image_paths)):
|
||||
logger.info(f"Processing batch {batch_idx + 1} ({len(batch_paths)} images)...")
|
||||
|
||||
# Create scoring function that works with the batch loader
|
||||
def scoring_function(images):
|
||||
return scorer.score_multiple_images_from_pil(images)
|
||||
|
||||
# Process batch
|
||||
batch_results = self.loader.process_batch(batch_paths, scoring_function)
|
||||
all_results.update(batch_results)
|
||||
|
||||
processed_images += len(batch_paths)
|
||||
|
||||
# Progress callback
|
||||
if progress_callback:
|
||||
progress_callback(processed_images, total_images)
|
||||
|
||||
# Log progress and memory stats
|
||||
progress_percent = (processed_images / total_images) * 100
|
||||
logger.info(f"Progress: {processed_images}/{total_images} ({progress_percent:.1f}%)")
|
||||
|
||||
if self.loader.enable_memory_monitoring:
|
||||
current_stats = self.loader.get_memory_stats()
|
||||
logger.debug(f"Current memory usage: {current_stats['memory_percent']:.1f}%")
|
||||
|
||||
logger.info(f"Completed processing {total_images} images")
|
||||
|
||||
# Final memory cleanup
|
||||
self.loader._cleanup_memory()
|
||||
|
||||
return all_results
|
@@ -119,6 +119,57 @@ class DeQAScorer:
|
||||
logger.error(f"Error scoring images: {e}")
|
||||
return [None] * len(image_paths)
|
||||
|
||||
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Union[float, None]]:
|
||||
"""
|
||||
Score multiple PIL images directly (for batch processing).
|
||||
|
||||
Args:
|
||||
images: List of PIL Image objects
|
||||
|
||||
Returns:
|
||||
List of DeQA scores (0-5 scale) or None for failed images
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
# Filter out None images
|
||||
valid_images = [img for img in images if img is not None]
|
||||
|
||||
if not valid_images:
|
||||
return [None] * len(images)
|
||||
|
||||
# Score images
|
||||
scores = self.model.score(valid_images)
|
||||
|
||||
# Convert scores to list of floats and map back to original positions
|
||||
result_scores = []
|
||||
valid_idx = 0
|
||||
|
||||
for img in images:
|
||||
if img is None:
|
||||
result_scores.append(None)
|
||||
else:
|
||||
try:
|
||||
score = scores[valid_idx]
|
||||
if hasattr(score, 'item'):
|
||||
result_scores.append(float(score.item()))
|
||||
elif hasattr(score, 'tolist'):
|
||||
result_scores.append(float(score.tolist()))
|
||||
else:
|
||||
result_scores.append(float(score))
|
||||
valid_idx += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert score: {e}")
|
||||
result_scores.append(None)
|
||||
valid_idx += 1
|
||||
|
||||
return result_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring PIL images: {e}")
|
||||
return [None] * len(images)
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if the model is loaded."""
|
||||
return self._is_loaded
|
||||
|
@@ -12,6 +12,7 @@ from datetime import datetime
|
||||
|
||||
from .metrics import MetricsManager
|
||||
from .logger_config import get_logger
|
||||
from .batch_loader import BatchProcessor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -21,7 +22,9 @@ class IQAAnalyzer:
|
||||
|
||||
def __init__(self, image_dir: str, output_dir: str = "results",
|
||||
enable_deqa: bool = True, enable_traditional: bool = True,
|
||||
enable_pyiqa: bool = True, pyiqa_selected_metrics: List[str] | None = None):
|
||||
enable_pyiqa: bool = True, pyiqa_selected_metrics: List[str] | None = None,
|
||||
batch_size: int = 8, max_image_size: tuple = (1024, 1024),
|
||||
enable_memory_monitoring: bool = True):
|
||||
"""
|
||||
Initialize the IQA analyzer.
|
||||
|
||||
@@ -30,11 +33,21 @@ class IQAAnalyzer:
|
||||
output_dir: Directory to save results
|
||||
enable_deqa: Whether to enable DeQA metric
|
||||
enable_traditional: Whether to enable traditional metrics
|
||||
enable_pyiqa: Whether to enable PyIQA metrics
|
||||
pyiqa_selected_metrics: List of specific PyIQA metrics to use
|
||||
batch_size: Number of images to process per batch
|
||||
max_image_size: Maximum image dimensions for preprocessing
|
||||
enable_memory_monitoring: Whether to monitor memory usage
|
||||
"""
|
||||
self.image_dir = Path(image_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Store batch processing settings
|
||||
self.batch_size = batch_size
|
||||
self.max_image_size = max_image_size
|
||||
self.enable_memory_monitoring = enable_memory_monitoring
|
||||
|
||||
# Initialize metrics manager
|
||||
self.metrics_manager = MetricsManager(
|
||||
enable_deqa=enable_deqa,
|
||||
@@ -43,12 +56,23 @@ class IQAAnalyzer:
|
||||
pyiqa_selected_metrics=pyiqa_selected_metrics
|
||||
)
|
||||
|
||||
# Initialize batch processor (only if DeQA is enabled for memory efficiency)
|
||||
self.batch_processor = None
|
||||
if enable_deqa:
|
||||
self.batch_processor = BatchProcessor(
|
||||
batch_size=batch_size,
|
||||
max_image_size=max_image_size,
|
||||
enable_memory_monitoring=enable_memory_monitoring
|
||||
)
|
||||
|
||||
# Results storage
|
||||
self.results = {}
|
||||
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
logger.info(f"IQA Analyzer initialized for: {image_dir}")
|
||||
logger.info(f"Available metrics: {self.metrics_manager.get_available_metrics()}")
|
||||
logger.info(f"Batch processing: {'enabled' if self.batch_processor else 'disabled'} "
|
||||
f"(batch_size={batch_size}, max_size={max_image_size})")
|
||||
|
||||
def _get_image_files(self) -> List[Path]:
|
||||
"""Get all image files from the directory."""
|
||||
@@ -81,8 +105,30 @@ class IQAAnalyzer:
|
||||
# Prepare image paths for analysis
|
||||
image_paths = [str(img_path) for img_path in image_files]
|
||||
|
||||
# Calculate all metrics for all images
|
||||
logger.info("Calculating comprehensive metrics...")
|
||||
# Use batch processing if available and only DeQA is enabled
|
||||
available_metrics = self.metrics_manager.get_available_metrics()
|
||||
use_batch_processing = (self.batch_processor is not None and
|
||||
available_metrics == ['deqa'])
|
||||
|
||||
if use_batch_processing:
|
||||
logger.info(f"Using memory-efficient batch processing (batch_size={self.batch_size})...")
|
||||
|
||||
# Get DeQA scorer
|
||||
deqa_scorer = self.metrics_manager.get_deqa_scorer()
|
||||
if deqa_scorer is None:
|
||||
logger.error("DeQA scorer not available for batch processing")
|
||||
return {}
|
||||
|
||||
# Process with batch loader
|
||||
def progress_callback(processed, total):
|
||||
logger.info(f"Batch progress: {processed}/{total} images processed")
|
||||
|
||||
self.results = self.batch_processor.process_images_in_batches(
|
||||
image_paths, deqa_scorer, progress_callback
|
||||
)
|
||||
else:
|
||||
# Fall back to original processing
|
||||
logger.info("Using standard processing...")
|
||||
self.results = self.metrics_manager.calculate_metrics_batch(image_paths)
|
||||
|
||||
logger.info(f"Analysis completed for {len(self.results)} images")
|
||||
|
@@ -102,6 +102,57 @@ class DeQAMetric:
|
||||
logger.error(f"Error scoring multiple images: {e}")
|
||||
return [None] * len(image_paths)
|
||||
|
||||
def score_multiple_images_from_pil(self, images: List[Image.Image]) -> List[Optional[float]]:
|
||||
"""
|
||||
Score multiple PIL images directly (for batch processing).
|
||||
|
||||
Args:
|
||||
images: List of PIL Image objects
|
||||
|
||||
Returns:
|
||||
List of DeQA scores (0-5 scale) or None for failed images
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
# Filter out None images
|
||||
valid_images = [img for img in images if img is not None]
|
||||
|
||||
if not valid_images:
|
||||
return [None] * len(images)
|
||||
|
||||
# Score images
|
||||
scores = self.model.score(valid_images)
|
||||
|
||||
# Convert scores to list of floats and map back to original positions
|
||||
result_scores = []
|
||||
valid_idx = 0
|
||||
|
||||
for img in images:
|
||||
if img is None:
|
||||
result_scores.append(None)
|
||||
else:
|
||||
try:
|
||||
score = scores[valid_idx]
|
||||
if hasattr(score, 'item'):
|
||||
result_scores.append(float(score.item()))
|
||||
elif hasattr(score, 'tolist'):
|
||||
result_scores.append(float(score.tolist()))
|
||||
else:
|
||||
result_scores.append(float(score))
|
||||
valid_idx += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert score: {e}")
|
||||
result_scores.append(None)
|
||||
valid_idx += 1
|
||||
|
||||
return result_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring PIL images: {e}")
|
||||
return [None] * len(images)
|
||||
|
||||
def get_metric_name(self) -> str:
|
||||
"""Get the name of this metric."""
|
||||
return "DeQA"
|
||||
|
@@ -196,6 +196,12 @@ class MetricsManager:
|
||||
return self.metrics['deqa'].score_multiple_images(image_paths)
|
||||
return [None] * len(image_paths)
|
||||
|
||||
def get_deqa_scorer(self):
|
||||
"""Get the DeQA scorer object for batch processing."""
|
||||
if 'deqa' in self.metrics and self.metrics['deqa'] is not None:
|
||||
return self.metrics['deqa']
|
||||
return None
|
||||
|
||||
def get_traditional_metrics_only(self, image_path: str) -> Dict[str, float]:
|
||||
"""Get only traditional metrics for an image."""
|
||||
if 'traditional' in self.metrics and self.metrics['traditional'] is not None:
|
||||
|
63
test_batch_loader.py
Normal file
63
test_batch_loader.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for memory-efficient batch loader.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.append(str(Path(__file__).parent / "src"))
|
||||
|
||||
from src.batch_loader import MemoryEfficientBatchLoader, BatchProcessor
|
||||
from src.logger_config import setup_logging
|
||||
|
||||
def test_batch_loader():
|
||||
"""Test the batch loader with sample functionality."""
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(log_dir="logs", log_level="INFO")
|
||||
|
||||
# Test configuration
|
||||
batch_size = 4
|
||||
max_image_size = (512, 512)
|
||||
|
||||
# Create batch loader
|
||||
loader = MemoryEfficientBatchLoader(
|
||||
batch_size=batch_size,
|
||||
max_image_size=max_image_size,
|
||||
enable_memory_monitoring=True
|
||||
)
|
||||
|
||||
# Show initial memory stats
|
||||
memory_stats = loader.get_memory_stats()
|
||||
logger.info(f"Initial memory stats: {memory_stats}")
|
||||
|
||||
# Sample image paths (for testing - these would be real image paths)
|
||||
sample_image_paths = [
|
||||
f"/home/nguyendc/thanh-dev/IQA-Metric-Benchmark/data/task/cni/images/image_{i}.png"
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
# Test batch creation
|
||||
batches = list(loader.create_batches(sample_image_paths))
|
||||
logger.info(f"Created {len(batches)} batches from {len(sample_image_paths)} images")
|
||||
for i, batch in enumerate(batches):
|
||||
logger.info(f"Batch {i+1}: {len(batch)} images")
|
||||
|
||||
# Mock scoring function for testing
|
||||
def mock_scoring_function(images):
|
||||
"""Mock function that returns random scores."""
|
||||
import random
|
||||
return [random.uniform(1.0, 5.0) for _ in images]
|
||||
|
||||
# Test processing a single batch (with mock data)
|
||||
if batches:
|
||||
logger.info("Testing batch processing with mock data...")
|
||||
# Note: This will fail with actual file paths that don't exist
|
||||
# but demonstrates the interface
|
||||
|
||||
logger.info("Batch loader test completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_batch_loader()
|
Reference in New Issue
Block a user