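"""Evaluation driver for OCR-related benchmarks of large multimodal models.

Instantiates one of several models (BLIP2, LLaVA, MiniGPT4, mPLUG-Owl,
OpenFlamingo, InstructBLIP) and runs it over VQA, scene-text OCR, key
information extraction, and formula-recognition datasets, writing per-sample
predictions and aggregate accuracies under --answer_path. Example invocations
are given after parse_args() below.
"""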
import sys

# Make the vendored model packages importable.
sys.path.append('./models/MiniGPT4')
sys.path.append('./models/mPLUG_Owl')

import argparse
import datetime
import json
import os
import re

import more_itertools
import numpy as np
import torch
from tqdm import tqdm

from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
from datasets.ocr_dataset import ocrDataset, IAMDataset, ReCTSDataset
from datasets.kie_dataset import SROIEDataset, FUNSDDataset, POIEDataset
from datasets.formula_dataset import HMEDataset
from models.lavis.lavis import lavis
from models.LLaVA.LLaVA import LLaVA
from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
from models.MiniGPT4.MiniGPT4 import MiniGPT4
from models.OpenFlamingo.OpenFlamingo import OpenFlamingo
from models.BLIP2.BLIP2 import BLIP2
from models.InstructBLIP.InstructBLIP import InstructBLIP
def get_model(args):
    """Instantiate the requested model on args.device."""
    if args.model_name == 'BLIP2':
        model = BLIP2("/home/zhangli/.cache/huggingface/hub/models--Salesforce--blip2-opt-6.7b/snapshots/f998da12f28eb37d7e7f080cfe3291d6d9d7e1fb", args.device)
        #model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
    elif args.model_name == 'LLaVA':
        model = LLaVA(args.LLaVA_model_path, args.device)
    elif args.model_name == 'MiniGPT4':
        model = MiniGPT4(args, args.device)
    elif args.model_name == 'mPLUG':
        model = mPLUG(args.mPLUG_model_name, args.device)
    elif args.model_name == 'OpenFlamingo':
        model = OpenFlamingo(args.llama_path, args.check_point, args.device)
    elif args.model_name == 'instructblip':
        model = InstructBLIP('blip2_vicuna_instruct', args.device)
    else:
        # The original fell through and raised UnboundLocalError on `return model`.
        raise ValueError(f"Unknown model name: {args.model_name}")
    return model
def has_word(sentence, word):
    """Return True if `word` occurs in `sentence` as a whole word."""
    pattern = r"\b" + re.escape(word) + r"\b"
    return re.search(pattern, sentence) is not None


def remove_special_chars(s):
    """Strip every character that is not alphanumeric or whitespace."""
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
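# Illustrative matching behavior (not executed):
#   has_word("the word hello appears", "hello")  -> True  (whole-word match)
#   has_word("hellos", "hello")                  -> False (\b word boundaries)
#   remove_special_chars("it's $5!")             -> "its 5"
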
class VQAEval:
    def __init__(self):
        # Contraction normalization table from the official VQA evaluation code.
        self.contractions = {
            "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've",
            "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
            "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
            "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't",
            "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've",
            "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's",
            "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've",
            "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've",
            "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't",
            "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
            "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've",
            "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't",
            "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's",
            "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebodyd": "somebody'd",  # key and value were swapped in the original
            "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
            "someoned've": "someone'd've", "someone'dve": "someone'd've",
            "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd",
            "somethingd've": "something'd've", "something'dve": "something'd've",
            "somethingll": "something'll", "thats": "that's", "thered": "there'd",
            "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're",
            "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
            "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're",
            "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've",
            "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll",
            "whatre": "what're", "whats": "what's", "whatve": "what've", "whens": "when's",
            "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd",
            "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's",
            "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's",
            "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all",
            "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd",
            "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll",
            "youre": "you're", "youve": "you've",
        }
        # Spell out small numbers so e.g. "two" and "2" compare equal.
        self.manualMap = {
            "none": "0", "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
            "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        # Strip periods that are not decimal points. The original pattern used the
        # malformed lookaround "(?!<=\d)" where a negative lookbehind was intended.
        self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";", r"/", "[", "]", '"', "{", "}", "(", ")", "=", "+", "\\", "_", "-",
            ">", "<", "@", "`", ",", "?", "!",
        ]
    def evaluate(self, answer, gt_answers):
        answer = answer.replace("\n", " ").replace("\t", " ").strip()
        answer = self.processPunctuation(answer)
        answer = self.processDigitArticle(answer)
        if isinstance(gt_answers, list):
            for gt_answer in gt_answers:
                gt_answer = gt_answer.replace("\n", " ").replace("\t", " ").strip()
                gt_answer = self.processPunctuation(gt_answer)
                gt_answer = self.processDigitArticle(gt_answer)
                if has_word(answer, gt_answer):
                    return 1
            return 0
        gt_answers = gt_answers.replace("\n", " ").replace("\t", " ").strip()
        gt_answers = self.processPunctuation(gt_answers)
        gt_answers = self.processDigitArticle(gt_answers)
        return 1 if has_word(answer, gt_answers) else 0

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            # Delete the character when it abuts a space or when the text contains
            # a digit-grouping comma; otherwise replace it with a space.
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # The original passed re.UNICODE as sub()'s third positional argument,
        # which is the replacement *count*, not a flags argument.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        outText = []
        for word in inText.lower().split():
            # .get() avoids mutating manualMap, unlike the original setdefault().
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        return " ".join(outText)
def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    answer_path='./answers'
):
    predictions = []
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]  # inference is effectively single-sample (batch_size=1)
        output = model.generate(image=batch['image_path'], question=batch['question'])
        answer_dict = {'question': batch['question'], 'answer': output,
                       'gt_answers': batch['gt_answers'], 'image_path': batch['image_path'],
                       'model_name': model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_file = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_file, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    evaluator = VQAEval()  # the original name "eval" shadowed the builtin
    correct = 0
    num = 0
    with open(answer_file, 'r') as f:
        records = json.load(f)  # the original name "dict" shadowed the builtin
    for record in records:
        if evaluator.evaluate(record['answer'], record['gt_answers']) == 1:
            correct += 1
        num += 1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num
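# Each evaluate_* function writes <answer_path>/<time>/<dataset_name>.json, a
# list of records shaped like (illustrative values):
#   {"question": "...", "answer": "...", "gt_answers": "...",
#    "image_path": "...", "model_name": "BLIP2"}
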
def evaluate_OCR(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    question='what is written in the image?',
    batch_size=1,
    answer_path='./answers'
):
    predictions = []
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]
        output = model.generate(image=batch['image_path'], question=question)
        answer_dict = {'question': question, 'answer': output,
                       'gt_answers': batch['gt_answers'], 'image_path': batch['image_path'],
                       'model_name': model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_file = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_file, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    correct = 0
    num = 0
    with open(answer_file, 'r') as f:
        records = json.load(f)
    for record in records:
        # Case-insensitive whole-word match after stripping punctuation.
        gt_answers = remove_special_chars(record['gt_answers']).lower()
        answer = remove_special_chars(record['answer']).lower()
        if has_word(answer, gt_answers):
            correct += 1
        num += 1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

def evaluate_ReCTS(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    # Chinese version of the prompt: '图像中的中文是什么?'
    question='What are the Chinese characters in the image?',
    batch_size=1,
    answer_path='./answers'
):
    predictions = []
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]
        output = model.generate(image=batch['image_path'], question=question)
        answer_dict = {'question': question, 'answer': output,
                       'gt_answers': batch['gt_answers'], 'image_path': batch['image_path'],
                       'model_name': model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_file = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_file, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    correct = 0
    num = 0
    with open(answer_file, 'r') as f:
        records = json.load(f)
    for record in records:
        # Keep only CJK characters and whitespace, then test substring containment.
        gt_answers = re.sub(r'[^\u4e00-\u9fa5\s]+', '', record['gt_answers'])
        answer = re.sub(r'[^\u4e00-\u9fa5\s]+', '', record['answer'])
        if gt_answers in answer:
            correct += 1
        num += 1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

def evaluate_Formula(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    question='Please write out the expression of the formula in the image using LaTeX format.',
    batch_size=1,
    answer_path='./answers'
):
    predictions = []
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]
        output = model.generate(image=batch['image_path'], question=question)
        answer_dict = {'question': question, 'answer': output,
                       'gt_answers': batch['gt_answers'], 'image_path': batch['image_path'],
                       'model_name': model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_file = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_file, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    correct = 0
    num = 0
    with open(answer_file, 'r') as f:
        records = json.load(f)
    for record in records:
        # Compare with all whitespace removed, since LaTeX spacing is not significant.
        gt_answers = re.sub(r'\s+', '', record['gt_answers'])
        answer = re.sub(r'\s+', '', record['answer'])
        if gt_answers in answer:
            correct += 1
        num += 1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

def parse_args():
    parser = argparse.ArgumentParser(description="Demo")

    # OCR datasets
    parser.add_argument("--ocr_dir_path", type=str, default="./data")
    parser.add_argument("--ocr_dataset_name", type=str,
                        default="IIIT5K svt IC13_857 IC15_1811 svtp ct80 cocotext ctw totaltext HOST WOST WordArt")
    # IAM
    parser.add_argument("--IAM_dir_path", type=str, default="./data/IAM")
    # ReCTS
    parser.add_argument("--ReCTS_dir_path", type=str, default="./data/ReCTS")

    # HME100K
    parser.add_argument("--HME_image_dir_path", type=str, default="./data/HME100K/test_images")
    parser.add_argument("--HME_ann_path", type=str, default="./data/HME100K/test_labels.txt")

    # textVQA
    parser.add_argument("--textVQA_image_dir_path", type=str, default="./data/textVQA/train_images")
    parser.add_argument("--textVQA_ann_path", type=str, default="./data/textVQA/TextVQA_0.5.1_val.json")

    # docVQA
    parser.add_argument("--docVQA_image_dir_path", type=str, default="./data/docVQA/val")
    parser.add_argument("--docVQA_ann_path", type=str, default="./data/docVQA/val/val_v1.0.json")

    # ocrVQA
    parser.add_argument("--ocrVQA_image_dir_path", type=str, default="./data/ocrVQA/images")
    parser.add_argument("--ocrVQA_ann_path", type=str, default="./data/ocrVQA/dataset.json")

    # STVQA
    parser.add_argument("--STVQA_image_dir_path", type=str, default="./data/STVQA")
    parser.add_argument("--STVQA_ann_path", type=str, default="./data/STVQA/train_task_3.json")

    # ESTVQA
    parser.add_argument("--ESTVQA_image_dir_path", type=str, default="./data/ESTVQA/images/train")
    parser.add_argument("--ESTVQA_CN_ann_path", type=str, default="./data/ESTVQA/annotations/cn_train.json")
    parser.add_argument("--ESTVQA_EN_ann_path", type=str, default="./data/ESTVQA/annotations/en_train.json")

    # SROIE
    parser.add_argument("--SROIE_dir_path", type=str, default="./data/SROIE")
    # FUNSD
    parser.add_argument("--FUNSD_dir_path", type=str, default="./data/FUNSD/testing_data/annotations")
    # POIE
    parser.add_argument("--POIE_dir_path", type=str, default="./data/POIE/test.txt")

    # Result path
    parser.add_argument("--answer_path", type=str, default="./answers")

    # Dataset switches ("store_true" flags already default to False)
    parser.add_argument("--eval_textVQA", action="store_true", help="Whether to evaluate on textVQA.")
    parser.add_argument("--eval_docVQA", action="store_true", help="Whether to evaluate on docVQA.")
    parser.add_argument("--eval_ocrVQA", action="store_true", help="Whether to evaluate on ocrVQA.")
    parser.add_argument("--eval_STVQA", action="store_true", help="Whether to evaluate on STVQA.")
    parser.add_argument("--eval_ESTVQA_CN", action="store_true", help="Whether to evaluate on ESTVQA_CN.")
    parser.add_argument("--eval_ESTVQA_EN", action="store_true", help="Whether to evaluate on ESTVQA_EN.")
    parser.add_argument("--eval_SROIE", action="store_true", help="Whether to evaluate on SROIE.")
    parser.add_argument("--eval_FUNSD", action="store_true", help="Whether to evaluate on FUNSD.")
    parser.add_argument("--eval_POIE", action="store_true", help="Whether to evaluate on POIE.")
    parser.add_argument("--eval_HME", action="store_true", help="Whether to evaluate on HME100K.")
    parser.add_argument("--eval_IAM", action="store_true", help="Whether to evaluate on IAM (handwritten).")
    parser.add_argument("--eval_ReCTS", action="store_true", help="Whether to evaluate on ReCTS (Chinese).")
    parser.add_argument("--eval_ocr", action="store_true", help="Whether to evaluate on the scene-text OCR datasets.")
    parser.add_argument("--eval_all", action="store_true", help="Whether to evaluate on all datasets.")

    # BLIP2
    #parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
    parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")  # blip2_t5 blip2_opt blip2_vicuna_instruct
    parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")  # pretrain_flant5xxl pretrain_opt6.7b vicuna13b
    # LLaVA
    parser.add_argument("--LLaVA_model_path", type=str, default="./models/LLaVA/model_weight")
    # MiniGPT4
    parser.add_argument("--MiniGPT4_cfg_path", type=str, default="./models/MiniGPT4/eval_configs/minigpt4_eval.yaml")
    # mPLUG-Owl
    parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
    #parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")
    # OpenFlamingo
    parser.add_argument("--llama_path", type=str, default="/home/zhangli/llama_models/llama/llama-7b")
    parser.add_argument("--check_point", type=str, default="/home/zhangli/code/open_flamingo/checkpoint/checkpoint.pt")

    # One of: BLIP2, LLaVA, MiniGPT4, mPLUG, OpenFlamingo, instructblip
    parser.add_argument("--model_name", type=str, default="BLIP2")
    parser.add_argument("--device", type=str, default="cuda:0")

    args = parser.parse_args()
    return args
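# Example invocations (assuming this file is saved as eval.py):
#   python eval.py --model_name BLIP2 --eval_textVQA --device cuda:0
#   python eval.py --model_name mPLUG --eval_ocr --ocr_dataset_name "IIIT5K svt ct80"
# --ocr_dataset_name is whitespace-split in main(), so any subset of the default
# scene-text datasets can be passed as one quoted string.
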
def main(args):
    np.random.seed(0)
    max_sample_num = 5000
    model = get_model(args)
    # e.g. ['IIIT5K', 'svt', 'IC13_857', 'IC15_1811', 'svtp', 'ct80',
    #       'cocotext', 'ctw', 'totaltext', 'HOST', 'WOST', 'WordArt']
    ocr_dataset_name = args.ocr_dataset_name.split()
    result = {}
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    if args.eval_textVQA or args.eval_all:
        dataset = textVQADataset(args.textVQA_image_dir_path, args.textVQA_ann_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time)
        result['textVQA'] = acc

    if args.eval_docVQA or args.eval_all:
        dataset = docVQADataset(args.docVQA_image_dir_path, args.docVQA_ann_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time)
        result['docVQA'] = acc

    # Due to network issues it is difficult to download the entire OCR-VQA
    # dataset, so only the first 5000 questions are used for testing.
    if args.eval_ocrVQA or args.eval_all:
        dataset = ocrVQADataset(args.ocrVQA_image_dir_path, args.ocrVQA_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ocrVQA', time)
        result['ocrVQA'] = acc

    if args.eval_STVQA or args.eval_all:
        dataset = STVQADataset(args.STVQA_image_dir_path, args.STVQA_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'STVQA', time)
        result['STVQA'] = acc

    if args.eval_ESTVQA_CN or args.eval_all:
        dataset = ESTVQADataset(args.ESTVQA_image_dir_path, args.ESTVQA_CN_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ESTVQA_CN', time)
        result['ESTVQA_CN'] = acc

    if args.eval_ESTVQA_EN or args.eval_all:
        dataset = ESTVQADataset(args.ESTVQA_image_dir_path, args.ESTVQA_EN_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ESTVQA_EN', time)
        result['ESTVQA_EN'] = acc

    if args.eval_SROIE or args.eval_all:
        dataset = SROIEDataset(args.SROIE_dir_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'SROIE', time)
        result['SROIE'] = acc

    if args.eval_FUNSD or args.eval_all:
        dataset = FUNSDDataset(args.FUNSD_dir_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'FUNSD', time)
        result['FUNSD'] = acc

    if args.eval_POIE or args.eval_all:
        dataset = POIEDataset(args.POIE_dir_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'POIE', time)
        result['POIE'] = acc

    if args.eval_HME or args.eval_all:
        dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_Formula(model, dataset, args.model_name, 'HME', time)
        result['HME'] = acc

    if args.eval_IAM or args.eval_all:
        dataset = IAMDataset(args.IAM_dir_path)
        dataset = torch.utils.data.Subset(dataset, range(3000))
        acc = evaluate_OCR(model, dataset, args.model_name, 'IAM', time)
        result['IAM'] = acc

    if args.eval_ReCTS or args.eval_all:
        dataset = ReCTSDataset(args.ReCTS_dir_path)
        dataset = torch.utils.data.Subset(dataset, range(3000))
        acc = evaluate_ReCTS(model, dataset, args.model_name, 'ReCTS', time)
        result['ReCTS'] = acc

    if args.eval_ocr or args.eval_all:
        for name in ocr_dataset_name:
            dataset = ocrDataset(args.ocr_dir_path, name)
            acc = evaluate_OCR(model, dataset, args.model_name, name, time)
            result[name] = acc

    result_path = os.path.join(args.answer_path, time, 'result.json')
    with open(result_path, "w") as f:
        f.write(json.dumps(result, indent=4))


if __name__ == "__main__":
    args = parse_args()
    main(args)