Add files via upload
This commit is contained in:
47
eval_textvqa.py
Normal file
47
eval_textvqa.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
|
||||
#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Demo")
|
||||
parser.add_argument("--ocr_path", type=str, default="./data/")
|
||||
parser.add_argument("--ocr_dataset", type=str, default="textVQA")
|
||||
parser.add_argument("--answers-file", type=str, default="./answers_textvqa")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
model_name = "Salesforce/blip2-opt-6.7b"
|
||||
device = "cuda"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch.float16
|
||||
)
|
||||
model.to(device)
|
||||
ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w", encoding="utf-8")
|
||||
with open(args.ocr_path+args.ocr_dataset+'/TextVQA_0.5.1_val.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
for i in range(len(data['data'])):
|
||||
prompt = data['data'][i]['question']
|
||||
image_file = args.ocr_path+args.ocr_dataset+'/train_images/'+data['data'][i]['image_id']+'.jpg'
|
||||
question_id = data['data'][i]['question_id']
|
||||
gt_answers = data['data'][i]['answers']
|
||||
|
||||
image = Image.open(image_file)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
|
||||
generated_ids = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
ans_file.write(json.dumps({
|
||||
"image_path": image_file,
|
||||
"question_id": question_id,
|
||||
"prompt": prompt,
|
||||
"answer": generated_text,
|
||||
"gt_answers":gt_answers,
|
||||
"model_name":model_name}) + "\n")
|
||||
ans_file.flush()
|
||||
ans_file.close()
|
233
eval_vqa.py
Normal file
233
eval_vqa.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import re
|
||||
import json
|
||||
def has_word(sentence, word):
|
||||
pattern = r"\b" + re.escape(word) + r"\b"
|
||||
match = re.search(pattern, sentence)
|
||||
if match:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
class VQAEval:
|
||||
def __init__(self):
|
||||
self.contractions = {
|
||||
"aint": "ain't",
|
||||
"arent": "aren't",
|
||||
"cant": "can't",
|
||||
"couldve": "could've",
|
||||
"couldnt": "couldn't",
|
||||
"couldn'tve": "couldn't've",
|
||||
"couldnt've": "couldn't've",
|
||||
"didnt": "didn't",
|
||||
"doesnt": "doesn't",
|
||||
"dont": "don't",
|
||||
"hadnt": "hadn't",
|
||||
"hadnt've": "hadn't've",
|
||||
"hadn'tve": "hadn't've",
|
||||
"hasnt": "hasn't",
|
||||
"havent": "haven't",
|
||||
"hed": "he'd",
|
||||
"hed've": "he'd've",
|
||||
"he'dve": "he'd've",
|
||||
"hes": "he's",
|
||||
"howd": "how'd",
|
||||
"howll": "how'll",
|
||||
"hows": "how's",
|
||||
"Id've": "I'd've",
|
||||
"I'dve": "I'd've",
|
||||
"Im": "I'm",
|
||||
"Ive": "I've",
|
||||
"isnt": "isn't",
|
||||
"itd": "it'd",
|
||||
"itd've": "it'd've",
|
||||
"it'dve": "it'd've",
|
||||
"itll": "it'll",
|
||||
"let's": "let's",
|
||||
"maam": "ma'am",
|
||||
"mightnt": "mightn't",
|
||||
"mightnt've": "mightn't've",
|
||||
"mightn'tve": "mightn't've",
|
||||
"mightve": "might've",
|
||||
"mustnt": "mustn't",
|
||||
"mustve": "must've",
|
||||
"neednt": "needn't",
|
||||
"notve": "not've",
|
||||
"oclock": "o'clock",
|
||||
"oughtnt": "oughtn't",
|
||||
"ow's'at": "'ow's'at",
|
||||
"'ows'at": "'ow's'at",
|
||||
"'ow'sat": "'ow's'at",
|
||||
"shant": "shan't",
|
||||
"shed've": "she'd've",
|
||||
"she'dve": "she'd've",
|
||||
"she's": "she's",
|
||||
"shouldve": "should've",
|
||||
"shouldnt": "shouldn't",
|
||||
"shouldnt've": "shouldn't've",
|
||||
"shouldn'tve": "shouldn't've",
|
||||
"somebody'd": "somebodyd",
|
||||
"somebodyd've": "somebody'd've",
|
||||
"somebody'dve": "somebody'd've",
|
||||
"somebodyll": "somebody'll",
|
||||
"somebodys": "somebody's",
|
||||
"someoned": "someone'd",
|
||||
"someoned've": "someone'd've",
|
||||
"someone'dve": "someone'd've",
|
||||
"someonell": "someone'll",
|
||||
"someones": "someone's",
|
||||
"somethingd": "something'd",
|
||||
"somethingd've": "something'd've",
|
||||
"something'dve": "something'd've",
|
||||
"somethingll": "something'll",
|
||||
"thats": "that's",
|
||||
"thered": "there'd",
|
||||
"thered've": "there'd've",
|
||||
"there'dve": "there'd've",
|
||||
"therere": "there're",
|
||||
"theres": "there's",
|
||||
"theyd": "they'd",
|
||||
"theyd've": "they'd've",
|
||||
"they'dve": "they'd've",
|
||||
"theyll": "they'll",
|
||||
"theyre": "they're",
|
||||
"theyve": "they've",
|
||||
"twas": "'twas",
|
||||
"wasnt": "wasn't",
|
||||
"wed've": "we'd've",
|
||||
"we'dve": "we'd've",
|
||||
"weve": "we've",
|
||||
"werent": "weren't",
|
||||
"whatll": "what'll",
|
||||
"whatre": "what're",
|
||||
"whats": "what's",
|
||||
"whatve": "what've",
|
||||
"whens": "when's",
|
||||
"whered": "where'd",
|
||||
"wheres": "where's",
|
||||
"whereve": "where've",
|
||||
"whod": "who'd",
|
||||
"whod've": "who'd've",
|
||||
"who'dve": "who'd've",
|
||||
"wholl": "who'll",
|
||||
"whos": "who's",
|
||||
"whove": "who've",
|
||||
"whyll": "why'll",
|
||||
"whyre": "why're",
|
||||
"whys": "why's",
|
||||
"wont": "won't",
|
||||
"wouldve": "would've",
|
||||
"wouldnt": "wouldn't",
|
||||
"wouldnt've": "wouldn't've",
|
||||
"wouldn'tve": "wouldn't've",
|
||||
"yall": "y'all",
|
||||
"yall'll": "y'all'll",
|
||||
"y'allll": "y'all'll",
|
||||
"yall'd've": "y'all'd've",
|
||||
"y'alld've": "y'all'd've",
|
||||
"y'all'dve": "y'all'd've",
|
||||
"youd": "you'd",
|
||||
"youd've": "you'd've",
|
||||
"you'dve": "you'd've",
|
||||
"youll": "you'll",
|
||||
"youre": "you're",
|
||||
"youve": "you've",
|
||||
}
|
||||
self.manualMap = {
|
||||
"none": "0",
|
||||
"zero": "0",
|
||||
"one": "1",
|
||||
"two": "2",
|
||||
"three": "3",
|
||||
"four": "4",
|
||||
"five": "5",
|
||||
"six": "6",
|
||||
"seven": "7",
|
||||
"eight": "8",
|
||||
"nine": "9",
|
||||
"ten": "10",
|
||||
}
|
||||
self.articles = ["a", "an", "the"]
|
||||
|
||||
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
||||
self.commaStrip = re.compile("(\d)(\,)(\d)")
|
||||
self.punct = [
|
||||
";",
|
||||
r"/",
|
||||
"[",
|
||||
"]",
|
||||
'"',
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
"=",
|
||||
"+",
|
||||
"\\",
|
||||
"_",
|
||||
"-",
|
||||
">",
|
||||
"<",
|
||||
"@",
|
||||
"`",
|
||||
",",
|
||||
"?",
|
||||
"!",
|
||||
]
|
||||
|
||||
def evaluate(self, answer, gt_answers):
|
||||
answer = answer.replace("\n", " ")
|
||||
answer = answer.replace("\t", " ")
|
||||
answer = answer.strip()
|
||||
answer = self.processPunctuation(answer)
|
||||
answer = self.processDigitArticle(answer)
|
||||
for i in range(len(gt_answers)):
|
||||
gt_answers[i] = gt_answers[i].replace("\n", " ")
|
||||
gt_answers[i] = gt_answers[i].replace("\t", " ")
|
||||
gt_answers[i] = gt_answers[i].strip()
|
||||
gt_answers[i] = self.processPunctuation(gt_answers[i])
|
||||
gt_answers[i] = self.processDigitArticle(gt_answers[i])
|
||||
if has_word(answer, gt_answers[i]):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def processPunctuation(self, inText):
|
||||
outText = inText
|
||||
for p in self.punct:
|
||||
if (p + " " in inText or " " + p in inText) or (
|
||||
re.search(self.commaStrip, inText) != None
|
||||
):
|
||||
outText = outText.replace(p, "")
|
||||
else:
|
||||
outText = outText.replace(p, " ")
|
||||
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
||||
return outText
|
||||
|
||||
def processDigitArticle(self, inText):
|
||||
outText = []
|
||||
tempText = inText.lower().split()
|
||||
for word in tempText:
|
||||
word = self.manualMap.setdefault(word, word)
|
||||
if word not in self.articles:
|
||||
outText.append(word)
|
||||
else:
|
||||
pass
|
||||
for wordId, word in enumerate(outText):
|
||||
if word in self.contractions:
|
||||
outText[wordId] = self.contractions[word]
|
||||
outText = " ".join(outText)
|
||||
return outText
|
||||
if __name__ == "__main__":
|
||||
path = '/path/to/GPT4/mPLUG-Owl/answers_textvqa/textVQA.jsonl'
|
||||
cor = 0
|
||||
num = 0
|
||||
eval = VQAEval()
|
||||
with open(path, 'r') as f:
|
||||
for line in f:
|
||||
obj = json.loads(line)
|
||||
gt_answers = obj['gt_answers']
|
||||
answer = obj['answer']
|
||||
if eval.evaluate(answer,gt_answers)==1:
|
||||
cor+=1
|
||||
num+=1
|
||||
print(float(cor)/num)
|
||||
|
||||
|
49
kie_eval.py
Normal file
49
kie_eval.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
|
||||
#dataset_name=['ct80','IC13_857','IC15_1811','IIIT5K','svt','svtp']
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Demo")
|
||||
parser.add_argument("--ocr_path", type=str, default="/path/to/GPT4/KIE_data/")
|
||||
parser.add_argument("--ocr_dataset", type=str, default="FUNSD")
|
||||
parser.add_argument("--answers-file", type=str, default="/path/to/GPT4/KIE_data/cmr/blip2answer")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
model_name = "Salesforce/blip2-opt-6.7b"
|
||||
device = "cuda"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch.float16
|
||||
)
|
||||
model.to(device)
|
||||
# prompt = "Question: what is written in the image? Answer:"
|
||||
ans_file = open(args.answers_file + '/' + args.ocr_dataset + '.jsonl', "w+", encoding="utf-8")
|
||||
with open(args.ocr_path+ args.ocr_dataset +'.txt', 'r') as file:
|
||||
for line in file:
|
||||
image_file = line.split()[0]
|
||||
|
||||
question_start = line.find('question:') + len('question:')
|
||||
label_start = line.find('label:') + len('label:')
|
||||
question = line[question_start:line.find('label:')].strip()
|
||||
label = line[label_start:].strip()
|
||||
prompt = f"What is the '{question}' information written in this image ?"
|
||||
|
||||
img_path = os.path.join(args.ocr_path+args.ocr_dataset, image_file)
|
||||
image = Image.open(img_path)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
|
||||
generated_ids = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
ans_file.write(json.dumps({"prompt": prompt,
|
||||
"image_path": image_file,
|
||||
"label": label,
|
||||
"text": generated_text,
|
||||
"model_name":model_name}) + "\n")
|
||||
ans_file.flush()
|
||||
ans_file.close()
|
160
val_text_recognition.py
Normal file
160
val_text_recognition.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import PIL
|
||||
import time
|
||||
import lmdb
|
||||
import six
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import torch.distributed as dist
|
||||
from multiprocessing import Queue, Process
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
|
||||
|
||||
def pad_image(image, target_size):
|
||||
|
||||
"""
|
||||
:param image: input image
|
||||
:param target_size: a tuple (num,num)
|
||||
:return: new image
|
||||
"""
|
||||
|
||||
iw, ih = image.size
|
||||
w, h = target_size
|
||||
|
||||
scale = min(w / iw, h / ih)
|
||||
|
||||
nw = int(iw * scale+0.5)
|
||||
nh = int(ih * scale+0.5)
|
||||
|
||||
w += 128
|
||||
h += 128
|
||||
|
||||
|
||||
image = image.resize((nw, nh), PIL.Image.BICUBIC)
|
||||
new_image = PIL.Image.new('RGB', (w, h), (0, 0, 0))
|
||||
new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
|
||||
|
||||
return new_image
|
||||
|
||||
|
||||
def process_data(_quene, path, batch_size):
|
||||
from lavis.models import load_model_and_preprocess
|
||||
_, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=torch.device("cpu"))
|
||||
env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
|
||||
txn = env.begin(write=False)
|
||||
length = int(txn.get('num-samples'.encode()))
|
||||
print("the length of dataset:", length)
|
||||
batch_image = []
|
||||
batch_text = []
|
||||
idx_list = []
|
||||
for idx in range(length):
|
||||
image_key, label_key = f'image-{idx + 1:09d}', f'label-{idx + 1:09d}'
|
||||
imgbuf = txn.get(image_key.encode()) # image
|
||||
buf = six.BytesIO()
|
||||
buf.write(imgbuf)
|
||||
buf.seek(0)
|
||||
image = PIL.Image.open(buf).convert("RGB")
|
||||
image = pad_image(image, (224, 224))
|
||||
label = str(txn.get(label_key.encode()), 'utf-8').strip()
|
||||
batch_image.append(vis_processors["eval"](image).unsqueeze(0))
|
||||
batch_text.append(label)
|
||||
idx_list.append(idx)
|
||||
if len(batch_image) >= batch_size:
|
||||
assert len(batch_image) == len(batch_text)
|
||||
batch_image_tensor = torch.cat(batch_image, dim=0)
|
||||
batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
|
||||
_quene.put(batch)
|
||||
batch_text = []
|
||||
batch_image = []
|
||||
idx_list = []
|
||||
if len(batch_image) > 0:
|
||||
assert len(batch_image) == len(batch_text)
|
||||
batch_image_tensor = torch.cat(batch_image, dim=0)
|
||||
batch = {'text_input': batch_text, 'image': batch_image_tensor, 'idx_list': idx_list}
|
||||
_quene.put(batch)
|
||||
_quene.put(None)
|
||||
while True:
|
||||
pass
|
||||
|
||||
|
||||
def process_by_model(cuda_idx, _quene_get, _quene_put):
|
||||
from lavis.models import load_model_and_preprocess
|
||||
logging.info('init cuda:{}'.format(cuda_idx))
|
||||
device = torch.device("cuda:{}".format(cuda_idx))
|
||||
model, _, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device)
|
||||
logging.info('cuda {} ready'.format(cuda_idx))
|
||||
while True:
|
||||
batch = _quene_get.get(True)
|
||||
if batch is None:
|
||||
_quene_put.put(None)
|
||||
while True:
|
||||
pass
|
||||
text = batch['text_input']
|
||||
image = batch['image'].to(device)
|
||||
idx_list = batch['idx_list']
|
||||
batch_size = len(text)
|
||||
assert batch_size == image.shape[0]
|
||||
|
||||
with torch.no_grad():
|
||||
# What is the text of the picture? 66.92/81.39
|
||||
# What is the content of the text?
|
||||
# What does the text in the picture say? 66.81/82.39 66.87/82.39(32)
|
||||
# What is written on the picture? 68.80/81.78 68.86/81.78(32)
|
||||
answer = model.predict_answers(samples={"image": image, "text_input": ['Question: What does the text in the picture say? Short answer:'] * batch_size}, inference_method="generate", max_len=32)
|
||||
|
||||
_quene_put.put([idx_list, text, answer])
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
path = sys.argv[1]
|
||||
|
||||
queue_data = Queue(maxsize=32)
|
||||
queue_result = Queue()
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='[%(asctime)s][line:%(lineno)d][%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
batch_size = 128
|
||||
|
||||
data_process = Process(target=process_data, args=(queue_data, path, batch_size))
|
||||
data_process.start()
|
||||
|
||||
model_process_list = []
|
||||
for i in range(1):
|
||||
model_process = Process(target=process_by_model, args=(i, queue_data, queue_result))
|
||||
model_process.start()
|
||||
model_process_list.append(model_process)
|
||||
# time.sleep(20)
|
||||
|
||||
save_all = []
|
||||
last_time = time.time()
|
||||
while True:
|
||||
batch_data = queue_result.get(True)
|
||||
if batch_data is None:
|
||||
break
|
||||
for i in range(len(batch_data[0])):
|
||||
print('Label: {} Answer: {}'.format(batch_data[1][i], batch_data[2][i]))
|
||||
save_all.append([batch_data[1][i], batch_data[2][i]])
|
||||
|
||||
right_num = 0.0
|
||||
in_num = 0.0
|
||||
for label, answer in save_all:
|
||||
label = label.lower()
|
||||
answer = answer.lower()
|
||||
if label == answer or label == answer.split(' ')[0]:
|
||||
right_num += 1
|
||||
if label in answer.split(' ') or label in answer or label in answer.replace(' ', '').replace('\'', ''):
|
||||
in_num += 1
|
||||
else:
|
||||
print('[error] Label: {} Answer: {}'.format(label, answer))
|
||||
print(right_num / len(save_all), right_num, len(save_all))
|
||||
print('in', in_num / len(save_all), in_num, len(save_all))
|
Reference in New Issue
Block a user