add

eval.py
@@ -1,3 +1,6 @@
import sys
sys.path.append('./models/MiniGPT4')
sys.path.append('./models/mPLUG_Owl')
import argparse
#from models.BLIP2.BLIP2 import BLIP2
import more_itertools
@@ -6,17 +9,26 @@ import datetime
import os
import json
import re
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
from datasets.ocr_dataset import ocrDataset
from datasets.kie_dataset import SROIEDataset,FUNSDDataset
from datasets.formula_dataset import HMEDataset
from models.lavis.lavis import lavis
from models.LLaVA.LLaVA import LLaVA
from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
from models.MiniGPT4.MiniGPT4 import MiniGPT4
import torch
import numpy as np
def get_model(args):
    if args.model_name=='BLIP2':
        #model = BLIP2(args.BLIP2_model_path, args.device)
        model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
    #elif args.model_name=='mPLUG-Owl':
    # model =
    elif args.model_name=='LLaVA':
        model = LLaVA(args.LLaVA_model_path, args.device)
    elif args.model_name=='MiniGPT4':
        model = MiniGPT4(args, args.device)
    elif args.model_name=='mPLUG':
        model = mPLUG(args.mPLUG_model_name, args.device)
    return model
def has_word(sentence, word):
    pattern = r"\b" + re.escape(word) + r"\b"
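The `get_model` dispatcher above returns one of several wrappers (lavis for BLIP2, LLaVA, MiniGPT4, mPLUG). The evaluation loops only assume that whatever it returns exposes a `generate(image=..., question=...)` method, as the `model.generate` call in `evaluate_Formula` further down shows. A minimal sketch of that implicit interface; the class below is purely illustrative and not part of the commit:

```python
class DummyModel:
    """Illustrative stand-in satisfying the interface eval.py relies on."""
    def generate(self, image, question):
        # A real wrapper would load the image file at `image`, run the
        # multimodal model, and return its answer to `question`;
        # here we just echo a placeholder string.
        return f"placeholder answer for {image}"

model = DummyModel()
print(model.generate(image="example.jpg", question="What text is shown in the image?"))
```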
@@ -217,7 +229,7 @@ class VQAEval:
        gt_answers = gt_answers.strip()
        gt_answers = self.processPunctuation(gt_answers)
        gt_answers = self.processDigitArticle(gt_answers)
        if has_word(answer, gt_answers[i]):
        if has_word(answer, gt_answers):
            return 1
        else:
            return 0
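The one-line change above fixes what looks like a leftover index: in this hunk `gt_answers` is handled as a single string (it is `strip()`-ped and normalized directly), so `gt_answers[i]` would pick out one character rather than one answer. The fix passes the whole normalized string to `has_word`, whose word-boundary pattern is built near the top of the file. A small sketch, assuming `has_word` completes that pattern with a `re.search` (the rest of its body falls outside the hunk):

```python
import re

def has_word(sentence, word):
    # Word-boundary match built from the pattern shown earlier in the diff;
    # the re.search call is assumed, since the hunk cuts off after `pattern`.
    pattern = r"\b" + re.escape(word) + r"\b"
    return re.search(pattern, sentence) is not None

print(has_word("the answer is 42", "42"))  # True
print(has_word("the answer is 42", "4"))   # False: word boundary, no partial match
```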
@@ -324,13 +336,54 @@ def evaluate_OCR(
        num+=1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num


def evaluate_Formula(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    question='Please write out the expression of the formula in the image using LaTeX format.',
    batch_size=1,
    answer_path='./answers'
):
    #Please write out the expression of the formula in the image using LaTeX format.
    predictions=[]
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]
        output = model.generate(image=batch['image_path'], question=question)
        answer_dict={'question':question, 'answer':output,
        'gt_answers':batch['gt_answers'], 'image_path':batch['image_path'],
        'model_name':model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_path = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_path, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    correct = 0
    num = 0
    with open(answer_path, 'r') as f:
        dict = json.load(f)
    for i in range(len(dict)):
        gt_answers = re.sub(r'\s+', '', dict[i]['gt_answers'])
        answer = re.sub(r'\s+', '', dict[i]['answer'])
        if gt_answers in answer:
            correct+=1
        num+=1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num
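As the scoring loop above shows, `evaluate_Formula` counts a prediction as correct when the whitespace-stripped ground-truth LaTeX appears as a substring of the whitespace-stripped model output. A self-contained sketch of that matching rule (the function name below is illustrative, not from the commit):

```python
import re

def formula_match(prediction, gt_latex):
    # Same rule as evaluate_Formula: remove all whitespace on both sides,
    # then require the ground truth to occur verbatim inside the prediction.
    pred = re.sub(r'\s+', '', prediction)
    gt = re.sub(r'\s+', '', gt_latex)
    return gt in pred

print(formula_match(r"The expression is \frac{a}{ b } + c", r"\frac{a}{b}+c"))  # True
print(formula_match(r"\frac{a}{b}", r"\frac{a}{c}"))                            # False
```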
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    #OCR datasets
    parser.add_argument("--ocr_dir_path", type=str, default="./data")
    parser.add_argument("--ocr_dataset_name", type=str, default="IIIT5K svt IC13_857 IC15_1811 svtp ct80 cocotext ctw totaltext HOST WOST WordArt")
    #HME100k
    parser.add_argument("--HME_image_dir_path", type=str, default="./data/HME100K/test_images")
    parser.add_argument("--HME_ann_path", type=str, default="./data/HME100K/test_labels.txt")
    #textVQA
    parser.add_argument("--textVQA_image_dir_path", type=str, default="./data/textVQA/train_images")
    parser.add_argument("--textVQA_ann_path", type=str, default="./data/textVQA/TextVQA_0.5.1_val.json")
@@ -346,6 +399,16 @@ def parse_args():
    #STVQA
    parser.add_argument("--STVQA_image_dir_path", type=str, default="./data/STVQA")
    parser.add_argument("--STVQA_ann_path", type=str, default="./data/STVQA/train_task_3.json")
    #ESTVQA
    parser.add_argument("--ESTVQA_image_dir_path", type=str, default="./data/ESTVQA/images/train")
    parser.add_argument("--ESTVQA_CN_ann_path", type=str, default="./data/ESTVQA/annotations/cn_train.json")
    parser.add_argument("--ESTVQA_EN_ann_path", type=str, default="./data/ESTVQA/annotations/en_train.json")

    #SROIE
    parser.add_argument("--SROIE_dir_path", type=str, default="./data/SROIE")

    #FUNSD
    parser.add_argument("--FUNSD_dir_path", type=str, default="./data/FUNSD/testing_data/annotations")

    #result_path
    parser.add_argument("--answer_path", type=str, default="./answers")
@@ -374,20 +437,62 @@ def parse_args():
        default=False,
        help="Whether to evaluate on STVQA."
    )
    parser.add_argument(
        "--eval_ESTVQA_CN",
        action="store_true",
        default=False,
        help="Whether to evaluate on ESTVQA_CN."
    )
    parser.add_argument(
        "--eval_ESTVQA_EN",
        action="store_true",
        default=False,
        help="Whether to evaluate on ESTVQA_EN."
    )
    parser.add_argument(
        "--eval_SROIE",
        action="store_true",
        default=False,
        help="Whether to evaluate on SROIE."
    )
    parser.add_argument(
        "--eval_FUNSD",
        action="store_true",
        default=False,
        help="Whether to evaluate on FUNSD."
    )
    parser.add_argument(
        "--eval_HME",
        action="store_true",
        default=False,
        help="Whether to evaluate on HME100k."
    )
    parser.add_argument(
        "--eval_ocr",
        action="store_true",
        default=False,
        help="Whether to evaluate on ocr."
    )
    parser.add_argument(
        "--eval_all",
        action="store_true",
        default=False,
        help="Whether to evaluate all datasets"
    )
    #BLIP2
    #parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
    parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")#blip2_t5 blip2_opt blip2_vicuna_instruct
    parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")#pretrain_flant5xxl pretrain_opt6.7b vicuna13b

    #LLaVA
    parser.add_argument("--LLaVA_model_path", type=str, default="./models/LLaVA/model_weight")
    #miniGPT4
    parser.add_argument("--MiniGPT4_cfg_path", type=str, default="./models/MiniGPT4/eval_configs/minigpt4_eval.yaml")
    #mPLUG
    parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
    #parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")

    parser.add_argument("--model_name", type=str, default="BLIP2")#mPLUG,miniGPT4,LLaVA
    parser.add_argument("--device", type=str, default="cuda:2")
    parser.add_argument("--device", type=str, default="cuda:1")
    args = parser.parse_args()
    return args

@@ -400,34 +505,57 @@ def main(args):
    ocr_dataset_name = args.ocr_dataset_name.split()
    result = {}
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    if args.eval_textVQA:
    if args.eval_textVQA or args.eval_all:
        dataset = textVQADataset(args.textVQA_image_dir_path, args.textVQA_ann_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time)
        result['textVQA'] = acc
    if args.eval_docVQA:
    if args.eval_docVQA or args.eval_all:
        dataset = docVQADataset(args.docVQA_image_dir_path, args.docVQA_ann_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time)
        result['docVQA'] = acc

    if args.eval_ocrVQA:
    #Due to network issues, it's difficult to download the entire OCR-VQA dataset.
    # Therefore, we will extract the first 5000 questions for testing.
    if args.eval_ocrVQA or args.eval_all:
        dataset = ocrVQADataset(args.ocrVQA_image_dir_path, args.ocrVQA_ann_path)
        random_indices = np.random.choice(
            len(dataset), max_sample_num, replace=False
        )
        dataset = torch.utils.data.Subset(dataset,random_indices)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ocrVQA', time)
        result['ocrVQA'] = acc

    if args.eval_STVQA:
    if args.eval_STVQA or args.eval_all:
        dataset = STVQADataset(args.STVQA_image_dir_path, args.STVQA_ann_path)
        random_indices = np.random.choice(
            len(dataset), max_sample_num, replace=False
        )
        dataset = torch.utils.data.Subset(dataset,random_indices)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'STVQA', time)
        result['STVQA'] = acc

    if args.eval_ocr:
    if args.eval_ESTVQA_CN or args.eval_all:
        dataset = ESTVQADataset(args.ESTVQA_image_dir_path, args.ESTVQA_CN_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ESTVQA_CN', time)
        result['ESTVQA_CN'] = acc

    if args.eval_ESTVQA_EN or args.eval_all:
        dataset = ESTVQADataset(args.ESTVQA_image_dir_path, args.ESTVQA_EN_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'ESTVQA_EN', time)
        result['ESTVQA_EN'] = acc

    if args.eval_SROIE or args.eval_all:
        dataset = SROIEDataset(args.SROIE_dir_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'SROIE', time)
        result['SROIE'] = acc

    if args.eval_FUNSD or args.eval_all:
        dataset = FUNSDDataset(args.FUNSD_dir_path)
        acc = evaluate_VQA(model, dataset, args.model_name, 'FUNSD', time)
        result['FUNSD'] = acc

    if args.eval_HME or args.eval_all:
        dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_Formula(model, dataset, args.model_name, 'HME', time)
        result['HME'] = acc

    if args.eval_ocr or args.eval_all:
        for i in range(len(ocr_dataset_name)):
            dataset = ocrDataset(args.ocr_dir_path, ocr_dataset_name[i])
            acc = evaluate_OCR(model, dataset, args.model_name, ocr_dataset_name[i], time)
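One other change visible in the hunk above: the random `np.random.choice` subsampling for ocrVQA and STVQA is replaced by a deterministic `torch.utils.data.Subset(dataset, range(max_sample_num))`, so repeated runs score the same first `max_sample_num` questions (the in-diff comment mentions the first 5000 for OCR-VQA). A minimal sketch of the difference, using a plain list as a stand-in for a map-style dataset:

```python
import numpy as np
import torch

dataset = list(range(100))   # stand-in for ocrVQADataset / STVQADataset
max_sample_num = 5

# Old behaviour in the diff: a different random subset on every run.
random_indices = np.random.choice(len(dataset), max_sample_num, replace=False)
random_subset = torch.utils.data.Subset(dataset, random_indices)

# New behaviour: always the first max_sample_num samples, reproducible.
fixed_subset = torch.utils.data.Subset(dataset, range(max_sample_num))
print(list(fixed_subset))    # [0, 1, 2, 3, 4]
```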