Add files via upload

This commit is contained in:
Yuliang Liu
2023-05-12 16:54:54 +08:00
committed by GitHub
parent 80d43dca2c
commit 215accefa6
4 changed files with 489 additions and 0 deletions

47
eval_textvqa.py Normal file
View File

@@ -0,0 +1,47 @@
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import os
import argparse
import json
#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--ocr_path", type=str, default="./data/")
parser.add_argument("--ocr_dataset", type=str, default="textVQA")
parser.add_argument("--answers-file", type=str, default="./answers_textvqa")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
model_name = "Salesforce/blip2-opt-6.7b"
device = "cuda"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch.float16
)
model.to(device)
ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w", encoding="utf-8")
with open(args.ocr_path+args.ocr_dataset+'/TextVQA_0.5.1_val.json', 'r') as f:
data = json.load(f)
for i in range(len(data['data'])):
prompt = data['data'][i]['question']
image_file = args.ocr_path+args.ocr_dataset+'/train_images/'+data['data'][i]['image_id']+'.jpg'
question_id = data['data'][i]['question_id']
gt_answers = data['data'][i]['answers']
image = Image.open(image_file)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
ans_file.write(json.dumps({
"image_path": image_file,
"question_id": question_id,
"prompt": prompt,
"answer": generated_text,
"gt_answers":gt_answers,
"model_name":model_name}) + "\n")
ans_file.flush()
ans_file.close()

233
eval_vqa.py Normal file
View File

@@ -0,0 +1,233 @@
import re
import json
def has_word(sentence, word):
pattern = r"\b" + re.escape(word) + r"\b"
match = re.search(pattern, sentence)
if match:
return True
else:
return False
class VQAEval:
def __init__(self):
self.contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
self.manualMap = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
self.articles = ["a", "an", "the"]
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
self.commaStrip = re.compile("(\d)(\,)(\d)")
self.punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def evaluate(self, answer, gt_answers):
answer = answer.replace("\n", " ")
answer = answer.replace("\t", " ")
answer = answer.strip()
answer = self.processPunctuation(answer)
answer = self.processDigitArticle(answer)
for i in range(len(gt_answers)):
gt_answers[i] = gt_answers[i].replace("\n", " ")
gt_answers[i] = gt_answers[i].replace("\t", " ")
gt_answers[i] = gt_answers[i].strip()
gt_answers[i] = self.processPunctuation(gt_answers[i])
gt_answers[i] = self.processDigitArticle(gt_answers[i])
if has_word(answer, gt_answers[i]):
return 1
return 0
def processPunctuation(self, inText):
outText = inText
for p in self.punct:
if (p + " " in inText or " " + p in inText) or (
re.search(self.commaStrip, inText) != None
):
outText = outText.replace(p, "")
else:
outText = outText.replace(p, " ")
outText = self.periodStrip.sub("", outText, re.UNICODE)
return outText
def processDigitArticle(self, inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = self.manualMap.setdefault(word, word)
if word not in self.articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in self.contractions:
outText[wordId] = self.contractions[word]
outText = " ".join(outText)
return outText
if __name__ == "__main__":
path = '/path/to/GPT4/mPLUG-Owl/answers_textvqa/textVQA.jsonl'
cor = 0
num = 0
eval = VQAEval()
with open(path, 'r') as f:
for line in f:
obj = json.loads(line)
gt_answers = obj['gt_answers']
answer = obj['answer']
if eval.evaluate(answer,gt_answers)==1:
cor+=1
num+=1
print(float(cor)/num)

49
kie_eval.py Normal file
View File

@@ -0,0 +1,49 @@
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import os
import argparse
import json
#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--ocr_path", type=str, default="/path/to/GPT4/KIE_data/")
parser.add_argument("--ocr_dataset", type=str, default="FUNSD")
parser.add_argument("--answers-file", type=str, default="/path/to/GPT4/KIE_data/cmr/blip2answer")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
model_name = "Salesforce/blip2-opt-6.7b"
device = "cuda"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch.float16
)
model.to(device)
# prompt = "Question: what is written in the image? Answer:"
ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w+", encoding="utf-8")
with open(args.ocr_path+ args.ocr_dataset +'.txt', 'r') as file:
for line in file:
image_file = line.split()[0]
question_start = line.find('question:') + len('question:')
label_start = line.find('label:') + len('label:')
question = line[question_start:line.find('label:')].strip()
label = line[label_start:].strip()
prompt = f"What is the '{question}' information written in this image ?"
img_path = os.path.join(args.ocr_path+args.ocr_dataset, image_file)
image = Image.open(img_path)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
ans_file.write(json.dumps({"prompt": prompt,
"image_path": image_file,
"label": label,
"text": generated_text,
"model_name":model_name}) + "\n")
ans_file.flush()
ans_file.close()

