IAM ReCTS
eval.py · 88 · Normal file → Executable file
@@ -10,7 +10,7 @@ import os
import json
import re
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
from datasets.ocr_dataset import ocrDataset
from datasets.ocr_dataset import ocrDataset, IAMDataset, ReCTSDataset
from datasets.kie_dataset import SROIEDataset,FUNSDDataset,POIEDataset
from datasets.formula_dataset import HMEDataset
from models.lavis.lavis import lavis
@@ -19,12 +19,13 @@ from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
from models.MiniGPT4.MiniGPT4 import MiniGPT4
from models.OpenFlamingo.OpenFlamingo import OpenFlamingo
from models.BLIP2.BLIP2 import BLIP2
from models.InstructBLIP.InstructBLIP import InstructBLIP
import torch
import numpy as np
def get_model(args):
    if args.model_name=='BLIP2':
        #model = BLIP2(args.BLIP2_model_name, args.device)
        model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
        model = BLIP2("/home/zhangli/.cache/huggingface/hub/models--Salesforce--blip2-opt-6.7b/snapshots/f998da12f28eb37d7e7f080cfe3291d6d9d7e1fb", args.device)
        #model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
    elif args.model_name=='LLaVA':
        model = LLaVA(args.LLaVA_model_path, args.device)
    elif args.model_name=='MiniGPT4':
@@ -33,6 +34,8 @@ def get_model(args):
        model = mPLUG(args.mPLUG_model_name, args.device)
    elif args.model_name=='OpenFlamingo':
        model = OpenFlamingo(args.llama_path, args.check_point, args.device)
    elif args.model_name == 'instructblip':
        model = InstructBLIP('blip2_vicuna_instruct',args.device)
    return model
def has_word(sentence, word):
    pattern = r"\b" + re.escape(word) + r"\b"
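    # Worked example of the boundary pattern built above (illustrative strings only):
    # for word = "cat" the pattern is r"\bcat\b", which finds a match in
    # "The cat sat." but not inside "concatenate", since \b only matches at word
    # boundaries and re.escape protects any regex metacharacters in the word.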
@@ -45,6 +48,7 @@ def remove_special_chars(s):
    pattern = r"[^a-zA-Z0-9\s]"
    s = re.sub(pattern, "", s)
    return s

class VQAEval:
    def __init__(self):
        self.contractions = {
@@ -341,6 +345,47 @@ def evaluate_OCR(
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

def evaluate_ReCTS(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    #question='图像中的中文是什么?',
    question = 'What are the Chinese characters in the image?',
    batch_size=1,
    answer_path='./answers'
):
    predictions=[]
    for batch in more_itertools.chunked(
        tqdm(dataset, desc="Running inference"), batch_size
    ):
        batch = batch[0]
        output = model.generate(image=batch['image_path'], question=question)
        answer_dict={'question':question, 'answer':output,
                     'gt_answers':batch['gt_answers'], 'image_path':batch['image_path'],
                     'model_name':model_name}
        predictions.append(answer_dict)
    answer_dir = os.path.join(answer_path, time)
    os.makedirs(answer_dir, exist_ok=True)
    answer_path = os.path.join(answer_dir, f"{dataset_name}.json")
    with open(answer_path, "w") as f:
        f.write(json.dumps(predictions, indent=4))
    correct = 0
    num = 0
    with open(answer_path, 'r') as f:
        dict = json.load(f)
        for i in range(len(dict)):
            gt_answers = dict[i]['gt_answers']
            answer = dict[i]['answer']
            gt_answers = re.sub(r'[^\u4e00-\u9fa5\s]+', '', gt_answers)
            answer = re.sub(r'[^\u4e00-\u9fa5\s]+', '', answer)
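            # Illustrative example of the matching rule used here (made-up strings,
            # not taken from ReCTS): the regex keeps only CJK characters in the
            # U+4E00-U+9FA5 range plus whitespace, so a ground truth of '好运来' and
            # a prediction like 'The text is 好运来.' both reduce to '好运来', and the
            # containment check below then counts the prediction as correct. The
            # commented-out prompt in the signature is the same question in Chinese.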
            if gt_answers in answer:
                correct+=1
            num+=1
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

def evaluate_Formula(
    model,
    dataset,
@@ -385,6 +430,11 @@ def parse_args():
    #OCR datasets
    parser.add_argument("--ocr_dir_path", type=str, default="./data")
    parser.add_argument("--ocr_dataset_name", type=str, default="IIIT5K svt IC13_857 IC15_1811 svtp ct80 cocotext ctw totaltext HOST WOST WordArt")
    #IAM
    parser.add_argument("--IAM_dir_path", type=str, default="./data/IAM")
    #ReCTS
    parser.add_argument("--ReCTS_dir_path", type=str, default="./data/ReCTS")

    #HME100k
    parser.add_argument("--HME_image_dir_path", type=str, default="./data/HME100K/test_images")
    parser.add_argument("--HME_ann_path", type=str, default="./data/HME100K/test_labels.txt")
@@ -480,6 +530,18 @@ def parse_args():
        default=False,
        help="Whether to evaluate on HME100k."
    )
    parser.add_argument(
        "--eval_IAM",
        action="store_true",
        default=False,
        help="Whether to evaluate on IAM (handwritten)."
    )
    parser.add_argument(
        "--eval_ReCTS",
        action="store_true",
        default=False,
        help="Whether to evaluate on ReCTS (Chinese)."
    )
    parser.add_argument(
        "--eval_ocr",
        action="store_true",
@@ -493,6 +555,7 @@ def parse_args():
        help="Whether to evaluate all datasets"
    )
    #BLIP2
    #parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
    parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")#blip2_t5 blip2_opt blip2_vicuna_instruct
    parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")#pretrain_flant5xxl pretrain_opt6.7b vicuna13b
    #LLaVA
@@ -503,11 +566,11 @@ def parse_args():
    parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
    #parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")
    #OpenFlamingo
    parser.add_argument("--llama_path", type=str, default="llama-7b_path")
    parser.add_argument("--check_point", type=str, default="open_flamingo/checkpoint/checkpoint.pt")
    parser.add_argument("--llama_path", type=str, default="/home/zhangli/llama_models/llama/llama-7b")
    parser.add_argument("--check_point", type=str, default="/home/zhangli/code/open_flamingo/checkpoint/checkpoint.pt")

    parser.add_argument("--model_name", type=str, default="BLIP2")#mPLUG,miniGPT4,LLaVA
    parser.add_argument("--device", type=str, default="cuda:3")#2,3,7
    parser.add_argument("--device", type=str, default="cuda:0")#2,3,7
    args = parser.parse_args()
    return args

@@ -572,8 +635,17 @@ def main(args):
        dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_Formula(model, dataset, args.model_name, 'HME', time)
        result['HME'] = acc

        result['HME'] = acc
    if args.eval_IAM or args.eval_all:
        dataset = IAMDataset(args.IAM_dir_path)
        dataset = torch.utils.data.Subset(dataset, range(3000))
        acc = evaluate_OCR(model, dataset, args.model_name, 'IAM', time)
        result['IAM'] = acc
    if args.eval_ReCTS or args.eval_all:
        dataset = ReCTSDataset(args.ReCTS_dir_path)
        dataset = torch.utils.data.Subset(dataset, range(3000))
        acc = evaluate_ReCTS(model, dataset, args.model_name, 'ReCTS', time)
        result['ReCTS'] = acc
    if args.eval_ocr or args.eval_all:
        for i in range(len(ocr_dataset_name)):
            dataset = ocrDataset(args.ocr_dir_path, ocr_dataset_name[i])
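Assuming the flags and defaults added above, the new evaluations can be run with an invocation along these lines (hypothetical command; adjust the model name and dataset paths to your setup):

    python eval.py --model_name BLIP2 --eval_IAM --eval_ReCTS --IAM_dir_path ./data/IAM --ReCTS_dir_path ./data/ReCTS

--eval_IAM routes through the existing evaluate_OCR, while --eval_ReCTS goes through the new evaluate_ReCTS with its Chinese-character containment metric; both runs are capped at the first 3000 samples via torch.utils.data.Subset.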