init structure
This commit is contained in:
124
src/deqa_scorer.py
Normal file
124
src/deqa_scorer.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
DeQA Image Quality Scorer
|
||||
Core module for scoring images using the DeQA model.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
from PIL import Image
|
||||
from typing import List, Union
|
||||
from .logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DeQAScorer:
|
||||
"""DeQA model wrapper for image quality scoring."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the DeQA scorer."""
|
||||
self.model = None
|
||||
self._is_loaded = False
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""Load the DeQA scoring model."""
|
||||
if self._is_loaded:
|
||||
return
|
||||
|
||||
logger.info("Loading DeQA model...")
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"zhiyuanyou/DeQA-Score-Mix3",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
self._is_loaded = True
|
||||
logger.info("DeQA model loaded successfully!")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load DeQA model: {e}")
|
||||
raise
|
||||
|
||||
def score_single_image(self, image_path: str) -> Union[float, None]:
|
||||
"""
|
||||
Score a single image using DeQA model.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
|
||||
Returns:
|
||||
DeQA score (0-5 scale) or None if failed
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
scores = self.model.score([image])
|
||||
|
||||
# Convert tensor to float
|
||||
if hasattr(scores, 'item'):
|
||||
return float(scores.item())
|
||||
elif hasattr(scores, 'tolist'):
|
||||
return float(scores.tolist()[0])
|
||||
else:
|
||||
return float(scores[0])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring image {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def score_multiple_images(self, image_paths: List[str]) -> List[Union[float, None]]:
|
||||
"""
|
||||
Score multiple images using DeQA model.
|
||||
|
||||
Args:
|
||||
image_paths: List of image file paths
|
||||
|
||||
Returns:
|
||||
List of DeQA scores (0-5 scale) or None for failed images
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
# Open all images
|
||||
images = []
|
||||
for path in image_paths:
|
||||
try:
|
||||
image = Image.open(path)
|
||||
images.append(image)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to open image {path}: {e}")
|
||||
images.append(None)
|
||||
|
||||
# Score images
|
||||
scores = self.model.score(images)
|
||||
|
||||
# Convert scores to list of floats
|
||||
result_scores = []
|
||||
for i, score in enumerate(scores):
|
||||
if images[i] is None:
|
||||
result_scores.append(None)
|
||||
else:
|
||||
try:
|
||||
if hasattr(score, 'item'):
|
||||
result_scores.append(float(score.item()))
|
||||
elif hasattr(score, 'tolist'):
|
||||
result_scores.append(float(score.tolist()))
|
||||
else:
|
||||
result_scores.append(float(score))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert score for image {image_paths[i]}: {e}")
|
||||
result_scores.append(None)
|
||||
|
||||
return result_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring images: {e}")
|
||||
return [None] * len(image_paths)
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if the model is loaded."""
|
||||
return self._is_loaded
|
Reference in New Issue
Block a user