This commit is contained in:
echo840
2023-05-27 17:21:39 +08:00
parent b388fba03e
commit 6e02bedd46
450 changed files with 1148092 additions and 38254 deletions

View File

@@ -1,6 +1,7 @@
from torch.utils.data import Dataset
import os
import json
import re
class SROIEDataset(Dataset):
def __init__(
self,
@@ -90,6 +91,14 @@ class FUNSDDataset(Dataset):
"image_path": img_path,
"question": question,
"gt_answers": answers}
# Lookup table from abbreviated nutrition-label entity keys to the phrase
# that gets interpolated into the generated question text
# (f'what is {entities[key]} in the image?').
# Suffixes as used by the descriptions below: -PS per serving, -D daily value,
# -P1 per 100g/ml, -PP per 100g/ml percentage.
# NOTE: the values are runtime prompt text; keep them exactly as written.
entities = {
    "CE-PS": "Calories/Energy of per serving",
    "TF-PS": "Total fat of per serving",
    "CAR-PS": "Total carbohydrate of per serving",
    "PRO-PS": "Protein of per serving",
    "SS": "Serving size",
    "SO-PS": "Sodium of per serving",
    "TF-D": "Total fat of daily value",
    "CAR-D": "Total carbohydrate of daily value",
    "SO-D": "Sodium of daily value",
    "CE-P1": "Calories/Energy of per 100g/ml",
    "PRO-P1": "Protein of per 100g/ml",
    "CAR-P1": "Total carbohydrate of per 100g/ml",
    "TF-P1": "Total Fat of per 100g/ml",
    "PRO-D": "Protein of daily value",
    "SO-P1": "Sodium of per 100g/ml",
    "CE-D": "Calories/Energy of daily value",
    "TF-PP": "Total fat of per 100g/ml percentage",
    "CAR-PP": "Total carbohydrate of per 100g/ml percentage",
    "SO-PP": "Sodium of per 100g/ml percentage",
    "PRO-PP": "Protein of per 100g/ml percentage",
    "CE-PP": "Calories/Energy of per 100g/ml percentage",
}
class POIEDataset(Dataset):
def __init__(
self,
@@ -101,12 +110,14 @@ class POIEDataset(Dataset):
with open(dir_path, 'r') as f:
lines = f.readlines()
for line in lines:
import pdb;pdb.set_trace()
dict = json.loads(line)
for key, value in dict['entity_dict'].items():
self.image_list.append(dir_path.replace("test.txt", dict['file_name']))
self.question_list.append(key)
self.answer_list.append(value)
self.question_list.append(f'what is {entities[key]} in the image?')
matches = re.findall(r"\((.*?)\)", value)
answer = [match.strip() for match in matches]
answer.append(re.sub(r'\(.*?\)', '', value).strip())
self.answer_list.append(answer)
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):