import json
import multiprocessing
import os
from argparse import ArgumentParser
from multiprocessing import Manager, Pool

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, switch to this grid only if the image is large enough to fill it.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _candidate_grids(min_num, max_num):
    # All (columns, rows) grids whose tile count lies in [min_num, max_num],
    # sorted by tile count.
    grids = set(
        (i, j) for n in range(min_num, max_num + 1)
        for i in range(1, n + 1) for j in range(1, n + 1)
        if min_num <= i * j <= max_num)
    return sorted(grids, key=lambda x: x[0] * x[1])


def _split_into_tiles(image, grid, image_size, use_thumbnail):
    # calculate the target width and height
    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = []
    for i in range(blocks):
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            (i % cols + 1) * image_size,
            (i // cols + 1) * image_size,
        )
        # split the image
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images


def dynamic_preprocess(image, min_num=5, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # find the candidate grid closest to the image's aspect ratio
    target_ratios = _candidate_grids(min_num, max_num)
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    processed_images = _split_into_tiles(image, target_aspect_ratio, image_size, use_thumbnail)
    return processed_images, target_aspect_ratio


def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = _candidate_grids(min_num, max_num)
    if prior_aspect_ratio is not None:
        # Keep only grids that do not evenly divide the prior grid in both
        # dimensions, so the second tiling cuts the image at different places.
        target_ratios = [
            r for r in target_ratios
            if prior_aspect_ratio[0] % r[0] != 0 or prior_aspect_ratio[1] % r[1] != 0]
    # Without a prior grid, all candidates are considered (the original left
    # the candidate list empty in that case and always fell back to (1, 1)).

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    return _split_into_tiles(image, target_aspect_ratio, image_size, use_thumbnail)


def load_image(image_file, input_size=448, min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = torch.stack([transform(image) for image in images])
    return pixel_values, target_aspect_ratio


def load_image2(image_file, input_size=448, target_aspect_ratio=(1, 1), min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess2(
        image, image_size=input_size, prior_aspect_ratio=target_aspect_ratio,
        use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = torch.stack([transform(image) for image in images])
    return pixel_values


def split_list(lst, n):
    length = len(lst)
    avg = length // n  # size of each chunk
    result = []  # stores the resulting sub-lists
    for i in range(n - 1):
        result.append(lst[i * avg:(i + 1) * avg])
    result.append(lst[(n - 1) * avg:])  # the last chunk takes the remainder
    return result


def save_json(json_list, save_path):
    with open(save_path, 'w') as file:
        json.dump(json_list, file, indent=4)


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
    parser.add_argument("--output_folder", type=str, default="./results")
    parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
    parser.add_argument("--model_path", type=str, default='mx262/MiniMokney')  # TODO Set the address of your model's weights
    parser.add_argument("--save_name", type=str, default="MiniMokney")  # TODO Set the name of the JSON file you save in the output_folder.
    parser.add_argument("--num_workers", type=int, default=1)
    args = parser.parse_args()
    return args


# The original literal listed "Doc-oriented VQA" twice; the duplicate key is removed.
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {key: 0 for key in AllDataset_score}


def eval_worker(args, data, eval_id, output_queue):
    print(f"Process {eval_id} start.")
    device = f'cuda:{eval_id}'  # one GPU per worker
    checkpoint = args.model_path
    model = AutoModel.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    generation_config = dict(
        num_beams=1,
        max_new_tokens=512,
        do_sample=False,
    )

    for i in tqdm(range(len(data))):
        # Skip samples answered in a previous run, as the resume message in
        # __main__ promises.
        if data[i].get('predict', 0) != 0:
            continue
        image_path = os.path.join(args.image_folder, data[i]['image_path'])
        qs = data[i]['question']
        pixel_values, target_aspect_ratio = load_image(image_path, min_num=12, max_num=24)
        pixel_values = pixel_values.to(device).to(torch.bfloat16)
        pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=3, max_num=11)
        pixel_values2 = pixel_values2.to(device).to(torch.bfloat16)
        # Tiles of the first grid, tiles of the second grid, then a single thumbnail.
        pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)

        question = '\n' + qs + '\nAnswer the question using a single word or phrase.'
        response = model.chat(tokenizer, pixel_values, target_aspect_ratio, question, generation_config)
        data[i]['predict'] = response
    output_queue.put({eval_id: data})
    print(f"Process {eval_id} has completed.")


if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    args = _get_args()

    output_path = os.path.join(args.output_folder, f"{args.save_name}.json")
    if os.path.exists(output_path):
        data_path = output_path
        print(f"output_path: {data_path} exists! Only generating the results that are missing from {data_path}.")
    else:
        data_path = args.OCRBench_file

    with open(data_path, "r") as f:
        data = json.load(f)

    data_list = split_list(data, args.num_workers)
    output_queue = Manager().Queue()
    pool = Pool(processes=args.num_workers)
    handles = [
        pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
        for i in range(len(data_list))]
    pool.close()
    pool.join()
    for handle in handles:
        handle.get()  # re-raise any exception that occurred inside a worker

    results = {}
    while not output_queue.empty():
        results.update(output_queue.get())
    data = []
    for i in range(len(data_list)):
        data.extend(results[i])

    for item in data:
        if item.get('predict', 0) == 0:
            continue
        answers = item["answers"]
        if not isinstance(answers, list):
            answers = [answers]
        if item["dataset_name"] == "HME100k":
            # Formula strings are compared with all whitespace removed.
            def normalize(text):
                return text.strip().replace("\n", " ").replace(" ", "")
        else:
            def normalize(text):
                return text.lower().strip().replace("\n", " ")
        predict = normalize(item['predict'])
        # A sample counts as correct if any reference answer is a substring of the prediction.
        item['result'] = int(any(normalize(answer) in predict for answer in answers))

    os.makedirs(args.output_folder, exist_ok=True)
    save_json(data, output_path)

    if len(data) == 1000:
        for item in data:
            if item.get("result", 100) == 100:
                continue
            OCRBench_score[item['type']] += item['result']
        recognition_score = (
            OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
            + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
            + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition'])
        Final_score = (
            recognition_score + OCRBench_score['Scene Text-centric VQA'] + OCRBench_score['Doc-oriented VQA']
            + OCRBench_score['Key Information Extraction']
            + OCRBench_score['Handwritten Mathematical Expression Recognition'])
        print("###########################OCRBench##############################")
        print(f"Text Recognition(Total 300):{recognition_score}")
        print("------------------Details of Recognition Score-------------------")
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
        print("----------------------------------------------------------------")
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
        print("----------------------------------------------------------------")
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
        print("----------------------------------------------------------------")
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
        print("----------------------------------------------------------------")
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
        print("----------------------Final Score-------------------------------")
        print(f"Final Score(Total 1000): {Final_score}")
    else:
        for item in data:
            num_all[item['dataset_name']] += 1
            if item.get("result", 100) == 100:
                continue
            AllDataset_score[item['dataset_name']] += item['result']
        for key in AllDataset_score:
            if num_all[key] == 0:
                continue  # dataset not present in this run; avoids division by zero
            print(f"{key}: {AllDataset_score[key] / float(num_all[key])}")