160
val_text_recognition.py Normal file
View File

@@ -0,0 +1,160 @@
import torch
import math
import os
import PIL
import time
import lmdb
import six
import logging
import sys
import traceback
import torch.distributed as dist
from multiprocessing import Queue, Process
torch.multiprocessing.set_sharing_strategy('file_system')
def pad_image(image, target_size):
"""
:param image: input image
:param target_size: a tuple (num,num)
:return: new image
"""
iw, ih = image.size
w, h = target_size
scale = min(w / iw, h / ih)
nw = int(iw * scale+0.5)
nh = int(ih * scale+0.5)
w += 128
h += 128
image = image.resize((nw, nh), PIL.Image.BICUBIC)
new_image = PIL.Image.new('RGB', (w, h), (0, 0, 0))
new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
return new_image
def process_data(_quene, path, batch_size):
from lavis.models import load_model_and_preprocess
_, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=torch.device("cpu"))
env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
txn = env.begin(write=False)
length = int(txn.get('num-samples'.encode()))
print("the length of dataset:", length)
batch_image = []
batch_text = []
idx_list = []
for idx in range(length):
image_key, label_key = f'image-{idx + 1:09d}', f'label-{idx + 1:09d}'
imgbuf = txn.get(image_key.encode()) # image
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
image = PIL.Image.open(buf).convert("RGB")
image = pad_image(image, (224, 224))
label = str(txn.get(label_key.encode()), 'utf-8').strip()
batch_image.append(vis_processors["eval"](image).unsqueeze(0))
batch_text.append(label)
idx_list.append(idx)
if len(batch_image) >= batch_size:
assert len(batch_image) == len(batch_text)
batch_image_tensor = torch.cat(batch_image, dim=0)
batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
_quene.put(batch)
batch_text = []
batch_image = []
idx_list = []
if len(batch_image) > 0:
assert len(batch_image) == len(batch_text)
batch_image_tensor = torch.cat(batch_image, dim=0)
batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
_quene.put(batch)
_quene.put(None)
while True:
pass
def process_by_model(cuda_idx, _quene_get, _quene_put):
from lavis.models import load_model_and_preprocess
logging.info('init cuda:{}'.format(cuda_idx))
device = torch.device("cuda:{}".format(cuda_idx))
model, _, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device)
logging.info('cuda {} ready'.format(cuda_idx))
while True:
batch = _quene_get.get(True)
if batch is None:
_quene_put.put(None)
while True:
pass
text = batch['text_input']
image = batch['image'].to(device)
idx_list = batch['idx_list']
batch_size = len(text)
assert batch_size == image.shape[0]
with torch.no_grad():
# What is the text of the picture? 66.92/81.39
# What is the content of the text?
# What does the text in the picture say? 66.81/82.39 66.87/82.39(32)
# What is written on the picture? 68.80/81.78 68.86/81.78(32)
answer = model.predict_answers(samples={"image": image, "text_input": ['Question: What does the text in the picture say? Short answer:'] * batch_size}, inference_method="generate", max_len=32)
_quene_put.put([idx_list, text, answer])
if __name__ == '__main__':
path = sys.argv[1]
queue_data = Queue(maxsize=32)
queue_result = Queue()
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s][line:%(lineno)d][%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
batch_size = 128
data_process = Process(target=process_data, args=(queue_data, path, batch_size))
data_process.start()
model_process_list = []
for i in range(1):
model_process = Process(target=process_by_model, args=(i, queue_data, queue_result))
model_process.start()
model_process_list.append(model_process)
# time.sleep(20)
save_all = []
last_time = time.time()
while True:
batch_data = queue_result.get(True)
if batch_data is None:
break
for i in range(len(batch_data[0])):
print('Label: {} Answer: {}'.format(batch_data[1][i], batch_data[2][i]))
save_all.append([batch_data[1][i], batch_data[2][i]])
right_num = 0.0
in_num = 0.0
for label, answer in save_all:
label = label.lower()
answer = answer.lower()
if label == answer or label == answer.split(' ')[0]:
right_num += 1
if label in answer.split(' ') or label in answer or label in answer.replace(' ', '').replace('\'', ''):
in_num += 1
else:
print('[error] Label: {} Answer: {}'.format(label, answer))
print(right_num / len(save_all), right_num, len(save_all))
print('in', in_num / len(save_all), in_num, len(save_all))