Files
MultimodalOCR/datasets/vqa_dataset.py
2023-05-17 03:38:36 +08:00

103 lines
3.4 KiB
Python

from torch.utils.data import Dataset
import os
import json
class textVQADataset(Dataset):
    """Map-style dataset over TextVQA question/answer annotations.

    The annotation file is a JSON object whose ``"data"`` entry is an
    indexable sequence of records. Each record is read for its
    ``"question"``, ``"answers"`` and ``"image_id"`` keys. Images are not
    opened here; only their paths are returned.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/textVQA/train_images",
        ann_path: str = "./data/textVQA/TextVQA_0.5.1_val.json",
    ) -> None:
        """Load the annotation records into memory.

        Args:
            image_dir_path: Directory containing the ``<image_id>.jpg`` files.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            KeyError: If the JSON object has no ``"data"`` entry.
        """
        # Use a context manager so the handle is closed as soon as parsing
        # finishes; JSON is read as UTF-8 regardless of the platform locale.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir_path = image_dir_path

    def __len__(self) -> int:
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        """Return the image path, question and ground-truth answers at ``idx``."""
        record = self.data[idx]
        img_path = os.path.join(self.image_dir_path, f"{record['image_id']}.jpg")
        return {
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
        }
class docVQADataset(Dataset):
    """Map-style dataset over DocVQA question/answer annotations.

    The annotation file is a JSON object whose ``"data"`` entry is an
    indexable sequence of records. Each record is read for its
    ``"question"``, ``"answers"`` and ``"image"`` keys, where ``"image"`` is
    a path joined onto ``image_dir_path``. Images are not opened here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/docVQA/val",
        ann_path: str = "./data/docVQA/val/val_v1.0.json",
    ) -> None:
        """Load the annotation records into memory.

        Args:
            image_dir_path: Root directory the per-record ``"image"`` paths
                are joined onto.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            KeyError: If the JSON object has no ``"data"`` entry.
        """
        # Use a context manager so the handle is closed as soon as parsing
        # finishes; JSON is read as UTF-8 regardless of the platform locale.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir_path = image_dir_path

    def __len__(self) -> int:
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        """Return the image path, question and ground-truth answers at ``idx``."""
        record = self.data[idx]
        img_path = os.path.join(self.image_dir_path, record["image"])
        return {
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
        }
class ocrVQADataset(Dataset):
    """Map-style dataset over OCR-VQA annotations, one sample per question.

    The annotation file is a JSON object keyed by image id. Each value is
    read for its ``"questions"`` sequence and a parallel ``"answers"``
    sequence. Every (question, answer) pair becomes its own sample pointing
    at ``<image_id>.jpg`` inside ``image_dir_path``. Images are not opened
    here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/ocrVQA/images",
        ann_path: str = "./data/ocrVQA/dataset.json",
    ) -> None:
        """Flatten the per-image annotations into parallel sample lists.

        Args:
            image_dir_path: Directory containing the ``<image_id>.jpg`` files.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            IndexError: If an image has fewer answers than questions.
        """
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        # Use a context manager so the handle is closed right after parsing.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            dataset = json.load(ann_file)
        # NOTE: a leftover pdb.set_trace() used to sit here and blocked every
        # construction in an interactive debugger; it has been removed.
        for image_id, entry in dataset.items():
            # The image path is the same for all questions of one image.
            image_file = os.path.join(image_dir_path, f'{image_id}.jpg')
            answers = entry['answers']
            for index, question in enumerate(entry['questions']):
                self.image_list.append(image_file)
                # Index explicitly so a short answers list still fails loudly.
                self.answer_list.append(answers[index])
                self.question_list.append(question)

    def __len__(self) -> int:
        """Return the number of flattened (image, question, answer) samples."""
        # This class never defines self.data; the sample lists are the
        # source of truth for the dataset size.
        return len(self.question_list)

    def __getitem__(self, idx: int) -> dict:
        """Return the image path, question and ground-truth answer at ``idx``."""
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }
class STVQADataset(Dataset):
    """Map-style dataset over ST-VQA annotations.

    The annotation file is a JSON object whose ``"data"`` entry is a sequence
    of records. Each record is read for its ``"dataset"``, ``"file_name"``,
    ``"question"`` and ``"answers"`` keys. The image path is built as
    ``<image_dir_path>/<dataset>/<file_name>``. Images are not opened here.
    """

    def __init__(
        self,
        image_dir_path: str = "./data/STVQA",
        ann_path: str = "./data/STVQA/train_task_3.json",
    ) -> None:
        """Load the annotation records into parallel sample lists.

        Args:
            image_dir_path: Root directory holding one sub-directory per
                source dataset.
            ann_path: Path to the annotation JSON file.

        Raises:
            OSError: If ``ann_path`` cannot be opened.
            KeyError: If the JSON object has no ``"data"`` entry.
        """
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        # Use a context manager so the handle is closed right after parsing.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            data = json.load(ann_file)
        # Iterate the records themselves: len(data) would count the top-level
        # keys of the JSON object, not the number of samples.
        for record in data['data']:
            # Keep the '/'-joined form so the emitted paths stay unchanged.
            image_path = f"{image_dir_path}/{record['dataset']}/{record['file_name']}"
            self.image_list.append(image_path)
            self.answer_list.append(record['answers'])
            self.question_list.append(record['question'])

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        # This class never defines self.data; the sample lists are the
        # source of truth for the dataset size.
        return len(self.question_list)

    def __getitem__(self, idx: int) -> dict:
        """Return the image path, question and ground-truth answers at ``idx``."""
        return {
            "image_path": self.image_list[idx],
            "question": self.question_list[idx],
            "gt_answers": self.answer_list[idx],
        }