125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
"""
|
|
DeQA Image Quality Scorer
|
|
Core module for scoring images using the DeQA model.
|
|
"""
|
|
|
|
import torch
|
|
from transformers import AutoModelForCausalLM
|
|
from PIL import Image
|
|
from typing import List, Union
|
|
from .logger_config import get_logger
|
|
|
|
# Module-level logger, obtained by passing this module's name to the
# package-local get_logger helper imported from .logger_config.
logger = get_logger(__name__)
|
|
|
|
|
|
class DeQAScorer:
    """Lazy-loading wrapper around the DeQA model for image quality scoring.

    The model is not loaded at construction time; it is loaded on the first
    scoring call that needs it, or explicitly via ``load_model()``.
    """

    def __init__(self):
        """Initialize the DeQA scorer with no model loaded."""
        # Set by load_model(); remains None until a load succeeds.
        self.model = None
        self._is_loaded = False

    def load_model(self) -> None:
        """Load the DeQA scoring model if it is not loaded yet.

        Calling this again after a successful load is a no-op.

        Raises:
            Exception: Any error raised by ``from_pretrained``; it is logged
                and then re-raised unchanged.
        """
        if self._is_loaded:
            return

        logger.info("Loading DeQA model...")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "zhiyuanyou/DeQA-Score-Mix3",
                trust_remote_code=True,
                attn_implementation="eager",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            logger.error(f"Failed to load DeQA model: {e}")
            raise

        self._is_loaded = True
        logger.info("DeQA model loaded successfully!")

    @staticmethod
    def _to_float(score) -> float:
        """Convert one score (tensor-like, array-like or plain number) to float."""
        if hasattr(score, "item"):
            return float(score.item())
        if hasattr(score, "tolist"):
            return float(score.tolist())
        return float(score)

    def score_single_image(self, image_path: str) -> Union[float, None]:
        """Score a single image using the DeQA model.

        Args:
            image_path: Path to the image file.

        Returns:
            DeQA score (0-5 scale), or None if opening or scoring failed.

        Raises:
            Exception: Propagated from ``load_model()`` when the model has
                not been loaded yet and loading fails.
        """
        if not self._is_loaded:
            self.load_model()

        try:
            # Image.open() is lazy, so the file must stay open while the
            # model reads it; the context manager closes it afterwards.
            with Image.open(image_path) as image:
                scores = self.model.score([image])

            # The result may be a tensor-like, an array-like or a sequence.
            if hasattr(scores, "item"):
                return float(scores.item())
            if hasattr(scores, "tolist"):
                scores = scores.tolist()
            return float(scores[0])
        except Exception as e:
            logger.error(f"Error scoring image {image_path}: {e}")
            return None

    def score_multiple_images(self, image_paths: List[str]) -> List[Union[float, None]]:
        """Score multiple images using the DeQA model in one batch.

        Images that cannot be opened or decoded are skipped and reported as
        None; the remaining images are still scored.

        Args:
            image_paths: List of image file paths.

        Returns:
            A list with one entry per input path, in the same order: the
            DeQA score (0-5 scale), or None for images that failed. An empty
            input returns an empty list without loading the model.

        Raises:
            Exception: Propagated from ``load_model()`` when the model has
                not been loaded yet and loading fails.
        """
        if not image_paths:
            return []

        if not self._is_loaded:
            self.load_model()

        results: List[Union[float, None]] = [None] * len(image_paths)
        opened = []  # (index into image_paths, PIL image) for readable files
        try:
            for index, path in enumerate(image_paths):
                image = None
                try:
                    image = Image.open(path)
                    # Force decoding now so a corrupt file fails here, for
                    # this image only, instead of inside the batch call.
                    image.load()
                except Exception as e:
                    logger.warning(f"Failed to open image {path}: {e}")
                    if image is not None:
                        image.close()
                    continue
                opened.append((index, image))

            if not opened:
                return results

            # Only pass real images to the model; failed slots stay None.
            scores = list(self.model.score([image for _, image in opened]))
            if len(scores) != len(opened):
                logger.warning(
                    f"Expected {len(opened)} scores but got {len(scores)}"
                )

            for (index, _), score in zip(opened, scores):
                try:
                    results[index] = self._to_float(score)
                except Exception as e:
                    logger.warning(
                        f"Failed to convert score for image {image_paths[index]}: {e}"
                    )
            return results
        except Exception as e:
            logger.error(f"Error scoring images: {e}")
            return [None] * len(image_paths)
        finally:
            # Release file handles and image memory on success and failure.
            for _, image in opened:
                image.close()

    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded
|