# Source: MultimodalOCR/val_text_recognition.py
# Date: 2023-05-12 16:54:54 +08:00
# Original listing: 160 lines, 5.3 KiB, Python
import torch
import math
import os
import PIL
import time
import lmdb
import six
import logging
import sys
import traceback
import torch.distributed as dist
from multiprocessing import Queue, Process
# Module-level side effect: select torch's 'file_system' strategy for sharing
# CPU tensors between processes. This applies to the image batches put on the
# multiprocessing queues further down.
# NOTE(review): presumably chosen to avoid running out of file descriptors with
# the default 'file_descriptor' strategy; also assumes torch's tensor pickling
# hooks apply to the stdlib multiprocessing.Queue used here — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def pad_image(image, target_size, margin=128):
    """Letterbox ``image`` onto a black RGB canvas.

    The image is resized with bicubic resampling so that it fits inside
    ``target_size`` while keeping its aspect ratio. It is then pasted, centred,
    onto a black canvas that is ``margin`` pixels larger than ``target_size``
    in each dimension.

    :param image: input PIL image.
    :param target_size: ``(width, height)`` box the image is scaled to fit.
    :param margin: extra pixels added to each canvas dimension, split between
        both sides (floor division on odd values). Defaults to 128, the value
        this function always used.
    :return: new RGB ``PIL.Image.Image`` of size
        ``(target_size[0] + margin, target_size[1] + margin)``.
    """
    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    iw, ih = image.size
    w, h = target_size
    # Largest uniform scale that keeps the image inside the target box.
    scale = min(w / iw, h / ih)
    nw = int(iw * scale + 0.5)
    nh = int(ih * scale + 0.5)
    w += margin
    h += margin
    resized = image.resize((nw, nh), Image.BICUBIC)
    new_image = Image.new('RGB', (w, h), (0, 0, 0))
    new_image.paste(resized, ((w - nw) // 2, (h - nh) // 2))
    return new_image
def _lmdb_get(txn, key):
    """Return the raw bytes stored under the ``str`` ``key`` in LMDB ``txn``.

    Raises ``KeyError`` naming the key when it is absent, instead of letting a
    ``None`` value fail later with an unrelated ``TypeError``.
    """
    value = txn.get(key.encode())
    if value is None:
        raise KeyError('missing LMDB key: {}'.format(key))
    return value


def _put_batch(queue, images, texts, idx_list):
    """Concatenate per-sample image tensors and put one batch dict on ``queue``."""
    assert len(images) == len(texts)
    batch = {'text_input': texts, 'image': torch.cat(images, dim=0), 'idx_list': idx_list}
    queue.put(batch)


def process_data(_quene, path, batch_size):
    """Producer: read an LMDB text-recognition dataset and enqueue image batches.

    Loads the LAVIS ``blip2_t5`` / ``pretrain_flant5xxl`` eval image processor
    on CPU. For every ``idx`` in ``range(num-samples)`` it then:

    - reads ``image-{idx+1:09d}`` and ``label-{idx+1:09d}`` from the LMDB at
      ``path``;
    - letterboxes the image to 224x224 (plus margin) with :func:`pad_image`;
    - preprocesses the image and collects it into a batch.

    Batches of at most ``batch_size`` samples are put on ``_quene`` as
    ``{'text_input': [label, ...], 'image': Tensor, 'idx_list': [idx, ...]}``.
    A final ``None`` marks the end of the data.

    After the sentinel the process idles (without spinning) rather than
    returning; the parent is expected to terminate it.

    :raises KeyError: if an expected LMDB key is missing.
    """
    from lavis.models import load_model_and_preprocess
    from PIL import Image

    # NOTE(review): the model returned here is discarded; if this call builds
    # the full model, a processor-only loader would save a lot of memory.
    _, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=torch.device("cpu"))
    preprocess = vis_processors["eval"]

    batch_image, batch_text, idx_list = [], [], []
    env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
    try:
        with env.begin(write=False) as txn:
            length = int(_lmdb_get(txn, 'num-samples'))
            print("the length of dataset:", length)
            for idx in range(length):
                # Record keys are 1-based and zero-padded to nine digits.
                image_key, label_key = f'image-{idx + 1:09d}', f'label-{idx + 1:09d}'
                image = Image.open(six.BytesIO(_lmdb_get(txn, image_key))).convert("RGB")
                image = pad_image(image, (224, 224))
                label = str(_lmdb_get(txn, label_key), 'utf-8').strip()
                batch_image.append(preprocess(image).unsqueeze(0))
                batch_text.append(label)
                idx_list.append(idx)
                if len(batch_image) >= batch_size:
                    _put_batch(_quene, batch_image, batch_text, idx_list)
                    # Rebind rather than .clear(): the queue pickles the batch
                    # asynchronously in its feeder thread.
                    batch_image, batch_text, idx_list = [], [], []
    finally:
        env.close()

    if batch_image:
        _put_batch(_quene, batch_image, batch_text, idx_list)
    _quene.put(None)
    # Stay alive as before, but sleep instead of busy-waiting (the original
    # `while True: pass` pinned a CPU core).
    while True:
        time.sleep(3600)
def process_by_model(cuda_idx, _quene_get, _quene_put,
                     prompt='Question: What does the text in the picture say? Short answer:',
                     max_len=32):
    """Consumer: answer a fixed question about each image batch on one GPU.

    Loads the LAVIS ``blip2_t5`` / ``pretrain_flant5xxl`` model on
    ``cuda:{cuda_idx}``. It then repeatedly takes batch dicts
    (``'text_input'``, ``'image'``, ``'idx_list'``) from ``_quene_get``. For
    each batch it runs ``model.predict_answers`` with ``prompt`` repeated once
    per sample, and puts ``[idx_list, text_input, answers]`` on ``_quene_put``.

    When it receives ``None`` it forwards ``None`` to ``_quene_put``. It then
    idles (without spinning) rather than returning; the parent is expected to
    terminate it.

    :param cuda_idx: index of the CUDA device to run on.
    :param _quene_get: queue of input batches, terminated by ``None``.
    :param _quene_put: queue receiving results, terminated by ``None``.
    :param prompt: question text sent with every image (previously hard-coded).
    :param max_len: ``max_len`` passed to ``predict_answers`` (previously 32).
    """
    from lavis.models import load_model_and_preprocess
    logging.info('init cuda:{}'.format(cuda_idx))
    device = torch.device("cuda:{}".format(cuda_idx))
    model, _, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device)
    logging.info('cuda {} ready'.format(cuda_idx))

    while True:
        batch = _quene_get.get(True)
        if batch is None:
            break
        text = batch['text_input']
        image = batch['image'].to(device)
        idx_list = batch['idx_list']
        batch_size = len(text)
        assert batch_size == image.shape[0]
        with torch.no_grad():
            # Prompt variants tried, with the author's recorded scores.
            # Presumably these are the exact / "in" accuracies printed by the
            # main script, and "(32)" means max_len=32 — confirm.
            #   What is the text of the picture? 66.92/81.39
            #   What is the content of the text?
            #   What does the text in the picture say? 66.81/82.39 66.87/82.39(32)
            #   What is written on the picture? 68.80/81.78 68.86/81.78(32)
            answer = model.predict_answers(samples={"image": image, "text_input": [prompt] * batch_size}, inference_method="generate", max_len=max_len)
        _quene_put.put([idx_list, text, answer])

    _quene_put.put(None)
    # Stay alive as before, but sleep instead of busy-waiting (the original
    # `while True: pass` pinned a CPU core).
    while True:
        time.sleep(3600)
if __name__ == '__main__':
path = sys.argv[1]
queue_data = Queue(maxsize=32)
queue_result = Queue()
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s][line:%(lineno)d][%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
batch_size = 128
data_process = Process(target=process_data, args=(queue_data, path, batch_size))
data_process.start()
model_process_list = []
for i in range(1):
model_process = Process(target=process_by_model, args=(i, queue_data, queue_result))
model_process.start()
model_process_list.append(model_process)
# time.sleep(20)
save_all = []
last_time = time.time()
while True:
batch_data = queue_result.get(True)
if batch_data is None:
break
for i in range(len(batch_data[0])):
print('Label: {} Answer: {}'.format(batch_data[1][i], batch_data[2][i]))
save_all.append([batch_data[1][i], batch_data[2][i]])
right_num = 0.0
in_num = 0.0
for label, answer in save_all:
label = label.lower()
answer = answer.lower()
if label == answer or label == answer.split(' ')[0]:
right_num += 1
if label in answer.split(' ') or label in answer or label in answer.replace(' ', '').replace('\'', ''):
in_num += 1
else:
print('[error] Label: {} Answer: {}'.format(label, answer))
print(right_num / len(save_all), right_num, len(save_all))
print('in', in_num / len(save_all), in_num, len(save_all))