Files
IQA-Metric-Benchmark/src/deqa_scorer.py

125 lines
3.9 KiB
Python
Raw Normal View History

2025-08-26 09:35:24 +00:00
"""
DeQA Image Quality Scorer
Core module for scoring images using the DeQA model.
"""
from contextlib import ExitStack
from typing import List, Union

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

from .logger_config import get_logger
# Module-level logger obtained from the project's `get_logger` helper
# (imported from .logger_config); used by DeQAScorer for progress and errors.
logger = get_logger(__name__)
class DeQAScorer:
    """DeQA model wrapper for image quality scoring.

    The Hugging Face model is loaded lazily on the first scoring call (or
    explicitly via ``load_model``). Images are passed to the model's
    ``score`` method as a list of PIL images.
    """

    # Hugging Face Hub identifier of the model loaded when none is given.
    DEFAULT_MODEL_NAME = "zhiyuanyou/DeQA-Score-Mix3"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Initialize the DeQA scorer without loading the model.

        Args:
            model_name: Hub id (or local path) passed to
                ``AutoModelForCausalLM.from_pretrained``. Defaults to
                ``DEFAULT_MODEL_NAME``.
        """
        self.model_name = model_name
        self.model = None
        self._is_loaded = False

    def load_model(self) -> None:
        """Load the DeQA scoring model; does nothing if already loaded.

        Raises:
            Exception: Any error from ``from_pretrained`` is logged and
                re-raised unchanged.
        """
        if self._is_loaded:
            return
        logger.info("Loading DeQA model...")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                # The repo ships custom modeling code (which provides `score`).
                trust_remote_code=True,
                attn_implementation="eager",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            logger.error(f"Failed to load DeQA model: {e}")
            raise
        self._is_loaded = True
        logger.info("DeQA model loaded successfully!")

    @staticmethod
    def _to_float(value) -> float:
        """Convert one model score to a Python float.

        Accepts a tensor / NumPy value (via ``.item()``), an array-like with
        only ``.tolist()``, a one-element list/tuple, or a plain number.

        Raises:
            ValueError: If a sequence holding more than one value is given.
        """
        if hasattr(value, "item"):
            return float(value.item())
        if hasattr(value, "tolist"):
            value = value.tolist()
        if isinstance(value, (list, tuple)):
            if len(value) != 1:
                raise ValueError(f"Expected a single score, got {len(value)} values")
            return DeQAScorer._to_float(value[0])
        return float(value)

    def score_single_image(self, image_path: str) -> Union[float, None]:
        """
        Score a single image using DeQA model.

        Args:
            image_path: Path to the image file

        Returns:
            DeQA score, or None if opening/scoring failed. (Previously
            documented as a 0-5 scale — NOTE(review): confirm the range
            against the model card.)
        """
        if not self._is_loaded:
            self.load_model()
        try:
            # `with` closes the file handle Image.open keeps for lazy loading;
            # the model reads pixel data inside the block.
            with Image.open(image_path) as image:
                scores = self.model.score([image])
            return self._to_float(scores)
        except Exception as e:
            logger.error(f"Error scoring image {image_path}: {e}")
            return None

    def score_multiple_images(self, image_paths: List[str]) -> List[Union[float, None]]:
        """
        Score multiple images using DeQA model.

        Images that fail to open are skipped (never sent to the model), so one
        bad file no longer invalidates the rest of the batch.

        Args:
            image_paths: List of image file paths

        Returns:
            A list aligned with ``image_paths``: the DeQA score for each image,
            or None where opening, scoring, or conversion failed.
        """
        if not self._is_loaded:
            self.load_model()

        results: List[Union[float, None]] = [None] * len(image_paths)
        if not image_paths:
            return results

        valid_indices: List[int] = []
        with ExitStack() as stack:
            # Open every image, remembering which input positions succeeded;
            # the stack closes all opened files when scoring is done.
            valid_images = []
            for index, path in enumerate(image_paths):
                try:
                    image = stack.enter_context(Image.open(path))
                except Exception as e:
                    logger.warning(f"Failed to open image {path}: {e}")
                    continue
                valid_indices.append(index)
                valid_images.append(image)

            if not valid_images:
                return results

            try:
                scores = self.model.score(valid_images)
            except Exception as e:
                logger.error(f"Error scoring images: {e}")
                return results

        # Map each score back to the position of the image it belongs to.
        for index, score in zip(valid_indices, scores):
            try:
                results[index] = self._to_float(score)
            except Exception as e:
                logger.warning(f"Failed to convert score for image {image_paths[index]}: {e}")
        return results

    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded