Files
MultimodalOCR/datasets/kie_dataset.py

134 lines
6.5 KiB
Python
Raw Normal View History

2023-05-23 18:24:16 +08:00
from torch.utils.data import Dataset
import os
import json
2023-05-27 17:21:39 +08:00
import re
2023-05-23 18:24:16 +08:00
class SROIEDataset(Dataset):
    """Question/answer samples built from SROIE key-information annotations.

    Every ``*.txt`` file in ``dir_path`` whose name does not contain ``(`` is
    parsed as a JSON object.  For each of the fields ``company``, ``date``,
    ``address`` and ``total`` present in that object, one sample is created:
    the image path (the annotation path with its extension swapped for
    ``.jpg``), a fixed question for the field, and the field's value as the
    ground-truth answer.

    Args:
        dir_path: Directory containing the annotation ``.txt`` files; the
            ``.jpg`` images are expected next to them.

    Raises:
        OSError: If ``dir_path`` or an annotation file cannot be read.
        json.JSONDecodeError: If an annotation file is not valid JSON.
    """

    # (annotation key, question) pairs, in the order samples are emitted for
    # each annotation file.
    # NOTE: the original author also tried the generic wording
    # "what is the <field> information in the image?" and annotated the
    # company prompts with "llava 0.12" (current) vs "llava 0.08" (generic)
    # -- presumably model scores; the specific wording below was kept.
    _FIELD_QUESTIONS = (
        ("company", "what is the name of the company that issued this invoice?"),
        ("date", "when was this invoice issued?"),
        ("address", "where was this invoice issued?"),
        ("total", "what is the total amount of this invoice?"),
    )

    def __init__(
        self,
        dir_path: str = "./data/SROIE",
    ):
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        # os.listdir() yields entries in arbitrary order; sorting makes a
        # given index refer to the same sample on every run and machine.
        for file_name in sorted(os.listdir(dir_path)):
            # Names containing "(" are deliberately skipped (presumably
            # duplicate copies such as "x(1).txt" -- verify against the data).
            if not file_name.endswith(".txt") or "(" in file_name:
                continue
            file_path = os.path.join(dir_path, file_name)
            # Swap only the extension: str.replace('.txt', '.jpg') would also
            # rewrite a ".txt" that happens to occur in a directory name.
            img_path = os.path.splitext(file_path)[0] + ".jpg"
            with open(file_path, encoding="utf-8") as f:
                info = json.load(f)
            for key, question in self._FIELD_QUESTIONS:
                if key in info:
                    self.question_list.append(question)
                    self.answer_list.append(info[key])
                    self.image_list.append(img_path)

    def __len__(self) -> int:
        """Return the number of (image, question, answer) samples."""
        return len(self.image_list)

    def __getitem__(self, idx) -> dict:
        """Return sample ``idx`` as a dict with the keys
        ``image_path``, ``question`` and ``gt_answers``."""
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }
class FUNSDDataset(Dataset):
    """Question/answer samples derived from FUNSD form annotations.

    Each file in ``ann_dir_path`` is read as JSON and its ``form`` list is
    scanned.  Entities without a non-empty ``linking`` list are ignored.  For
    every remaining entity labelled ``question``, the texts of the ``answer``
    entities whose ``id`` equals the second element of one of its links are
    joined with single spaces; if the result is non-empty a sample is
    created.

    Args:
        ann_dir_path: Directory containing the annotation JSON files.  The
            image path of a sample is the annotation path with
            ``annotations`` replaced by ``images`` and ``.json`` by ``.png``.

    Raises:
        OSError: If the directory or an annotation file cannot be read.
        json.JSONDecodeError: If an annotation file is not valid JSON.
        KeyError: If a file has no ``form`` key.
    """

    def __init__(self, ann_dir_path: str = "./data/FUNSD/testing_data/annotations"):
        questions = []
        answers = []
        images = []
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for file_name in sorted(os.listdir(ann_dir_path)):
            file_path = os.path.join(ann_dir_path, file_name)
            with open(file_path, 'r', encoding='utf-8') as f:
                form = json.load(f)['form']
            image_path = file_path.replace('annotations', 'images').replace('.json', '.png')

            # Drop entities whose linking is missing or empty.
            linked = [item for item in form if item.get('linking')]

            # id -> texts of the answer entities, so every link is resolved
            # with one lookup instead of rescanning all answers.
            answer_texts = {}
            for item in linked:
                if item.get('label') == 'answer':
                    answer_texts.setdefault(item['id'], []).append(item['text'])

            for item in linked:
                if item.get('label') != 'question':
                    continue
                # link[1] is the id on the far side of the link; ids that do
                # not belong to an answer entity simply contribute nothing.
                # Empty answer texts are skipped so the ground truth never
                # picks up trailing or doubled spaces.
                parts = [
                    text
                    for link in item['linking']
                    for text in answer_texts.get(link[1], ())
                    if text
                ]
                if not parts:
                    continue
                questions.append(f"what is \"{item['text']}\" information in the image?")
                answers.append(' '.join(parts))
                images.append(image_path)
        self.questions = questions
        self.answers = answers
        self.images = images

    def __len__(self) -> int:
        """Return the number of (image, question, answer) samples."""
        return len(self.questions)

    def __getitem__(self, idx) -> dict:
        """Return sample ``idx`` as a dict with the keys
        ``image_path``, ``question`` and ``gt_answers``."""
        return {
            "image_path": self.images[idx],
            "question": self.questions[idx],
            "gt_answers": self.answers[idx],
        }
