103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
from torch.utils.data import Dataset
|
|
import os
|
|
import json
|
|
class textVQADataset(Dataset):
    """Map-style dataset over TextVQA question/answer annotations.

    The annotation file is a JSON object whose ``"data"`` entry holds the
    records; each record is read for its ``"question"``, ``"answers"`` and
    ``"image_id"`` keys. Images are never opened here: every item carries
    only the path built from ``image_dir_path``.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/textVQA/train_images",
        ann_path: str = "./data/textVQA/TextVQA_0.5.1_val.json",
    ):
        """Load the annotation records into memory.

        Args:
            image_dir_path: Directory that the ``<image_id>.jpg`` paths are
                built under.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            ValueError: If the file is not valid JSON.
            KeyError: If the JSON has no top-level ``"data"`` entry.
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir_path = image_dir_path

    def __len__(self) -> int:
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx``.

        Returns:
            A dict with ``"image_path"`` (``<image_dir_path>/<image_id>.jpg``),
            ``"question"`` and ``"gt_answers"`` (the record's ``"answers"``
            value, passed through unchanged).
        """
        record = self.data[idx]
        img_path = os.path.join(self.image_dir_path, f"{record['image_id']}.jpg")
        return {
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
        }
|
|
|
|
class docVQADataset(Dataset):
    """Map-style dataset over DocVQA question/answer annotations.

    The annotation file is a JSON object whose ``"data"`` entry holds the
    records; each record is read for its ``"question"``, ``"answers"`` and
    ``"image"`` keys. The ``"image"`` value is joined onto
    ``image_dir_path`` as-is; images themselves are not opened here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/docVQA/val",
        ann_path: str = "./data/docVQA/val/val_v1.0.json",
    ):
        """Load the annotation records into memory.

        Args:
            image_dir_path: Directory the records' ``"image"`` values are
                resolved against.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            ValueError: If the file is not valid JSON.
            KeyError: If the JSON has no top-level ``"data"`` entry.
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir_path = image_dir_path

    def __len__(self) -> int:
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx``.

        Returns:
            A dict with ``"image_path"`` (``image_dir_path`` joined with the
            record's ``"image"`` value), ``"question"`` and ``"gt_answers"``
            (the record's ``"answers"`` value, passed through unchanged).
        """
        record = self.data[idx]
        img_path = os.path.join(self.image_dir_path, record["image"])
        return {
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
        }
|
|
|
|
class ocrVQADataset(Dataset):
    """Map-style dataset over OCR-VQA annotations, one item per question.

    The annotation file is a JSON object keyed by image id; each value is
    read for its parallel ``"questions"`` and ``"answers"`` sequences. Every
    (question, answer) pair is flattened into its own sample that points at
    ``<image_dir_path>/<image id>.jpg``. Images are not opened here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/ocrVQA/images",
        ann_path: str = "./data/ocrVQA/dataset.json",
    ):
        """Load the annotations and flatten them into per-question lists.

        Args:
            image_dir_path: Directory that the ``<image id>.jpg`` paths are
                built under.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            ValueError: If the file is not valid JSON.
            KeyError: If an entry lacks ``"questions"`` (or ``"answers"``
                while it has questions).
            IndexError: If an entry has fewer answers than questions.
        """
        self.image_list = []
        self.question_list = []
        self.answer_list = []

        # Use a context manager so the file handle is closed deterministically.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            dataset = json.load(ann_file)

        for image_id, entry in dataset.items():
            # The same image path is shared by all questions of this entry.
            image_file = os.path.join(image_dir_path, f"{image_id}.jpg")
            for index, question in enumerate(entry["questions"]):
                # Index the answers explicitly so an entry with too few
                # answers fails loudly rather than being silently truncated.
                gt_answers = entry["answers"][index]
                self.image_list.append(image_file)
                self.answer_list.append(gt_answers)
                self.question_list.append(question)

    def __len__(self) -> int:
        """Return the number of flattened (image, question, answer) samples."""
        # The three lists grow in lockstep; any of them gives the size.
        return len(self.question_list)

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx``.

        Returns:
            A dict with ``"image_path"``, ``"question"`` and ``"gt_answers"``
            (the answer stored at the question's position in the entry).
        """
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }
|
|
|
|
class STVQADataset(Dataset):
    """Map-style dataset over ST-VQA question/answer annotations.

    The annotation file is a JSON object whose ``"data"`` entry holds the
    records; each record is read for its ``"dataset"``, ``"file_name"``,
    ``"question"`` and ``"answers"`` keys. Image paths have the form
    ``<image_dir_path>/<dataset>/<file_name>``; images are not opened here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/STVQA",
        ann_path: str = "./data/STVQA/train_task_3.json",
    ):
        """Load the annotations into parallel per-sample lists.

        Args:
            image_dir_path: Root directory containing one sub-directory per
                source dataset.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            ValueError: If the file is not valid JSON.
            KeyError: If the JSON has no ``"data"`` entry or a record lacks
                one of the keys listed on the class.
        """
        self.image_list = []
        self.question_list = []
        self.answer_list = []

        # Use a context manager so the file handle is closed deterministically.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            data = json.load(ann_file)

        # Iterate the records themselves: the length of the top-level JSON
        # object is its number of keys, not the number of samples.
        for record in data["data"]:
            # Plain '/' join keeps the path strings exactly as callers have
            # always received them.
            image_path = "/".join(
                (image_dir_path, record["dataset"], record["file_name"])
            )
            self.image_list.append(image_path)
            self.answer_list.append(record["answers"])
            self.question_list.append(record["question"])

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        # The three lists grow in lockstep; any of them gives the size.
        return len(self.question_list)

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx``.

        Returns:
            A dict with ``"image_path"``, ``"question"`` and ``"gt_answers"``
            (the record's ``"answers"`` value, passed through unchanged).
        """
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }