# MultimodalOCR/OCRBench_v2/eval_scripts/vqa_metric.py
import re
import math
def levenshtein_distance(s1, s2):
    """Compute the Levenshtein (edit) distance between two strings."""
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
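# Quick illustrative check (not part of the original script): "kitten" vs.
# "sitting" needs three single-character edits (k->s, e->i, insert g), so
#   levenshtein_distance("kitten", "sitting")  # -> 3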
def vqa_evaluation(predict, answers):
    """Score a prediction against one or more reference answers.

    Short answers (fewer than 5 words) are scored by substring match;
    longer answers fall back to ANLS (1 - normalized Levenshtein
    distance), counted only when it reaches the 0.5 threshold.
    """
    score = 0
    if isinstance(answers, list):
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            answer = answers[j].lower().strip().replace("\n", " ")
            if isinstance(predict, (int, float)):
                predict = str(predict)
            predict = predict.lower().strip().replace("\n", " ")
            if len(answer.split()) < 5:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value
                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value
    else:
        answers = answers.lower().strip().replace("\n", " ")
        predict = predict.lower().strip().replace("\n", " ")
        if len(answers.split()) < 5:
            if answers in predict:
                score = 1
        else:
            dist = levenshtein_distance(predict, answers)
            length = max(len(predict), len(answers))
            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
            ANLS_value = 1 - ANLS_value
            if ANLS_value >= 0.5 and ANLS_value > score:
                score = ANLS_value
    return score
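# Illustrative ANLS case (assumed inputs): a 5-word reference answer takes
# the Levenshtein branch. With answer "the quick brown fox jumps" and
# predict "the quick brown fox jump", dist = 1 over length 25, so the
# score is 1 - 1/25 = 0.96, which clears the 0.5 threshold.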
def cn_vqa_evaluation(predict, answers):
    """Chinese VQA variant: all whitespace is removed, and "short" means
    fewer than 4 comma-separated parts rather than fewer than 5 words."""
    score = 0
    if isinstance(answers, list):
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            answer = answers[j].lower().strip().replace("\n", " ").replace(" ", "")
            if isinstance(predict, (int, float)):
                predict = str(predict)
            predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
            if len(answer.split(",")) < 4:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value
                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value
    else:
        answers = answers.lower().strip().replace("\n", " ").replace(" ", "")
        predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
        if len(answers.split(",")) < 4:
            if answers in predict:
                score = 1
        else:
            dist = levenshtein_distance(predict, answers)
            length = max(len(predict), len(answers))
            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
            ANLS_value = 1 - ANLS_value
            if ANLS_value >= 0.5 and ANLS_value > score:
                score = ANLS_value
    return score
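# Illustrative check (assumed inputs): all spaces are removed before
# matching, so predict "答案是 北京" contains answer "北京"; only answers
# with 4 or more comma-separated parts fall through to the ANLS branch.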
def vqa_evaluation_case_sensitive(predict, answers):
    """Same scoring as vqa_evaluation, but without lowercasing."""
    score = 0
    if isinstance(answers, list):
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            answer = answers[j].strip().replace("\n", " ")
            predict = predict.strip().replace("\n", " ")
            if len(answer.split()) < 5:
                if answer in predict:
                    score = 1
            else:
                dist = levenshtein_distance(predict, answer)
                length = max(len(predict), len(answer))
                ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
                ANLS_value = 1 - ANLS_value
                if ANLS_value >= 0.5 and ANLS_value > score:
                    score = ANLS_value
    else:
        answers = answers.strip().replace("\n", " ")
        predict = predict.strip().replace("\n", " ")
        if len(answers.split()) < 5:
            if answers in predict:
                score = 1
        else:
            dist = levenshtein_distance(predict, answers)
            length = max(len(predict), len(answers))
            ANLS_value = 0.0 if length == 0 else float(dist) / float(length)
            ANLS_value = 1 - ANLS_value
            if ANLS_value >= 0.5 and ANLS_value > score:
                score = ANLS_value
    return score
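# Illustrative contrast (assumed inputs): with answer "Paris" and predict
# "paris", this case-sensitive scorer returns 0 where vqa_evaluation
# returns 1.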
def extract_first_number(string):
    """Return the first integer found in the string, or None."""
    match = re.search(r'\d+', string)
    if match:
        return int(match.group())
    return None
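# Examples: extract_first_number("roughly 12 coins") -> 12,
# extract_first_number("none") -> None.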
def counting_evaluation(predict, answers, eval_method):
    """Score a counting answer by exact (substring) match or by a
    relative-error "regression" rule that awards 1 - |p - a| / a when the
    extracted count p is strictly within 50% of the reference a."""
    score = 0
    if isinstance(predict, str):
        predict_processed = predict.lower().strip().replace("\n", " ")
    elif math.isnan(predict):
        return 0
    else:
        predict_processed = str(int(predict))
    if isinstance(answers, list):
        temp_score = 0
        for j in range(len(answers)):
            if isinstance(answers[j], (int, float)):
                answers[j] = str(answers[j])
            answer = answers[j].lower().strip().replace("\n", " ")
            if eval_method == "exact match":
                score = 1 if answer in predict_processed else 0
            elif eval_method == "regression":
                predict_number = extract_first_number(predict_processed)
                if predict_number is not None:
                    answer = int(answer)
                    if predict_number <= 0 or predict_number >= 2 * answer:
                        score = 0
                    else:
                        iou = 1 - abs(predict_number - answer) / answer
                        score = iou if iou > 0.5 else 0
                else:
                    score = 0
            if score > temp_score:
                temp_score = score
        score = temp_score
    else:
        answers = answers.lower().strip().replace("\n", " ")
        if eval_method == "exact match":
            score = 1 if answers in predict_processed else 0
        elif eval_method == "regression":
            predict_number = extract_first_number(predict_processed)
            if predict_number is not None:
                answer = int(answers)
                if predict_number <= 0 or predict_number >= 2 * answer:
                    score = 0
                else:
                    iou = 1 - abs(predict_number - answer) / answer
                    score = iou if iou > 0.5 else 0
            else:
                score = 0
    return score
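# Illustrative regression case (assumed inputs): predict "about 8" against
# answer "10" extracts 8, which lies inside (0, 2 * 10), so the score is
# 1 - |8 - 10| / 10 = 0.8, kept because it exceeds 0.5.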
def math_expression_evaluation(predict, answers):
    """Substring match after stripping all whitespace from both sides."""
    score = 0
    if isinstance(answers, list):
        for j in range(len(answers)):
            answer = answers[j].strip().replace("\n", " ").replace(" ", "")
            predict = predict.strip().replace("\n", " ").replace(" ", "")
            if answer in predict:
                score = 1
    else:
        answers = answers.strip().replace("\n", " ").replace(" ", "")
        predict = predict.strip().replace("\n", " ").replace(" ", "")
        if answers in predict:
            score = 1
    return score
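# Note the asymmetry (illustrative): the reference must appear inside the
# prediction, so math_expression_evaluation("2x+1", ["2x+1=5"]) is 0 while
# math_expression_evaluation("2x+1=5", ["2x+1"]) is 1.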
def remove_text_tags(latex_str):
    r"""
    Removes LaTeX \text{...} tags while keeping their content.

    :param latex_str: A string containing LaTeX expressions
    :return: The processed string with \text{...} tags removed
    """
    pattern = r'\\text\{([^{}]*)\}'
    processed_str = re.sub(pattern, r'\1', latex_str)
    return processed_str
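# Example: remove_text_tags(r"\text{area} = \pi r^2") -> r"area = \pi r^2".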
def cn_math_expression_evaluation(predict, answers):
    r"""Chinese variant: strips \text{...} tags before matching.

    Expects a single-element answer list; after normalization the check is
    the same whitespace-free substring match as math_expression_evaluation.
    """
    score = 0
    assert len(answers) == 1
    answers = [remove_text_tags(answers[0])]
    predict = remove_text_tags(predict)
    for j in range(len(answers)):
        answer = answers[j].strip().replace("\n", " ").replace(" ", "")
        predict = predict.strip().replace("\n", " ").replace(" ", "")
        if answer in predict:
            score = 1
    return score
if __name__ == "__main__":
    test_predict = "apple pie and banana"
    test_answers = ["apple", "banana pie", "apple pie and orange"]
    vqa_score = vqa_evaluation(test_predict, test_answers)
    print(f"VQA evaluation score for predict '{test_predict}' and answers {test_answers}: {vqa_score}")