add OCRBench v2

This commit is contained in:
99Franklin
2024-12-30 19:30:31 +08:00
parent f65b1fb959
commit f5cbec6681
59 changed files with 152469 additions and 34 deletions

58
OCRBench/README.md Normal file
View File

@@ -0,0 +1,58 @@
# OCRBench: On the Hidden Mystery of OCR in Large Multimodal Models
<img src="./images/all_data.png" width="96%" height="96%">
> Large models have recently played a dominant role in natural language processing and multimodal vision-language learning. However, their effectiveness in text-related visual tasks remains relatively unexplored. In this paper, we conducted a comprehensive evaluation of Large Multimodal Models, such as GPT4V and Gemini, in various text-related visual tasks including Text Recognition, Scene Text-Centric Visual Question Answering (VQA), Document-Oriented VQA, Key Information Extraction (KIE), and Handwritten Mathematical Expression Recognition (HMER). To facilitate the assessment of Optical Character Recognition (OCR) capabilities in Large Multimodal Models, we propose OCRBench, a comprehensive evaluation benchmark. Our study encompasses 29 datasets, making it the most comprehensive OCR evaluation benchmark available. Furthermore, our study reveals both the strengths and weaknesses of these models, particularly in handling multilingual text, handwritten text, non-semantic text, and mathematical expression recognition. Most importantly, the baseline results showcased in this study could provide a foundational framework for the conception and assessment of innovative strategies targeted at enhancing zero-shot multimodal techniques.
**[Project Page [This Page]](https://github.com/Yuliang-Liu/MultimodalOCR)** | **[Paper](https://arxiv.org/abs/2305.07895)** |**[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)**|**[Opencompass Leaderboard](https://rank.opencompass.org.cn/leaderboard-multimodal)**|
# Data
| Data | Link | Description |
| --- | --- | --- |
| Full Test Json | [Full Test](./json_files/FullTest.json) | This file contains the test data used in Table 1 and Table 2 from [Paper](https://arxiv.org/abs/2305.07895). |
| OCRBench Json | [OCRBench](./json_files/OCRBench.json) | This file contains the test data in OCRBench used in Table3 from [Paper](https://arxiv.org/abs/2305.07895). |
| All Test Images |[All Images](https://drive.google.com/file/d/1U5AtLoJ7FrJe9yfcbssfeLmlKb7dTosc/view?usp=drive_link) | This file contains all the testing images used in [Paper](https://arxiv.org/abs/2305.07895), including OCRBench Images.|
| OCRBench Images | [OCRBench Images](https://drive.google.com/file/d/1a3VRJx3V3SdOmPr7499Ky0Ug8AwqGUHO/view?usp=drive_link) | This file only contains the images used in OCRBench. |
| Test Results | [Test Results](https://drive.google.com/drive/folders/15XlHCuNTavI1Ihqm4G7u3J34BHpkaqyE?usp=drive_link) | This file file contains the result files for the test models. |
# OCRBench
OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models. It comprises five components: Text Recognition, SceneText-Centric VQA, Document-Oriented VQA, Key Information Extraction, and Handwritten Mathematical Expression Recognition. The benchmark includes 1000 question-answer pairs, and all the answers undergo manual verification and correction to ensure a more precise evaluation.
You can find the results of Large Multimodal Models in **[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)**, if you would like to include your model in the OCRBench leaderboard, please follow the evaluation instructions provided below and feel free to contact us via email at zhangli123@hust.edu.cn. We will update the leaderboard in time.
<img src="./images/GPT4V_Gemini.png" width="96%" height="96%">
# Evaluation
The test code for evaluating models in the paper can be found in [scripts](./scripts). Before conducting the evaluation, you need to configure the model weights and environment based on the official code link provided in the scripts. If you want to evaluate other models, please edit the "TODO" things in [example](./example.py).
You can also use [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for evaluation.
Example evaluation scripts:
```python
python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/OCRBench.json --save_name Monkey_OCRBench --num_workers GPU_Nums # Test on OCRBench
python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/FullTest.json --save_name Monkey_FullTest --num_workers GPU_Nums # Full Test
```
# Citation
If you wish to refer to the baseline results published here, please use the following BibTeX entries:
```BibTeX
@article{Liu_2024,
title={OCRBench: on the hidden mystery of OCR in large multimodal models},
volume={67},
ISSN={1869-1919},
url={http://dx.doi.org/10.1007/s11432-024-4235-6},
DOI={10.1007/s11432-024-4235-6},
number={12},
journal={Science China Information Sciences},
publisher={Springer Science and Business Media LLC},
author={Liu, Yuliang and Li, Zhang and Huang, Mingxin and Yang, Biao and Yu, Wenwen and Li, Chunyuan and Yin, Xu-Cheng and Liu, Cheng-Lin and Jin, Lianwen and Bai, Xiang},
year={2024},
month=dec }
```

186
OCRBench/example.py Normal file
View File

@@ -0,0 +1,186 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
# TODO model packages import
# from transformers import AutoModelForCausalLM, AutoTokenizer
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="")#TODO Set the address of your model's weights
parser.add_argument("--save_name", type=str, default="") #TODO Set the name of the JSON file you save in the output_folder.
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
# TODO model init
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
# tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
# tokenizer.padding_side = 'left'
# tokenizer.pad_token_id = tokenizer.eod_id
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
# TODO Generation process
# query = f'<img>{img_path}</img> {qs} Answer: '
# input_ids = tokenizer(query, return_tensors='pt', padding='longest')
# attention_mask = input_ids.attention_mask
# input_ids = input_ids.input_ids
# pred = model.generate(
# input_ids=input_ids.to(f'cuda:{eval_id}'),
# attention_mask=attention_mask.to(f'cuda:{eval_id}'),
# do_sample=False,
# num_beams=1,
# max_new_tokens=100,
# min_new_tokens=1,
# length_penalty=1,
# num_return_sequences=1,
# output_hidden_states=True,
# use_cache=True,
# pad_token_id=tokenizer.eod_id,
# eos_token_id=tokenizer.eod_id,
# )
# response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

Binary file not shown.

After

Width:  |  Height:  |  Size: 408 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

150
OCRBench/scripts/GPT4V.py Normal file
View File

@@ -0,0 +1,150 @@
import base64
import requests
from tqdm import tqdm
import json
from PIL import Image
import random
import time
import pathlib
import textwrap
from argparse import ArgumentParser
import google.generativeai as genai
import json
from PIL import Image
from IPython.display import display
from IPython.display import Markdown
from tqdm import tqdm
import os
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_path", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--OPENAI_API_KEY", type=str, default="")
parser.add_argument("--API_BASE", type=str, default="")
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
if __name__ == "__main__":
args = _get_args()
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
data_path = os.path.join(args.output_path,f"{args.model}.json")
else:
data_path = args.OCRBench_file
with open(data_path,"r") as f:
data = json.load(f)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
question = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
base64_image = encode_image(img_path)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {args.OPENAI_API_KEY}"
}
payload = {
"model": args.model,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": f"{question}"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 500
}
try:
response = requests.post(args.API_BASE, headers=headers, json=payload)
print(response.json())
answer = response.json()['choices'][0]['message']['content']
data[i]['predict'] = answer
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
except:
time.sleep(100)
print(f"{img_path} error")
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")

115
OCRBench/scripts/Genimi.py Normal file
View File

@@ -0,0 +1,115 @@
import pathlib
import textwrap
from argparse import ArgumentParser
import google.generativeai as genai
import json
from PIL import Image
from IPython.display import display
from IPython.display import Markdown
from tqdm import tqdm
import os
import sys
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_path", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--GOOGLE_API_KEY", type=str, default="")
parser.add_argument("--model", type=str, default="gemini-pro-vision")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _get_args()
genai.configure(api_key=args.GOOGLE_API_KEY)
model = genai.GenerativeModel(args.model)
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
data_path = os.path.join(args.output_path,f"{args.model}.json")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
question = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
try:
img = Image.open(img_path).convert("RGB")
response = model.generate_content([question, img])
data[i]['predict'] = response.text
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
except:
print(f"{img_path}: API call failed.")
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")

View File

@@ -0,0 +1,210 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
# https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa_loader.py
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="liuhaotian/llava-v1.5-7b")
parser.add_argument("--model_base", type=str, default=None)
parser.add_argument("--save_name", type=str, default="llava1_5_7b")
parser.add_argument("--conv_mode", type=str, default="vicuna_v1")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
device = f"cuda:{eval_id}"
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model( model_path = model_path, model_base = args.model_base, model_name = model_name,device = device)
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
args.conv_mode = args.conv_mode + '_mmtag'
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
qs = qs+"\nAnswer the question using a single word or phrase."
if model.config.mm_use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
image = Image.open(img_path).convert('RGB')
image_tensor = process_images([image], image_processor, model.config)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
input_ids = input_ids.to(device=device, non_blocking=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.to(dtype=torch.float16, device=device, non_blocking=True),
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
max_new_tokens=128,
use_cache=True)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
data[i]['predict'] = outputs
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,313 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
from transformers import AutoTokenizer
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
#https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=5, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
new_target_ratios = []
if prior_aspect_ratio is not None:
for i in target_ratios:
if prior_aspect_ratio[0]%i[0] !=0 or prior_aspect_ratio[1]%i[1] !=0:
new_target_ratios.append(i)
else:
continue
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, min_num=1, max_num=6):
image = Image.open(image_file).convert('RGB')
transform = build_transform(input_size=input_size)
images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values, target_aspect_ratio
def load_image2(image_file, input_size=448, target_aspect_ratio=(1,1), min_num=1, max_num=6):
image = Image.open(image_file).convert('RGB')
transform = build_transform(input_size=input_size)
images = dynamic_preprocess2(image, image_size=input_size, prior_aspect_ratio=target_aspect_ratio, use_thumbnail=True, min_num=min_num, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default='mx262/MiniMonkey')#TODO Set the address of your model's weights
parser.add_argument("--save_name", type=str, default="MiniMokney") #TODO Set the name of the JSON file you save in the output_folder.
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
model = AutoModel.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).eval().to(f'cuda:{eval_id}')
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
for i in tqdm(range(len(data))):
dataset_name = data[i]["dataset_name"]
image_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
pixel_values, target_aspect_ratio = load_image(image_path, min_num=12, max_num=24)
pixel_values = pixel_values.to(f'cuda:{eval_id}').to(torch.bfloat16)
pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=3, max_num=11)
pixel_values2 = pixel_values2.to(f'cuda:{eval_id}').to(torch.bfloat16)
pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)
generation_config = dict(
num_beams=1,
max_new_tokens=512,
do_sample=False,
)
question = '<image>\n'+qs+ '\nAnswer the question using a single word or phrase.'
response = model.chat(tokenizer, pixel_values, target_aspect_ratio, question, generation_config)
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

