2023-05-17 03:38:36 +08:00
|
|
|
|
from torch.utils.data import Dataset
|
2023-06-09 10:29:18 +08:00
|
|
|
|
import xml.etree.ElementTree as ET
|
2023-05-17 03:38:36 +08:00
|
|
|
|
import os
|
2023-06-09 10:29:18 +08:00
|
|
|
|
import re
|
|
|
|
|
def remove_special_chars(s):
    """Return ``s`` with everything except letters, digits and whitespace removed.

    Only ASCII letters (a-z, A-Z) and the digits 0-9 are kept as
    alphanumerics; whitespace characters are preserved unchanged.

    Args:
        s: The text to clean.

    Returns:
        The cleaned string; empty if no character of ``s`` is kept.
    """
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
|
2023-05-17 03:38:36 +08:00
|
|
|
|
class ocrDataset(Dataset):
    """OCR samples listed in a whitespace-separated ``test_label.txt`` file.

    The label file is read from
    ``<image_dir_path>/<dataset_name>/test_label.txt``. On every non-blank
    line the first whitespace-separated token is the image file name
    (relative to the dataset directory) and the second token is the
    ground-truth text.
    """

    def __init__(
        self,
        image_dir_path="./data/ocr",
        dataset_name="ct80",
    ):
        """Load the label file for one dataset.

        Args:
            image_dir_path: Root directory that contains the dataset folders.
            dataset_name: Name of the sub-directory holding the images and
                ``test_label.txt``.

        Raises:
            OSError: If the label file cannot be opened.
        """
        super().__init__()
        self.image_dir_path = image_dir_path
        self.dataset_name = dataset_name
        file_path = os.path.join(image_dir_path, f'{dataset_name}/test_label.txt')
        # Use a context manager so the handle is closed deterministically, and
        # drop blank lines: they carry no sample and would otherwise raise
        # IndexError in __getitem__ (which silently ends iter()-style loops).
        with open(file_path, "r") as file:
            self.lines = [line for line in file if line.strip()]

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Return the image path and ground-truth text of sample ``idx``.

        Returns:
            A dict with keys ``"image_path"`` and ``"gt_answers"``.

        Raises:
            IndexError: If ``idx`` is out of range, or the label line has
                fewer than two whitespace-separated tokens.
        """
        # Split once and reuse the tokens for both fields.
        fields = self.lines[idx].split()
        image_id = fields[0]
        answers = fields[1]
        img_path = os.path.join(self.image_dir_path, f'{self.dataset_name}/{image_id}')
        return {
            "image_path": img_path,
            "gt_answers": answers,
        }
|
|
|
|
|
class IAMDataset(Dataset):
    """Word-level samples collected from IAM XML annotation files.

    Every ``*.xml`` file in ``<image_dir_path>/xml`` is parsed and each
    ``<word>`` element contributes one sample. The ground-truth text is the
    element's ``text`` attribute passed through ``remove_special_chars``;
    words whose cleaned text is empty are dropped. The image path is built
    as ``<image_dir_path>/<prefix>/<form>/<word id>.png``, where ``<form>``
    is the XML file name up to its first ``.`` and ``<prefix>`` is the XML
    file name up to its first ``-``.
    """

    def __init__(self, image_dir_path='./data/IAM') -> None:
        """Scan the annotation directory and build the sample lists.

        Args:
            image_dir_path: Root directory containing the ``xml`` folder and
                the word image folders.

        Raises:
            OSError: If the ``xml`` directory cannot be listed or a file
                cannot be read.
            xml.etree.ElementTree.ParseError: If an annotation file is not
                well-formed XML.
        """
        super().__init__()
        ann_path = image_dir_path + '/xml'
        self.images = []   # image path of each kept word
        self.answers = []  # cleaned ground-truth text, parallel to self.images
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample indices reproducible across runs and machines.
        for filename in sorted(os.listdir(ann_path)):
            if not filename.endswith('.xml'):
                continue
            # Read the XML annotation file.
            xml_file = os.path.join(ann_path, filename)
            root = ET.parse(xml_file).getroot()
            # The directory part of the image path depends only on the file
            # name, so compute it once per file rather than once per word.
            form_dir = (
                image_dir_path + '/' + filename.split('-')[0]
                + '/' + filename.split('.')[0]
            )
            # Walk every <word> element of the annotation.
            for word in root.iter('word'):
                text = word.get('text')
                img_id = word.get('id')
                # A word missing either attribute cannot yield a sample;
                # skip it instead of failing on None.
                if text is None or img_id is None:
                    continue
                text = remove_special_chars(text)
                if text:
                    self.images.append(form_dir + '/' + img_id + '.png')
                    self.answers.append(text)

    def __len__(self):
        """Return the number of kept word samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image path and ground-truth text of sample ``idx``.

        Returns:
            A dict with keys ``"image_path"`` and ``"gt_answers"``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        img_path = self.images[idx]
        answers = self.answers[idx]
        return {
            "image_path": img_path,
            "gt_answers": answers,
        }
|
|
|
|
|
class ReCTSDataset(Dataset):
    """ReCTS samples listed in a whitespace-separated ``test_label.txt`` file.

    The label file is read from ``<dir_path>/test_label.txt``. On every
    non-blank line the first whitespace-separated token is the image file
    name inside ``<dir_path>/crops`` and the second token is the
    ground-truth text.
    """

    def __init__(
        self,
        dir_path="./data/ReCTS",
    ):
        """Load the label file.

        Args:
            dir_path: Directory containing ``test_label.txt`` and the
                ``crops`` image folder.

        Raises:
            OSError: If the label file cannot be opened.
        """
        super().__init__()
        self.image_dir_path = os.path.join(dir_path, 'crops')
        file_path = os.path.join(dir_path, 'test_label.txt')
        # Use a context manager so the handle is closed deterministically, and
        # drop blank lines: they carry no sample and would otherwise raise
        # IndexError in __getitem__ (which silently ends iter()-style loops).
        # NOTE(review): the file is decoded with the platform default
        # encoding, as before — confirm whether an explicit encoding is needed.
        with open(file_path, "r") as file:
            self.lines = [line for line in file if line.strip()]

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Return the image path and ground-truth text of sample ``idx``.

        Returns:
            A dict with keys ``"image_path"`` and ``"gt_answers"``.

        Raises:
            IndexError: If ``idx`` is out of range, or the label line has
                fewer than two whitespace-separated tokens.
        """
        # Split once and reuse the tokens for both fields.
        fields = self.lines[idx].split()
        image_id = fields[0]
        answers = fields[1]
        img_path = os.path.join(self.image_dir_path, image_id)
        return {
            "image_path": img_path,
            "gt_answers": answers,
        }
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the ReCTS dataset, then print its size and
    # its first sample.
    import sys

    # An optional first command-line argument overrides the default location,
    # so the check can run on machines with a different directory layout.
    default_dir = '/home/zhangli/GPT4/MutimodelOCR/data/ReCTS'
    data = ReCTSDataset(sys.argv[1] if len(sys.argv) > 1 else default_dir)
    print(len(data))
    batch = next(iter(data))
    print(batch)
|
|
|
|
|
|
|
|
|
|
|