import json
import multiprocessing
import os
from argparse import ArgumentParser
from multiprocessing import Pool, Manager

import torch
from PIL import Image
from tqdm import tqdm
# TODO model packages import
# from transformers import AutoModelForCausalLM, AutoTokenizer
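# NOTE: the commented-out snippets below appear to follow Qwen-VL conventions
# ('<img>...</img>' prompts, tokenizer.eod_id); adapt them to your own model.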
def split_list(lst, n):
    length = len(lst)
    avg = length // n  # size of each chunk
    result = []  # holds the split sub-lists
    for i in range(n - 1):
        result.append(lst[i * avg:(i + 1) * avg])
    result.append(lst[(n - 1) * avg:])
    return result
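# Example: split_list([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4, 5]]
# (the last chunk absorbs the remainder when len(lst) is not divisible by n)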
def save_json(json_list, save_path):
    with open(save_path, 'w') as file:
        json.dump(json_list, file, indent=4)
def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
    parser.add_argument("--output_folder", type=str, default="./results")
    parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
    parser.add_argument("--model_path", type=str, default="")  # TODO: set the path to your model weights
    parser.add_argument("--save_name", type=str, default="")  # TODO: set the name of the JSON file saved in output_folder
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()
    return args
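# OCRBench_score: per-task score tally for the 1000-sample OCRBench split.
# AllDataset_score / num_all: per-dataset correct counts and sample totals,
# used to report average accuracy when evaluating the full dataset collection.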
OCRBench_score = {"Regular Text Recognition": 0, "Irregular Text Recognition": 0, "Artistic Text Recognition": 0, "Handwriting Recognition": 0,
                  "Digit String Recognition": 0, "Non-Semantic Text Recognition": 0, "Scene Text-centric VQA": 0, "Doc-oriented VQA": 0,
                  "Key Information Extraction": 0, "Handwritten Mathematical Expression Recognition": 0}
AllDataset_score = {"IIIT5K": 0, "svt": 0, "IC13_857": 0, "IC15_1811": 0, "svtp": 0, "ct80": 0, "cocotext": 0, "ctw": 0, "totaltext": 0, "HOST": 0, "WOST": 0, "WordArt": 0, "IAM": 0, "ReCTS": 0, "ORAND": 0, "NonSemanticText": 0, "SemanticText": 0,
                    "STVQA": 0, "textVQA": 0, "ocrVQA": 0, "ESTVQA": 0, "ESTVQA_cn": 0, "docVQA": 0, "infographicVQA": 0, "ChartQA": 0, "ChartQA_Human": 0, "FUNSD": 0, "SROIE": 0, "POIE": 0, "HME100k": 0}
num_all = {"IIIT5K": 0, "svt": 0, "IC13_857": 0, "IC15_1811": 0, "svtp": 0, "ct80": 0, "cocotext": 0, "ctw": 0, "totaltext": 0, "HOST": 0, "WOST": 0, "WordArt": 0, "IAM": 0, "ReCTS": 0, "ORAND": 0, "NonSemanticText": 0, "SemanticText": 0,
           "STVQA": 0, "textVQA": 0, "ocrVQA": 0, "ESTVQA": 0, "ESTVQA_cn": 0, "docVQA": 0, "infographicVQA": 0, "ChartQA": 0, "ChartQA_Human": 0, "FUNSD": 0, "SROIE": 0, "POIE": 0, "HME100k": 0}
def eval_worker(args, data, eval_id, output_queue):
    print(f"Process {eval_id} start.")
    checkpoint = args.model_path
    # TODO model init
    # model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
    # tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    # tokenizer.padding_side = 'left'
    # tokenizer.pad_token_id = tokenizer.eod_id
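    # Left padding is required for batched generation with decoder-only models,
    # since new tokens are appended at the right end of the sequence. `eod_id`
    # is Qwen-specific; most tokenizers expose `eos_token_id` instead.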
    for i in tqdm(range(len(data))):
        img_path = os.path.join(args.image_folder, data[i]['image_path'])
        qs = data[i]['question']
        # TODO Generation process
        # query = f'<img>{img_path}</img> {qs} Answer: '
        # input_ids = tokenizer(query, return_tensors='pt', padding='longest')
        # attention_mask = input_ids.attention_mask
        # input_ids = input_ids.input_ids
        # pred = model.generate(
        #     input_ids=input_ids.to(f'cuda:{eval_id}'),
        #     attention_mask=attention_mask.to(f'cuda:{eval_id}'),
        #     do_sample=False,
        #     num_beams=1,
        #     max_new_tokens=100,
        #     min_new_tokens=1,
        #     length_penalty=1,
        #     num_return_sequences=1,
        #     output_hidden_states=True,
        #     use_cache=True,
        #     pad_token_id=tokenizer.eod_id,
        #     eos_token_id=tokenizer.eod_id,
        # )
        # response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
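        # `response` must be produced by the generation code above; the next
        # line raises NameError until the TODO block is filled in.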
        data[i]['predict'] = response
    output_queue.put({eval_id: data})
    print(f"Process {eval_id} has completed.")
if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    args = _get_args()
    if os.path.exists(os.path.join(args.output_folder, f"{args.save_name}.json")):
        data_path = os.path.join(args.output_folder, f"{args.save_name}.json")
        print(f"output_path: {data_path} exists! Only generating the results that are missing from {data_path}.")
    else:
        data_path = args.OCRBench_file
    with open(data_path, "r") as f:
        data = json.load(f)
    data_list = split_list(data, args.num_workers)
    output_queue = Manager().Queue()
    pool = Pool(processes=args.num_workers)
    for i in range(len(data_list)):
        pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
    pool.close()
    pool.join()
    results = {}
    while not output_queue.empty():
        result = output_queue.get()
        results.update(result)
    data = []
    for i in range(len(data_list)):
        data.extend(results[i])
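    # Scoring: a sample is correct (result = 1) if any ground-truth answer is a
    # substring of the prediction. Matching is case-insensitive, except for
    # HME100k, where it is case-sensitive and all whitespace is removed first.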
    for i in range(len(data)):
        data_type = data[i]["type"]
        dataset_name = data[i]["dataset_name"]
        answers = data[i]["answers"]
        if data[i].get('predict', 0) == 0:
            continue
        predict = data[i]['predict']
        data[i]['result'] = 0
        if dataset_name == "HME100k":
            if isinstance(answers, list):
                for j in range(len(answers)):
                    answer = answers[j].strip().replace("\n", " ").replace(" ", "")
                    predict = predict.strip().replace("\n", " ").replace(" ", "")
                    if answer in predict:
                        data[i]['result'] = 1
            else:
                answers = answers.strip().replace("\n", " ").replace(" ", "")
                predict = predict.strip().replace("\n", " ").replace(" ", "")
                if answers in predict:
                    data[i]['result'] = 1
        else:
            if isinstance(answers, list):
                for j in range(len(answers)):
                    answer = answers[j].lower().strip().replace("\n", " ")
                    predict = predict.lower().strip().replace("\n", " ")
                    if answer in predict:
                        data[i]['result'] = 1
            else:
                answers = answers.lower().strip().replace("\n", " ")
                predict = predict.lower().strip().replace("\n", " ")
                if answers in predict:
                    data[i]['result'] = 1
    save_json(data, os.path.join(args.output_folder, f"{args.save_name}.json"))
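    # The official OCRBench split contains exactly 1000 samples; in that case
    # print the per-task score breakdown, otherwise report per-dataset accuracy.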
    if len(data) == 1000:
        for i in range(len(data)):
            if data[i].get("result", 100) == 100:
                continue
            OCRBench_score[data[i]['type']] += data[i]['result']
        recognition_score = OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition'] + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition'] + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
        Final_score = recognition_score + OCRBench_score['Scene Text-centric VQA'] + OCRBench_score['Doc-oriented VQA'] + OCRBench_score['Key Information Extraction'] + OCRBench_score['Handwritten Mathematical Expression Recognition']
        print("###########################OCRBench##############################")
        print(f"Text Recognition(Total 300): {recognition_score}")
        print("------------------Details of Recognition Score-------------------")
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
        print("----------------------------------------------------------------")
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
        print("----------------------------------------------------------------")
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
        print("----------------------------------------------------------------")
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
        print("----------------------------------------------------------------")
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
        print("----------------------Final Score-------------------------------")
        print(f"Final Score(Total 1000): {Final_score}")
    else:
        for i in range(len(data)):
            num_all[data[i]['dataset_name']] += 1
            if data[i].get("result", 100) == 100:
                continue
            AllDataset_score[data[i]['dataset_name']] += data[i]['result']
        for key in AllDataset_score.keys():
            if num_all[key] > 0:  # skip datasets absent from this run to avoid division by zero
                print(f"{key}: {AllDataset_score[key] / float(num_all[key])}")