update
This commit is contained in:
Binary file not shown.
@@ -1,6 +1,7 @@
|
||||
from torch.utils.data import Dataset
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
class SROIEDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -90,6 +91,14 @@ class FUNSDDataset(Dataset):
|
||||
"image_path": img_path,
|
||||
"question": question,
|
||||
"gt_answers": answers}
|
||||
# Lookup table from entity code to the human-readable phrase that is
# interpolated into the generated question ("what is <phrase> in the image?").
# Code suffixes, as spelled out by the phrases themselves:
#   -PS = per serving, -D = daily value,
#   -P1 = per 100g/ml, -PP = per 100g/ml percentage.
# Insertion order is kept exactly as originally written.
entities = {
    "CE-PS": "Calories/Energy of per serving",
    "TF-PS": "Total fat of per serving",
    "CAR-PS": "Total carbohydrate of per serving",
    "PRO-PS": "Protein of per serving",
    "SS": "Serving size",
    "SO-PS": "Sodium of per serving",
    "TF-D": "Total fat of daily value",
    "CAR-D": "Total carbohydrate of daily value",
    "SO-D": "Sodium of daily value",
    "CE-P1": "Calories/Energy of per 100g/ml",
    "PRO-P1": "Protein of per 100g/ml",
    "CAR-P1": "Total carbohydrate of per 100g/ml",
    "TF-P1": "Total Fat of per 100g/ml",
    "PRO-D": "Protein of daily value",
    "SO-P1": "Sodium of per 100g/ml",
    "CE-D": "Calories/Energy of daily value",
    "TF-PP": "Total fat of per 100g/ml percentage",
    "CAR-PP": "Total carbohydrate of per 100g/ml percentage",
    "SO-PP": "Sodium of per 100g/ml percentage",
    "PRO-PP": "Protein of per 100g/ml percentage",
    "CE-PP": "Calories/Energy of per 100g/ml percentage",
}
|
||||
class POIEDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -101,12 +110,14 @@ class POIEDataset(Dataset):
|
||||
# Read the annotation file at `dir_path`: every non-blank line is parsed as a
# JSON object that provides a 'file_name' and an 'entity_dict'.
# For each (entity code, value) pair exactly ONE entry is appended to each of
# image_list / question_list / answer_list so the three stay index-aligned.
with open(dir_path, 'r') as f:
    # Iterate the file lazily instead of materialising it with readlines().
    for line in f:
        if not line.strip():
            # A blank line would make json.loads raise; skip it.
            continue
        # Named `record` rather than `dict` so the builtin is not shadowed.
        record = json.loads(line)
        for key, value in record['entity_dict'].items():
            # Image path is derived by swapping the "test.txt" component of
            # the annotation path for this record's file name.
            self.image_list.append(dir_path.replace("test.txt", record['file_name']))
            # Question uses the readable phrase for the entity code; an
            # unknown code raises KeyError, as in the original.
            self.question_list.append(f'what is {entities[key]} in the image?')
            # Accepted answers: each parenthesised segment of the value,
            # followed by the value with all parenthesised segments removed.
            matches = re.findall(r"\((.*?)\)", value)
            answer = [match.strip() for match in matches]
            answer.append(re.sub(r'\(.*?\)', '', value).strip())
            self.answer_list.append(answer)
|
||||
def __len__(self):
|
||||
return len(self.image_list)
|
||||
def __getitem__(self, idx):
|
||||
|
Reference in New Issue
Block a user