Files
MultimodalOCR/datasets/formula_dataset.py

24 lines
889 B
Python
Raw Normal View History

2023-05-23 18:24:16 +08:00
from torch.utils.data import Dataset
import os
class HMEDataset(Dataset):
def __init__(
self,
image_dir_path= "./data/HME100K/test_images",
ann_path= "./data/HME100K/test_labels.txt",
):
file = open(ann_path, "r")
self.lines = file.readlines()
self.image_dir_path = image_dir_path
def __len__(self):
return len(self.lines)
def __getitem__(self, idx):
image_id = self.lines[idx].split("\t")[0]
img_path = os.path.join(self.image_dir_path, image_id)
answers = self.lines[idx].split("\t")[1]
return {
"image_path": img_path,
"gt_answers": answers}
if __name__ == "__main__":
data = HMEDataset("/home/zhangli/GPT4/MutimodelOCR/data/HME100K/test_images","/home/zhangli/GPT4/MutimodelOCR/data/HME100K/test_labels.txt")
data = iter(data)
batch = next(data)