163
OCRBench/scripts/blip2.py Normal file
View File

@@ -0,0 +1,163 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
# https://huggingface.co/Salesforce/blip2-opt-6.7b
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="./model_weights/blip2-opt-6.7b")
parser.add_argument("--save_name", type=str, default="blip2_opt_6_7b")
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
processor = Blip2Processor.from_pretrained(args.model_path)
model = Blip2ForConditionalGeneration.from_pretrained(args.model_path, load_in_8bit=False, device_map={"": eval_id}, torch_dtype=torch.float16)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
image = Image.open(img_path).convert("RGB")
prompt = f"Question: {qs} Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=f"cuda:{eval_id}", dtype=torch.float16)
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
data[i]['predict'] = generated_text
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,175 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
# https://huggingface.co/Salesforce/instructblip-vicuna-7b
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="./model_weights/instructblip-vicuna-7b")
parser.add_argument("--save_name", type=str, default="instructblip_vicuna_7b")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--temperature", type=float, default=0.0)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
device = f"cuda:{eval_id}"
model = InstructBlipForConditionalGeneration.from_pretrained(args.model_path)
processor = InstructBlipProcessor.from_pretrained(args.model_path)
model.to(device)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
image = Image.open(img_path).convert('RGB')
inputs = processor(images=image, text=qs, return_tensors="pt").to(device)
outputs = model.generate(
**inputs,
do_sample=False,
num_beams=5,
max_length=100,
min_length=1,
top_p=0.9,
repetition_penalty=1.5,
length_penalty=1.0,
temperature=0,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
data[i]['predict'] = generated_text
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

179
OCRBench/scripts/bliva.py Normal file
View File

@@ -0,0 +1,179 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from bliva.models import load_model_and_preprocess
import numpy as np
# https://github.com/mlpc-ucsd/BLIVA/blob/main/evaluate.py
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="bliva_vicuna")
parser.add_argument("--save_name", type=str, default="bliva")
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
device = f"cuda:{eval_id}"
np.random.seed(0)
disable_torch_init()
if "vicuna" in args.model_path.lower():
print("load bliva-vicuna")
model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="vicuna7b", is_eval=True, device=device)
if "flant5xxl" in args.model_path.lower():
print("load bliva-flant5xxl")
model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="flant5xxl", is_eval=True, device=device)
vis_processor = vis_processors["eval"]
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
image = Image.open(img_path).convert('RGB')
question = [qs]
image = vis_processor(image).unsqueeze(0).to(device)
outputs = model.generate({"image": image, "prompt": qs}, max_length=150)
data[i]['predict'] = outputs[0].split('### Assistant:')[0]
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

