# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
import re
import ast
import json
import distance
from apted import APTED, Config
from itertools import product
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from parallel import parallel_process
from tqdm import tqdm
from zss import simple_distance, Node
import string
from typing import Any, Callable, Optional, Sequence
import numpy as np
import Levenshtein
import editdistance
class TableTree(Tree):
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag
self.colspan = colspan
self.rowspan = rowspan
self.content = content
self.children = list(children)
def bracket(self):
"""Show tree using brackets notation"""
if self.tag == 'td':
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
(self.tag, self.colspan, self.rowspan, self.content)
else:
result = '"tag": %s' % self.tag
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
return self.normalized_distance(node1.content, node2.content)
return 0.
class TEDS(object):
    ''' Tree Edit Distance based Similarity
'''
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
self.structure_only = structure_only
self.n_jobs = n_jobs
self.ignore_nodes = ignore_nodes
self.__tokens__ = []
def tokenize(self, node):
''' Tokenizes table cells
'''
self.__tokens__.append('<%s>' % node.tag)
if node.text is not None:
self.__tokens__ += list(node.text)
for n in node.getchildren():
self.tokenize(n)
if node.tag != 'unk':
self.__tokens__.append('</%s>' % node.tag)
if node.tag != 'td' and node.tail is not None:
self.__tokens__ += list(node.tail)
def load_html_tree(self, node, parent=None):
''' Converts HTML tree to the format required by apted
'''
if node.tag == 'td':
if self.structure_only:
cell = []
else:
self.__tokens__ = []
self.tokenize(node)
cell = self.__tokens__[1:-1].copy()
new_node = TableTree(node.tag,
int(node.attrib.get('colspan', '1')),
int(node.attrib.get('rowspan', '1')),
cell, *deque())
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
if node.tag != 'td':
for n in node.getchildren():
self.load_html_tree(n, new_node)
if parent is None:
return new_node
def evaluate(self, pred, true):
''' Computes TEDS score between the prediction and the ground truth of a
given sample
'''
if (not pred) or (not true):
return 0.0
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
pred = html.fromstring(pred, parser=parser)
true = html.fromstring(true, parser=parser)
#print("pred:",pred)
#print("true:",true)
if pred.xpath('body/table') and true.xpath('body/table'):
pred = pred.xpath('body/table')[0]
true = true.xpath('body/table')[0]
if self.ignore_nodes:
etree.strip_tags(pred, *self.ignore_nodes)
etree.strip_tags(true, *self.ignore_nodes)
n_nodes_pred = len(pred.xpath(".//*"))
n_nodes_true = len(true.xpath(".//*"))
n_nodes = max(n_nodes_pred, n_nodes_true)
tree_pred = self.load_html_tree(pred)
tree_true = self.load_html_tree(true)
            edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
            return 1.0 - (float(edit_distance) / n_nodes)
else:
return 0.0
def batch_evaluate(self, pred_json, true_json):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
@params pred_json: {'FILENAME': 'HTML CODE', ...}
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
@output: {'FILENAME': 'TEDS SCORE', ...}
'''
samples = true_json.keys()
if self.n_jobs == 1:
scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
else:
            inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
            scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
scores = dict(zip(samples, scores))
return scores
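# Illustrative sanity check (not part of the original script): identical tables
# score 1.0, since TEDS = 1 - edit_distance / max(node_count).
# >>> teds = TEDS()
# >>> table = '<html><body><table><tr><td>a</td></tr></table></body></html>'
# >>> teds.evaluate(table, table)
# 1.0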
def convert_table_to_html_str(table_row_list=[]):
"""
Given a list of table rows, build the corresponding html string, which is used to compute the TEDS score.
    We use the official PubTabNet code to compute the TEDS score; it does not consider the '<th>' tag.
    We also remove unnecessary spaces within table cells and extra '\n' characters, as they influence the TEDS score.
"""
html_table_str = "<html><body><table>" + '\n'
for data_row in table_row_list:
html_table_str += "<tr>"
for cell_str in data_row:
html_table_str += f"<td>{cell_str}</td>"
html_table_str += "</tr>"
html_table_str += '\n'
html_table_str += "</table></body></html>"
html_table_str = html_table_str.replace('\n','')
return html_table_str
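# Example (illustrative): each row list becomes a <tr> of <td> cells in one flat string:
# convert_table_to_html_str([["a", "b"], ["1", "2"]]) ==
#   '<html><body><table><tr><td>a</td><td>b</td></tr><tr><td>1</td><td>2</td></tr></table></body></html>'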
def convert_markdown_table_to_html(markdown_table):
"""
Converts a markdown table to the corresponding html string for TEDS computation.
"""
    # remove code-fence tokens like '```markdown' and '```'
    markdown_table = re.sub(r'^```(?:markdown)?\s*|\s*```$', '', markdown_table.strip()).strip()
    row_str_list = markdown_table.split('\n')
    # keep the first (header) row and the data rows; skip the markdown separator row
    valid_row_str_list = [row_str_list[0]] + row_str_list[2:]
table_rows = []
for row_str in valid_row_str_list:
one_row = []
for cell in row_str.strip().split('|')[1:-1]:
if set(cell) != set(' '):
one_row.append(cell.strip())
else:
one_row.append(' ')
table_rows.append(one_row)
# build html string based on table rows
html_str = convert_table_to_html_str(table_rows)
return html_str
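# Example (illustrative): the '| --- |' separator row is dropped and the header
# row is kept as an ordinary <tr> row:
# convert_markdown_table_to_html('| a | b |\n| --- | --- |\n| 1 | 2 |') ==
#   '<html><body><table><tr><td>a</td><td>b</td></tr><tr><td>1</td><td>2</td></tr></table></body></html>'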
def dict_to_html(data):
    html = "<html><body><table>\n"
    for key, value in data.items():
        # join list values with spaces; cast anything else to a plain string
        if isinstance(value, list):
            value_str = ' '.join(str(v) for v in value)
        else:
            value_str = str(value)
        html += f"  <tr><td>{key}</td><td>{value_str}</td></tr>\n"
    html += "</table></body></html>"
    return html
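# Example (illustrative, assuming the list-joining behavior above):
# dict_to_html({'total': ['30']}) produces
# '<html><body><table>\n  <tr><td>total</td><td>30</td></tr>\n</table></body></html>'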
def convert_str_to_dict(predict_str: str):
"""
Parses the 'predict' string and returns a dictionary.
Missing or unparseable content is handled gracefully.
Parameters:
- predict_str (str): The prediction string containing the output dict.
Returns:
- dict: A dictionary extracted from the predict string.
"""
# Remove code fences like ```python\n...\n```
code_fence_pattern = r'```(?:python|json)?\n(.*?)\n```'
match = re.search(code_fence_pattern, predict_str, re.DOTALL | re.IGNORECASE)
if match:
content = match.group(1)
else:
content = predict_str.strip()
data = {}
success = False
# try parsing with JSON
    try:
        data = json.loads(content)
        success = isinstance(data, dict)
    except json.JSONDecodeError:
        pass
# try parsing with ast.literal_eval
if not success:
try:
data = ast.literal_eval(content)
if isinstance(data, dict):
success = True
except (ValueError, SyntaxError):
pass
# try parsing with regex
if not success:
key_value_pattern = r'["\']?([\w\s]+)["\']?\s*[:=]\s*["\']?([^\n,"\'{}]+)["\']?'
matches = re.findall(key_value_pattern, content)
try:
for key, value in matches:
data[key.strip()] = value.strip()
        except Exception:
return {}
if not data:
return {}
try:
result = {k.strip(): str(v).strip() for k, v in data.items()}
    except Exception:
return {}
return result
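# Example (illustrative): a fenced JSON answer parses via json.loads, and all
# values are normalized to stripped strings:
# convert_str_to_dict('```json\n{"company": "OLD TOWN", "total": 30}\n```')
#   -> {'company': 'OLD TOWN', 'total': '30'}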
def convert_str_to_multi_dict(predict_str: str):
"""
Parses the 'predict' string and returns a dictionary.
Handles nested dictionaries and missing or unparseable content gracefully.
Parameters:
- predict_str (str): The prediction string containing the output dict.
Returns:
- dict: A dictionary extracted from the predict string.
"""
# Remove code fences like ```python\n...\n```
code_fence_pattern = r'```(?:python|json)?\n(.*?)\n```'
matches = re.findall(code_fence_pattern, predict_str, re.DOTALL | re.IGNORECASE)
if matches:
content = max(matches, key=len)
else:
content = predict_str.strip()
def strip_variable_assignment(s):
variable_assignment_pattern = r'^\s*\w+\s*=\s*'
return re.sub(variable_assignment_pattern, '', s.strip(), count=1)
content = strip_variable_assignment(content)
def remove_comments(s):
return re.sub(r'#.*', '', s)
content = remove_comments(content)
last_brace_pos = content.rfind('}')
if last_brace_pos != -1:
content = content[:last_brace_pos+1]
data = {}
success = False
# try parsing with ast.literal_eval
try:
data = ast.literal_eval(content)
if isinstance(data, dict):
success = True
except (ValueError, SyntaxError, TypeError):
pass
if not success:
return {}
def process_data(obj):
if isinstance(obj, dict):
return {k: process_data(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [process_data(elem) for elem in obj]
else:
return obj
data = process_data(data)
return data
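# Example (illustrative): variable assignments, '#' comments, and trailing text
# after the last closing brace are stripped before ast.literal_eval:
# convert_str_to_multi_dict("info = {'a': {'b': 1}}  # model note")
#   -> {'a': {'b': 1}}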
def generate_combinations(input_dict):
"""
Function to generate all possible combinations of values from a dictionary.
"""
kie_answer = input_dict
if not isinstance(kie_answer, dict):
kie_answer = kie_answer.strip('"')
try:
kie_answer = json.loads(kie_answer)
except json.JSONDecodeError:
        try:
            kie_answer = ast.literal_eval(kie_answer)
            # a double-encoded string may need a second literal_eval pass
            if not isinstance(kie_answer, dict):
                kie_answer = ast.literal_eval(kie_answer)
        except (ValueError, SyntaxError):
            print(f"Unable to parse 'answers' field: {kie_answer}")
            return []
# Ensure the parsed result is a dictionary.
if not isinstance(kie_answer, dict):
print("Parsed 'answers' is still not a dictionary.")
raise ValueError("Input could not be parsed into a dictionary.")
keys = list(kie_answer.keys())
value_lists = []
    for single_key in keys:
        single_value = kie_answer[single_key]
        if not isinstance(single_value, list):
            single_value = [single_value]
        value_lists.append(single_value)
# Compute the Cartesian product of the value lists.
combinations = list(product(*value_lists))
# Create a dictionary for each combination of values.
result = [dict(zip(keys, values)) for values in combinations]
return result
else:
keys = list(input_dict.keys())
value_lists = [input_dict[key] for key in keys]
# Compute the Cartesian product of the value lists.
combinations = list(product(*value_lists))
# Create a dictionary for each combination of values.
result = [dict(zip(keys, values)) for values in combinations]
return result
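# Example (illustrative): list-valued fields expand via the Cartesian product:
# generate_combinations({'date': ['2024', '2024/9/27'], 'total': ['30']})
#   -> [{'date': '2024', 'total': '30'}, {'date': '2024/9/27', 'total': '30'}]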
def compute_f1_score(preds, gts, ignores=[]):
"""Compute the F1-score for KIE task between predicted and ground truth dictionaries.
Args:
preds (dict): The predicted key-value pairs.
gts (dict): The ground truth key-value pairs.
ignores (list): The list of keys to ignore during evaluation.
    Returns:
        float: The average F1-score across all evaluated keys.
    """
# Optionally remove ignored keys from predictions and ground truths
keys = set(preds.keys()).union(set(gts.keys())) - set(ignores)
f1_scores = {}
for key in keys:
pred_value = preds.get(key, None)
gt_value = gts.get(key, None)
if pred_value:
pred_value = pred_value.lower().strip().replace("\n"," ").replace(" ", "")
if gt_value:
gt_value = gt_value.lower().strip().replace("\n"," ").replace(" ", "")
if pred_value is None and gt_value is None:
continue
elif pred_value is None:
precision = 0.0
recall = 0.0
elif gt_value is None:
# false positive
precision = 0.0
recall = 0.0
else:
if pred_value == gt_value:
# True positive
precision = 1.0
recall = 1.0
else:
precision = 0.0
recall = 0.0
# Compute F1-score
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
f1_scores[key] = f1_score
if len(f1_scores) == 0:
return 0
average_f1 = sum(f1_scores.values()) / len(f1_scores)
return average_f1
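# Example (illustrative): values match exactly after lowercasing and removing
# whitespace, and a key missing from the prediction counts as a miss:
# compute_f1_score({'date': '2024/9/27'}, {'date': ' 2024/9/27 ', 'total': '30'})
#   -> 0.5   # F1 is 1.0 for 'date', 0.0 for the missing 'total'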
def pre_clean(text):
text = re.sub(r'<bos>|<eos>|<pad>|<unk>', '', text)
text = re.sub(r'\s##(\S)', r'\1', text)
text = re.sub(r'\\\s', r'\\', text)
text = re.sub(r'\s\*\s\*\s', r'**', text)
text = re.sub(r'{\s', r'{', text)
    text = re.sub(r'\s}', r'}', text)
text = re.sub(r'\\begin\s', r'\\begin', text)
text = re.sub(r'\\end\s', r'\\end', text)
text = re.sub(r'\\end{table}', r'\\end{table} \n\n', text)
text = text.replace('\n', ' ')
text = text.replace('*', ' ')
text = text.replace('_', ' ')
return text
def get_tree(input_str):
tree = (Node('ROOT').addkid(Node('TITLE')))
lines = input_str.split("\n")
lines = [pre_clean(line) for line in lines]
last_title = ''
for line in lines:
if line.startswith('#'):
child = tree.get('ROOT')
line = line.replace('#', '')
child.addkid(Node(line))
last_title = line
else:
if last_title == '':
child = tree.get('TITLE')
child.addkid(Node(line))
else:
child = tree.get(last_title)
child.addkid(Node(line))
return tree
def STEDS(pred_tree, ref_tree):
def my_distance(pred, ref):
if len(pred.split()) == 0 or len(ref.split()) == 0:
return 1
else:
return 0
total_distance = simple_distance(pred_tree, ref_tree, label_dist=my_distance)
num_of_nodes = max(len(list(pred_tree.iter())), len(list(ref_tree.iter())))
return 1-total_distance/num_of_nodes
def doc_parsing_evaluation(pred, gt):
score = 0
if not isinstance(pred, str):
return 0
pred_tree = get_tree(pred)
gt_tree = get_tree(gt)
score = STEDS(pred_tree, gt_tree)
return score
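# Example (illustrative): '#' headings become children of ROOT and body lines
# attach to the most recent heading, so identical documents score 1.0:
# doc_parsing_evaluation('# Intro\nsome text', '# Intro\nsome text') -> 1.0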
def wrap_html_table(html_table):
"""
The TEDS computation from PubTabNet code requires that the input html table should have <html>, <body>, and <table> tags.
Add them if they are missing.
"""
html_table = html_table.replace('\n','')
    # add opening/closing <table> tags if either is missing
    if "<table" in html_table and "</table>" not in html_table:
        html_table = html_table + "</table>"
    elif "<table" not in html_table and "</table>" in html_table:
        html_table = "<table>" + html_table
    elif "<table" not in html_table and "</table>" not in html_table:
        html_table = "<table>" + html_table + "</table>"
# add <body> and <html> tags if missing
if '<body>' not in html_table:
html_table = '<body>' + html_table + '</body>'
if '<html>' not in html_table:
html_table = '<html>' + html_table + '</html>'
return html_table
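# Example (illustrative): a bare row gets wrapped in all three required tags:
# wrap_html_table('<tr><td>1</td></tr>') ==
#   '<html><body><table><tr><td>1</td></tr></table></body></html>'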
def get_anls(s1, s2):
    try:
        s1 = s1.lower()
        s2 = s2.lower()
    except AttributeError:
        pass
    if s1 == s2:
        return 1.0
    # ANLS: 1 - normalized Levenshtein distance
    anls = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
    return anls
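# Example (illustrative):
# get_anls('hello', 'helo') -> 0.8   # one deletion / max length 5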
def ocr_eval(references, predictions):
    score_ = 0.0
    none_num = 0
    for idx, ref_value in enumerate(references):
        pred_value = predictions[idx]
        pred_values, ref_values = [], []
        if isinstance(pred_value, str):
            pred_values.append(pred_value)
        else:
            pred_values = pred_value
        if isinstance(ref_value, str):
            ref_values.append(ref_value)
        else:
            ref_values = ref_value
        temp_score = 0.0
        temp_num = len(ref_values)
        for tmpidx, tmpref in enumerate(ref_values):
            tmppred = pred_values[tmpidx] if tmpidx < len(pred_values) else pred_values[0]
            if len(pred_values) == 1 and tmppred != "None" and "None" not in ref_values:
                # a single non-'None' prediction: keep its best match over all references
                temp_score = max(temp_score, get_anls(tmppred, tmpref))
                temp_num = len(ref_values)
            else:
                if tmppred == 'None' and tmpref != 'None':
                    temp_score += 0.0
                elif tmpref == 'None':
                    temp_num -= 1
                else:
                    temp_score += get_anls(tmppred, tmpref)
        if temp_num == 0:
            ocr_score = 0.0
            none_num += 1
        else:
            ocr_score = temp_score / temp_num
        score_ += ocr_score
    if none_num == len(references):
        # every sample had a 'None'-only reference; return a sentinel value
        return 9999
    else:
        return round(score_ / (len(references) - none_num), 5)
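# Example (illustrative): each reference may be a string or a list of strings;
# a perfect single match scores 1.0:
# ocr_eval(references=[['OLD TOWN']], predictions=['OLD TOWN']) -> 1.0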
def csv_eval(predictions, references, easy, pred_type='json'):
    labels = references
def is_int(val):
try:
int(val)
return True
except ValueError:
return False
def is_float(val):
try:
float(val)
return True
except ValueError:
return False
def convert_dict_to_list(data):
"""
Convert a dictionary to a list of tuples, handling both simple and nested dictionaries.
Args:
data (dict): The input dictionary, which might be nested or simple.
Returns:
list: A list of tuples generated from the input dictionary.
"""
        converted_list = []
        for key, value in data.items():
            # a dict value indicates a nested structure
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    converted_list.append((key, subkey, re.sub(r'[^\d.-]', '', str(subvalue))))
            else:
                # simple key-value pair; keep only numeric characters in the value
                converted_list.append((key, "value", re.sub(r'[^\d.-]', '', str(value))))
return converted_list
def csv2triples(csv, separator='\\t', delimiter='\\n'):
lines = csv.strip().split(delimiter)
header = lines[0].split(separator)
triples = []
for line in lines[1:]:
if not line:
continue
values = line.split(separator)
entity = values[0]
for i in range(1, len(values)):
if i >= len(header):
break
                temp = [entity.strip(), header[i].strip()]
                # drop a trailing ':' from entity/header names
                temp = [x if len(x) == 0 or x[-1] != ':' else x[:-1] for x in temp]
                value = values[i].strip()
                # keep only digits, '.', and '-' in the value
                value = re.sub(r'[^\d.-]', '', str(value))
                triples.append((temp[0], temp[1], value))
        return triples
def csv2triples_noheader(csv, separator='\\t', delimiter='\\n'):
lines = csv.strip().split(delimiter)
maybe_header = [x.strip() for x in lines[0].split(separator)]
not_header = False
if len(maybe_header) > 2:
for c in maybe_header[1:]:
                try:
                    float(c)
                    not_header = True
                except ValueError:
                    continue
if not_header:
break
header = None if not_header else maybe_header
data_start = 0 if not_header and separator in lines[0] else 1
triples = []
for line in lines[data_start:]:
if not line:
continue
values = [x.strip() for x in line.split(separator)]
entity = values[0]
for i in range(1, len(values)):
try:
temp = [entity if entity[-1]!=':' else entity[:-1], ""]
except:
temp = [entity, ""]
                if header is not None:
                    try:
                        this_header = header[i]
                        temp = [entity, this_header]
                        temp = [x if x[-1] != ':' else x[:-1] for x in temp]
                    except Exception:
                        this_header = entity.strip()
                value = values[i].strip()
                # keep only digits, '.', and '-' in the value
                value = re.sub(r'[^\d.-]', '', str(value))
                triples.append((temp[0], temp[1], value))
        return triples
def process_triplets(triplets):
new_triplets = []
for triplet in triplets:
new_triplet = []
triplet_temp = []
if len(triplet) > 2:
if is_int(triplet[2]) or is_float(triplet[2]):
triplet_temp = (triplet[0].lower(), triplet[1].lower(), float(triplet[2]))
else:
triplet_temp = (triplet[0].lower(), triplet[1].lower(), triplet[2].lower())
else:
triplet_temp = (triplet[0].lower(), triplet[1].lower(), "no meaning")
new_triplets.append(triplet_temp)
return new_triplets
def intersection_with_tolerance(a, b, tol_word, tol_num):
a = set(a)
b = set(b)
c = set()
for elem1 in a:
for elem2 in b:
if is_float(elem1[-1]) and is_float(elem2[-1]):
if ((Levenshtein.distance(''.join(elem1[:-1]),''.join(elem2[:-1])) <= tol_word) and (abs(elem1[-1] - elem2[-1]) / (abs(elem2[-1])+0.000001) <= tol_num))or \
((''.join(elem1[:-1]) in ''.join(elem2[:-1])) and (abs(elem1[-1] - elem2[-1]) / (abs(elem2[-1])+0.000001) <= tol_num)) or \
((''.join(elem2[:-1]) in ''.join(elem1[:-1])) and (abs(elem1[-1] - elem2[-1]) / (abs(elem2[-1])+0.000001) <= tol_num)):
c.add(elem1)
else:
if (Levenshtein.distance(''.join([str(i) for i in elem1]),''.join([str(j) for j in elem2])) <= tol_word):
c.add(elem1)
return list(c)
    def union_with_tolerance(a, b, tol_word, tol_num):
        c = set(a) | set(b)  # plain union
        d = set(a) & set(b)  # exact matches
        e = intersection_with_tolerance(a, b, tol_word, tol_num)  # tolerant matches
        f = set(e)
        g = c - (f - d)  # keep exact matches, drop tolerant-only duplicates
        return list(g)
def get_eval_list(pred_csv, label_csv, separator='\\t', delimiter='\\n', tol_word=3, tol_num=0.05, pred_type='json'):
if pred_type == 'json':
pred_triple_list=[]
for it in pred_csv:
pred_triple_temp = convert_dict_to_list(it)
pred_triple_pre = process_triplets(pred_triple_temp)
pred_triple_list.append(pred_triple_pre)
else:
pred_triple_list=[]
for it in pred_csv:
                pred_triple_temp = csv2triples(it, separator=separator, delimiter=delimiter)
pred_triple_pre = process_triplets(pred_triple_temp)
pred_triple_list.append(pred_triple_pre)
label_triple_list=[]
for it in label_csv:
label_triple_temp = convert_dict_to_list(it)
label_triple_pre = process_triplets(label_triple_temp)
label_triple_list.append(label_triple_pre)
intersection_list=[]
union_list=[]
sim_list=[]
# for each chart image
        for pred, label in zip(pred_triple_list, label_triple_list):
            for idx in range(len(pred)):
                try:
                    # normalize heads so ('x', 'value') in pred and label line up
                    if label[idx][1] == "value" and "value" not in pred[idx][:2]:
                        pred[idx] = (pred[idx][0], "value", pred[idx][2])
                    temp_pred_head = sorted(pred[idx][:2])
                    temp_gt_head = sorted(label[idx][:2])
                    pred[idx] = (temp_pred_head[0], temp_pred_head[1], pred[idx][2])
                    label[idx] = (temp_gt_head[0], temp_gt_head[1], label[idx][2])
                except Exception:
                    continue
            intersection = intersection_with_tolerance(pred, label, tol_word=tol_word, tol_num=tol_num)
            union = union_with_tolerance(pred, label, tol_word=tol_word, tol_num=tol_num)
            sim = len(intersection) / (len(union) + 1e-16)
            intersection_list.append(intersection)
            union_list.append(union)
            sim_list.append(sim)
return intersection_list, union_list, sim_list
    def get_ap(predictions, labels, sim_threshold, tolerance, separator='\\t', delimiter='\\n', easy=1):
        # tolerance presets: word-level edit-distance budget and relative numeric error
        if tolerance == 'strict':
            tol_word = 0
            tol_num = 0 if easy == 1 else 0.1
        elif tolerance == 'slight':
            tol_word = 2
            tol_num = 0.05 if easy == 1 else 0.3
        elif tolerance == 'high':
            tol_word = 5
            tol_num = 0.1 if easy == 1 else 0.5
        else:
            raise ValueError(f"unknown tolerance level: {tolerance}")
        intersection_list, union_list, sim_list = get_eval_list(predictions, labels, separator=separator, delimiter=delimiter, tol_word=tol_word, tol_num=tol_num, pred_type=pred_type)
        ap = len([num for num in sim_list if num >= sim_threshold]) / (len(sim_list) + 1e-16)
        return ap
map_strict = 0
map_slight = 0
map_high = 0
s="\\t"
d="\\n"
    for sim_threshold in np.arange(0.5, 1, 0.05):
        map_temp_strict = get_ap(predictions, labels, sim_threshold=sim_threshold, tolerance='strict', separator=s, delimiter=d, easy=easy)
        map_temp_slight = get_ap(predictions, labels, sim_threshold=sim_threshold, tolerance='slight', separator=s, delimiter=d, easy=easy)
        map_temp_high = get_ap(predictions, labels, sim_threshold=sim_threshold, tolerance='high', separator=s, delimiter=d, easy=easy)
        map_strict += map_temp_strict / 10
        map_slight += map_temp_slight / 10
        map_high += map_temp_high / 10
    em = get_ap(predictions, labels, sim_threshold=1, tolerance='strict', separator=s, delimiter=d, easy=easy)
    ap_50_strict = get_ap(predictions, labels, sim_threshold=0.5, tolerance='strict', separator=s, delimiter=d, easy=easy)
    ap_75_strict = get_ap(predictions, labels, sim_threshold=0.75, tolerance='strict', separator=s, delimiter=d, easy=easy)
    ap_90_strict = get_ap(predictions, labels, sim_threshold=0.90, tolerance='strict', separator=s, delimiter=d, easy=easy)
    ap_50_slight = get_ap(predictions, labels, sim_threshold=0.5, tolerance='slight', separator=s, delimiter=d, easy=easy)
    ap_75_slight = get_ap(predictions, labels, sim_threshold=0.75, tolerance='slight', separator=s, delimiter=d, easy=easy)
    ap_90_slight = get_ap(predictions, labels, sim_threshold=0.90, tolerance='slight', separator=s, delimiter=d, easy=easy)
    ap_50_high = get_ap(predictions, labels, sim_threshold=0.5, tolerance='high', separator=s, delimiter=d, easy=easy)
    ap_75_high = get_ap(predictions, labels, sim_threshold=0.75, tolerance='high', separator=s, delimiter=d, easy=easy)
    ap_90_high = get_ap(predictions, labels, sim_threshold=0.90, tolerance='high', separator=s, delimiter=d, easy=easy)
return em, map_strict, map_slight, map_high, ap_50_strict, ap_75_strict, ap_90_strict, ap_50_slight, ap_75_slight, ap_90_slight, ap_50_high, ap_75_high, ap_90_high
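# Illustrative sketch (not from the original script): with pred_type='json',
# predictions and references are lists of (possibly nested) dicts, and a
# perfect prediction drives every returned metric to ~1.0, e.g.
# csv_eval([{'2019': {'revenue': '10'}}], [{'2019': {'revenue': '10'}}], easy=1)
# returns em ~= 1.0 and all mAP/AP values ~= 1.0.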
def draw_SCRM_table(em, map_strict, map_slight, map_high, ap_50_strict, ap_75_strict, ap_90_strict, ap_50_slight, ap_75_slight, ap_90_slight, ap_50_high, ap_75_high, ap_90_high, title_ocr_score, source_ocr_score, x_title_ocr_score, y_title_ocr_score, structure_accuracy):
    result = f'''
    -----------------------------------------------------------\n
    | Metrics     | Sim_threshold | Tolerance | Value  |\n
    -----------------------------------------------------------\n
    |             |               | strict    | {'%.4f' % map_strict} |\n
    |             |               ----------------------------\n
    | mPrecision  | 0.5:0.05:0.95 | slight    | {'%.4f' % map_slight} |\n
    |             |               ---------------------------\n
    |             |               | high      | {'%.4f' % map_high} |\n
    -----------------------------------------------------------\n
    |             |               | strict    | {'%.4f' % ap_50_strict} |\n
    |             |               ---------------------------\n
    | Precision   | 0.5           | slight    | {'%.4f' % ap_50_slight} |\n
    |             |               ---------------------------\n
    |             |               | high      | {'%.4f' % ap_50_high} |\n
    -----------------------------------------------------------\n
    |             |               | strict    | {'%.4f' % ap_75_strict} |\n
    |             |               ---------------------------\n
    | Precision   | 0.75          | slight    | {'%.4f' % ap_75_slight} |\n
    |             |               ---------------------------\n
    |             |               | high      | {'%.4f' % ap_75_high} |\n
    -----------------------------------------------------------\n
    |             |               | strict    | {'%.4f' % ap_90_strict} |\n
    |             |               ---------------------------\n
    | Precision   | 0.9           | slight    | {'%.4f' % ap_90_slight} |\n
    |             |               ---------------------------\n
    |             |               | high      | {'%.4f' % ap_90_high} |\n
    -----------------------------------------------------------\n
    | Precision(EM)  | {'%.4f' % em} |\n
    -----------------------------------------------------------\n
    | Title(EM)      | {'%.4f' % title_ocr_score} |\n
    -----------------------------------------------------------\n
    | Source(EM)     | {'%.4f' % source_ocr_score} |\n
    -----------------------------------------------------------\n
    | X_title(EM)    | {'%.4f' % x_title_ocr_score} |\n
    -----------------------------------------------------------\n
    | Y_title(EM)    | {'%.4f' % y_title_ocr_score} |\n
    -----------------------------------------------------------\n
    | structure_acc  | {'%.4f' % structure_accuracy} |\n
    -----------------------------------------------------------\n
    '''
return result
if __name__ == '__main__':
    import pprint
# markdown structure for Table Parsing task
pred_markdown = "| 1 | august 5 , 1972 | detroit lions | l 23 - 31 | 0 - 1 |\n| 2 | august 12 , 1972 | green bay packers | l 13 - 14 | 0 - 2 |\n| 3 | august 19 , 1972 | cincinnati bengals | w 35 - 17 | 1 - 2 |\n| 4 | august 25 , 1972 | atlanta falcons | w 24 - 10 | 2 - 2 |\n| 5 | august 31 , 1972 | washington redskins | l 24 - 27 | 2 - 3 |\n| 6 | september 10 , 1972 | minnesota vikings | w 21 - 19 | 3 - 3 |"
true_markdown = "| week | date | opponent | result | record |\n| --- | --- | --- | --- | --- |\n| 1 | august 5 , 1972 | detroit lions | l 23 - 31 | 0 - 1 |\n| 2 | august 12 , 1972 | green bay packers | l 13 - 14 | 0 - 2 |\n| 3 | august 19 , 1972 | cincinnati bengals | w 35 - 17 | 1 - 2 |\n| 4 | august 25 , 1972 | atlanta falcons | w 24 - 10 | 2 - 2 |\n| 5 | august 31 , 1972 | washington redskins | l 24 - 27 | 2 - 3 |\n| 6 | september 10 , 1972 | minnesota vikings | w 21 - 19 | 3 - 3 |"
teds = TEDS(n_jobs=4)
pred_table_html = convert_markdown_table_to_html(pred_markdown)
true_table_html = convert_markdown_table_to_html(true_markdown)
scores = teds.evaluate(pred_table_html, true_table_html)
pp = pprint.PrettyPrinter()
pp.pprint(scores)
# dict structure for Key Information Extraction task
pred_dict = {
"company": [
"OLD TOWN "
],
"date": [
"2024"
],
"address": [
"SRI RAMPAI"
],
"total": [
"30"
]
}
true_dict = {
"company": [
"OLD TOWN KOPITAM SND BHD"
],
"date": [
"2024/9/27"
],
"address": [
"SRI RAMPAI"
],
"total": [
"30"
]
}
teds = TEDS(n_jobs=4)
pred_dict_html = dict_to_html(pred_dict)
true_dict_html = dict_to_html(true_dict)
print(pred_dict_html)
print(true_dict_html)
scores = teds.evaluate(pred_dict_html, true_dict_html)
pp = pprint.PrettyPrinter()
pp.pprint(scores)