update
This commit is contained in:
27
eval.py
27
eval.py
@@ -11,17 +11,19 @@ import json
|
||||
import re
|
||||
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
|
||||
from datasets.ocr_dataset import ocrDataset
|
||||
from datasets.kie_dataset import SROIEDataset,FUNSDDataset
|
||||
from datasets.kie_dataset import SROIEDataset,FUNSDDataset,POIEDataset
|
||||
from datasets.formula_dataset import HMEDataset
|
||||
from models.lavis.lavis import lavis
|
||||
from models.LLaVA.LLaVA import LLaVA
|
||||
from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
|
||||
from models.MiniGPT4.MiniGPT4 import MiniGPT4
|
||||
from models.OpenFlamingo.OpenFlamingo import OpenFlamingo
|
||||
from models.BLIP2.BLIP2 import BLIP2
|
||||
import torch
|
||||
import numpy as np
|
||||
def get_model(args):
|
||||
if args.model_name=='BLIP2':
|
||||
#model = BLIP2(args.BLIP2_model_path, args.device)
|
||||
#model = BLIP2(args.BLIP2_model_name, args.device)
|
||||
model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
|
||||
elif args.model_name=='LLaVA':
|
||||
model = LLaVA(args.LLaVA_model_path, args.device)
|
||||
@@ -29,6 +31,8 @@ def get_model(args):
|
||||
model = MiniGPT4(args, args.device)
|
||||
elif args.model_name=='mPLUG':
|
||||
model = mPLUG(args.mPLUG_model_name, args.device)
|
||||
elif args.model_name=='OpenFlamingo':
|
||||
model = OpenFlamingo(args.llama_path, args.check_point, args.device)
|
||||
return model
|
||||
def has_word(sentence, word):
|
||||
pattern = r"\b" + re.escape(word) + r"\b"
|
||||
@@ -410,6 +414,9 @@ def parse_args():
|
||||
#FUNSD
|
||||
parser.add_argument("--FUNSD_dir_path", type=str, default="./data/FUNSD/testing_data/annotations")
|
||||
|
||||
#POIE
|
||||
parser.add_argument("--POIE_dir_path", type=str, default="./data/POIE/test.txt")
|
||||
|
||||
#result_path
|
||||
parser.add_argument("--answer_path", type=str, default="./answers")
|
||||
|
||||
@@ -461,6 +468,12 @@ def parse_args():
|
||||
default=False,
|
||||
help="Whether to evaluate on FUNSD."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_POIE",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to evaluate on POIE."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_HME",
|
||||
action="store_true",
|
||||
@@ -480,7 +493,6 @@ def parse_args():
|
||||
help="Whether to evaluate all datasets"
|
||||
)
|
||||
#BLIP2
|
||||
#parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
|
||||
parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")#blip2_t5 blip2_opt blip2_vicuna_instruct
|
||||
parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")#pretrain_flant5xxl pretrain_opt6.7b vicuna13b
|
||||
#LLaVA
|
||||
@@ -490,9 +502,12 @@ def parse_args():
|
||||
#mPLUG
|
||||
parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
|
||||
#parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")
|
||||
#OpenFlamingo
|
||||
parser.add_argument("--llama_path", type=str, default="llama-7b_path")
|
||||
parser.add_argument("--check_point", type=str, default="open_flamingo/checkpoint/checkpoint.pt")
|
||||
|
||||
parser.add_argument("--model_name", type=str, default="BLIP2")#mPLUG,miniGPT4,LLaVA
|
||||
parser.add_argument("--device", type=str, default="cuda:1")
|
||||
parser.add_argument("--device", type=str, default="cuda:3")#2,3,7
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -548,6 +563,10 @@ def main(args):
|
||||
dataset = FUNSDDataset(args.FUNSD_dir_path)
|
||||
acc = evaluate_VQA(model, dataset, args.model_name, 'FUNSD', time)
|
||||
result['FUNSD'] = acc
|
||||
if args.eval_POIE or args.eval_all:
|
||||
dataset = POIEDataset(args.POIE_dir_path)
|
||||
acc = evaluate_VQA(model, dataset, args.model_name, 'POIE', time)
|
||||
result['POIE'] = acc
|
||||
|
||||
if args.eval_HME or args.eval_all:
|
||||
dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)
|
||||
|
Reference in New Issue
Block a user