IAM ReCTS

echo840
2023-06-09 10:29:18 +08:00
parent 3c59897aa6
commit e22b12b169
185 changed files with 294244 additions and 22 deletions

eval.py Normal file → Executable file

@@ -10,7 +10,7 @@ import os
import json
import re
from datasets.vqa_dataset import textVQADataset, docVQADataset, ocrVQADataset, STVQADataset, ESTVQADataset
-from datasets.ocr_dataset import ocrDataset
+from datasets.ocr_dataset import ocrDataset, IAMDataset, ReCTSDataset
from datasets.kie_dataset import SROIEDataset,FUNSDDataset,POIEDataset
from datasets.formula_dataset import HMEDataset
from models.lavis.lavis import lavis
@@ -19,12 +19,13 @@ from models.mPLUG_Owl.pipeline.mPLUG import mPLUG
from models.MiniGPT4.MiniGPT4 import MiniGPT4
from models.OpenFlamingo.OpenFlamingo import OpenFlamingo
from models.BLIP2.BLIP2 import BLIP2
+from models.InstructBLIP.InstructBLIP import InstructBLIP
import torch
import numpy as np
def get_model(args):
    if args.model_name=='BLIP2':
-        #model = BLIP2(args.BLIP2_model_name, args.device)
-        model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
+        model = BLIP2("/home/zhangli/.cache/huggingface/hub/models--Salesforce--blip2-opt-6.7b/snapshots/f998da12f28eb37d7e7f080cfe3291d6d9d7e1fb", args.device)
+        #model = lavis(args.BLIP2_model_name, args.BLIP2_model_type, args.device)
    elif args.model_name=='LLaVA':
        model = LLaVA(args.LLaVA_model_path, args.device)
    elif args.model_name=='MiniGPT4':
@@ -33,6 +34,8 @@ def get_model(args):
        model = mPLUG(args.mPLUG_model_name, args.device)
    elif args.model_name=='OpenFlamingo':
        model = OpenFlamingo(args.llama_path, args.check_point, args.device)
+    elif args.model_name == 'instructblip':
+        model = InstructBLIP('blip2_vicuna_instruct',args.device)
    return model
def has_word(sentence, word):
    pattern = r"\b" + re.escape(word) + r"\b"
@@ -45,6 +48,7 @@ def remove_special_chars(s):
    pattern = r"[^a-zA-Z0-9\s]"
    s = re.sub(pattern, "", s)
    return s

class VQAEval:
    def __init__(self):
        self.contractions = {
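For context, a minimal standalone sketch (not part of the commit) of how the two helpers above behave: has_word matches only at \b word boundaries, while remove_special_chars keeps letters, digits and whitespace.

import re

def has_word(sentence, word):
    # \b boundaries: "ran" does not match inside "brand"
    pattern = r"\b" + re.escape(word) + r"\b"
    return bool(re.search(pattern, sentence))

def remove_special_chars(s):
    # keep only letters, digits and whitespace
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)

assert has_word("a brand new car", "brand")
assert not has_word("a brand new car", "ran")
assert remove_special_chars("it's $5!") == "its 5"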
@@ -341,6 +345,47 @@ def evaluate_OCR(
    print(f'{dataset_name}:{float(correct)/num}')
    return float(correct)/num

+def evaluate_ReCTS(
+    model,
+    dataset,
+    model_name,
+    dataset_name,
+    time,
+    #question='图像中的中文是什么?',
+    question = 'What are the Chinese characters in the image?',
+    batch_size=1,
+    answer_path='./answers'
+):
+    predictions=[]
+    for batch in more_itertools.chunked(
+        tqdm(dataset, desc="Running inference"), batch_size
+    ):
+        batch = batch[0]
+        output = model.generate(image=batch['image_path'], question=question)
+        answer_dict={'question':question, 'answer':output,
+            'gt_answers':batch['gt_answers'], 'image_path':batch['image_path'],
+            'model_name':model_name}
+        predictions.append(answer_dict)
+    answer_dir = os.path.join(answer_path, time)
+    os.makedirs(answer_dir, exist_ok=True)
+    answer_path = os.path.join(answer_dir, f"{dataset_name}.json")
+    with open(answer_path, "w") as f:
+        f.write(json.dumps(predictions, indent=4))
+    correct = 0
+    num = 0
+    with open(answer_path, 'r') as f:
+        dict = json.load(f)
+    for i in range(len(dict)):
+        gt_answers = dict[i]['gt_answers']
+        answer = dict[i]['answer']
+        gt_answers = re.sub(r'[^\u4e00-\u9fa5\s]+', '', gt_answers)
+        answer = re.sub(r'[^\u4e00-\u9fa5\s]+', '', answer)
+        if gt_answers in answer:
+            correct+=1
+        num+=1
+    print(f'{dataset_name}:{float(correct)/num}')
+    return float(correct)/num
def evaluate_Formula(
    model,
    dataset,
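The new evaluate_ReCTS scorer reduces both the ground truth and the model output to CJK ideographs (U+4E00 through U+9FA5) plus whitespace, then counts a hit when the cleaned ground truth appears as a substring of the cleaned output. A standalone sketch of that rule (the example strings are illustrative):

import re

def rects_match(output, gt):
    # strip everything except CJK Unified Ideographs and whitespace
    clean = lambda s: re.sub(r'[^\u4e00-\u9fa5\s]+', '', s)
    return clean(gt) in clean(output)

print(rects_match('The sign reads: 北京欢迎你!', '北京欢迎你'))  # True
print(rects_match('unreadable', '北京'))                        # False

Note that whitespace survives the cleaning, so a multi-character ground truth with internal spaces must also match its spacing in the output.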
@@ -385,6 +430,11 @@ def parse_args():
    #OCR datasets
    parser.add_argument("--ocr_dir_path", type=str, default="./data")
    parser.add_argument("--ocr_dataset_name", type=str, default="IIIT5K svt IC13_857 IC15_1811 svtp ct80 cocotext ctw totaltext HOST WOST WordArt")
+    #IAM
+    parser.add_argument("--IAM_dir_path", type=str, default="./data/IAM")
+    #ReCTS
+    parser.add_argument("--ReCTS_dir_path", type=str, default="./data/ReCTS")
    #HME100k
    parser.add_argument("--HME_image_dir_path", type=str, default="./data/HME100K/test_images")
    parser.add_argument("--HME_ann_path", type=str, default="./data/HME100K/test_labels.txt")
@@ -480,6 +530,18 @@ def parse_args():
        default=False,
        help="Whether to evaluate on HME100k."
    )
+    parser.add_argument(
+        "--eval_IAM",
+        action="store_true",
+        default=False,
+        help="Whether to evaluate on IAM (handwritten)."
+    )
+    parser.add_argument(
+        "--eval_ReCTS",
+        action="store_true",
+        default=False,
+        help="Whether to evaluate on ReCTS (Chinese)."
+    )
    parser.add_argument(
        "--eval_ocr",
        action="store_true",
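With these two flags, the new datasets can be enabled from the command line; a typical invocation (using the BLIP2 default model) would be:

python eval.py --model_name BLIP2 --eval_IAM --eval_ReCTS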
@@ -493,6 +555,7 @@ def parse_args():
help="Whether to evaluate all datasets"
)
#BLIP2
#parser.add_argument("--BLIP2_model_path", type=str, default="/home/zhangli/GPT4/BLIP2-flant5")
parser.add_argument("--BLIP2_model_name", type=str, default="blip2_opt")#blip2_t5 blip2_opt blip2_vicuna_instruct
parser.add_argument("--BLIP2_model_type", type=str, default="pretrain_opt6.7b")#pretrain_flant5xxl pretrain_opt6.7b vicuna13b
#LLaVA
@@ -503,11 +566,11 @@ def parse_args():
parser.add_argument("--mPLUG_model_name", type=str, default="MAGAer13/mplug-owl-llama-7b")
#parser.add_argument("--mPLUG_tokenizer_path", type=str, default="./models/mPLUG_Owl/model_weight/tokenizer.model")
#OpenFlamingo
parser.add_argument("--llama_path", type=str, default="llama-7b_path")
parser.add_argument("--check_point", type=str, default="open_flamingo/checkpoint/checkpoint.pt")
parser.add_argument("--llama_path", type=str, default="/home/zhangli/llama_models/llama/llama-7b")
parser.add_argument("--check_point", type=str, default="/home/zhangli/code/open_flamingo/checkpoint/checkpoint.pt")
parser.add_argument("--model_name", type=str, default="BLIP2")#mPLUG,miniGPT4,LLaVA
parser.add_argument("--device", type=str, default="cuda:3")#2,3,7
parser.add_argument("--device", type=str, default="cuda:0")#2,3,7
args = parser.parse_args()
return args
@@ -572,8 +635,17 @@ def main(args):
        dataset = HMEDataset(args.HME_image_dir_path, args.HME_ann_path)
        dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_Formula(model, dataset, args.model_name, 'HME', time)
        result['HME'] = acc
+    if args.eval_IAM or args.eval_all:
+        dataset = IAMDataset(args.IAM_dir_path)
+        dataset = torch.utils.data.Subset(dataset, range(3000))
+        acc = evaluate_OCR(model, dataset, args.model_name, 'IAM', time)
+        result['IAM'] = acc
+    if args.eval_ReCTS or args.eval_all:
+        dataset = ReCTSDataset(args.ReCTS_dir_path)
+        dataset = torch.utils.data.Subset(dataset, range(3000))
+        acc = evaluate_ReCTS(model, dataset, args.model_name, 'ReCTS', time)
+        result['ReCTS'] = acc
    if args.eval_ocr or args.eval_all:
        for i in range(len(ocr_dataset_name)):
            dataset = ocrDataset(args.ocr_dir_path, ocr_dataset_name[i])
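IAMDataset and ReCTSDataset themselves are defined in datasets/ocr_dataset.py (added elsewhere in this commit, not shown in this diff). As used by evaluate_OCR and evaluate_ReCTS above, each item only needs to be a dict carrying 'image_path' and 'gt_answers'. A hypothetical minimal dataset of that shape, assuming a tab-separated annotation file (the class name and file format are illustrative, not the commit's actual loaders):

from torch.utils.data import Dataset

class TabSeparatedOCRDataset(Dataset):
    # hypothetical loader: each annotation line is "<image_path>\t<ground-truth text>"
    def __init__(self, ann_path):
        with open(ann_path, encoding='utf-8') as f:
            self.samples = [line.rstrip('\n').split('\t', 1) for line in f]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, gt = self.samples[idx]
        return {'image_path': image_path, 'gt_answers': gt}

Subclassing torch's Dataset keeps the items compatible with the torch.utils.data.Subset(dataset, range(3000)) truncation used in main above.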