Add files via upload
0    datasets/__init__.py    Normal file
BIN  datasets/__pycache__/__init__.cpython-310.pyc    Normal file (binary file not shown)
BIN  datasets/__pycache__/ocr_dataset.cpython-310.pyc    Normal file (binary file not shown)
BIN  datasets/__pycache__/vqa_dataset.cpython-310.pyc    Normal file (binary file not shown)
0    datasets/kie_dataset.py    Normal file
22   datasets/ocr_dataset.py    Normal file
@@ -0,0 +1,22 @@
from torch.utils.data import Dataset
import os


class ocrDataset(Dataset):
    def __init__(
        self,
        image_dir_path="./data/ocr",
        dataset_name="ct80",
    ):
        self.image_dir_path = image_dir_path
        self.dataset_name = dataset_name
        # Each line of test_label.txt is "<image_file> <label>".
        file_path = os.path.join(image_dir_path, f'{dataset_name}/test_label.txt')
        with open(file_path, "r") as file:
            self.lines = file.readlines()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        image_id = self.lines[idx].split()[0]
        img_path = os.path.join(self.image_dir_path, f'{self.dataset_name}/{image_id}')
        answers = self.lines[idx].split()[1]
        return {
            "image_path": img_path,
            "gt_answers": answers}
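A minimal usage sketch for the loader above, assuming the data has been placed under ./data/ocr/ct80 as the default arguments expect (the DataLoader wrapper and batch size are illustrative, not part of the commit):

from torch.utils.data import DataLoader

dataset = ocrDataset(image_dir_path="./data/ocr", dataset_name="ct80")
loader = DataLoader(dataset, batch_size=4, shuffle=False)

for batch in loader:
    # The samples contain only strings, so the default collate function
    # batches each field into a plain list of strings.
    print(batch["image_path"][0], batch["gt_answers"][0])
    break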
103  datasets/vqa_dataset.py    Normal file
@@ -0,0 +1,103 @@
from torch.utils.data import Dataset
import os
import json


class textVQADataset(Dataset):
    def __init__(
        self,
        image_dir_path="./data/textVQA/train_images",
        ann_path="./data/textVQA/TextVQA_0.5.1_val.json",
    ):
        self.data = json.load(open(ann_path, "r"))["data"]
        self.image_dir_path = image_dir_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        img_path = os.path.join(self.image_dir_path, f"{self.data[idx]['image_id']}.jpg")
        return {
            "image_path": img_path,
            "question": question,
            "gt_answers": answers}


class docVQADataset(Dataset):
    def __init__(
        self,
        image_dir_path="./data/docVQA/val",
        ann_path="./data/docVQA/val/val_v1.0.json",
    ):
        self.data = json.load(open(ann_path, "r"))["data"]
        self.image_dir_path = image_dir_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        img_path = os.path.join(self.image_dir_path, self.data[idx]['image'])
        return {
            "image_path": img_path,
            "question": question,
            "gt_answers": answers}


class ocrVQADataset(Dataset):
    def __init__(
        self,
        image_dir_path="./data/ocrVQA/images",
        ann_path="./data/ocrVQA/dataset.json",
    ):
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        # dataset.json maps each image id to its questions and answers;
        # flatten it so every (image, question, answer) triple is one sample.
        dataset = json.load(open(ann_path, "r"))
        for image_id in dataset:
            questions = dataset[image_id]['questions']
            for index, question in enumerate(questions):
                image_file = os.path.join(image_dir_path, f'{image_id}.jpg')
                gt_answers = dataset[image_id]['answers'][index]
                self.image_list.append(image_file)
                self.answer_list.append(gt_answers)
                self.question_list.append(question)

    def __len__(self):
        # This class keeps parallel lists rather than a single `data` attribute.
        return len(self.question_list)

    def __getitem__(self, idx):
        question = self.question_list[idx]
        answers = self.answer_list[idx]
        img_path = self.image_list[idx]
        return {
            "image_path": img_path,
            "question": question,
            "gt_answers": answers}


class STVQADataset(Dataset):
    def __init__(
        self,
        image_dir_path="./data/STVQA",
        ann_path="./data/STVQA/train_task_3.json",
    ):
        self.image_list = []
        self.question_list = []
        self.answer_list = []
        data = json.load(open(ann_path, "r"))
        # The annotations live under the top-level "data" key, so iterate over
        # that list rather than over the enclosing dict.
        for sample in data['data']:
            image_path = os.path.join(image_dir_path, sample['dataset'], sample['file_name'])
            self.image_list.append(image_path)
            self.answer_list.append(sample['answers'])
            self.question_list.append(sample['question'])

    def __len__(self):
        return len(self.question_list)

    def __getitem__(self, idx):
        question = self.question_list[idx]
        answers = self.answer_list[idx]
        img_path = self.image_list[idx]
        return {
            "image_path": img_path,
            "question": question,
            "gt_answers": answers}
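All four classes in vqa_dataset.py return the same {"image_path", "question", "gt_answers"} dict, so an evaluation loop can treat them interchangeably. A minimal sketch, assuming the default ./data layout implied by the constructors; the dict of datasets and the print loop are illustrative only, and iterating the dataset objects directly sidesteps collating variable-length answer lists:

datasets_to_eval = {
    "textVQA": textVQADataset(),
    "docVQA": docVQADataset(),
    "ocrVQA": ocrVQADataset(),
    "STVQA": STVQADataset(),
}

for name, dataset in datasets_to_eval.items():
    sample = dataset[0]
    # Each sample carries the image path, the question, and the reference
    # answers; a model's prediction on (image_path, question) would go here.
    print(name, sample["image_path"], sample["question"], sample["gt_answers"])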