init commit of samurai
This commit is contained in:
199
lib/train/data/loader.py
Normal file
199
lib/train/data/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.utils.data.dataloader
|
||||
import importlib
|
||||
import collections
|
||||
# from torch._six import string_classes
|
||||
from lib.utils import TensorDict, TensorList
|
||||
|
||||
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
|
||||
int_classes = int
|
||||
else:
|
||||
from torch._six import int_classes
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
string_classes = str
|
||||
|
||||
def _check_use_shared_memory():
|
||||
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
|
||||
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
|
||||
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
|
||||
if hasattr(collate_lib, '_use_shared_memory'):
|
||||
return getattr(collate_lib, '_use_shared_memory')
|
||||
return torch.utils.data.get_worker_info() is not None
|
||||
|
||||
|
||||
def ltr_collate(batch):
|
||||
"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 0, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
def ltr_collate_stack1(batch):
|
||||
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 1, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 1)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate_stack1(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
class LTRLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
|
||||
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
|
||||
select along which dimension the data should be stacked to form a batch.
|
||||
|
||||
Arguments:
|
||||
dataset (Dataset): dataset from which to load the data.
|
||||
batch_size (int, optional): how many samples per batch to load
|
||||
(default: 1).
|
||||
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
||||
at every epoch (default: False).
|
||||
sampler (Sampler, optional): defines the strategy to draw samples from
|
||||
the dataset. If specified, ``shuffle`` must be False.
|
||||
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
||||
indices at a time. Mutually exclusive with batch_size, shuffle,
|
||||
sampler, and drop_last.
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means that the data will be loaded in the main process.
|
||||
(default: 0)
|
||||
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
||||
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
|
||||
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
||||
into CUDA pinned memory before returning them.
|
||||
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
||||
if the dataset size is not divisible by the batch size. If ``False`` and
|
||||
the size of dataset is not divisible by the batch size, then the last batch
|
||||
will be smaller. (default: False)
|
||||
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
||||
from workers. Should always be non-negative. (default: 0)
|
||||
worker_init_fn (callable, optional): If not None, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: None)
|
||||
|
||||
.. note:: By default, each worker will have its PyTorch seed set to
|
||||
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
||||
by main process using its RNG. However, seeds for other libraries
|
||||
may be duplicated upon initializing workers (w.g., NumPy), causing
|
||||
each worker to return identical random numbers. (See
|
||||
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
|
||||
use ``torch.initial_seed()`` to access the PyTorch seed for each
|
||||
worker in :attr:`worker_init_fn`, and use it to set other seeds
|
||||
before data loading.
|
||||
|
||||
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
|
||||
unpicklable object, e.g., a lambda function.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
print("pin_memory is", pin_memory)
|
||||
if collate_fn is None:
|
||||
if stack_dim == 0:
|
||||
collate_fn = ltr_collate
|
||||
elif stack_dim == 1:
|
||||
collate_fn = ltr_collate_stack1
|
||||
else:
|
||||
raise ValueError('Stack dim no supported. Must be 0 or 1.')
|
||||
|
||||
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
Reference in New Issue
Block a user