This commit is contained in:
echo840
2023-05-23 18:24:16 +08:00
parent da758a9ca7
commit b388fba03e
470 changed files with 2523750 additions and 7307 deletions

View File

@@ -54,7 +54,6 @@ class ocrVQADataset(Dataset):
self.question_list = []
self.answer_list = []
dataset = json.load(open(ann_path, "r"))
import pdb;pdb.set_trace()
for idx, data in enumerate(dataset):
questions = dataset[data]['questions']
for index, question in enumerate(questions):
@@ -64,7 +63,7 @@ class ocrVQADataset(Dataset):
self.answer_list.append(gt_answers)
self.question_list.append(question)
def __len__(self):
return len(self.data)
return len(self.image_list)
def __getitem__(self, idx):
question = self.question_list[idx]
@@ -85,13 +84,43 @@ class STVQADataset(Dataset):
self.question_list = []
self.answer_list = []
data = json.load(open(ann_path, "r"))
for i in range(len(data)):
for i in range(len(data['data'])):
image_path = image_dir_path+'/'+data['data'][i]['dataset']+'/'+data['data'][i]['file_name']
self.image_list.append(image_path)
self.answer_list.append(data['data'][i]['answers'])
self.question_list.append(data['data'][i]['question'])
def __len__(self):
return len(self.data)
return len(self.image_list)
def __getitem__(self, idx):
question = self.question_list[idx]
answers = self.answer_list[idx]
img_path = self.image_list[idx]
return {
"image_path": img_path,
"question": question,
"gt_answers": answers}
class ESTVQADataset(Dataset):
def __init__(
self,
image_dir_path= "./data/ESTVQA/images/train",
ann_path= "./data/ESTVQA/annotations/train.json",
):
self.image_list = []
self.question_list = []
self.answer_list = []
with open(ann_path,'r') as f:
data = json.load(f)
for i in range(len(data)):
image_path = os.path.join(image_dir_path, data[i]['image'])
for j in range(len(data[i]['annotation'])):
question = data[i]['annotation'][j]['question']
answer = data[i]['annotation'][j]['answer']
self.image_list.append(image_path)
self.question_list.append(question)
self.answer_list.append(answer)
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
question = self.question_list[idx]