"""
DeQA Image Quality Scorer

Core module for scoring images using the DeQA model.
"""

from typing import List, Union

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

from .logger_config import get_logger

logger = get_logger(__name__)

# Hugging Face Hub identifier of the checkpoint loaded when none is given.
DEFAULT_MODEL_NAME = "zhiyuanyou/DeQA-Score-Mix3"


class DeQAScorer:
    """DeQA model wrapper for image quality scoring.

    The model is loaded lazily on first use (or explicitly via
    ``load_model``) and kept on the instance afterwards.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Initialize the DeQA scorer without loading any weights.

        Args:
            model_name: Model identifier or local path passed to
                ``AutoModelForCausalLM.from_pretrained``. Defaults to the
                DeQA-Score-Mix3 checkpoint.
        """
        self.model_name = model_name
        self.model = None
        self._is_loaded = False

    def load_model(self) -> None:
        """Load the DeQA scoring model if it has not been loaded yet.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged and
                re-raised so callers can decide how to handle it.
        """
        if self._is_loaded:
            return

        logger.info("Loading DeQA model...")
        try:
            # trust_remote_code executes Python shipped with the checkpoint;
            # only point model_name at repositories you trust.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                attn_implementation="eager",
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            logger.error(f"Failed to load DeQA model: {e}")
            raise
        self._is_loaded = True
        logger.info("DeQA model loaded successfully!")

    @staticmethod
    def _open_image(image_path: str) -> Image.Image:
        """Decode an image file into an in-memory RGB image and close the file."""
        # Image.open is lazy and holds the file handle until the image is
        # loaded or closed; convert() forces a full decode into a new image,
        # so the context manager can release the handle immediately.
        # NOTE(review): RGB is assumed to be what the model's preprocessing
        # expects; grayscale/palette/RGBA inputs are normalized here.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_float(value) -> float:
        """Convert one per-image model output to a Python float.

        Accepts plain numbers, objects exposing ``.item()`` (0-d / 1-element
        tensors or arrays), and single-element lists/tuples.

        Raises:
            ValueError: If a list/tuple does not hold exactly one value.
        """
        if hasattr(value, "item"):
            return float(value.item())
        if isinstance(value, (list, tuple)):
            if len(value) != 1:
                raise ValueError(f"Expected a single score, got {len(value)} values")
            return float(value[0])
        return float(value)

    def _run_model(self, images: List[Image.Image]) -> list:
        """Score decoded images and return one raw model output per image.

        Raises:
            ValueError: If the model returns a different number of scores than
                the number of images it was given.
        """
        # Pure inference: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            scores = self.model.score(images)
        # Turn tensors/arrays into Python lists; wrap a bare scalar result.
        if hasattr(scores, "tolist"):
            scores = scores.tolist()
        if isinstance(scores, (int, float)):
            scores = [scores]
        scores = list(scores)
        if len(scores) != len(images):
            raise ValueError(
                f"Model returned {len(scores)} scores for {len(images)} images"
            )
        return scores

    def score_single_image(self, image_path: str) -> Union[float, None]:
        """
        Score a single image using DeQA model.

        Args:
            image_path: Path to the image file

        Returns:
            DeQA score (0-5 scale), or None if the image could not be opened
            or scored (the error is logged).

        Raises:
            Exception: If the model still has to be loaded and loading fails.
        """
        if not self._is_loaded:
            self.load_model()

        try:
            image = self._open_image(image_path)
            return self._to_float(self._run_model([image])[0])
        except Exception as e:
            logger.error(f"Error scoring image {image_path}: {e}")
            return None

    def score_multiple_images(self, image_paths: List[str]) -> List[Union[float, None]]:
        """
        Score multiple images using DeQA model.

        Images that cannot be opened are skipped and reported as None; the
        remaining images are still scored in a single batch.

        Args:
            image_paths: List of image file paths

        Returns:
            List of DeQA scores (0-5 scale) aligned with ``image_paths``, with
            None for images that failed to open or convert. If the batch call
            itself fails, every entry is None.

        Raises:
            Exception: If the model still has to be loaded and loading fails.
        """
        if not image_paths:
            return []

        if not self._is_loaded:
            self.load_model()

        results: List[Union[float, None]] = [None] * len(image_paths)

        # Open every image, remembering the original index of each one that
        # succeeds so failed paths never reach the model.
        images: List[Image.Image] = []
        positions: List[int] = []
        for idx, path in enumerate(image_paths):
            try:
                images.append(self._open_image(path))
            except Exception as e:
                logger.warning(f"Failed to open image {path}: {e}")
                continue
            positions.append(idx)

        if not images:
            return results

        try:
            raw_scores = self._run_model(images)
        except Exception as e:
            logger.error(f"Error scoring images: {e}")
            return results

        for idx, score in zip(positions, raw_scores):
            try:
                results[idx] = self._to_float(score)
            except Exception as e:
                logger.warning(f"Failed to convert score for image {image_paths[idx]}: {e}")

        return results

    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded