Files
Grounded-SAM-2/lib/train/data/loader.py
2024-11-19 22:12:54 -08:00

200 lines
9.4 KiB
Python

import torch
import torch.utils.data.dataloader
import importlib
import collections
# from torch._six import string_classes
from lib.utils import TensorDict, TensorList
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
int_classes = int
else:
from torch._six import int_classes
import warnings
warnings.filterwarnings("ignore")
string_classes = str
def _check_use_shared_memory():
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
if hasattr(collate_lib, '_use_shared_memory'):
return getattr(collate_lib, '_use_shared_memory')
return torch.utils.data.get_worker_info() is not None
def ltr_collate(batch):
"""Puts each data field into a tensor with outer dimension batch size"""
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
elem_type = type(batch[0])
if isinstance(batch[0], torch.Tensor):
out = None
if _check_use_shared_memory():
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = batch[0].storage()._new_shared(numel)
out = batch[0].new(storage)
return torch.stack(batch, 0, out=out)
# if batch[0].dim() < 4:
# return torch.stack(batch, 0, out=out)
# return torch.cat(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
raise TypeError(error_msg.format(elem.dtype))
return torch.stack([torch.from_numpy(b) for b in batch], 0)
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
elif isinstance(batch[0], int_classes):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], string_classes):
return batch
elif isinstance(batch[0], TensorDict):
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
elif isinstance(batch[0], collections.Mapping):
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], TensorList):
transposed = zip(*batch)
return TensorList([ltr_collate(samples) for samples in transposed])
elif isinstance(batch[0], collections.Sequence):
transposed = zip(*batch)
return [ltr_collate(samples) for samples in transposed]
elif batch[0] is None:
return batch
raise TypeError((error_msg.format(type(batch[0]))))
def ltr_collate_stack1(batch):
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
elem_type = type(batch[0])
if isinstance(batch[0], torch.Tensor):
out = None
if _check_use_shared_memory():
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = batch[0].storage()._new_shared(numel)
out = batch[0].new(storage)
return torch.stack(batch, 1, out=out)
# if batch[0].dim() < 4:
# return torch.stack(batch, 0, out=out)
# return torch.cat(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
raise TypeError(error_msg.format(elem.dtype))
return torch.stack([torch.from_numpy(b) for b in batch], 1)
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
elif isinstance(batch[0], int_classes):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], string_classes):
return batch
elif isinstance(batch[0], TensorDict):
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
elif isinstance(batch[0], collections.Mapping):
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], TensorList):
transposed = zip(*batch)
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
elif isinstance(batch[0], collections.Sequence):
transposed = zip(*batch)
return [ltr_collate_stack1(samples) for samples in transposed]
elif batch[0] is None:
return batch
raise TypeError((error_msg.format(type(batch[0]))))
class LTRLoader(torch.utils.data.dataloader.DataLoader):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
select along which dimension the data should be stacked to form a batch.
Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: 1).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: False).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, ``shuffle`` must be False.
batch_sampler (Sampler, optional): like sampler, but returns a batch of
indices at a time. Mutually exclusive with batch_size, shuffle,
sampler, and drop_last.
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means that the data will be loaded in the main process.
(default: 0)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: False)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: None)
.. note:: By default, each worker will have its PyTorch seed set to
``base_seed + worker_id``, where ``base_seed`` is a long generated
by main process using its RNG. However, seeds for other libraries
may be duplicated upon initializing workers (w.g., NumPy), causing
each worker to return identical random numbers. (See
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
use ``torch.initial_seed()`` to access the PyTorch seed for each
worker in :attr:`worker_init_fn`, and use it to set other seeds
before data loading.
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
unpicklable object, e.g., a lambda function.
"""
__initialized = False
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
print("pin_memory is", pin_memory)
if collate_fn is None:
if stack_dim == 0:
collate_fn = ltr_collate
elif stack_dim == 1:
collate_fn = ltr_collate_stack1
else:
raise ValueError('Stack dim no supported. Must be 0 or 1.')
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
num_workers, collate_fn, pin_memory, drop_last,
timeout, worker_init_fn)
self.name = name
self.training = training
self.epoch_interval = epoch_interval
self.stack_dim = stack_dim