import json
import math
import multiprocessing
import os
from argparse import ArgumentParser
from multiprocessing import Pool, Queue, Manager

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor

# https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey
# Standard ImageNet normalization statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
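
# Illustrative usage (not part of the original script): the transform converts
# any PIL image to a normalized float tensor of shape (3, input_size, input_size).
#   transform = build_transform(448)
#   t = transform(Image.new('RGB', (800, 600)))  # t.shape == torch.Size([3, 448, 448])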

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the larger grid when the image covers more
            # than half of the corresponding tiled area.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
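
# Worked example (illustrative): for a 2:1 landscape image and grids of 5-6
# tiles, the candidates are (1,5), (5,1), (1,6), (6,1), (2,3), (3,2); 3/2 = 1.5
# is the closest to 2.0, so (3, 2) is chosen.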

def dynamic_preprocess(image, min_num=5, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    # calculate the existing image aspect ratio
    aspect_ratio = orig_width / orig_height
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio
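
# Worked example (illustrative): an 800x400 image with the defaults above maps
# to grid (3, 2); it is resized to 1344x896 and cut into six 448x448 tiles, and
# use_thumbnail=True appends a 448x448 global view, giving 7 crops in total.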

def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
    orig_width, orig_height = image.size
    # calculate the existing image aspect ratio
    aspect_ratio = orig_width / orig_height
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    new_target_ratios = []
    if prior_aspect_ratio is not None:
        # Drop grids that tile evenly into the prior grid, so this pass crops
        # along boundaries different from the first pass.
        for i in target_ratios:
            if prior_aspect_ratio[0] % i[0] != 0 or prior_aspect_ratio[1] % i[1] != 0:
                new_target_ratios.append(i)
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
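
# Note: this second pass realizes MiniMonkey's complementary cropping. For
# example, with prior_aspect_ratio=(3, 2), grids such as (1, 1), (1, 2), and
# (3, 2) are discarded because they would cut along the same boundaries as the
# first pass, while grids such as (2, 1) or (2, 3) are kept.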

def load_image(image_file, input_size=448, min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, target_aspect_ratio

def load_image2(image_file, input_size=448, target_aspect_ratio=(1, 1), min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess2(image, image_size=input_size, prior_aspect_ratio=target_aspect_ratio, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
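
# Typical two-pass usage (a sketch mirroring eval_worker below):
#   pv1, ratio = load_image(path, min_num=12, max_num=24)  # fine-grained tiles
#   pv2 = load_image2(path, target_aspect_ratio=ratio, min_num=3, max_num=11)
#   pixels = torch.cat((pv1[:-1], pv2[:-1], pv1[-1:]), 0)  # shared thumbnail kept last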

def split_list(lst, n):
    length = len(lst)
    avg = length // n  # size of each chunk
    result = []  # holds the resulting sublists
    for i in range(n - 1):
        result.append(lst[i * avg:(i + 1) * avg])
    result.append(lst[(n - 1) * avg:])
    return result
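
# Example: split_list(list(range(10)), 3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]];
# the final chunk absorbs the remainder.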

def save_json(json_list, save_path):
    with open(save_path, 'w') as file:
        json.dump(json_list, file, indent=4)

def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
    parser.add_argument("--output_folder", type=str, default="./results")
    parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
    parser.add_argument("--model_path", type=str, default='mx262/MiniMonkey')  # TODO Set the path to your model's weights.
    parser.add_argument("--save_name", type=str, default="MiniMonkey")  # TODO Set the name of the JSON file saved in output_folder.
    parser.add_argument("--num_workers", type=int, default=1)
    args = parser.parse_args()
    return args
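
# Example invocation (the script filename is illustrative):
#   python eval_minimonkey.py --image_folder ./OCRBench_Images \
#       --OCRBench_file ./OCRBench/OCRBench.json --num_workers 2
# Each worker pins its model to cuda:<worker_id>, so --num_workers should not
# exceed the number of visible CUDA devices.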

OCRBench_score = {
    "Regular Text Recognition": 0, "Irregular Text Recognition": 0, "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0, "Digit String Recognition": 0, "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0, "Doc-oriented VQA": 0, "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0}
AllDataset_score = {
    "IIIT5K": 0, "svt": 0, "IC13_857": 0, "IC15_1811": 0, "svtp": 0, "ct80": 0, "cocotext": 0,
    "ctw": 0, "totaltext": 0, "HOST": 0, "WOST": 0, "WordArt": 0, "IAM": 0, "ReCTS": 0, "ORAND": 0,
    "NonSemanticText": 0, "SemanticText": 0, "STVQA": 0, "textVQA": 0, "ocrVQA": 0, "ESTVQA": 0,
    "ESTVQA_cn": 0, "docVQA": 0, "infographicVQA": 0, "ChartQA": 0, "ChartQA_Human": 0, "FUNSD": 0,
    "SROIE": 0, "POIE": 0, "HME100k": 0}
# Per-dataset sample counts, mirroring the keys of AllDataset_score.
num_all = {key: 0 for key in AllDataset_score}

def eval_worker(args, data, eval_id, output_queue):
    print(f"Process {eval_id} start.")
    checkpoint = args.model_path
    model = AutoModel.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().to(f'cuda:{eval_id}')
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    for i in tqdm(range(len(data))):
        dataset_name = data[i]["dataset_name"]
        image_path = os.path.join(args.image_folder, data[i]['image_path'])
        qs = data[i]['question']
        # First pass: fine-grained tiling.
        pixel_values, target_aspect_ratio = load_image(image_path, min_num=12, max_num=24)
        pixel_values = pixel_values.to(f'cuda:{eval_id}').to(torch.bfloat16)
        # Second pass: complementary tiling guided by the first-pass aspect ratio.
        pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=3, max_num=11)
        pixel_values2 = pixel_values2.to(f'cuda:{eval_id}').to(torch.bfloat16)
        # Merge the two tile sets, keeping the single thumbnail (the last tile
        # of the first pass) at the end.
        pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)
        generation_config = dict(
            num_beams=1,
            max_new_tokens=512,
            do_sample=False,
        )
        question = '<image>\n' + qs + '\nAnswer the question using a single word or phrase.'
        response = model.chat(tokenizer, pixel_values, target_aspect_ratio, question, generation_config)
        data[i]['predict'] = response
    output_queue.put({eval_id: data})
    print(f"Process {eval_id} has completed.")
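
# Design note: each worker posts its finished shard keyed by eval_id, so the
# parent process can reassemble the shards in their original order regardless
# of which worker finishes first.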

if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    args = _get_args()
    if os.path.exists(os.path.join(args.output_folder, f"{args.save_name}.json")):
        data_path = os.path.join(args.output_folder, f"{args.save_name}.json")
        print(f"output_path: {data_path} exists! Only generating the results missing from {data_path}.")
    else:
        data_path = args.OCRBench_file
    with open(data_path, "r") as f:
        data = json.load(f)
    data_list = split_list(data, args.num_workers)
    output_queue = Manager().Queue()
    pool = Pool(processes=args.num_workers)
    for i in range(len(data_list)):
        pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
    pool.close()
    pool.join()
    results = {}
    while not output_queue.empty():
        result = output_queue.get()
        results.update(result)
    data = []
    for i in range(len(data_list)):
        data.extend(results[i])
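    # Scoring rule implemented below: a sample counts as correct if any
    # ground-truth answer appears as a substring of the normalized prediction.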
    for i in range(len(data)):
        data_type = data[i]["type"]
        dataset_name = data[i]["dataset_name"]
        answers = data[i]["answers"]
        if data[i].get('predict', 0) == 0:
            continue
        predict = data[i]['predict']
        data[i]['result'] = 0
        if dataset_name == "HME100k":
            # Handwritten math is compared case-sensitively with all
            # whitespace removed.
            if type(answers) == list:
                for j in range(len(answers)):
                    answer = answers[j].strip().replace("\n", " ").replace(" ", "")
                    predict = predict.strip().replace("\n", " ").replace(" ", "")
                    if answer in predict:
                        data[i]['result'] = 1
            else:
                answers = answers.strip().replace("\n", " ").replace(" ", "")
                predict = predict.strip().replace("\n", " ").replace(" ", "")
                if answers in predict:
                    data[i]['result'] = 1
        else:
            # All other datasets are compared case-insensitively.
            if type(answers) == list:
                for j in range(len(answers)):
                    answer = answers[j].lower().strip().replace("\n", " ")
                    predict = predict.lower().strip().replace("\n", " ")
                    if answer in predict:
                        data[i]['result'] = 1
            else:
                answers = answers.lower().strip().replace("\n", " ")
                predict = predict.lower().strip().replace("\n", " ")
                if answers in predict:
                    data[i]['result'] = 1
    save_json(data, os.path.join(args.output_folder, f"{args.save_name}.json"))
    # The full OCRBench split contains exactly 1000 samples.
    if len(data) == 1000:
        for i in range(len(data)):
            if data[i].get("result", 100) == 100:
                continue
            OCRBench_score[data[i]['type']] += data[i]['result']
        recognition_score = (OCRBench_score['Regular Text Recognition']
                             + OCRBench_score['Irregular Text Recognition']
                             + OCRBench_score['Artistic Text Recognition']
                             + OCRBench_score['Handwriting Recognition']
                             + OCRBench_score['Digit String Recognition']
                             + OCRBench_score['Non-Semantic Text Recognition'])
        Final_score = (recognition_score
                       + OCRBench_score['Scene Text-centric VQA']
                       + OCRBench_score['Doc-oriented VQA']
                       + OCRBench_score['Key Information Extraction']
                       + OCRBench_score['Handwritten Mathematical Expression Recognition'])
        print("###########################OCRBench##############################")
        print(f"Text Recognition(Total 300): {recognition_score}")
        print("------------------Details of Recognition Score-------------------")
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
        print("----------------------------------------------------------------")
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
        print("----------------------------------------------------------------")
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
        print("----------------------------------------------------------------")
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
        print("----------------------------------------------------------------")
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
        print("----------------------Final Score-------------------------------")
        print(f"Final Score(Total 1000): {Final_score}")
    else:
        for i in range(len(data)):
            num_all[data[i]['dataset_name']] += 1
            if data[i].get("result", 100) == 100:
                continue
            AllDataset_score[data[i]['dataset_name']] += data[i]['result']
        for key in AllDataset_score.keys():
            if num_all[key] == 0:
                continue  # skip datasets absent from this run to avoid division by zero
            print(f"{key}: {AllDataset_score[key] / float(num_all[key])}")