2024-01-16 17:26:24 +08:00
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool , Queue , Manager
import sys
sys . path . append ( " ./scripts/MiniGPT-4/ " )
from minigpt4 . common . eval_utils import prepare_texts , init_model , eval_parser
from minigpt4 . conversation . conversation import CONV_VISION_minigptv2
from minigpt4 . common . config import Config
import random
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/eval_scripts/eval_vqa.py
def split_list(lst, n):
    """Split ``lst`` into ``n`` contiguous chunks, preserving order.

    The first ``n - 1`` chunks each hold ``len(lst) // n`` items and the last
    chunk also receives the remainder, so concatenating the chunks in order
    always reproduces ``lst``.

    Args:
        lst: Sequence to split (anything supporting ``len`` and slicing).
        n: Number of chunks to produce; must be at least 1.

    Returns:
        A list containing exactly ``n`` slices of ``lst``.

    Raises:
        ValueError: If ``n`` is smaller than 1. Without this check ``n == 0``
            fails with a bare ZeroDivisionError and a negative ``n`` silently
            returns a single empty chunk, dropping every item.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    chunk_size = len(lst) // n  # size of every chunk except the last one
    # Collect the first n - 1 equally sized chunks.
    chunks = [lst[k * chunk_size:(k + 1) * chunk_size] for k in range(n - 1)]
    # The final chunk takes whatever is left, including the remainder.
    chunks.append(lst[(n - 1) * chunk_size:])
    return chunks
def save_json(json_list, save_path):
    """Serialise ``json_list`` to ``save_path`` as JSON indented by 4 spaces.

    Args:
        json_list: Any JSON-serialisable object (the script passes the list
            of sample dicts).
        save_path: Destination file path; an existing file is overwritten.

    Raises:
        OSError: If the directory or file cannot be created or written.
        TypeError: If ``json_list`` contains a non-serialisable value.
    """
    # Create the destination directory first, so that results are not lost
    # at the very end of a run just because the folder was never made.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as file:
        json.dump(json_list, file, indent=4)
def _get_args ( ) :
parser = ArgumentParser ( )
2024-02-06 14:25:25 +08:00
parser . add_argument ( " --image_folder " , type = str , default = " ./OCRBench_Images " )
2024-01-16 17:26:24 +08:00
parser . add_argument ( " --output_folder " , type = str , default = " ./results " )
parser . add_argument ( " --OCRBench_file " , type = str , default = " ./OCRBench/OCRBench.json " )
parser . add_argument ( " --cfg-path " , default = ' ./scripts/MiniGPT-4/eval_configs/minigptv2_eval.yaml ' )
parser . add_argument ( " --save_name " , type = str , default = " minigptv2 " )
parser . add_argument ( " --num_workers " , type = int , default = 1 )
parser . add_argument ( " --temperature " , type = float , default = 0.0 )
parser . add_argument (
" --options " ,
nargs = " + " ,
help = " override some settings in the used config, the key-value pair "
" in xxx=yyy format will be merged into config file (deprecate), "
" change to --cfg-options instead. " ,
)
args = parser . parse_args ( )
return args
# Running score per OCRBench task type; the entry matching each sample's
# "type" field is incremented by that sample's 0/1 result.
# (The key "Doc-oriented VQA" used to be listed twice in this literal.)
OCRBench_score = {
    "Regular Text Recognition": 0,
    "Irregular Text Recognition": 0,
    "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0,
    "Digit String Recognition": 0,
    "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0,
    "Doc-oriented VQA": 0,
    "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0,
}

# Dataset names used for the per-dataset report. The order is the order in
# which the report lines are printed.
_DATASET_NAMES = (
    "IIIT5K", "svt", "IC13_857", "IC15_1811", "svtp", "ct80", "cocotext",
    "ctw", "totaltext", "HOST", "WOST", "WordArt", "IAM", "ReCTS", "ORAND",
    "NonSemanticText", "SemanticText",
    "STVQA", "textVQA", "ocrVQA", "ESTVQA", "ESTVQA_cn", "docVQA",
    "infographicVQA", "ChartQA", "ChartQA_Human", "FUNSD", "SROIE", "POIE",
    "HME100k",
)

# Number of correct predictions per dataset.
AllDataset_score = dict.fromkeys(_DATASET_NAMES, 0)
# Number of samples seen per dataset (the denominator of the accuracy).
num_all = dict.fromkeys(_DATASET_NAMES, 0)
def eval_worker(args, data, eval_id, output_queue):
    """Generate answers for one shard of samples and publish them on a queue.

    Args:
        args: Parsed command-line namespace; passed to ``Config`` and
            ``init_model``, and ``args.image_folder`` is used to locate images.
        data: List of sample dicts carrying ``image_path`` and ``question``.
            Samples that already have a ``predict`` entry are left untouched;
            the others get one written in place.
        eval_id: Worker index; also selects the device ``cuda:<eval_id>``.
        output_queue: Queue that receives ``{eval_id: data}`` once every
            sample of the shard has been processed.
    """
    print(f"Process {eval_id} start.")
    device = f'cuda:{eval_id}'
    # NOTE(review): ``cfg`` is not used below; the call is kept in case
    # constructing Config has side effects - confirm before removing.
    cfg = Config(args)
    model, vis_processor = init_model(args, device)
    conv_temp = CONV_VISION_minigptv2.copy()
    conv_temp.system = ""
    model.eval()
    instruction_pool = ["[vqa] {}"]

    for sample in tqdm(data):
        img_path = os.path.join(args.image_folder, sample['image_path'])
        qs = sample['question']
        # Resume support: keep predictions produced by an earlier run.
        if sample.get("predict", 0) != 0:
            print(f"{img_path} predict exist, continue.")
            continue

        # Load the image and turn it into a batch of one on the worker device.
        pixels = vis_processor(Image.open(img_path).convert("RGB"))
        pixels = pixels.unsqueeze(0).to(device)

        # Build the prompt: task template first, then the image placeholder.
        question_prompt = random.choice(instruction_pool).format(qs)
        # NOTE(review): confirm the exact whitespace of this template
        # (in particular any trailing space) against the upstream script.
        full_prompt = "<Img><ImageHere></Img> {}".format(question_prompt)
        # Wrap the text with the conversation template.
        texts = prepare_texts(full_prompt, conv_temp)

        answers = model.generate(pixels, texts, max_new_tokens=100, do_sample=False)
        sample['predict'] = answers[0]

    output_queue.put({eval_id: data})
    print(f"Process {eval_id} has completed.")
def _normalize_text(text, dataset_name):
    """Canonicalise ``text`` before the substring comparison."""
    # NOTE(review): newlines are taken to be replaced by a single space here;
    # confirm against the upstream script.
    if dataset_name == "HME100k":
        # HME100k stays case-sensitive and ignores every space.
        return text.strip().replace("\n", " ").replace(" ", "")
    return text.lower().strip().replace("\n", " ")


def _is_correct(dataset_name, answers, predict):
    """Return 1 if any reference answer is contained in ``predict``, else 0."""
    references = answers if isinstance(answers, list) else [answers]
    prediction = _normalize_text(predict, dataset_name)
    matched = any(
        _normalize_text(reference, dataset_name) in prediction
        for reference in references
    )
    return int(matched)


def _run_inference(args, data):
    """Shard ``data`` over the workers and return the samples in input order."""
    shards = split_list(data, args.num_workers)
    output_queue = Manager().Queue()
    pool = Pool(processes=args.num_workers)
    pending = [
        pool.apply_async(eval_worker, args=(args, shard, worker_id, output_queue))
        for worker_id, shard in enumerate(shards)
    ]
    pool.close()
    pool.join()
    # apply_async stores worker exceptions instead of raising them; fetching
    # each result re-raises the real error here rather than failing later
    # with a KeyError on the missing shard.
    for handle in pending:
        handle.get()

    results = {}
    while not output_queue.empty():
        results.update(output_queue.get())

    merged = []
    for worker_id in range(len(shards)):
        merged.extend(results[worker_id])
    return merged


def _print_ocrbench_report(data):
    """Accumulate per-task scores into ``OCRBench_score`` and print them."""
    for item in data:
        if "result" not in item:
            continue
        OCRBench_score[item['type']] += item['result']
    recognition_score = (
        OCRBench_score['Regular Text Recognition']
        + OCRBench_score['Irregular Text Recognition']
        + OCRBench_score['Artistic Text Recognition']
        + OCRBench_score['Handwriting Recognition']
        + OCRBench_score['Digit String Recognition']
        + OCRBench_score['Non-Semantic Text Recognition']
    )
    Final_score = (
        recognition_score
        + OCRBench_score['Scene Text-centric VQA']
        + OCRBench_score['Doc-oriented VQA']
        + OCRBench_score['Key Information Extraction']
        + OCRBench_score['Handwritten Mathematical Expression Recognition']
    )
    print("###########################OCRBench##############################")
    print(f"Text Recognition(Total 300): {recognition_score}")
    print("------------------Details of Recognition Score-------------------")
    print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
    print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
    print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
    print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
    print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
    print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
    print("----------------------------------------------------------------")
    print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
    print("----------------------------------------------------------------")
    print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
    print("----------------------------------------------------------------")
    print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
    print("----------------------------------------------------------------")
    print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
    print("----------------------Final Score-------------------------------")
    print(f"Final Score(Total 1000): {Final_score}")


def _print_per_dataset_report(data):
    """Accumulate per-dataset tallies and print the accuracy of each dataset."""
    for item in data:
        num_all[item['dataset_name']] += 1
        if "result" not in item:
            continue
        AllDataset_score[item['dataset_name']] += item['result']
    for key in AllDataset_score:
        # Datasets absent from this run have no samples; skip them rather
        # than dividing by zero.
        if num_all[key] == 0:
            continue
        print(f"{key}: {AllDataset_score[key] / float(num_all[key])}")


if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    args = _get_args()

    # Create the output folder up front so the final save cannot fail on a
    # missing directory after all inference work is done.
    os.makedirs(args.output_folder, exist_ok=True)
    output_path = os.path.join(args.output_folder, f"{args.save_name}.json")

    # Resume from an earlier results file when there is one.
    if os.path.exists(output_path):
        data_path = output_path
        print(f"output_path: {data_path} exist! Only generate the results that were not generated in {data_path}.")
    else:
        data_path = args.OCRBench_file

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    data = _run_inference(args, data)

    # Score every sample that has a prediction: 1 on a substring hit, else 0.
    for item in data:
        if "predict" not in item:
            continue
        item['result'] = _is_correct(item["dataset_name"], item["answers"], item['predict'])

    save_json(data, output_path)

    # Exactly 1000 samples is treated as the full OCRBench set and gets the
    # task-level report; anything else is reported per dataset.
    if len(data) == 1000:
        _print_ocrbench_report(data)
    else:
        _print_per_dataset_report(data)