2023-05-27 17:21:39 +08:00
# POIE entity tag -> human-readable description used to phrase the question.
# Suffix legend (as spelled out in the descriptions themselves):
#   -PS per serving, -D daily value, -P1 per 100g/ml,
#   -PP per 100g/ml percentage; "SS" is the serving size.
entities = {
    "CE-PS": "Calories/Energy of per serving",
    "TF-PS": "Total fat of per serving",
    "CAR-PS": "Total carbohydrate of per serving",
    "PRO-PS": "Protein of per serving",
    "SS": "Serving size",
    "SO-PS": "Sodium of per serving",
    "TF-D": "Total fat of daily value",
    "CAR-D": "Total carbohydrate of daily value",
    "SO-D": "Sodium of daily value",
    "CE-P1": "Calories/Energy of per 100g/ml",
    "PRO-P1": "Protein of per 100g/ml",
    "CAR-P1": "Total carbohydrate of per 100g/ml",
    "TF-P1": "Total Fat of per 100g/ml",
    "PRO-D": "Protein of daily value",
    "SO-P1": "Sodium of per 100g/ml",
    "CE-D": "Calories/Energy of daily value",
    "TF-PP": "Total fat of per 100g/ml percentage",
    "CAR-PP": "Total carbohydrate of per 100g/ml percentage",
    "SO-PP": "Sodium of per 100g/ml percentage",
    "PRO-PP": "Protein of per 100g/ml percentage",
    "CE-PP": "Calories/Energy of per 100g/ml percentage",
}
2023-05-23 18:24:16 +08:00
class POIEDataset(Dataset):
    """Question/answer samples built from a POIE annotation file.

    The annotation file holds one JSON object per line.  Each object provides
    ``file_name`` and an ``entity_dict`` mapping entity tags to value
    strings.  One sample is created per entity: the question is phrased with
    the description looked up in the module-level ``entities`` mapping, and
    the ground truth is a list of accepted answers -- every parenthesised
    fragment of the value (stripped), followed by the value with all
    parenthesised fragments removed (stripped).

    Args:
        dir_path: Path of the annotation file (despite the name, a file, not
            a directory).  ``file_name`` entries are resolved relative to the
            directory that contains it.

    Raises:
        OSError: If the annotation file cannot be read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``entity_dict``, or an entity tag is not
            present in ``entities``.
    """

    # Non-greedy match of one "(...)" group; the capture is its content.
    _PAREN = re.compile(r"\((.*?)\)")

    def __init__(
        self,
        dir_path: str = "./data/POIE/test.txt",
    ):
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        # Resolve images against the annotation file's directory.  Replacing
        # the literal "test.txt" in the path only works when the file has
        # exactly that name; for any other name it would return the
        # annotation path itself as the image path.
        image_dir = os.path.dirname(dir_path)
        with open(dir_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                record = json.loads(line)
                for key, value in record['entity_dict'].items():
                    self.image_list.append(os.path.join(image_dir, record['file_name']))
                    self.question_list.append(f'what is {entities[key]} in the image?')
                    answer = [match.strip() for match in self._PAREN.findall(value)]
                    answer.append(self._PAREN.sub('', value).strip())
                    self.answer_list.append(answer)

    def __len__(self) -> int:
        """Return the number of (image, question, answers) samples."""
        return len(self.image_list)

    def __getitem__(self, idx) -> dict:
        """Return sample ``idx`` as a dict with the keys
        ``image_path``, ``question`` and ``gt_answers`` (a list of strings)."""
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }
if __name__ == "__main__":
    # Manual smoke test: build the POIE dataset and pull the first sample.
    # NOTE(review): the absolute path below points into one user's home
    # directory, so this only runs on that machine as written.
    data = iter(POIEDataset("/home/zhangli/GPT4/MutimodelOCR/data/POIE/test.txt"))
    batch = next(data)