MultimodalOCR/models/lavis/lavis.py

import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from ..process import pad_image, resize_image


class lavis:
    def __init__(self, model_name, model_type, device) -> None:
        # Load the LAVIS model together with its paired image/text preprocessors in eval mode.
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name=model_name, model_type=model_type, is_eval=True, device=device
        )
        self.model_name = model_name
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    def generate(self, image, question, name='resize'):
        # OPT- and T5-based checkpoints use slightly different prompt templates;
        # fall back to the OPT-style prompt for anything else.
        if 'opt' in self.model_name:
            prompt = f'Question: {question} Answer:'
        elif 't5' in self.model_name:
            prompt = f'Question: {question} Short answer:'
        else:
            prompt = f'Question: {question} Answer:'
        image = Image.open(image).convert("RGB")
        # Either pad or resize the image to 224x224 before handing it to the LAVIS preprocessor.
        if name == "pad":
            image = pad_image(image, (224, 224))
        elif name == "resize":
            image = resize_image(image, (224, 224))
        # Apply the eval-time preprocessors and add a batch dimension to the image tensor.
        image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        prompt = self.txt_processors["eval"](prompt)
        answer = self.model.predict_answers(
            samples={"image": image, "text_input": prompt},
            inference_method="generate",
            max_len=48,
            min_len=1,
        )[0]
        return answer
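

# Hypothetical usage sketch, not part of the original file: the import path, image path,
# and checkpoint names below are assumptions. "blip2_opt" / "pretrain_opt2.7b" is one of
# the model/type pairs shipped with LAVIS, and because of the relative import above this
# module has to be imported as part of the package rather than run directly, e.g.:
#
#   from models.lavis.lavis import lavis
#
#   vqa = lavis(model_name="blip2_opt", model_type="pretrain_opt2.7b", device="cuda")
#   answer = vqa.generate("/path/to/image.jpg", "What is written on the sign?")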