161
OCRBench/scripts/interlm.py Normal file
View File

@@ -0,0 +1,161 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModel, AutoTokenizer
# https://github.com/InternLM/InternLM-XComposer/tree/main/InternLM-XComposer-1.0
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer-7b')#TODO Set the address of your model's weights
parser.add_argument("--save_name", type=str, default="internlm-xcomposer-7b") #TODO Set the name of the JSON file you save in the output_folder.
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
torch.set_grad_enabled(False)
# init model and tokenizer
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model.tokenizer = tokenizer
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
response = model.generate(qs, img_path)
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,162 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModel, AutoTokenizer
#https://github.com/InternLM/InternLM-XComposer/tree/main
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer2-vl-7b')#TODO Set the address of your model's weights
parser.add_argument("--save_name", type=str, default="internlm-xcomposer2-vl-7b") #TODO Set the name of the JSON file you save in the output_folder.
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
torch.set_grad_enabled(False)
# init model and tokenizer
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
text = f'<ImageHere>{qs}'
with torch.cuda.amp.autocast():
response, _ = model.chat(tokenizer, query=text, image=img_path, history=[], do_sample=False)
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,247 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModelForCausalLM, AutoTokenizer,AutoModel
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=12):
image = Image.open(image_file).convert('RGB')
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
# https://huggingface.co/OpenGVLab/InternVL2-1B
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default='OpenGVLab/InternVL2-4B')
parser.add_argument("--save_name", type=str, default="internvl2-4B")
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
model = AutoModel.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).eval().to(f'cuda:{eval_id}')
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
pixel_values = load_image(img_path, max_num=12).to(torch.bfloat16).to(f'cuda:{eval_id}')
generation_config = dict(max_new_tokens=1024, do_sample=False)
question = f'<image>\n{qs}'
response = model.chat(tokenizer, pixel_values, question, generation_config)
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
# eval_worker(args, data_list[0], 0, output_queue)
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

