add OCRBench v2
OCRBench/README.md (new file, +58 lines)
@@ -0,0 +1,58 @@
# OCRBench: On the Hidden Mystery of OCR in Large Multimodal Models

<img src="./images/all_data.png" width="96%" height="96%">

> Large models have recently played a dominant role in natural language processing and multimodal vision-language learning. However, their effectiveness in text-related visual tasks remains relatively unexplored. In this paper, we conducted a comprehensive evaluation of Large Multimodal Models, such as GPT4V and Gemini, on various text-related visual tasks including Text Recognition, Scene Text-Centric Visual Question Answering (VQA), Document-Oriented VQA, Key Information Extraction (KIE), and Handwritten Mathematical Expression Recognition (HMER). To facilitate the assessment of Optical Character Recognition (OCR) capabilities in Large Multimodal Models, we propose OCRBench, a comprehensive evaluation benchmark. Our study encompasses 29 datasets, making it the most comprehensive OCR evaluation benchmark available. Furthermore, our study reveals both the strengths and weaknesses of these models, particularly in handling multilingual text, handwritten text, non-semantic text, and mathematical expression recognition. Most importantly, the baseline results showcased in this study could provide a foundational framework for the conception and assessment of innovative strategies targeted at enhancing zero-shot multimodal techniques.

**[Project Page [This Page]](https://github.com/Yuliang-Liu/MultimodalOCR)** | **[Paper](https://arxiv.org/abs/2305.07895)** | **[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)** | **[OpenCompass Leaderboard](https://rank.opencompass.org.cn/leaderboard-multimodal)**

# Data

| Data | Link | Description |
| --- | --- | --- |
| Full Test Json | [Full Test](./json_files/FullTest.json) | The test data used in Table 1 and Table 2 of the [paper](https://arxiv.org/abs/2305.07895). |
| OCRBench Json | [OCRBench](./json_files/OCRBench.json) | The OCRBench test data used in Table 3 of the [paper](https://arxiv.org/abs/2305.07895). |
| All Test Images | [All Images](https://drive.google.com/file/d/1U5AtLoJ7FrJe9yfcbssfeLmlKb7dTosc/view?usp=drive_link) | All testing images used in the [paper](https://arxiv.org/abs/2305.07895), including the OCRBench images. |
| OCRBench Images | [OCRBench Images](https://drive.google.com/file/d/1a3VRJx3V3SdOmPr7499Ky0Ug8AwqGUHO/view?usp=drive_link) | Only the images used in OCRBench. |
| Test Results | [Test Results](https://drive.google.com/drive/folders/15XlHCuNTavI1Ihqm4G7u3J34BHpkaqyE?usp=drive_link) | The result files of the models evaluated in the paper. |
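Each entry in these JSON files follows the schema read by the evaluation scripts in this commit. A minimal sketch of loading one entry (the field names come from the scripts; the values shown in the comment are illustrative only):

```python
import json

with open("./json_files/OCRBench.json") as f:
    data = json.load(f)

sample = data[0]
# A sample contains roughly:
#   "dataset_name": "IIIT5K"                # source dataset
#   "type": "Regular Text Recognition"      # OCRBench task category
#   "image_path": "..."                     # path relative to --image_folder
#   "question": "..."                       # prompt given to the model
#   "answers": [...]                        # ground-truth answer(s), a string or a list
# The scripts later add "predict" (model output) and "result" (0 or 1).
```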
# OCRBench

OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models. It comprises five components: Text Recognition, Scene Text-Centric VQA, Document-Oriented VQA, Key Information Extraction, and Handwritten Mathematical Expression Recognition. The benchmark includes 1000 question-answer pairs, and all answers undergo manual verification and correction to ensure a more precise evaluation.

You can find the results of Large Multimodal Models on the **[OCRBench Leaderboard](https://huggingface.co/spaces/echo840/ocrbench-leaderboard)**. If you would like your model included in the leaderboard, please follow the evaluation instructions below and contact us by email at zhangli123@hust.edu.cn; we will update the leaderboard promptly.

<img src="./images/GPT4V_Gemini.png" width="96%" height="96%">
# Evaluation

The code used to evaluate the models in the paper can be found in [scripts](./scripts). Before running an evaluation, configure the model weights and environment following the official code link provided in each script. To evaluate other models, complete the sections marked "TODO" in [example.py](./example.py); a minimal sketch is shown below.
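The following sketch shows one way to fill in the "TODO" blocks of `example.py` with a Hugging Face `transformers` model. It mirrors the commented-out template inside the script; the prompt format and generation settings are model-specific placeholders, not fixed requirements.

```python
# TODO model init (inside eval_worker)
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    checkpoint, device_map=f"cuda:{eval_id}", trust_remote_code=True
).eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

# TODO generation process (inside the loop over data)
query = f"<img>{img_path}</img> {qs} Answer: "  # prompt format depends on the model
inputs = tokenizer(query, return_tensors="pt").to(f"cuda:{eval_id}")
pred = model.generate(**inputs, do_sample=False, max_new_tokens=100)
response = tokenizer.decode(pred[0][inputs["input_ids"].size(1):],
                            skip_special_tokens=True).strip()
```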
You can also use [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for evaluation.

Example evaluation commands:

```bash
python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/OCRBench.json --save_name Monkey_OCRBench --num_workers GPU_Nums # Test on OCRBench
python ./scripts/monkey.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/FullTest.json --save_name Monkey_FullTest --num_workers GPU_Nums # Full Test
```
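For reference, the scripts in this commit score a prediction as correct when a ground-truth answer appears as a substring of the model output: matching is case-insensitive, except for HME100k, where all whitespace is removed and case is preserved. A condensed restatement of that rule:

```python
def is_correct(answers, predict, dataset_name):
    """Condensed restatement of the scoring rule used in example.py and ./scripts."""
    answers = answers if isinstance(answers, list) else [answers]
    for answer in answers:
        if dataset_name == "HME100k":
            # LaTeX-like answers: strip all whitespace, keep case
            a = answer.strip().replace("\n", " ").replace(" ", "")
            p = predict.strip().replace("\n", " ").replace(" ", "")
        else:
            a = answer.lower().strip().replace("\n", " ")
            p = predict.lower().strip().replace("\n", " ")
        if a in p:
            return True
    return False
```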
# Citation

If you wish to refer to the baseline results published here, please use the following BibTeX entry:

```BibTeX
@article{Liu_2024,
  title={OCRBench: on the hidden mystery of OCR in large multimodal models},
  author={Liu, Yuliang and Li, Zhang and Huang, Mingxin and Yang, Biao and Yu, Wenwen and Li, Chunyuan and Yin, Xu-Cheng and Liu, Cheng-Lin and Jin, Lianwen and Bai, Xiang},
  journal={Science China Information Sciences},
  publisher={Springer Science and Business Media LLC},
  volume={67},
  number={12},
  year={2024},
  month=dec,
  ISSN={1869-1919},
  DOI={10.1007/s11432-024-4235-6},
  url={http://dx.doi.org/10.1007/s11432-024-4235-6}
}
```
OCRBench/example.py (new file, +186 lines)
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
# TODO model packages import
|
||||
# from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # holds the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="")#TODO Set the address of your model's weights
|
||||
parser.add_argument("--save_name", type=str, default="") #TODO Set the name of the JSON file you save in the output_folder.
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
|
||||
# TODO model init
|
||||
|
||||
# model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
|
||||
# tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
# tokenizer.padding_side = 'left'
|
||||
# tokenizer.pad_token_id = tokenizer.eod_id
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
|
||||
# TODO Generation process
|
||||
# query = f'<img>{img_path}</img> {qs} Answer: '
|
||||
|
||||
# input_ids = tokenizer(query, return_tensors='pt', padding='longest')
|
||||
# attention_mask = input_ids.attention_mask
|
||||
# input_ids = input_ids.input_ids
|
||||
|
||||
# pred = model.generate(
|
||||
# input_ids=input_ids.to(f'cuda:{eval_id}'),
|
||||
# attention_mask=attention_mask.to(f'cuda:{eval_id}'),
|
||||
# do_sample=False,
|
||||
# num_beams=1,
|
||||
# max_new_tokens=100,
|
||||
# min_new_tokens=1,
|
||||
# length_penalty=1,
|
||||
# num_return_sequences=1,
|
||||
# output_hidden_states=True,
|
||||
# use_cache=True,
|
||||
# pad_token_id=tokenizer.eod_id,
|
||||
# eos_token_id=tokenizer.eod_id,
|
||||
# )
|
||||
# response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
|
||||
data[i]['predict'] = response  # 'response' comes from the generation code you implement above
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
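# Score each prediction: a sample is correct (result = 1) when a ground-truth answer
# appears as a substring of the prediction (case-insensitive; for HME100k all whitespace
# is removed and case is preserved). Samples without a 'predict' field are skipped.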
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
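# Exactly 1000 samples means this is the OCRBench subset: report per-task scores and the
# final score out of 1000. Otherwise, report per-dataset accuracy for the full test set.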
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
OCRBench/images/GPT4V_Gemini.png (new binary file, 408 KiB; not shown)
OCRBench/images/all_data.png (new binary file, 1.8 MiB; not shown)
OCRBench/scripts/GPT4V.py (new file, +150 lines)
@@ -0,0 +1,150 @@
|
||||
import base64
import json
import os
import time
from argparse import ArgumentParser

import requests
from tqdm import tqdm
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_path", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--OPENAI_API_KEY", type=str, default="")
|
||||
parser.add_argument("--API_BASE", type=str, default="")
|
||||
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
|
||||
data_path = os.path.join(args.output_path,f"{args.model}.json")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path,"r") as f:
|
||||
data = json.load(f)
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
question = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
base64_image = encode_image(img_path)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {args.OPENAI_API_KEY}"
|
||||
}
|
||||
payload = {
|
||||
"model": args.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{question}"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 500
|
||||
}
|
||||
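# On an API error, wait and move on; re-running the script resumes from the saved JSON
# and only queries samples that do not yet have a 'predict' field.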
try:
|
||||
response = requests.post(args.API_BASE, headers=headers, json=payload)
|
||||
print(response.json())
|
||||
answer = response.json()['choices'][0]['message']['content']
|
||||
data[i]['predict'] = answer
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
except:
|
||||
time.sleep(100)
|
||||
print(f"{img_path} error")
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
OCRBench/scripts/Genimi.py (new file, +115 lines)
@@ -0,0 +1,115 @@
|
||||
import pathlib
|
||||
import textwrap
|
||||
from argparse import ArgumentParser
|
||||
import google.generativeai as genai
|
||||
import json
|
||||
from PIL import Image
|
||||
from IPython.display import display
|
||||
from IPython.display import Markdown
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import sys
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_path", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--GOOGLE_API_KEY", type=str, default="")
|
||||
parser.add_argument("--model", type=str, default="gemini-pro-vision")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _get_args()
|
||||
genai.configure(api_key=args.GOOGLE_API_KEY)
|
||||
model = genai.GenerativeModel(args.model)
|
||||
|
||||
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
|
||||
data_path = os.path.join(args.output_path,f"{args.model}.json")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
question = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
response = model.generate_content([question, img])
|
||||
data[i]['predict'] = response.text
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
except:
|
||||
print(f"{img_path}: API call failed.")
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
OCRBench/scripts/LLaVA1_5.py (new file, +210 lines)
@@ -0,0 +1,210 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
||||
from llava.conversation import conv_templates, SeparatorStyle
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.utils import disable_torch_init
|
||||
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
|
||||
|
||||
# https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa_loader.py
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # holds the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="liuhaotian/llava-v1.5-7b")
|
||||
parser.add_argument("--model_base", type=str, default=None)
|
||||
parser.add_argument("--save_name", type=str, default="llava1_5_7b")
|
||||
parser.add_argument("--conv_mode", type=str, default="vicuna_v1")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument("--top_p", type=float, default=None)
|
||||
parser.add_argument("--num_beams", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
device = f"cuda:{eval_id}"
|
||||
disable_torch_init()
|
||||
model_path = os.path.expanduser(args.model_path)
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model( model_path = model_path, model_base = args.model_base, model_name = model_name,device = device)
|
||||
if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
|
||||
args.conv_mode = args.conv_mode + '_mmtag'
|
||||
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
qs = qs+"\nAnswer the question using a single word or phrase."
|
||||
if model.config.mm_use_im_start_end:
|
||||
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
|
||||
else:
|
||||
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
||||
conv = conv_templates[args.conv_mode].copy()
|
||||
conv.append_message(conv.roles[0], qs)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
|
||||
stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2
|
||||
input_ids = input_ids.to(device=device, non_blocking=True)
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor.to(dtype=torch.float16, device=device, non_blocking=True),
|
||||
do_sample=True if args.temperature > 0 else False,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
num_beams=args.num_beams,
|
||||
max_new_tokens=128,
|
||||
use_cache=True)
|
||||
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
||||
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
||||
outputs = outputs.strip()
|
||||
if outputs.endswith(stop_str):
|
||||
outputs = outputs[:-len(stop_str)]
|
||||
outputs = outputs.strip()
|
||||
|
||||
data[i]['predict'] = outputs
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
OCRBench/scripts/MiniMonkey.py (new file, +313 lines)
@@ -0,0 +1,313 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, CLIPImageProcessor
|
||||
from transformers import AutoTokenizer
|
||||
import torchvision.transforms as T
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
#https://github.com/Yuliang-Liu/Monkey/tree/main/project/mini_monkey
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
|
||||
def build_transform(input_size):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def dynamic_preprocess(image, min_num=5, max_num=6, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images, target_aspect_ratio
|
||||
|
||||
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
new_target_ratios = []
|
||||
if prior_aspect_ratio is not None:
|
||||
for i in target_ratios:
|
||||
if prior_aspect_ratio[0]%i[0] !=0 or prior_aspect_ratio[1]%i[1] !=0:
|
||||
new_target_ratios.append(i)
|
||||
else:
|
||||
continue
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
def load_image(image_file, input_size=448, min_num=1, max_num=6):
|
||||
image = Image.open(image_file).convert('RGB')
|
||||
transform = build_transform(input_size=input_size)
|
||||
images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values, target_aspect_ratio
|
||||
|
||||
def load_image2(image_file, input_size=448, target_aspect_ratio=(1,1), min_num=1, max_num=6):
|
||||
image = Image.open(image_file).convert('RGB')
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess2(image, image_size=input_size, prior_aspect_ratio=target_aspect_ratio, use_thumbnail=True, min_num=min_num, max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # holds the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default='mx262/MiniMonkey')#TODO Set the address of your model's weights
|
||||
parser.add_argument("--save_name", type=str, default="MiniMokney") #TODO Set the name of the JSON file you save in the output_folder.
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
model = AutoModel.from_pretrained(
|
||||
checkpoint,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True).eval().to(f'cuda:{eval_id}')
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
|
||||
image_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
|
||||
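# MiniMonkey multi-scale input: a fine tiling (12-24 crops) plus a complementary coarser
# tiling whose grid does not evenly divide the first; the crops are concatenated with the
# global thumbnail kept as the final entry.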
pixel_values, target_aspect_ratio = load_image(image_path, min_num=12, max_num=24)
|
||||
pixel_values = pixel_values.to(f'cuda:{eval_id}').to(torch.bfloat16)
|
||||
pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=3, max_num=11)
|
||||
pixel_values2 = pixel_values2.to(f'cuda:{eval_id}').to(torch.bfloat16)
|
||||
pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)
|
||||
|
||||
generation_config = dict(
|
||||
num_beams=1,
|
||||
max_new_tokens=512,
|
||||
do_sample=False,
|
||||
)
|
||||
question = '<image>\n'+qs+ '\nAnswer the question using a single word or phrase.'
|
||||
response = model.chat(tokenizer, pixel_values, target_aspect_ratio, question, generation_config)
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
163
OCRBench/scripts/blip2.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager

from transformers import Blip2Processor, Blip2ForConditionalGeneration
# https://huggingface.co/Salesforce/blip2-opt-6.7b
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
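# Note: split_list gives the first n-1 shards floor(len/n) items each and the last
# shard the remainder, so the per-worker shards may be slightly uneven.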
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="./model_weights/blip2-opt-6.7b")
|
||||
parser.add_argument("--save_name", type=str, default="blip2_opt_6_7b")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
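# Example invocation (a sketch; the script path and folders are illustrative and
# should point at wherever you keep the images, the OCRBench json and the weights):
#   python ./scripts/blip2.py --image_folder ./OCRBench_Images \
#       --OCRBench_file ./OCRBench/OCRBench.json --num_workers 8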
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
processor = Blip2Processor.from_pretrained(args.model_path)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(args.model_path, load_in_8bit=False, device_map={"": eval_id}, torch_dtype=torch.float16)
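# device_map={"": eval_id} above places the whole model on GPU eval_id, so with
# num_workers processes each worker holds its own full copy on a separate GPU.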
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
prompt = f"Question: {qs} Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=f"cuda:{eval_id}", dtype=torch.float16)
|
||||
generated_ids = model.generate(**inputs)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
data[i]['predict'] = generated_text
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
175
OCRBench/scripts/blip2_vicuna_instruct.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager

from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

# https://huggingface.co/Salesforce/instructblip-vicuna-7b
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="./model_weights/instructblip-vicuna-7b")
|
||||
parser.add_argument("--save_name", type=str, default="instructblip_vicuna_7b")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
device = f"cuda:{eval_id}"
|
||||
model = InstructBlipForConditionalGeneration.from_pretrained(args.model_path)
|
||||
processor = InstructBlipProcessor.from_pretrained(args.model_path)
|
||||
model.to(device)
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
inputs = processor(images=image, text=qs, return_tensors="pt").to(device)
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
num_beams=5,
|
||||
max_length=100,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1.0,
|
||||
temperature=0,
|
||||
)
|
||||
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
||||
data[i]['predict'] = generated_text
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
179
OCRBench/scripts/bliva.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from bliva.models import load_model_and_preprocess
import numpy as np

# https://github.com/mlpc-ucsd/BLIVA/blob/main/evaluate.py
|
||||
|
||||
def disable_torch_init():
|
||||
"""
|
||||
Disable the redundant torch default initialization to accelerate model creation.
|
||||
"""
|
||||
import torch
|
||||
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="bliva_vicuna")
|
||||
parser.add_argument("--save_name", type=str, default="bliva")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
device = f"cuda:{eval_id}"
|
||||
np.random.seed(0)
|
||||
disable_torch_init()
|
||||
if "vicuna" in args.model_path.lower():
|
||||
print("load bliva-vicuna")
|
||||
model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="vicuna7b", is_eval=True, device=device)
|
||||
if "flant5xxl" in args.model_path.lower():
|
||||
print("load bliva-flant5xxl")
|
||||
model, vis_processors, _ = load_model_and_preprocess(name=args.model_path, model_type="flant5xxl", is_eval=True, device=device)
|
||||
vis_processor = vis_processors["eval"]
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
question = [qs]
|
||||
image = vis_processor(image).unsqueeze(0).to(device)
|
||||
outputs = model.generate({"image": image, "prompt": qs}, max_length=150)
|
||||
data[i]['predict'] = outputs[0].split('### Assistant:')[0]
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
161
OCRBench/scripts/interlm.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModel, AutoTokenizer
# https://github.com/InternLM/InternLM-XComposer/tree/main/InternLM-XComposer-1.0
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer-7b')#TODO Set the address of your model's weights
|
||||
parser.add_argument("--save_name", type=str, default="internlm-xcomposer-7b") #TODO Set the name of the JSON file you save in the output_folder.
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# init model and tokenizer
|
||||
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
model.tokenizer = tokenizer
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
response = model.generate(qs, img_path)
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
162
OCRBench/scripts/interlm2.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager
from transformers import AutoModel, AutoTokenizer
# https://github.com/InternLM/InternLM-XComposer/tree/main
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default='internlm/internlm-xcomposer2-vl-7b')#TODO Set the address of your model's weights
|
||||
parser.add_argument("--save_name", type=str, default="internlm-xcomposer2-vl-7b") #TODO Set the name of the JSON file you save in the output_folder.
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# init model and tokenizer
|
||||
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True,device_map=f'cuda:{eval_id}').eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
text = f'<ImageHere>{qs}'
|
||||
with torch.cuda.amp.autocast():
|
||||
response, _ = model.chat(tokenizer, query=text, image=img_path, history=[], do_sample=False)
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
247
OCRBench/scripts/internvl2_s
Normal file
@@ -0,0 +1,247 @@
|
||||
import json
from argparse import ArgumentParser
import torch
import os
from tqdm import tqdm
from PIL import Image
import math
import multiprocessing
from multiprocessing import Pool, Queue, Manager

import numpy as np
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
def build_transform(input_size):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
||||
i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
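# dynamic_preprocess tiles the input into between min_num and max_num crops of
# image_size x image_size whose grid best matches the original aspect ratio, and
# optionally appends a full-image thumbnail; this is the dynamic-resolution
# preprocessing used by the InternVL2 chat models.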
|
||||
|
||||
def load_image(image_file, input_size=448, max_num=12):
|
||||
image = Image.open(image_file).convert('RGB')
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
result = []  # holds the split sublists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default='OpenGVLab/InternVL2-4B')
|
||||
parser.add_argument("--save_name", type=str, default="internvl2-4B")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
model = AutoModel.from_pretrained(
|
||||
checkpoint,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True).eval().to(f'cuda:{eval_id}')
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
pixel_values = load_image(img_path, max_num=12).to(torch.bfloat16).to(f'cuda:{eval_id}')
|
||||
generation_config = dict(max_new_tokens=1024, do_sample=False)
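# do_sample=False makes decoding greedy, so the benchmark outputs are deterministic
# across runs for a given checkpoint.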
|
||||
question = f'<image>\n{qs}'
|
||||
response = model.chat(tokenizer, pixel_values, question, generation_config)
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
# eval_worker(args, data_list[0], 0, output_queue)
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
178
OCRBench/scripts/intervl.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, CLIPImageProcessor
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
#https://github.com/OpenGVLab/InternVL
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # list holding the split chunks
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default='OpenGVLab/InternVL-Chat-Chinese-V1-1')#TODO Set the address of your model's weights
|
||||
parser.add_argument("--save_name", type=str, default="InternVL-Chat-Chinese-V1-1") #TODO Set the name of the JSON file you save in the output_folder.
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
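# Example invocation (illustrative paths, adjust to your setup): python ./scripts/intervl.py --image_folder ./OCRBench_Images --OCRBench_file ./OCRBench/OCRBench.json --num_workers 1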
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
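# Each worker loads the model, runs inference over its shard of the data, and pushes the annotated shard back through the shared result queue.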
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
model = AutoModel.from_pretrained(
|
||||
checkpoint,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
device_map='cuda').eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
if data[i].get("predict", 0) != 0:
    print(f"{img_path} prediction exists, skipping.")
    continue
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
image = image.resize((448, 448))
|
||||
image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
|
||||
|
||||
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
|
||||
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
||||
|
||||
generation_config = dict(
|
||||
num_beams=1,
|
||||
max_new_tokens=512,
|
||||
do_sample=False,
|
||||
)
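# Greedy decoding (single beam, do_sample=False) keeps the benchmark results deterministic across runs.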
|
||||
response = model.chat(tokenizer, pixel_values, qs, generation_config)
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
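# Drain the result queue and reassemble the per-worker shards in their original order.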
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
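# HME100k answers are handwritten math expressions, so all whitespace is stripped before the substring match; other datasets are compared case-insensitively instead.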
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
330
OCRBench/scripts/llavar.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
from llava import LlavaLlamaForCausalLM
|
||||
from llava.conversation import conv_templates
|
||||
from llava import conversation as conversation_lib
|
||||
from llava.utils import disable_torch_init
|
||||
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
|
||||
from PIL import Image,ImageOps
|
||||
# https://github.com/SALT-NLP/LLaVAR/blob/main/LLaVA/llava/eval/model_vqa.py
|
||||
|
||||
def resize_image(image, target_size):
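# Scale the image so it fits inside target_size while preserving aspect ratio, then pad the remainder with black to reach the exact target size.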
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
if aspect_ratio > 1:
|
||||
new_width = target_size[0]
|
||||
new_height = int(new_width / aspect_ratio)
|
||||
else:
|
||||
new_height = target_size[1]
|
||||
new_width = int(new_height * aspect_ratio)
|
||||
image = image.resize((new_width, new_height))
|
||||
width_diff = target_size[0] - image.size[0]
|
||||
height_diff = target_size[1] - image.size[1]
|
||||
left_padding = 0
|
||||
top_padding = 0
|
||||
right_padding = width_diff - left_padding
|
||||
bottom_padding = height_diff - top_padding
|
||||
padded_image = ImageOps.expand(image, border=(left_padding, top_padding, right_padding, bottom_padding), fill=0)
|
||||
return padded_image
|
||||
|
||||
DEFAULT_IMAGE_TOKEN = "<image>"
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
def patch_config(config):
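# Older LLaVA/LLaVAR checkpoints may ship without the multimodal fields in config.json; patch in the CLIP vision tower settings and save them back to disk.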
|
||||
patch_dict = {
|
||||
"use_mm_proj": True,
|
||||
"mm_vision_tower": "openai/clip-vit-large-patch14",
|
||||
"mm_hidden_size": 1024
|
||||
}
|
||||
|
||||
cfg = AutoConfig.from_pretrained(config)
|
||||
if not hasattr(cfg, "mm_vision_tower"):
|
||||
print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
|
||||
for k, v in patch_dict.items():
|
||||
setattr(cfg, k, v)
|
||||
cfg.save_pretrained(config)
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # list holding the split chunks
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="./model_weights/LLaVar")
|
||||
parser.add_argument("--save_name", type=str, default="llavar")
|
||||
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
||||
parser.add_argument("--mm-projector", type=str, default=None)
|
||||
parser.add_argument("--vision-tower", type=str, default=None)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
device = f"cuda:{eval_id}"
|
||||
disable_torch_init()
|
||||
model_name = os.path.expanduser(args.model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if args.mm_projector is None:
|
||||
patch_config(model_name)
|
||||
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
|
||||
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
|
||||
|
||||
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
if mm_use_im_start_end:
|
||||
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
|
||||
vision_tower = model.model.vision_tower[0]
|
||||
vision_tower.to(device=device, dtype=torch.float16)
|
||||
vision_config = vision_tower.config
|
||||
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
vision_config.use_im_start_end = mm_use_im_start_end
|
||||
if mm_use_im_start_end:
|
||||
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
else:
|
||||
# in case of using a pretrained model with only a MLP projector weights
|
||||
model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
|
||||
|
||||
vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).to(device)
|
||||
image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
|
||||
|
||||
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
if mm_use_im_start_end:
|
||||
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
|
||||
vision_config = vision_tower.config
|
||||
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
vision_config.use_im_start_end = mm_use_im_start_end
|
||||
if mm_use_im_start_end:
|
||||
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
|
||||
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
||||
|
||||
mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
|
||||
mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
|
||||
mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
|
||||
|
||||
model.model.mm_projector = mm_projector.to(device).half()
|
||||
model.model.vision_tower = [vision_tower]
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
# qs = qs+"\nAnswer the question using a single word or phrase."
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
if mm_use_im_start_end:
|
||||
qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
|
||||
else:
|
||||
qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
|
||||
if args.conv_mode == 'simple_legacy':
|
||||
qs += '\n\n### Response:'
|
||||
# conv = default_conversation.copy()
|
||||
conv = conv_templates[args.conv_mode].copy()
|
||||
conv.append_message(conv.roles[0], qs)
|
||||
# modified
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
inputs = tokenizer([prompt])
|
||||
image = Image.open(img_path)
|
||||
# if "REval" in args.image_folder:
|
||||
image = resize_image(image, (336, 336))
|
||||
|
||||
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
|
||||
input_ids = torch.as_tensor(inputs.input_ids).to(device)
|
||||
|
||||
# new stopping implementation
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
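# Custom stopping criterion: generation halts as soon as any keyword (here the '</s>' end-of-sequence string) appears in the decoded continuation.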
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
self.tokenizer = tokenizer
|
||||
self.start_len = None
|
||||
self.input_ids = input_ids
|
||||
|
||||
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
if self.start_len is None:
|
||||
self.start_len = self.input_ids.shape[1]
|
||||
else:
|
||||
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
||||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
||||
|
||||
# keywords = ['###']
|
||||
# modified
|
||||
keywords = ['</s>']
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor.unsqueeze(0).half().to(device),
|
||||
do_sample=False,
|
||||
temperature=0,
|
||||
max_new_tokens=200,
|
||||
stopping_criteria=[stopping_criteria])
|
||||
input_token_len = input_ids.shape[1]
|
||||
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
||||
if n_diff_input_output > 0:
|
||||
print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
|
||||
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
||||
|
||||
# modified
|
||||
if args.conv_mode == 'simple_legacy' or args.conv_mode == 'simple':
|
||||
while True:
|
||||
cur_len = len(outputs)
|
||||
outputs = outputs.strip()
|
||||
for pattern in ['###', 'Assistant:', 'Response:']:
|
||||
if outputs.startswith(pattern):
|
||||
outputs = outputs[len(pattern):].strip()
|
||||
if len(outputs) == cur_len:
|
||||
break
|
||||
|
||||
if conv.sep_style == conversation_lib.SeparatorStyle.TWO:
|
||||
sep = conv.sep2
|
||||
else:
|
||||
sep = conv.sep
|
||||
|
||||
try:
|
||||
index = outputs.index(sep)
|
||||
except ValueError:
|
||||
outputs += sep
|
||||
index = outputs.index(sep)
|
||||
outputs = outputs[:index].strip()
|
||||
data[i]['predict'] = outputs
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
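# If a results file from a previous run exists, reload it so samples that already carry a prediction are skipped in the worker loop.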
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
329
OCRBench/scripts/mPLUG-DocOwl15.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from multiprocessing import Manager, Pool, Queue
|
||||
|
||||
import torch
|
||||
from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
|
||||
from mplug_docowl.conversation import conv_templates
|
||||
from mplug_docowl.mm_utils import (
|
||||
KeywordsStoppingCriteria,
|
||||
get_model_name_from_path,
|
||||
process_images,
|
||||
tokenizer_image_token,
|
||||
)
|
||||
from mplug_docowl.model.builder import load_pretrained_model
|
||||
from mplug_docowl.processor import DocProcessor
|
||||
from tqdm import tqdm
|
||||
from transformers import TextStreamer
|
||||
|
||||
|
||||
# https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # list holding the split chunks
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i * avg : (i + 1) * avg])
|
||||
result.append(lst[(n - 1) * avg :])
|
||||
return result
|
||||
|
||||
|
||||
def save_json(json_list, save_path):
|
||||
with open(save_path, "w", encoding="utf-8") as file:
|
||||
json.dump(json_list, file, indent=4)
|
||||
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="mPLUG/DocOwl1.5")
|
||||
parser.add_argument("--save_name", type=str, default="mplug-DocOwl1.5")
|
||||
parser.add_argument("--conv_mode", type=str, default="mplug_owl2")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
OCRBench_score = {
|
||||
"Regular Text Recognition": 0,
|
||||
"Irregular Text Recognition": 0,
|
||||
"Artistic Text Recognition": 0,
|
||||
"Handwriting Recognition": 0,
|
||||
"Digit String Recognition": 0,
|
||||
"Non-Semantic Text Recognition": 0,
|
||||
"Scene Text-centric VQA": 0,
|
||||
"Doc-oriented VQA": 0,
|
||||
"Key Information Extraction": 0,
|
||||
"Handwritten Mathematical Expression Recognition": 0,
|
||||
}
|
||||
AllDataset_score = {
|
||||
"IIIT5K": 0,
|
||||
"svt": 0,
|
||||
"IC13_857": 0,
|
||||
"IC15_1811": 0,
|
||||
"svtp": 0,
|
||||
"ct80": 0,
|
||||
"cocotext": 0,
|
||||
"ctw": 0,
|
||||
"totaltext": 0,
|
||||
"HOST": 0,
|
||||
"WOST": 0,
|
||||
"WordArt": 0,
|
||||
"IAM": 0,
|
||||
"ReCTS": 0,
|
||||
"ORAND": 0,
|
||||
"NonSemanticText": 0,
|
||||
"SemanticText": 0,
|
||||
"STVQA": 0,
|
||||
"textVQA": 0,
|
||||
"ocrVQA": 0,
|
||||
"ESTVQA": 0,
|
||||
"ESTVQA_cn": 0,
|
||||
"docVQA": 0,
|
||||
"infographicVQA": 0,
|
||||
"ChartQA": 0,
|
||||
"ChartQA_Human": 0,
|
||||
"FUNSD": 0,
|
||||
"SROIE": 0,
|
||||
"POIE": 0,
|
||||
"HME100k": 0,
|
||||
}
|
||||
num_all = {
|
||||
"IIIT5K": 0,
|
||||
"svt": 0,
|
||||
"IC13_857": 0,
|
||||
"IC15_1811": 0,
|
||||
"svtp": 0,
|
||||
"ct80": 0,
|
||||
"cocotext": 0,
|
||||
"ctw": 0,
|
||||
"totaltext": 0,
|
||||
"HOST": 0,
|
||||
"WOST": 0,
|
||||
"WordArt": 0,
|
||||
"IAM": 0,
|
||||
"ReCTS": 0,
|
||||
"ORAND": 0,
|
||||
"NonSemanticText": 0,
|
||||
"SemanticText": 0,
|
||||
"STVQA": 0,
|
||||
"textVQA": 0,
|
||||
"ocrVQA": 0,
|
||||
"ESTVQA": 0,
|
||||
"ESTVQA_cn": 0,
|
||||
"docVQA": 0,
|
||||
"infographicVQA": 0,
|
||||
"ChartQA": 0,
|
||||
"ChartQA_Human": 0,
|
||||
"FUNSD": 0,
|
||||
"SROIE": 0,
|
||||
"POIE": 0,
|
||||
"HME100k": 0,
|
||||
}
|
||||
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
model_name = get_model_name_from_path(args.model_path)
|
||||
tokenizer, model, _, _ = load_pretrained_model(
|
||||
args.model_path,
|
||||
None,
|
||||
model_name,
|
||||
load_8bit=False,
|
||||
load_4bit=False,
|
||||
device=f"cuda:{eval_id}",
|
||||
)
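# Each worker loads its own model instance pinned to a separate GPU (cuda:<eval_id>), so the processes do not contend for one device.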
|
||||
|
||||
doc_image_processor = DocProcessor(
|
||||
image_size=448,
|
||||
anchors="grid_9",
|
||||
add_global_img=True,
|
||||
add_textual_crop_indicator=True,
|
||||
)
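# DocProcessor crops each page into a grid of sub-images (anchors="grid_9"), keeps an additional global view, and marks crop positions with textual indicators in the prompt.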
|
||||
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]["image_path"])
|
||||
qs = data[i]["question"]
|
||||
if data[i].get("predict", 0) != 0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
|
||||
image_tensor, patch_positions, text = doc_image_processor(
|
||||
images=img_path, query="<|image|>" + qs
|
||||
)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
patch_positions = patch_positions.to(model.device)
|
||||
|
||||
conv = conv_templates["mplug_owl2"].copy()
|
||||
conv.append_message(conv.roles[0], text)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
input_ids = (
|
||||
tokenizer_image_token(
|
||||
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.to(model.device)
|
||||
)
|
||||
|
||||
stop_str = conv.sep2
|
||||
keywords = [stop_str]
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
patch_positions=patch_positions,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
max_new_tokens=512,
|
||||
streamer=streamer,
|
||||
use_cache=True,
|
||||
stopping_criteria=[stopping_criteria],
|
||||
)
|
||||
|
||||
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
|
||||
data[i]["predict"] = outputs
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multiprocessing.set_start_method("spawn")
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder, f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder, f"{args.save_name}.json")
|
||||
print(
|
||||
f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}."
|
||||
)
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
# pool.apply(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get("predict", 0) == 0:
|
||||
continue
|
||||
predict = data[i]["predict"]
|
||||
data[i]["result"] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers) == list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n", " ").replace(" ", "")
|
||||
predict = predict.strip().replace("\n", " ").replace(" ", "")
|
||||
if answer in predict:
|
||||
data[i]["result"] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n", " ").replace(" ", "")
|
||||
predict = predict.strip().replace("\n", " ").replace(" ", "")
|
||||
if answers in predict:
|
||||
data[i]["result"] = 1
|
||||
else:
|
||||
if type(answers) == list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n", " ")
|
||||
predict = predict.lower().strip().replace("\n", " ")
|
||||
if answer in predict:
|
||||
data[i]["result"] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n", " ")
|
||||
predict = predict.lower().strip().replace("\n", " ")
|
||||
if answers in predict:
|
||||
data[i]["result"] = 1
|
||||
save_json(data, os.path.join(args.output_folder, f"{args.save_name}.json"))
|
||||
if len(data) == 1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result", 100) == 100:
|
||||
continue
|
||||
OCRBench_score[data[i]["type"]] += data[i]["result"]
|
||||
recognition_score = (
|
||||
OCRBench_score["Regular Text Recognition"]
|
||||
+ OCRBench_score["Irregular Text Recognition"]
|
||||
+ OCRBench_score["Artistic Text Recognition"]
|
||||
+ OCRBench_score["Handwriting Recognition"]
|
||||
+ OCRBench_score["Digit String Recognition"]
|
||||
+ OCRBench_score["Non-Semantic Text Recognition"]
|
||||
)
|
||||
Final_score = (
|
||||
recognition_score
|
||||
+ OCRBench_score["Scene Text-centric VQA"]
|
||||
+ OCRBench_score["Doc-oriented VQA"]
|
||||
+ OCRBench_score["Key Information Extraction"]
|
||||
+ OCRBench_score["Handwritten Mathematical Expression Recognition"]
|
||||
)
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(
|
||||
f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}"
|
||||
)
|
||||
print(
|
||||
f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}"
|
||||
)
|
||||
print(
|
||||
f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}"
|
||||
)
|
||||
print(
|
||||
f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}"
|
||||
)
|
||||
print(
|
||||
f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}"
|
||||
)
|
||||
print(
|
||||
f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}"
|
||||
)
|
||||
print("----------------------------------------------------------------")
|
||||
print(
|
||||
f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}"
|
||||
)
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(
|
||||
f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}"
|
||||
)
|
||||
print("----------------------------------------------------------------")
|
||||
print(
|
||||
f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}"
|
||||
)
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]["dataset_name"]] += 1
|
||||
if data[i].get("result", 100) == 100:
|
||||
continue
|
||||
AllDataset_score[data[i]["dataset_name"]] += data[i]["result"]
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
185
OCRBench/scripts/mPLUG-owl.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
import sys
|
||||
sys.path.append("./scripts/mPLUG-Owl/mPLUG-Owl/")
|
||||
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
|
||||
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
|
||||
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
|
||||
|
||||
# https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # list holding the split chunks
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl")
|
||||
parser.add_argument("--save_name", type=str, default="mplug-owl")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
pretrained_ckpt = args.model_path
|
||||
model = MplugOwlForConditionalGeneration.from_pretrained(
|
||||
pretrained_ckpt,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model.to(f"cuda:{eval_id}")
|
||||
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
|
||||
tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
|
||||
processor = MplugOwlProcessor(image_processor, tokenizer)
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
prompts = [
|
||||
f'''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
|
||||
Human: <image>
|
||||
Human: {qs}
|
||||
AI: ''']
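# mPLUG-Owl takes a plain conversation-style prompt with an <image> placeholder; the benchmark question goes into the second Human turn.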
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
generate_kwargs = {
|
||||
'do_sample': False,
|
||||
'top_k': 1,
|
||||
'max_length': 100
|
||||
}
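# do_sample=False with top_k=1 is effectively greedy decoding; max_length keeps the generated sequence short (100 tokens).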
|
||||
images = [Image.open(img_path)]
|
||||
inputs = processor(text=prompts, images=images, return_tensors='pt')
|
||||
inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
res = model.generate(**inputs, **generate_kwargs)
|
||||
sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
|
||||
data[i]['predict'] = sentence
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
191
OCRBench/scripts/mPLUG-owl2.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
from transformers import TextStreamer
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from mplug_owl2.conversation import conv_templates, SeparatorStyle
|
||||
from mplug_owl2.model.builder import load_pretrained_model
|
||||
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
||||
|
||||
# https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n  # size of each chunk
|
||||
result = []  # list holding the split chunks
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="./model_weights/mplug-owl2")
|
||||
parser.add_argument("--save_name", type=str, default="mplug-owl2")
|
||||
parser.add_argument("--conv_mode", type=str, default="mplug_owl2")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
model_name = get_model_name_from_path(args.model_path)
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, None, model_name, load_8bit=False, load_4bit=False, device=f"cuda:{eval_id}")
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
conv = conv_templates[args.conv_mode].copy()
|
||||
roles = conv.roles
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
max_edge = max(image.size)  # Resizing to a square image is recommended for best performance.
|
||||
image = image.resize((max_edge, max_edge))
|
||||
image_tensor = process_images([image], image_processor)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
inp = DEFAULT_IMAGE_TOKEN + qs
|
||||
conv.append_message(conv.roles[0], inp)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
stop_str = conv.sep2
|
||||
keywords = [stop_str]
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=False,
|
||||
temperature=args.temperature,
|
||||
max_new_tokens=100,
|
||||
streamer=streamer,
|
||||
use_cache=True,
|
||||
stopping_criteria=[stopping_criteria])
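# Decode only the newly generated tokens (everything after the prompt) as the model's answer.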
|
||||
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
||||
data[i]['predict'] = outputs
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
187
OCRBench/scripts/minigpt4v2.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
import sys
|
||||
sys.path.append("./scripts/MiniGPT-4/")
|
||||
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
|
||||
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
|
||||
from minigpt4.common.config import Config
|
||||
import random
|
||||
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/eval_scripts/eval_vqa.py
|
||||
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n # size of each chunk
|
||||
result = [] # store the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--cfg-path", default='./scripts/MiniGPT-4/eval_configs/minigptv2_eval.yaml')
|
||||
parser.add_argument("--save_name", type=str, default="minigptv2")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument(
|
||||
"--options",
|
||||
nargs="+",
|
||||
help="override some settings in the used config, the key-value pair "
|
||||
"in xxx=yyy format will be merged into config file (deprecate), "
|
||||
"change to --cfg-options instead.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
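# Example invocation (a sketch, assuming the script is run from the OCRBench directory and MiniGPT-4 is set up under ./scripts/MiniGPT-4):
# python ./scripts/minigpt4v2.py --image_folder ./OCRBench_Images --cfg-path ./scripts/MiniGPT-4/eval_configs/minigptv2_eval.yaml --num_workers 1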
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
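# OCRBench_score accumulates per-task scores for the 1000-question OCRBench split; AllDataset_score and num_all accumulate per-dataset correct counts and totals for the full test set.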
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
|
||||
print(f"Process {eval_id} start.")
|
||||
device = f'cuda:{eval_id}'
|
||||
cfg = Config(args)
|
||||
model, vis_processor = init_model(args, device)
|
||||
conv_temp = CONV_VISION_minigptv2.copy()
|
||||
conv_temp.system = ""
|
||||
model.eval()
|
||||
instruction_pool = [
|
||||
"[vqa] {}"
|
||||
]
|
||||
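# MiniGPT-v2 selects its task via a tag in the prompt: '[vqa]' puts the model in VQA mode, and the question is then wrapped as '<Img><ImageHere></Img> [vqa] {question}' below.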
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = vis_processor(image)
|
||||
image = image.unsqueeze(0).to(device)
|
||||
# question = self.text_processor(qs)
|
||||
instruction = random.choice(instruction_pool).format(qs)
|
||||
instruction = "<Img><ImageHere></Img> {} ".format(instruction)
|
||||
texts = prepare_texts(instruction, conv_temp) # wrap the texts with the conversation template
|
||||
answers = model.generate(image, texts, max_new_tokens=100, do_sample=False)
|
||||
data[i]['predict'] = answers[0]
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
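# Each worker pushes its completed slice keyed by its eval_id; reassemble the slices in their original order before scoring.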
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
182
OCRBench/scripts/monkey.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# https://github.com/Yuliang-Liu/Monkey
|
||||
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n # size of each chunk
|
||||
result = [] # store the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="echo840/Monkey")
|
||||
parser.add_argument("--save_name", type=str, default="monkey")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
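# Example invocation (a sketch, assuming the default paths exist and 8 GPUs are available, one per worker):
# python ./scripts/monkey.py --model_path echo840/Monkey --image_folder ./OCRBench_Images --num_workers 8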
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eod_id
|
||||
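# The Qwen-based tokenizer has no dedicated pad token, so the end-of-document token is reused for padding; left padding keeps the generated tokens contiguous with the prompt.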
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
query = f'<img>{img_path}</img> {qs} Answer: '
|
||||
|
||||
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
|
||||
attention_mask = input_ids.attention_mask
|
||||
input_ids = input_ids.input_ids
|
||||
|
||||
pred = model.generate(
|
||||
input_ids=input_ids.to(f'cuda:{eval_id}'),
|
||||
attention_mask=attention_mask.to(f'cuda:{eval_id}'),
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=100,
|
||||
min_new_tokens=1,
|
||||
length_penalty=1,
|
||||
num_return_sequences=1,
|
||||
output_hidden_states=True,
|
||||
use_cache=True,
|
||||
pad_token_id=tokenizer.eod_id,
|
||||
eos_token_id=tokenizer.eod_id,
|
||||
)
|
||||
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
181
OCRBench/scripts/qwenvl.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import math
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool, Queue, Manager
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/evaluate_vqa.py
|
||||
def split_list(lst, n):
|
||||
length = len(lst)
|
||||
avg = length // n # size of each chunk
|
||||
result = [] # store the split sub-lists
|
||||
for i in range(n - 1):
|
||||
result.append(lst[i*avg:(i+1)*avg])
|
||||
result.append(lst[(n-1)*avg:])
|
||||
return result
|
||||
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./OCRBench_Images")
|
||||
parser.add_argument("--output_folder", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model_path", type=str, default="Qwen/Qwen-VL")
|
||||
parser.add_argument("--save_name", type=str, default="qwenvl")
|
||||
parser.add_argument("--num_workers", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
|
||||
def eval_worker(args, data, eval_id, output_queue):
|
||||
print(f"Process {eval_id} start.")
|
||||
checkpoint = args.model_path
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
checkpoint, device_map=f'cuda:{eval_id}', trust_remote_code=True).eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint,
|
||||
trust_remote_code=True)
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eod_id
|
||||
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
qs = data[i]['question']
|
||||
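# Qwen-VL ingests the image through the <img>...</img> tag placed directly in the text query; the tokenizer resolves the embedded path, so no separate image preprocessing is done here.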
# query = f'<img>{img_path}</img> {qs} Answer: '
|
||||
query = f'<img>{img_path}</img>{qs} Answer:'
|
||||
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
|
||||
attention_mask = input_ids.attention_mask
|
||||
input_ids = input_ids.input_ids
|
||||
|
||||
pred = model.generate(
|
||||
input_ids=input_ids.to(f'cuda:{eval_id}'),
|
||||
attention_mask=attention_mask.to(f'cuda:{eval_id}'),
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=100,
|
||||
min_new_tokens=1,
|
||||
length_penalty=1,
|
||||
num_return_sequences=1,
|
||||
output_hidden_states=True,
|
||||
use_cache=True,
|
||||
pad_token_id=tokenizer.eod_id,
|
||||
eos_token_id=tokenizer.eod_id,
|
||||
)
|
||||
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
|
||||
data[i]['predict'] = response
|
||||
output_queue.put({eval_id: data})
|
||||
print(f"Process {eval_id} has completed.")
|
||||
|
||||
if __name__=="__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
args = _get_args()
|
||||
if os.path.exists(os.path.join(args.output_folder,f"{args.save_name}.json")):
|
||||
data_path = os.path.join(args.output_folder,f"{args.save_name}.json")
|
||||
print(f"output_path:{data_path} exist! Only generate the results that were not generated in {data_path}.")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
data_list = split_list(data, args.num_workers)
|
||||
|
||||
output_queue = Manager().Queue()
|
||||
|
||||
pool = Pool(processes=args.num_workers)
|
||||
for i in range(len(data_list)):
|
||||
pool.apply_async(eval_worker, args=(args, data_list[i], i, output_queue))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
results = {}
|
||||
while not output_queue.empty():
|
||||
result = output_queue.get()
|
||||
results.update(result)
|
||||
data = []
|
||||
for i in range(len(data_list)):
|
||||
data.extend(results[i])
|
||||
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_folder,f"{args.save_name}.json"))
|
||||
if len(data)==1000:
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
num_all[data[i]['dataset_name']] += 1
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
AllDataset_score[data[i]['dataset_name']] += data[i]['result']
|
||||
for key in AllDataset_score.keys():
|
||||
print(f"{key}: {AllDataset_score[key]/float(num_all[key])}")
|
138
OCRBench/scripts/qwenvl_api.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import pathlib
|
||||
from argparse import ArgumentParser
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import sys
|
||||
from http import HTTPStatus
|
||||
from dashscope import MultiModalConversation
|
||||
import time
|
||||
# You should follow the instructions here before you start: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start
|
||||
OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
|
||||
"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
|
||||
"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
|
||||
AllDataset_score = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
num_all = {"IIIT5K":0,"svt":0,"IC13_857":0,"IC15_1811":0,"svtp":0,"ct80":0,"cocotext":0,"ctw":0,"totaltext":0,"HOST":0,"WOST":0,"WordArt":0,"IAM":0,"ReCTS":0,"ORAND":0,"NonSemanticText":0,"SemanticText":0,
|
||||
"STVQA":0,"textVQA":0,"ocrVQA":0,"ESTVQA":0,"ESTVQA_cn":0,"docVQA":0,"infographicVQA":0,"ChartQA":0,"ChartQA_Human":0,"FUNSD":0,"SROIE":0,"POIE":0,"HME100k":0}
|
||||
def save_json(json_list,save_path):
|
||||
with open(save_path, 'w') as file:
|
||||
json.dump(json_list, file,indent=4)
|
||||
|
||||
def call_with_local_file(img_path, question, model_name):
|
||||
"""Sample of use local file.
|
||||
linux&mac file schema: file:///home/images/test.png
|
||||
windows file schema: file://D:/images/abc.png
|
||||
"""
|
||||
local_file_path1 = f'file://{img_path}'
|
||||
messages = [{
|
||||
'role': 'system',
|
||||
'content': [{
|
||||
'text': 'You are a helpful assistant.'
|
||||
}]
|
||||
}, {
|
||||
'role':
|
||||
'user',
|
||||
'content': [
|
||||
{
|
||||
'image': local_file_path1
|
||||
},
|
||||
{
|
||||
'text': question
|
||||
},
|
||||
]
|
||||
}]
|
||||
response = MultiModalConversation.call(model=model_name, messages=messages)
|
||||
# time.sleep(2) # For qwen-vl-max you may need to uncomment this line to avoid hitting rate limits.
|
||||
print(response)
|
||||
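# The answer text is nested inside the DashScope response as output -> choices[0] -> message -> content[0] -> text.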
return response['output']['choices'][0]["message"]['content'][0]['text']
|
||||
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--image_folder", type=str, default="./data")
|
||||
parser.add_argument("--output_path", type=str, default="./results")
|
||||
parser.add_argument("--OCRBench_file", type=str, default="./OCRBench/OCRBench.json")
|
||||
parser.add_argument("--model", type=str, default="qwen-vl-max")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
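# Example invocation (a sketch; assumes your DashScope access is configured as described in the link above):
# python ./scripts/qwenvl_api.py --model qwen-vl-max --image_folder ./OCRBench_Images --output_path ./results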
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _get_args()
|
||||
if os.path.exists(os.path.join(args.output_path,f"{args.model}.json")):
|
||||
data_path = os.path.join(args.output_path,f"{args.model}.json")
|
||||
else:
|
||||
data_path = args.OCRBench_file
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in tqdm(range(len(data))):
|
||||
img_path = os.path.join(args.image_folder, data[i]['image_path'])
|
||||
question = data[i]['question']
|
||||
if data[i].get("predict", 0)!=0:
|
||||
print(f"{img_path} predict exist, continue.")
|
||||
continue
|
||||
try:
|
||||
response = call_with_local_file(img_path, question, args.model)
|
||||
data[i]['predict'] = response
|
||||
except Exception as e:
|
||||
print("QwenVL api failed")
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
for i in range(len(data)):
|
||||
data_type = data[i]["type"]
|
||||
dataset_name = data[i]["dataset_name"]
|
||||
answers = data[i]["answers"]
|
||||
if data[i].get('predict',0)==0:
|
||||
continue
|
||||
predict = data[i]['predict']
|
||||
data[i]['result'] = 0
|
||||
if dataset_name == "HME100k":
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.strip().replace("\n"," ").replace(" ","")
|
||||
predict = predict.strip().replace("\n"," ").replace(" ","")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
if type(answers)==list:
|
||||
for j in range(len(answers)):
|
||||
answer = answers[j].lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answer in predict:
|
||||
data[i]['result'] = 1
|
||||
else:
|
||||
answers = answers.lower().strip().replace("\n"," ")
|
||||
predict = predict.lower().strip().replace("\n"," ")
|
||||
if answers in predict:
|
||||
data[i]['result'] = 1
|
||||
save_json(data, os.path.join(args.output_path,f"{args.model}.json"))
|
||||
for i in range(len(data)):
|
||||
if data[i].get("result",100)==100:
|
||||
continue
|
||||
OCRBench_score[data[i]['type']] += data[i]['result']
|
||||
recognition_score = OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
|
||||
Final_score = recognition_score+OCRBench_score['Scene Text-centric VQA']+OCRBench_score['Doc-oriented VQA']+OCRBench_score['Key Information Extraction']+OCRBench_score['Handwritten Mathematical Expression Recognition']
|
||||
print("###########################OCRBench##############################")
|
||||
print(f"Text Recognition(Total 300):{recognition_score}")
|
||||
print("------------------Details of Recognition Score-------------------")
|
||||
print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}")
|
||||
print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}")
|
||||
print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}")
|
||||
print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}")
|
||||
print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}")
|
||||
print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}")
|
||||
print("----------------------------------------------------------------")
|
||||
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}")
|
||||
print("----------------------Final Score-------------------------------")
|
||||
print(f"Final Score(Total 1000): {Final_score}")
|