This commit is contained in:
echo840
2023-05-27 17:21:39 +08:00
parent b388fba03e
commit 6e02bedd46
450 changed files with 1148092 additions and 38254 deletions

27
eval.py
View File

@@ -11,17 +11,19 @@ import json
import re
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
from datasets.ocr_dataset import ocrDataset
from datasets.kie_dataset import SROIEDataset,FUNSDDataset
from datasets.kie_dataset import SROIEDataset,FUNSDDataset,POIEDataset
from datasets.formula_dataset import HMEDataset
from models.lavis.lavis import lavis
from models.LLaVA.LLaVA import LLaVA
from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
from models.MiniGPT4.MiniGPT4 import MiniGPT4
from models.OpenFlamingo.OpenFlamingo import OpenFlamingo
from models.BLIP2.BLIP2 import BLIP2
import torch
import numpy as np
def get_model(args):
if args.model_name=='BLIP2':
#model = BLIP2(args.BLIP2_model_path, args.device)
#model = BLIP2(args.BLIP2_model_name, args.device)
model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
elif args.model_name=='LLaVA':
model = LLaVA(args.LLaVA_model_path, args.device)
@@ -29,6 +31,8 @@ def get_model(args):
model = MiniGPT4(args, args.device)
elif args.model_name=='mPLUG':
model = mPLUG(args.mPLUG_model_name, args.device)
elif args.model_name=='OpenFlamingo':
model = OpenFlamingo(args.llama_path, args.check_point, args.device)
return model
def has_word(sentence, word):
pattern = r"\b" + re.escape(word) + r"\b"
@@ -410,6 +414,9 @@ def parse_args():
#FUNSD
parser.add_argument("--FUNSD_dir_path", type=str, default="./data/FUNSD/testing_data/annotations")
#POIE
parser.add_argument("--POIE_dir_path", type=str, default="./data/POIE/test.txt")
#result_path
parser.add_argument("--answer_path", type=str, default="./answers")
@@ -461,6 +468,12 @@ def parse_args():
default=False,
help="Whether to evaluate on FUNSD."
)
parser.add_argument(
"--eval_POIE",
action="store_true",
default=False,
help="Whether to evaluate on POIE."
)
parser.add_argument(
"--eval_HME",
action="store_true",
@@ -480,7 +493,6 @@ def parse_args():
help="Whether to evaluate all datasets"
)
#BLIP2
#parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")#blip2_t5 blip2_opt blip2_vicuna_instruct
parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")#pretrain_flant5xxl pretrain_opt6.7b vicuna13b
#LLaVA
@@ -490,9 +502,12 @@ def parse_args():
#mPLUG
parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
#parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")
#OpenFlamingo
parser.add_argument("--llama_path", type=str, default="llama-7b_path")
parser.add_argument("--check_point", type=str, default="open_flamingo/checkpoint/checkpoint.pt")
parser.add_argument("--model_name", type=str, default="BLIP2")#mPLUG,miniGPT4,LLaVA
parser.add_argument("--device", type=str, default="cuda:1")
parser.add_argument("--device", type=str, default="cuda:3")#2,3,7
args = parser.parse_args()
return args
@@ -548,6 +563,10 @@ def main(args):
dataset = FUNSDDataset(args.FUNSD_dir_path)
acc = evaluate_VQA(model, dataset, args.model_name, 'FUNSD', time)
result['FUNSD'] = acc
if args.eval_POIE or args.eval_all:
dataset = POIEDataset(args.POIE_dir_path)
acc = evaluate_VQA(model, dataset, args.model_name, 'POIE', time)
result['POIE'] = acc
if args.eval_HME or args.eval_all:
dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)