178
OCRBench/scripts/intervl.py Normal file
View File

@@ -0,0 +1,178 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
from transformers import AutoTokenizer
#https://github.com/OpenGVLab/InternVL
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default='OpenGVLab/InternVL-Chat-Chinese-V1-1')#TODO Set the address of your model's weights
parser.add_argument("--save_name", type=str, default="InternVL-Chat-Chinese-V1-1") #TODO Set the name of the JSON file you save in the output_folder.
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
model = AutoModel.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
device_map='cuda').eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
image = Image.open(img_path).convert('RGB')
image = image.resize((448, 448))
image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()
generation_config = dict(
num_beams=1,
max_new_tokens=512,
do_sample=False,
)
response = model.chat(tokenizer, pixel_values, qs, generation_config)
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

330
OCRBench/scripts/llavar.py Normal file
View File

@@ -0,0 +1,330 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from llava import LlavaLlamaForCausalLM
from llava.conversation import conv_templates
from llava import conversation as conversation_lib
from llava.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from PIL import Image,ImageOps
# https://github.com/SALT-NLP/LLaVAR/blob/main/LLaVA/llava/eval/model_vqa.py
def resize_image(image, target_size):
width, height = image.size
aspect_ratio = width / height
if aspect_ratio > 1:
new_width = target_size[0]
new_height = int(new_width / aspect_ratio)
else:
new_height = target_size[1]
new_width = int(new_height * aspect_ratio)
image = image.resize((new_width, new_height))
width_diff = target_size[0] - image.size[0]
height_diff = target_size[1] - image.size[1]
left_padding = 0
top_padding = 0
right_padding = width_diff - left_padding
bottom_padding = height_diff - top_padding
padded_image = ImageOps.expand(image, border=(left_padding, top_padding, right_padding, bottom_padding), fill=0)
return padded_image
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def patch_config(config):
patch_dict = {
"use_mm_proj": True,
"mm_vision_tower": "openai/clip-vit-large-patch14",
"mm_hidden_size": 1024
}
cfg = AutoConfig.from_pretrained(config)
if not hasattr(cfg, "mm_vision_tower"):
print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
for k, v in patch_dict.items():
setattr(cfg, k, v)
cfg.save_pretrained(config)
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="./model_weights/LLaVar")
parser.add_argument("--save_name", type=str, default="llavar")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--mm-projector", type=str, default=None)
parser.add_argument("--vision-tower", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
device = f"cuda:{eval_id}"
disable_torch_init()
model_name = os.path.expanduser(args.model_path)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if args.mm_projector is None:
patch_config(model_name)
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
vision_tower = model.model.vision_tower[0]
vision_tower.to(device=device, dtype=torch.float16)
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
else:
# in case of using a pretrained model with only a MLP projector weights
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).to(device)
image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
model.model.mm_projector = mm_projector.to(device).half()
model.model.vision_tower = [vision_tower]
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
# qs = qs+"\nAnswer the question using a single word or phrase."
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
if mm_use_im_start_end:
qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
else:
qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
if args.conv_mode == 'simple_legacy':
qs += '\n\n### Response:'
# conv = default_conversation.copy()
conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
# modified
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
image = Image.open(img_path)
# if "REval" in args.image_folder:
image = resize_image(image, (336, 336))
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
input_ids = torch.as_tensor(inputs.input_ids).to(device)
# new stopping implementation
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
# keywords = ['###']
# modified
keywords = ['</s>']
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor.unsqueeze(0).half().to(device),
do_sample=False,
temperature=0,
max_new_tokens=200,
stopping_criteria=[stopping_criteria])
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
# modified
if args.conv_mode == 'simple_legacy' or args.conv_mode == 'simple':
while True:
cur_len = len(outputs)
outputs = outputs.strip()
for pattern in ['###', 'Assistant:', 'Response:']:
if outputs.startswith(pattern):
outputs = outputs[len(pattern):].strip()
if len(outputs) == cur_len:
break
if conv.sep_style == conversation_lib.SeparatorStyle.TWO:
sep = conv.sep2
else:
sep = conv.sep
try:
index = outputs.index(sep)
except ValueError:
outputs += sep
index = outputs.index(sep)
outputs = outputs[:index].strip()
data[i]['predict'] = outputs
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,329 @@
import json
import multiprocessing
import os
from argparse import ArgumentParser
from multiprocessing import Manager, Pool, Queue
import torch
from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from mplug_docowl.conversation import conv_templates
from mplug_docowl.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.processor import DocProcessor
from tqdm import tqdm
from transformers import TextStreamer
# https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i * avg : (i + 1) * avg])
result.append(lst[(n - 1) * avg :])
return result
def save_json(json_list, save_path):
with open(save_path, "w", encoding="utf-8") as file:
json.dump(json_list, file, indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="mPLUG/DocOwl1.5")
parser.add_argument("--save_name", type=str, default="mplug-DocOwl1.5")
parser.add_argument("--conv_mode", type=str, default="mplug_owl2")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--temperature", type=float, default=0.0)
args = parser.parse_args()
return args
OCRBench_score = {
"Regular Text Recognition": 0,
"Irregular Text Recognition": 0,
"Artistic Text Recognition": 0,
"Handwriting Recognition": 0,
"Digit String Recognition": 0,
"Non-Semantic Text Recognition": 0,
"Scene Text-centric VQA": 0,
"Doc-oriented VQA": 0,
"Key Information Extraction": 0,
"Handwritten Mathematical Expression Recognition": 0,
}
AllDataset_score = {
"IIIT5K": 0,
"svt": 0,
"IC13_857": 0,
"IC15_1811": 0,
"svtp": 0,
"ct80": 0,
"cocotext": 0,
"ctw": 0,
"totaltext": 0,
"HOST": 0,
"WOST": 0,
"WordArt": 0,
"IAM": 0,
"ReCTS": 0,
"ORAND": 0,
"NonSemanticText": 0,
"SemanticText": 0,
"STVQA": 0,
"textVQA": 0,
"ocrVQA": 0,
"ESTVQA": 0,
"ESTVQA_cn": 0,
"docVQA": 0,
"infographicVQA": 0,
"ChartQA": 0,
"ChartQA_Human": 0,
"FUNSD": 0,
"SROIE": 0,
"POIE": 0,
"HME100k": 0,
}
num_all = {
"IIIT5K": 0,
"svt": 0,
"IC13_857": 0,
"IC15_1811": 0,
"svtp": 0,
"ct80": 0,
"cocotext": 0,
"ctw": 0,
"totaltext": 0,
"HOST": 0,
"WOST": 0,
"WordArt": 0,
"IAM": 0,
"ReCTS": 0,
"ORAND": 0,
"NonSemanticText": 0,
"SemanticText": 0,
"STVQA": 0,
"textVQA": 0,
"ocrVQA": 0,
"ESTVQA": 0,
"ESTVQA_cn": 0,
"docVQA": 0,
"infographicVQA": 0,
"ChartQA": 0,
"ChartQA_Human": 0,
"FUNSD": 0,
"SROIE": 0,
"POIE": 0,
"HME100k": 0,
}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, _, _ = load_pretrained_model(
args.model_path,
None,
model_name,
load_8bit=False,
load_4bit=False,
device=f"cuda:{eval_id}",
)
doc_image_processor = DocProcessor(
image_size=448,
anchors="grid_9",
add_global_img=True,
add_textual_crop_indicator=True,
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]["image_path"])
qs = data[i]["question"]
if data[i].get("predict", 0) != 0:
print(f"{img_path} predict exist, continue.")
continue
image_tensor, patch_positions, text = doc_image_processor(
images=img_path, query="<|image|>" + qs
)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
patch_positions = patch_positions.to(model.device)
conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], text)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = (
tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.to(model.device)
)
stop_str = conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
patch_positions=patch_positions,
do_sample=False,
temperature=1.0,
max_new_tokens=512,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria],
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
data[i]["predict"] = outputs
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
args = _get_args()
if os.path.exists(os.path.join(args.output_folder, f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder, f"{args.save_name}.json")
print(
f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}."
)
else:
data_path = args.OCRBench_file
with open(data_path, "r", encoding="utf-8") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
# pool.apply(eval_worker, args=(args, data_list[i], i, output_queue))
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get("predict", 0) == 0:
continue
predict = data[i]["predict"]
data[i]["result"] = 0
if dataset_name == "HME100k":
if type(answers) == list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n", " ").replace(" ", "")
predict = predict.strip().replace("\n", " ").replace(" ", "")
if answer in predict:
data[i]["result"] = 1
else:
answers = answers.strip().replace("\n", " ").replace(" ", "")
predict = predict.strip().replace("\n", " ").replace(" ", "")
if answers in predict:
data[i]["result"] = 1
else:
if type(answers) == list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n", " ")
predict = predict.lower().strip().replace("\n", " ")
if answer in predict:
data[i]["result"] = 1
else:
answers = answers.lower().strip().replace("\n", " ")
predict = predict.lower().strip().replace("\n", " ")
if answers in predict:
data[i]["result"] = 1
save_json(data, os.path.join(args.output_folder, f"{args.save_name}.json"))
if len(data) == 1000:
for i in range(len(data)):
if data[i].get("result", 100) == 100:
continue
OCRBench_score[data[i]["type"]] += data[i]["result"]
recognition_score = (
OCRBench_score["Regular Text Recognition"]
+ OCRBench_score["Irregular Text Recognition"]
+ OCRBench_score["Artistic Text Recognition"]
+ OCRBench_score["Handwriting Recognition"]
+ OCRBench_score["Digit String Recognition"]
+ OCRBench_score["Non-Semantic Text Recognition"]
)
Final_score = (
recognition_score
+ OCRBench_score["Scene Text-centric VQA"]
+ OCRBench_score["Doc-oriented VQA"]
+ OCRBench_score["Key Information Extraction"]
+ OCRBench_score["Handwritten Mathematical Expression Recognition"]
)
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(
f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}"
)
print(
f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}"
)
print(
f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}"
)
print(
f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}"
)
print(
f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}"
)
print(
f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}"
)
print("----------------------------------------------------------------")
print(
f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}"
)
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(
f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}"
)
print("----------------------------------------------------------------")
print(
f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}"
)
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]["dataset_name"]] += 1
if data[i].get("result", 100) == 100:
continue
AllDataset_score[data[i]["dataset_name"]] += data[i]["result"]
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,185 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
import sys
sys.path.append("./scripts/mPLUG-Owl/mPLUG-Owl/")
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
# https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl")
parser.add_argument("--save_name", type=str, default="mplug-owl")
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
pretrained_ckpt = args.model_path
model = MplugOwlForConditionalGeneration.from_pretrained(
pretrained_ckpt,
torch_dtype=torch.bfloat16,
)
model.to(f"cuda:{eval_id}")
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
prompts = [
f'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: {qs}
AI: ''']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
generate_kwargs = {
'do_sample': False,
'top_k': 1,
'max_length': 100
}
images = [Image.open(img_path)]
inputs = processor(text=prompts, images=images, return_tensors='pt')
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
res = model.generate(**inputs, **generate_kwargs)
sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
data[i]['predict'] = sentence
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,191 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import TextStreamer
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates, SeparatorStyle
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
# https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl2")
parser.add_argument("--save_name", type=str, default="mplug-owl2")
parser.add_argument("--conv_mode", type=str, default="mplug_owl2")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--temperature", type=float, default=0.0)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, None, model_name, load_8bit=False, load_4bit=False, device=f"cuda:{eval_id}")
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
conv = conv_templates[args.conv_mode].copy()
roles = conv.roles
image = Image.open(img_path).convert('RGB')
max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance.
image = image.resize((max_edge, max_edge))
image_tensor = process_images([image], image_processor)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
inp = DEFAULT_IMAGE_TOKEN + qs
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=False,
temperature=args.temperature,
max_new_tokens=100,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria])
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
data[i]['predict'] = outputs
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,187 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
import sys
sys.path.append("./scripts/MiniGPT-4/")
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config
import random
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/eval_scripts/eval_vqa.py
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--cfg-path", default='./scripts/MiniGPT-4/eval_configs/minigptv2_eval.yaml')
parser.add_argument("--save_name", type=str, default="minigptv2")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
device = f'cuda:{eval_id}'
cfg = Config(args)
model, vis_processor = init_model(args, device)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()
instruction_pool = [
"[vqa] {}"
]
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
image = Image.open(img_path).convert("RGB")
image = vis_processor(image)
image = image.unsqueeze(0).to(device)
# question = self.text_processor(qs)
instruction = random.choice(instruction_pool).format(qs)
instruction = "<Img><ImageHere></Img> {} ".format(instruction)
texts = prepare_texts(instruction, conv_temp) # warp the texts with conversation template
answers = model.generate(image, texts, max_new_tokens=100, do_sample=False)
data[i]['predict'] = answers[0]
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

182
OCRBench/scripts/monkey.py Normal file
View File

@@ -0,0 +1,182 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModelForCausalLM, AutoTokenizer
# https://github.com/Yuliang-Liu/Monkey
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="echo840/Monkey")
parser.add_argument("--save_name", type=str, default="monkey")
parser.add_argument("--num_workers", type=int, default=8)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
query = f'<img>{img_path}</img> {qs} Answer: '
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.to(f'cuda:{eval_id}'),
attention_mask=attention_mask.to(f'cuda:{eval_id}'),
do_sample=False,
num_beams=1,
max_new_tokens=100,
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

181
OCRBench/scripts/qwenvl.py Normal file
View File

@@ -0,0 +1,181 @@
import json
from argparse import ArgumentParser
import torch
import os
import json
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModelForCausalLM, AutoTokenizer
# https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/evaluate_vqa.py
def split_list(lst, n):
length = len(lst)
avg = length // n # 每份的大小
result = [] # 存储分割后的子列表
for i in range(n - 1):
result.append(lst[i*avg:(i+1)*avg])
result.append(lst[(n-1)*avg:])
return result
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
parser.add_argument("--output_folder", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model_path", type=str, default="Qwen/Qwen-VL")
parser.add_argument("--save_name", type=str, default="qwenvl")
parser.add_argument("--num_workers", type=int, default=1)
args = parser.parse_args()
return args
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def eval_worker(args, data, eval_id, output_queue):
print(f"Process {eval_id} start.")
checkpoint = args.model_path
model = AutoModelForCausalLM.from_pretrained(
checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
qs = data[i]['question']
# query = f'<img>{img_path}</img> {qs} Answer: '
query = f'<img>{img_path}</img>{qs} Answer:'
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.to(f'cuda:{eval_id}'),
attention_mask=attention_mask.to(f'cuda:{eval_id}'),
do_sample=False,
num_beams=1,
max_new_tokens=100,
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
data[i]['predict'] = response
output_queue.put({eval_id: data})
print(f"Process {eval_id} has completed.")
if __name__=="__main__":
multiprocessing.set_start_method('spawn')
args = _get_args()
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
data_list = split_list(data, args.num_workers)
output_queue = Manager().Queue()
pool = Pool(processes=args.num_workers)
for i in range(len(data_list)):
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
pool.close()
pool.join()
results = {}
while not output_queue.empty():
result = output_queue.get()
results.update(result)
data = []
for i in range(len(data_list)):
data.extend(results[i])
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
if len(data)==1000:
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")
else:
for i in range(len(data)):
num_all[data[i]['dataset_name']] += 1
if data[i].get("result",100)==100:
continue
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
for key in AllDataset_score.keys():
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")

View File

@@ -0,0 +1,138 @@
import pathlib
from argparse import ArgumentParser
import json
from tqdm import tqdm
import os
import sys
from http import HTTPStatus
from dashscope import MultiModalConversation
import time
# You should follow the instructions here befor strat: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
def save_json(json_list,save_path):
with open(save_path, 'w') as file:
json.dump(json_list, file,indent=4)
def call_with_local_file(img_path, question, model_name):
"""Sample of use local file.
linux&mac file schema: file:///home/images/test.png
windows file schema: file://D:/images/abc.png
"""
local_file_path1 = f'file://{img_path}'
messages = [{
'role': 'system',
'content': [{
'text': 'You are a helpful assistant.'
}]
}, {
'role':
'user',
'content': [
{
'image': local_file_path1
},
{
'text': question
},
]
}]
response = MultiModalConversation.call(model=model_name, messages=messages)
# time.sleep(2) #For qwenvl-max you may need to add this line to avoid the limits.
print(response)
return response['output']['choices'][0]["message"]['content'][0]['text']
def _get_args():
parser = ArgumentParser()
parser.add_argument("--image_folder", type=str, default="./data")
parser.add_argument("--output_path", type=str, default="./results")
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
parser.add_argument("--model", type=str, default="qwen-vl-max")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _get_args()
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
data_path = os.path.join(args.output_path,f"{args.model}.json")
else:
data_path = args.OCRBench_file
with open(data_path, "r") as f:
data = json.load(f)
for i in tqdm(range(len(data))):
img_path = os.path.join(args.image_folder, data[i]['image_path'])
question = data[i]['question']
if data[i].get("predict", 0)!=0:
print(f"{img_path} predict exist, continue.")
continue
try:
response = call_with_local_file(img_path, question, args.model)
data[i]['predict'] = response
except:
print("QwenVL api failed")
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
for i in range(len(data)):
data_type = data[i]["type"]
dataset_name = data[i]["dataset_name"]
answers = data[i]["answers"]
if data[i].get('predict',0)==0:
continue
predict = data[i]['predict']
data[i]['result'] = 0
if dataset_name == "HME100k":
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.strip().replace("\n"," ").replace(" ","")
predict = predict.strip().replace("\n"," ").replace(" ","")
if answers in predict:
data[i]['result'] = 1
else:
if type(answers)==list:
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answer in predict:
data[i]['result'] = 1
else:
answers = answers.lower().strip().replace("\n"," ")
predict = predict.lower().strip().replace("\n"," ")
if answers in predict:
data[i]['result'] = 1
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
for i in range(len(data)):
if data[i].get("result",100)==100:
continue
OCRBench_score[data[i]['type']] += data[i]['result']
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
print("###########################OCRBench##############################")
print(f"Text Recognition(Total 300):{recognition_score}")
print("------------------Details of Recognition Score-------------------")
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
print("----------------------------------------------------------------")
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
print("----------------------------------------------------------------")
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
print("----------------------------------------------------------------")
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
print("----------------------Final Score-------------------------------")
print(f"Final Score(Total 1000): {Final_score}")