init commit of samurai
This commit is contained in:
2
lib/train/data/__init__.py
Normal file
2
lib/train/data/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .loader import LTRLoader
|
||||
from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader
|
150
lib/train/data/bounding_box_utils.py
Normal file
150
lib/train/data/bounding_box_utils.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def batch_center2corner(boxes):
|
||||
xmin = boxes[:, 0] - boxes[:, 2] * 0.5
|
||||
ymin = boxes[:, 1] - boxes[:, 3] * 0.5
|
||||
xmax = boxes[:, 0] + boxes[:, 2] * 0.5
|
||||
ymax = boxes[:, 1] + boxes[:, 3] * 0.5
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def batch_corner2center(boxes):
|
||||
cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
|
||||
cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
|
||||
w = (boxes[:, 2] - boxes[:, 0])
|
||||
h = (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center(boxes):
|
||||
cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2
|
||||
cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center2(boxes):
|
||||
cx = boxes[:, 0] + boxes[:, 2] / 2
|
||||
cy = boxes[:, 1] + boxes[:, 3] / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
|
||||
def batch_xywh2corner(boxes):
|
||||
xmin = boxes[:, 0]
|
||||
ymin = boxes[:, 1]
|
||||
xmax = boxes[:, 0] + boxes[:, 2]
|
||||
ymax = boxes[:, 1] + boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def rect_to_rel(bb, sz_norm=None):
|
||||
"""Convert standard rectangular parametrization of the bounding box [x, y, w, h]
|
||||
to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
|
||||
args:
|
||||
bb - N x 4 tensor of boxes.
|
||||
sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
|
||||
"""
|
||||
|
||||
c = bb[...,:2] + 0.5 * bb[...,2:]
|
||||
if sz_norm is None:
|
||||
c_rel = c / bb[...,2:]
|
||||
else:
|
||||
c_rel = c / sz_norm
|
||||
sz_rel = torch.log(bb[...,2:])
|
||||
return torch.cat((c_rel, sz_rel), dim=-1)
|
||||
|
||||
|
||||
def rel_to_rect(bb, sz_norm=None):
|
||||
"""Inverts the effect of rect_to_rel. See above."""
|
||||
|
||||
sz = torch.exp(bb[...,2:])
|
||||
if sz_norm is None:
|
||||
c = bb[...,:2] * sz
|
||||
else:
|
||||
c = bb[...,:2] * sz_norm
|
||||
tl = c - 0.5 * sz
|
||||
return torch.cat((tl, sz), dim=-1)
|
||||
|
||||
|
||||
def masks_to_bboxes(mask, fmt='c'):
|
||||
|
||||
""" Convert a mask tensor to one or more bounding boxes.
|
||||
Note: This function is a bit new, make sure it does what it says. /Andreas
|
||||
:param mask: Tensor of masks, shape = (..., H, W)
|
||||
:param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
|
||||
't' => "top left + size" or (x_left, y_top, width, height)
|
||||
'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
|
||||
:return: tensor containing a batch of bounding boxes, shape = (..., 4)
|
||||
"""
|
||||
batch_shape = mask.shape[:-2]
|
||||
mask = mask.reshape((-1, *mask.shape[-2:]))
|
||||
bboxes = []
|
||||
|
||||
for m in mask:
|
||||
mx = m.sum(dim=-2).nonzero()
|
||||
my = m.sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
bboxes.append(bb)
|
||||
|
||||
bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
|
||||
bboxes = bboxes.reshape(batch_shape + (4,))
|
||||
|
||||
if fmt == 'v':
|
||||
return bboxes
|
||||
|
||||
x1 = bboxes[..., :2]
|
||||
s = bboxes[..., 2:] - x1 + 1
|
||||
|
||||
if fmt == 'c':
|
||||
return torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
return torch.cat((x1, s), dim=-1)
|
||||
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
|
||||
|
||||
def masks_to_bboxes_multi(mask, ids, fmt='c'):
|
||||
assert mask.dim() == 2
|
||||
bboxes = []
|
||||
|
||||
for id in ids:
|
||||
mx = (mask == id).sum(dim=-2).nonzero()
|
||||
my = (mask == id).float().sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
|
||||
bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
|
||||
|
||||
x1 = bb[:2]
|
||||
s = bb[2:] - x1 + 1
|
||||
|
||||
if fmt == 'v':
|
||||
pass
|
||||
elif fmt == 'c':
|
||||
bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
bb = torch.cat((x1, s), dim=-1)
|
||||
else:
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
bboxes.append(bb)
|
||||
|
||||
return bboxes
|
103
lib/train/data/image_loader.py
Normal file
103
lib/train/data/image_loader.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import jpeg4py
|
||||
import cv2 as cv
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
|
||||
davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
|
||||
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
|
||||
[64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
|
||||
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
|
||||
[0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
|
||||
[0, 64, 128], [128, 64, 128]]
|
||||
|
||||
|
||||
def default_image_loader(path):
|
||||
"""The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
|
||||
but reverts to the opencv_loader if the former is not available."""
|
||||
if default_image_loader.use_jpeg4py is None:
|
||||
# Try using jpeg4py
|
||||
im = jpeg4py_loader(path)
|
||||
if im is None:
|
||||
default_image_loader.use_jpeg4py = False
|
||||
print('Using opencv_loader instead.')
|
||||
else:
|
||||
default_image_loader.use_jpeg4py = True
|
||||
return im
|
||||
if default_image_loader.use_jpeg4py:
|
||||
return jpeg4py_loader(path)
|
||||
return opencv_loader(path)
|
||||
|
||||
default_image_loader.use_jpeg4py = None
|
||||
|
||||
|
||||
def jpeg4py_loader(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_loader(path):
|
||||
""" Read image using opencv's imread function and returns it in rgb format"""
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def jpeg4py_loader_w_failsafe(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except:
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_seg_loader(path):
|
||||
""" Read segmentation annotation using opencv's imread function"""
|
||||
try:
|
||||
return cv.imread(path)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def imread_indexed(filename):
|
||||
""" Load indexed image with given filename. Used to read segmentation annotations."""
|
||||
|
||||
im = Image.open(filename)
|
||||
|
||||
annotation = np.atleast_3d(im)[...,0]
|
||||
return annotation
|
||||
|
||||
|
||||
def imwrite_indexed(filename, array, color_palette=None):
|
||||
""" Save indexed image as png. Used to save segmentation annotation."""
|
||||
|
||||
if color_palette is None:
|
||||
color_palette = davis_palette
|
||||
|
||||
if np.atleast_3d(array).shape[2] != 1:
|
||||
raise Exception("Saving indexed PNGs requires 2D array.")
|
||||
|
||||
im = Image.fromarray(array)
|
||||
im.putpalette(color_palette.ravel())
|
||||
im.save(filename, format='PNG')
|
199
lib/train/data/loader.py
Normal file
199
lib/train/data/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.utils.data.dataloader
|
||||
import importlib
|
||||
import collections
|
||||
# from torch._six import string_classes
|
||||
from lib.utils import TensorDict, TensorList
|
||||
|
||||
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
|
||||
int_classes = int
|
||||
else:
|
||||
from torch._six import int_classes
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
string_classes = str
|
||||
|
||||
def _check_use_shared_memory():
|
||||
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
|
||||
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
|
||||
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
|
||||
if hasattr(collate_lib, '_use_shared_memory'):
|
||||
return getattr(collate_lib, '_use_shared_memory')
|
||||
return torch.utils.data.get_worker_info() is not None
|
||||
|
||||
|
||||
def ltr_collate(batch):
|
||||
"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 0, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
def ltr_collate_stack1(batch):
|
||||
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 1, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 1)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate_stack1(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
class LTRLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
|
||||
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
|
||||
select along which dimension the data should be stacked to form a batch.
|
||||
|
||||
Arguments:
|
||||
dataset (Dataset): dataset from which to load the data.
|
||||
batch_size (int, optional): how many samples per batch to load
|
||||
(default: 1).
|
||||
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
||||
at every epoch (default: False).
|
||||
sampler (Sampler, optional): defines the strategy to draw samples from
|
||||
the dataset. If specified, ``shuffle`` must be False.
|
||||
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
||||
indices at a time. Mutually exclusive with batch_size, shuffle,
|
||||
sampler, and drop_last.
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means that the data will be loaded in the main process.
|
||||
(default: 0)
|
||||
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
||||
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
|
||||
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
||||
into CUDA pinned memory before returning them.
|
||||
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
||||
if the dataset size is not divisible by the batch size. If ``False`` and
|
||||
the size of dataset is not divisible by the batch size, then the last batch
|
||||
will be smaller. (default: False)
|
||||
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
||||
from workers. Should always be non-negative. (default: 0)
|
||||
worker_init_fn (callable, optional): If not None, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: None)
|
||||
|
||||
.. note:: By default, each worker will have its PyTorch seed set to
|
||||
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
||||
by main process using its RNG. However, seeds for other libraries
|
||||
may be duplicated upon initializing workers (w.g., NumPy), causing
|
||||
each worker to return identical random numbers. (See
|
||||
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
|
||||
use ``torch.initial_seed()`` to access the PyTorch seed for each
|
||||
worker in :attr:`worker_init_fn`, and use it to set other seeds
|
||||
before data loading.
|
||||
|
||||
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
|
||||
unpicklable object, e.g., a lambda function.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
print("pin_memory is", pin_memory)
|
||||
if collate_fn is None:
|
||||
if stack_dim == 0:
|
||||
collate_fn = ltr_collate
|
||||
elif stack_dim == 1:
|
||||
collate_fn = ltr_collate_stack1
|
||||
else:
|
||||
raise ValueError('Stack dim no supported. Must be 0 or 1.')
|
||||
|
||||
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
155
lib/train/data/processing.py
Normal file
155
lib/train/data/processing.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from lib.utils import TensorDict
|
||||
import lib.train.data.processing_utils as prutils
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def stack_tensors(x):
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
|
||||
return torch.stack(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaseProcessing:
|
||||
""" Base class for Processing. Processing class is used to process the data returned by a dataset, before passing it
|
||||
through the network. For example, it can be used to crop a search region around the object, apply various data
|
||||
augmentations, etc."""
|
||||
def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None, joint_transform=None):
|
||||
"""
|
||||
args:
|
||||
transform - The set of transformations to be applied on the images. Used only if template_transform or
|
||||
search_transform is None.
|
||||
template_transform - The set of transformations to be applied on the template images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
search_transform - The set of transformations to be applied on the search images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
joint_transform - The set of transformations to be applied 'jointly' on the template and search images. For
|
||||
example, it can be used to convert both template and search images to grayscale.
|
||||
"""
|
||||
self.transform = {'template': transform if template_transform is None else template_transform,
|
||||
'search': transform if search_transform is None else search_transform,
|
||||
'joint': joint_transform}
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class STARKProcessing(BaseProcessing):
|
||||
""" The processing class used for training LittleBoy. The images are processed in the following way.
|
||||
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
|
||||
centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is
|
||||
cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is
|
||||
always at the center of the search region. The search region is then resized to a fixed size given by the
|
||||
argument output_sz.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
|
||||
mode='pair', settings=None, *args, **kwargs):
|
||||
"""
|
||||
args:
|
||||
search_area_factor - The size of the search region relative to the target size.
|
||||
output_sz - An integer, denoting the size to which the search region is resized. The search region is always
|
||||
square.
|
||||
center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.search_area_factor = search_area_factor
|
||||
self.output_sz = output_sz
|
||||
self.center_jitter_factor = center_jitter_factor
|
||||
self.scale_jitter_factor = scale_jitter_factor
|
||||
self.mode = mode
|
||||
self.settings = settings
|
||||
|
||||
def _get_jittered_box(self, box, mode):
|
||||
""" Jitter the input box
|
||||
args:
|
||||
box - input bounding box
|
||||
mode - string 'template' or 'search' indicating template or search data
|
||||
|
||||
returns:
|
||||
torch.Tensor - jittered box
|
||||
"""
|
||||
|
||||
jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
|
||||
max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode]).float())
|
||||
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
|
||||
|
||||
return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the following fields:
|
||||
'template_images', search_images', 'template_anno', 'search_anno'
|
||||
returns:
|
||||
TensorDict - output data block with following fields:
|
||||
'template_images', 'search_images', 'template_anno', 'search_anno', 'test_proposals', 'proposal_iou'
|
||||
"""
|
||||
# Apply joint transforms
|
||||
if self.transform['joint'] is not None:
|
||||
data['template_images'], data['template_anno'], data['template_masks'] = self.transform['joint'](
|
||||
image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
|
||||
data['search_images'], data['search_anno'], data['search_masks'] = self.transform['joint'](
|
||||
image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)
|
||||
|
||||
for s in ['template', 'search']:
|
||||
assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
|
||||
"In pair mode, num train/test frames must be 1"
|
||||
|
||||
# Add a uniform noise to the center pos
|
||||
jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
|
||||
|
||||
# 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
|
||||
w, h = torch.stack(jittered_anno, dim=0)[:, 2], torch.stack(jittered_anno, dim=0)[:, 3]
|
||||
|
||||
crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
|
||||
if (crop_sz < 1).any():
|
||||
data['valid'] = False
|
||||
# print("Too small box is found. Replace it with new data.")
|
||||
return data
|
||||
|
||||
# Crop image region centered at jittered_anno box and get the attention mask
|
||||
crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno,
|
||||
data[s + '_anno'], self.search_area_factor[s],
|
||||
self.output_sz[s], masks=data[s + '_masks'])
|
||||
# Apply transforms
|
||||
data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[s + '_masks'] = self.transform[s](
|
||||
image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)
|
||||
|
||||
|
||||
# 2021.1.9 Check whether elements in data[s + '_att'] is all 1
|
||||
# Note that type of data[s + '_att'] is tuple, type of ele is torch.tensor
|
||||
for ele in data[s + '_att']:
|
||||
if (ele == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of original attention mask are all one. Replace it with new data.")
|
||||
return data
|
||||
# 2021.1.10 more strict conditions: require the donwsampled masks not to be all 1
|
||||
for ele in data[s + '_att']:
|
||||
feat_size = self.output_sz[s] // 16 # 16 is the backbone stride
|
||||
# (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
|
||||
mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
|
||||
if (mask_down == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of down-sampled attention mask are all one. "
|
||||
# "Replace it with new data.")
|
||||
return data
|
||||
|
||||
data['valid'] = True
|
||||
# if we use copy-and-paste augmentation
|
||||
if data["template_masks"] is None or data["search_masks"] is None:
|
||||
data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
|
||||
data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))
|
||||
# Prepare output
|
||||
if self.mode == 'sequence':
|
||||
data = data.apply(stack_tensors)
|
||||
else:
|
||||
data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
|
||||
|
||||
return data
|
168
lib/train/data/processing_utils.py
Normal file
168
lib/train/data/processing_utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
'''modified from the original test implementation
|
||||
Replace cv.BORDER_REPLICATE with cv.BORDER_CONSTANT
|
||||
Add a variable called att_mask for computing attention and positional encoding later'''
|
||||
|
||||
|
||||
def sample_target(im, target_bb, search_area_factor, output_sz=None, mask=None):
|
||||
""" Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area
|
||||
|
||||
args:
|
||||
im - cv image
|
||||
target_bb - target box [x, y, w, h]
|
||||
search_area_factor - Ratio of crop size to target size
|
||||
output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.
|
||||
|
||||
returns:
|
||||
cv image - extracted crop
|
||||
float - the factor by which the crop has been resized to make the crop size equal output_size
|
||||
"""
|
||||
if not isinstance(target_bb, list):
|
||||
x, y, w, h = target_bb.tolist()
|
||||
else:
|
||||
x, y, w, h = target_bb
|
||||
# Crop image
|
||||
crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
|
||||
|
||||
if crop_sz < 1:
|
||||
raise Exception('Too small bounding box.')
|
||||
|
||||
x1 = round(x + 0.5 * w - crop_sz * 0.5)
|
||||
x2 = x1 + crop_sz
|
||||
|
||||
y1 = round(y + 0.5 * h - crop_sz * 0.5)
|
||||
y2 = y1 + crop_sz
|
||||
|
||||
x1_pad = max(0, -x1)
|
||||
x2_pad = max(x2 - im.shape[1] + 1, 0)
|
||||
|
||||
y1_pad = max(0, -y1)
|
||||
y2_pad = max(y2 - im.shape[0] + 1, 0)
|
||||
|
||||
# Crop target
|
||||
im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
|
||||
if mask is not None:
|
||||
mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]
|
||||
|
||||
# Pad
|
||||
im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_CONSTANT)
|
||||
# deal with attention mask
|
||||
H, W, _ = im_crop_padded.shape
|
||||
att_mask = np.ones((H,W))
|
||||
end_x, end_y = -x2_pad, -y2_pad
|
||||
if y2_pad == 0:
|
||||
end_y = None
|
||||
if x2_pad == 0:
|
||||
end_x = None
|
||||
att_mask[y1_pad:end_y, x1_pad:end_x] = 0
|
||||
if mask is not None:
|
||||
mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)
|
||||
|
||||
if output_sz is not None:
|
||||
resize_factor = output_sz / crop_sz
|
||||
im_crop_padded = cv.resize(im_crop_padded, (output_sz, output_sz))
|
||||
att_mask = cv.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
|
||||
if mask is None:
|
||||
return im_crop_padded, resize_factor, att_mask
|
||||
mask_crop_padded = \
|
||||
F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
|
||||
return im_crop_padded, resize_factor, att_mask, mask_crop_padded
|
||||
|
||||
else:
|
||||
if mask is None:
|
||||
return im_crop_padded, att_mask.astype(np.bool_), 1.0
|
||||
return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded
|
||||
|
||||
|
||||
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
|
||||
crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box_in - the box for which the co-ordinates are to be transformed
|
||||
box_extract - the box about which the image crop has been extracted.
|
||||
resize_factor - the ratio between the original image scale and the scale of the image crop
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]
|
||||
|
||||
box_in_center = box_in[0:2] + 0.5 * box_in[2:4]
|
||||
|
||||
box_out_center = (crop_sz - 1) / 2 + (box_in_center - box_extract_center) * resize_factor
|
||||
box_out_wh = box_in[2:4] * resize_factor
|
||||
|
||||
box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
|
||||
def jittered_center_crop(frames, box_extract, box_gt, search_area_factor, output_sz, masks=None):
|
||||
""" For each frame in frames, extracts a square crop centered at box_extract, of area search_area_factor^2
|
||||
times box_extract area. The extracted crops are then resized to output_sz. Further, the co-ordinates of the box
|
||||
box_gt are transformed to the image crop co-ordinates
|
||||
|
||||
args:
|
||||
frames - list of frames
|
||||
box_extract - list of boxes of same length as frames. The crops are extracted using anno_extract
|
||||
box_gt - list of boxes of same length as frames. The co-ordinates of these boxes are transformed from
|
||||
image co-ordinates to the crop co-ordinates
|
||||
search_area_factor - The area of the extracted crop is search_area_factor^2 times box_extract area
|
||||
output_sz - The size to which the extracted crops are resized
|
||||
|
||||
returns:
|
||||
list - list of image crops
|
||||
list - box_gt location in the crop co-ordinates
|
||||
"""
|
||||
|
||||
if masks is None:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz)
|
||||
for f, a in zip(frames, box_extract)]
|
||||
frames_crop, resize_factors, att_mask = zip(*crops_resize_factors)
|
||||
masks_crop = None
|
||||
else:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz, m)
|
||||
for f, a, m in zip(frames, box_extract, masks)]
|
||||
frames_crop, resize_factors, att_mask, masks_crop = zip(*crops_resize_factors)
|
||||
# frames_crop: tuple of ndarray (128,128,3), att_mask: tuple of ndarray (128,128)
|
||||
crop_sz = torch.Tensor([output_sz, output_sz])
|
||||
|
||||
# find the bb location in the crop
|
||||
'''Note that here we use normalized coord'''
|
||||
box_crop = [transform_image_to_crop(a_gt, a_ex, rf, crop_sz, normalize=True)
|
||||
for a_gt, a_ex, rf in zip(box_gt, box_extract, resize_factors)] # (x1,y1,w,h) list of tensors
|
||||
|
||||
return frames_crop, box_crop, att_mask, masks_crop
|
||||
|
||||
|
||||
def transform_box_to_crop(box: torch.Tensor, crop_box: torch.Tensor, crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box - the box for which the co-ordinates are to be transformed
|
||||
crop_box - bounding box defining the crop in the original image
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
|
||||
box_out = box.clone()
|
||||
box_out[:2] -= crop_box[:2]
|
||||
|
||||
scale_factor = crop_sz / crop_box[2:]
|
||||
|
||||
box_out[:2] *= scale_factor
|
||||
box_out[2:] *= scale_factor
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
349
lib/train/data/sampler.py
Normal file
349
lib/train/data/sampler.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
from lib.utils import TensorDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def no_processing(data):
|
||||
return data
|
||||
|
||||
|
||||
class TrackingSampler(torch.utils.data.Dataset):
|
||||
""" Class responsible for sampling frames from training sequences to form batches.
|
||||
|
||||
The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
|
||||
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
|
||||
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
|
||||
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
|
||||
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.
|
||||
|
||||
The sampled frames are then passed through the input 'processing' function for the necessary processing-
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
|
||||
train_cls=False, pos_prob=0.5):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the test frames.
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the test frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
"""
|
||||
self.datasets = datasets
|
||||
self.train_cls = train_cls # whether we are training classification
|
||||
self.pos_prob = pos_prob # probability of sampling positive class when making classification
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.processing = processing
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
|
||||
allow_invisible=False, force_invisible=False):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
# get valid ids
|
||||
if force_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
|
||||
else:
|
||||
if allow_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id)]
|
||||
else:
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.train_cls:
|
||||
return self.getitem_cls()
|
||||
else:
|
||||
return self.getitem()
|
||||
|
||||
def getitem(self):
|
||||
"""
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
|
||||
if is_video_dataset:
|
||||
template_frame_ids = None
|
||||
search_frame_ids = None
|
||||
gap_increase = 0
|
||||
|
||||
if self.frame_sample_mode == 'causal':
|
||||
# Sample test and train frames in a causal manner, i.e. search_frame_ids > template_frame_ids
|
||||
while search_frame_ids is None:
|
||||
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_template_frames - 1,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
|
||||
min_id=base_frame_id[0] - self.max_gap - gap_increase,
|
||||
max_id=base_frame_id[0])
|
||||
if prev_frame_ids is None:
|
||||
gap_increase += 5
|
||||
continue
|
||||
template_frame_ids = base_frame_id + prev_frame_ids
|
||||
search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
|
||||
max_id=template_frame_ids[0] + self.max_gap + gap_increase,
|
||||
num_ids=self.num_search_frames)
|
||||
# Increase gap until a frame is found
|
||||
gap_increase += 5
|
||||
|
||||
elif self.frame_sample_mode == "trident" or self.frame_sample_mode == "trident_pro":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("Illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def getitem_cls(self):
|
||||
# get data for classification
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
aux (bool): whether the current data is for auxiliary use (e.g. copy-and-paste)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
label = None
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample template and search frame ids
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode in ["trident", "trident_pro"]:
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
# "try" is used to handle trackingnet data failure
|
||||
# get images and bounding boxes (for templates)
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
|
||||
seq_info_dict)
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
|
||||
(H, W))] * self.num_template_frames
|
||||
# get images and bounding boxes (for searches)
|
||||
# positive samples
|
||||
if random.random() < self.pos_prob:
|
||||
label = torch.ones(1,)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
# negative samples
|
||||
else:
|
||||
label = torch.zeros(1,)
|
||||
if is_video_dataset:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
|
||||
if search_frame_ids is None:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
|
||||
seq_info_dict)
|
||||
search_anno["bbox"] = [self.get_center_box(H, W)]
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
H, W, _ = search_frames[0].shape
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
# add classification label
|
||||
data["label"] = label
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def get_center_box(self, H, W, ratio=1/8):
|
||||
cx, cy, w, h = W/2, H/2, W * ratio, H * ratio
|
||||
return torch.tensor([int(cx-w/2), int(cy-h/2), int(w), int(h)])
|
||||
|
||||
def sample_seq_from_dataset(self, dataset, is_video_dataset):
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= 20
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
return seq_id, visible, seq_info_dict
|
||||
|
||||
def get_one_search(self):
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
# sample a sequence
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample a frame
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == "stark":
|
||||
search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
|
||||
else:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)
|
||||
else:
|
||||
search_frame_ids = [1]
|
||||
# get the image, bounding box and other info
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
return search_frames, search_anno, meta_obj_test
|
||||
|
||||
def get_frame_ids_trident(self, visible):
|
||||
# get template and search ids in a 'trident' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
if self.frame_sample_mode == "trident_pro":
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id,
|
||||
allow_invisible=True)
|
||||
else:
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
def get_frame_ids_stark(self, visible, valid):
|
||||
# get template and search ids in a 'stark' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
"""we require the frame to be valid but not necessary visible"""
|
||||
f_id = self._sample_visible_ids(valid, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
265
lib/train/data/sequence_sampler.py
Normal file
265
lib/train/data/sequence_sampler.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class SequenceSampler(torch.utils.data.Dataset):
|
||||
"""
|
||||
Sample sequence for sequence-level training
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the search frames.\
|
||||
max_interval - Maximum interval between sampled frames
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the search frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
prob - sequential sampling by prob / interval sampling by 1-prob
|
||||
"""
|
||||
self.datasets = datasets
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
self.prob=prob
|
||||
self.extra=1
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
|
||||
def _sequential_sample(self, visible):
|
||||
# Sample frames in sequential manner
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
if self.max_gap == -1:
|
||||
left = template_frame_ids[0]
|
||||
else:
|
||||
# template frame (1) ->(max_gap) -> search frame (num_search_frames)
|
||||
left_max = min(len(visible) - self.num_search_frames, template_frame_ids[0] + self.max_gap)
|
||||
left = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)[0]
|
||||
|
||||
valid_ids = [i for i in range(left, len(visible)) if visible[i]]
|
||||
search_frame_ids = valid_ids[:self.num_search_frames]
|
||||
|
||||
# if length is not enough
|
||||
last = search_frame_ids[-1]
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
if last >= len(visible) - 1:
|
||||
search_frame_ids.append(last)
|
||||
else:
|
||||
last += 1
|
||||
if visible[last]:
|
||||
search_frame_ids.append(last)
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def _random_interval_sample(self, visible):
|
||||
# Get valid ids
|
||||
valid_ids = [i for i in range(len(visible)) if visible[i]]
|
||||
|
||||
# Sample template frame
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - avg_interval * (self.num_search_frames - 1))
|
||||
if template_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == 0:
|
||||
template_frame_ids = [valid_ids[0]]
|
||||
break
|
||||
|
||||
# Sample first search frame
|
||||
if self.max_gap == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
else:
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
left_max = min(max(len(visible) - avg_interval * (self.num_search_frames - 1), template_frame_ids[0] + 1),
|
||||
template_frame_ids[0] + self.max_gap)
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)
|
||||
|
||||
if search_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
break
|
||||
|
||||
# Sample rest of the search frames with random interval
|
||||
last = search_frame_ids[0]
|
||||
while last <= len(visible) - 1 and len(search_frame_ids) < self.num_search_frames:
|
||||
# sample id with interval
|
||||
max_id = min(last + self.max_interval + 1, len(visible))
|
||||
id = self._sample_visible_ids(visible, num_ids=1, min_id=last,
|
||||
max_id=max_id)
|
||||
|
||||
if id is None:
|
||||
# If not found in current range, find from previous range
|
||||
last = last + self.max_interval
|
||||
else:
|
||||
search_frame_ids.append(id[0])
|
||||
last = search_frame_ids[-1]
|
||||
|
||||
# if length is not enough, randomly sample new ids
|
||||
if len(search_frame_ids) < self.num_search_frames:
|
||||
valid_ids = [x for x in valid_ids if x > search_frame_ids[0] and x not in search_frame_ids]
|
||||
|
||||
if len(valid_ids) > 0:
|
||||
new_ids = random.choices(valid_ids, k=min(len(valid_ids),
|
||||
self.num_search_frames - len(search_frame_ids)))
|
||||
search_frame_ids = search_frame_ids + new_ids
|
||||
search_frame_ids = sorted(search_frame_ids, key=int)
|
||||
|
||||
# if length is still not enough, duplicate last frame
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
search_frame_ids.append(search_frame_ids[-1])
|
||||
|
||||
for i in range(1, self.num_search_frames):
|
||||
if search_frame_ids[i] - search_frame_ids[i - 1] > self.max_interval:
|
||||
print(search_frame_ids[i] - search_frame_ids[i - 1])
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
if dataset.get_name() == 'got10k' :
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
else:
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
self.max_gap = max_gap * self.extra
|
||||
self.max_interval = max_interval * self.extra
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
while True:
|
||||
try:
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= (self.num_search_frames + self.num_template_frames)
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == 'sequential':
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
|
||||
elif self.frame_sample_mode == 'random_interval':
|
||||
if random.random() < self.prob:
|
||||
template_frame_ids, search_frame_ids = self._random_interval_sample(visible)
|
||||
else:
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
else:
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
#print(dataset.get_name(), search_frame_ids, self.max_gap, self.max_interval)
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
#print(self.max_gap, self.max_interval)
|
||||
template_frames, template_anno, meta_obj_template = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
template_bbox = [bbox.numpy() for bbox in template_anno['bbox']] # tensor -> numpy array
|
||||
search_bbox = [bbox.numpy() for bbox in search_anno['bbox']] # tensor -> numpy array
|
||||
# print("====================================================================================")
|
||||
# print("dataset index: {}".format(index))
|
||||
# print("seq_id: {}".format(seq_id))
|
||||
# print('template_frame_ids: {}'.format(template_frame_ids))
|
||||
# print('search_frame_ids: {}'.format(search_frame_ids))
|
||||
return TensorDict({'template_images': np.array(template_frames).squeeze(), # 1 template images
|
||||
'template_annos': np.array(template_bbox).squeeze(),
|
||||
'search_images': np.array(search_frames), # (num_frames) search images
|
||||
'search_annos': np.array(search_bbox),
|
||||
'seq_id': seq_id,
|
||||
'dataset': dataset.get_name(),
|
||||
'search_class': meta_obj_search.get('object_class_name'),
|
||||
'num_frames': len(search_frames)
|
||||
})
|
||||
except Exception:
|
||||
pass
|
335
lib/train/data/transforms.py
Normal file
335
lib/train/data/transforms.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
|
||||
|
||||
class Transform:
|
||||
"""A set of transformations, used for e.g. data augmentation.
|
||||
Args of constructor:
|
||||
transforms: An arbitrary number of transformations, derived from the TransformBase class.
|
||||
They are applied in the order they are given.
|
||||
|
||||
The Transform object can jointly transform images, bounding boxes and segmentation masks.
|
||||
This is done by calling the object with the following key-word arguments (all are optional).
|
||||
|
||||
The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
|
||||
image - Image
|
||||
coords - 2xN dimensional Tensor of 2D image coordinates [y, x]
|
||||
bbox - Bounding box on the form [x, y, w, h]
|
||||
mask - Segmentation mask with discrete classes
|
||||
|
||||
The following parameters can be supplied with calling the transform object:
|
||||
joint [Bool] - If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
|
||||
Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
|
||||
different random rolls. Default: True.
|
||||
new_roll [Bool] - If False, then no new random roll is performed, and the saved result from the previous roll
|
||||
is used instead. Default: True.
|
||||
|
||||
Check the DiMPProcessing class for examples.
|
||||
"""
|
||||
|
||||
def __init__(self, *transforms):
|
||||
if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
|
||||
transforms = transforms[0]
|
||||
self.transforms = transforms
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['joint', 'new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
|
||||
def __call__(self, **inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
for v in inputs.keys():
|
||||
if v not in self._valid_all:
|
||||
raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))
|
||||
|
||||
joint_mode = inputs.get('joint', True)
|
||||
new_roll = inputs.get('new_roll', True)
|
||||
|
||||
if not joint_mode:
|
||||
out = zip(*[self(**inp) for inp in self._split_inputs(inputs)])
|
||||
return tuple(list(o) for o in out)
|
||||
|
||||
out = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
|
||||
for t in self.transforms:
|
||||
out = t(**out, joint=joint_mode, new_roll=new_roll)
|
||||
if len(var_names) == 1:
|
||||
return out[var_names[0]]
|
||||
# Make sure order is correct
|
||||
return tuple(out[v] for v in var_names)
|
||||
|
||||
def _split_inputs(self, inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
split_inputs = [{k: v for k, v in zip(var_names, vals)} for vals in zip(*[inputs[vn] for vn in var_names])]
|
||||
for arg_name, arg_val in filter(lambda it: it[0]!='joint' and it[0] in self._valid_args, inputs.items()):
|
||||
if isinstance(arg_val, list):
|
||||
for inp, av in zip(split_inputs, arg_val):
|
||||
inp[arg_name] = av
|
||||
else:
|
||||
for inp in split_inputs:
|
||||
inp[arg_name] = arg_val
|
||||
return split_inputs
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class TransformBase:
|
||||
"""Base class for transformation objects. See the Transform class for details."""
|
||||
def __init__(self):
|
||||
"""2020.12.24 Add 'att' to valid inputs"""
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
self._rand_params = None
|
||||
|
||||
def __call__(self, **inputs):
|
||||
# Split input
|
||||
input_vars = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
input_args = {k: v for k, v in inputs.items() if k in self._valid_args}
|
||||
|
||||
# Roll random parameters for the transform
|
||||
if input_args.get('new_roll', True):
|
||||
rand_params = self.roll()
|
||||
if rand_params is None:
|
||||
rand_params = ()
|
||||
elif not isinstance(rand_params, tuple):
|
||||
rand_params = (rand_params,)
|
||||
self._rand_params = rand_params
|
||||
|
||||
outputs = dict()
|
||||
for var_name, var in input_vars.items():
|
||||
if var is not None:
|
||||
transform_func = getattr(self, 'transform_' + var_name)
|
||||
if var_name in ['coords', 'bbox']:
|
||||
params = (self._get_image_size(input_vars),) + self._rand_params
|
||||
else:
|
||||
params = self._rand_params
|
||||
if isinstance(var, (list, tuple)):
|
||||
outputs[var_name] = [transform_func(x, *params) for x in var]
|
||||
else:
|
||||
outputs[var_name] = transform_func(var, *params)
|
||||
return outputs
|
||||
|
||||
def _get_image_size(self, inputs):
|
||||
im = None
|
||||
for var_name in ['image', 'mask']:
|
||||
if inputs.get(var_name) is not None:
|
||||
im = inputs[var_name]
|
||||
break
|
||||
if im is None:
|
||||
return None
|
||||
if isinstance(im, (list, tuple)):
|
||||
im = im[0]
|
||||
if isinstance(im, np.ndarray):
|
||||
return im.shape[:2]
|
||||
if torch.is_tensor(im):
|
||||
return (im.shape[-2], im.shape[-1])
|
||||
raise Exception('Unknown image type')
|
||||
|
||||
def roll(self):
|
||||
return None
|
||||
|
||||
def transform_image(self, image, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return coords
|
||||
|
||||
def transform_bbox(self, bbox, image_shape, *rand_params):
|
||||
"""Assumes [x, y, w, h]"""
|
||||
# Check if not overloaded
|
||||
if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
|
||||
return bbox
|
||||
|
||||
coord = bbox.clone().view(-1,2).t().flip(0)
|
||||
|
||||
x1 = coord[1, 0]
|
||||
x2 = coord[1, 0] + coord[1, 1]
|
||||
|
||||
y1 = coord[0, 0]
|
||||
y2 = coord[0, 0] + coord[0, 1]
|
||||
|
||||
coord_all = torch.tensor([[y1, y1, y2, y2], [x1, x2, x2, x1]])
|
||||
|
||||
coord_transf = self.transform_coords(coord_all, image_shape, *rand_params).flip(0)
|
||||
tl = torch.min(coord_transf, dim=1)[0]
|
||||
sz = torch.max(coord_transf, dim=1)[0] - tl
|
||||
bbox_out = torch.cat((tl, sz), dim=-1).reshape(bbox.shape)
|
||||
return bbox_out
|
||||
|
||||
def transform_mask(self, mask, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, *rand_params):
|
||||
"""2020.12.24 Added to deal with attention masks"""
|
||||
return att
|
||||
|
||||
|
||||
class ToTensor(TransformBase):
|
||||
"""Convert to a Tensor"""
|
||||
|
||||
def transform_image(self, image):
|
||||
# handle numpy array
|
||||
if image.ndim == 2:
|
||||
image = image[:, :, None]
|
||||
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(image, torch.ByteTensor):
|
||||
return image.float().div(255)
|
||||
else:
|
||||
return image
|
||||
|
||||
def transfrom_mask(self, mask):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
|
||||
def transform_att(self, att):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class ToTensorAndJitter(TransformBase):
|
||||
"""Convert to a Tensor and jitter brightness"""
|
||||
def __init__(self, brightness_jitter=0.0, normalize=True):
|
||||
super().__init__()
|
||||
self.brightness_jitter = brightness_jitter
|
||||
self.normalize = normalize
|
||||
|
||||
def roll(self):
|
||||
return np.random.uniform(max(0, 1 - self.brightness_jitter), 1 + self.brightness_jitter)
|
||||
|
||||
def transform_image(self, image, brightness_factor):
|
||||
# handle numpy array
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
|
||||
# backward compatibility
|
||||
if self.normalize:
|
||||
return image.float().mul(brightness_factor/255.0).clamp(0.0, 1.0)
|
||||
else:
|
||||
return image.float().mul(brightness_factor).clamp(0.0, 255.0)
|
||||
|
||||
def transform_mask(self, mask, brightness_factor):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
else:
|
||||
return mask
|
||||
def transform_att(self, att, brightness_factor):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class Normalize(TransformBase):
|
||||
"""Normalize image"""
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
super().__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def transform_image(self, image):
|
||||
return tvisf.normalize(image, self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToGrayscale(TransformBase):
|
||||
"""Converts image to grayscale with probability"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_grayscale):
|
||||
if do_grayscale:
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
|
||||
return np.stack([img_gray, img_gray, img_gray], axis=2)
|
||||
# return np.repeat(np.sum(img * self.color_weights, axis=2, keepdims=True).astype(np.uint8), 3, axis=2)
|
||||
return image
|
||||
|
||||
|
||||
class ToBGR(TransformBase):
|
||||
"""Converts image to BGR"""
|
||||
def transform_image(self, image):
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
|
||||
return img_bgr
|
||||
|
||||
|
||||
class RandomHorizontalFlip(TransformBase):
|
||||
"""Horizontally flip image randomly with a probability p."""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(image):
|
||||
return image.flip((2,))
|
||||
return np.fliplr(image).copy()
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = (image_shape[1] - 1) - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
|
||||
def transform_mask(self, mask, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(mask):
|
||||
return mask.flip((-1,))
|
||||
return np.fliplr(mask).copy()
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(att):
|
||||
return att.flip((-1,))
|
||||
return np.fliplr(att).copy()
|
||||
return att
|
||||
|
||||
|
||||
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
|
||||
"""Horizontally flip image randomly with a probability p.
|
||||
The difference is that the coord is normalized to [0,1]"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
"""we should use 1 rather than image_shape"""
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = 1 - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
33
lib/train/data/wandb_logger.py
Normal file
33
lib/train/data/wandb_logger.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install wandb" to install wandb')
|
||||
|
||||
|
||||
class WandbWriter:
|
||||
def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
|
||||
self.wandb = wandb
|
||||
self.step = cur_step
|
||||
self.interval = step_interval
|
||||
wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)
|
||||
|
||||
def write_log(self, stats: OrderedDict, epoch=-1):
|
||||
self.step += 1
|
||||
for loader_name, loader_stats in stats.items():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
|
||||
log_dict = {}
|
||||
for var_name, val in loader_stats.items():
|
||||
if hasattr(val, 'avg'):
|
||||
log_dict.update({loader_name + '/' + var_name: val.avg})
|
||||
else:
|
||||
log_dict.update({loader_name + '/' + var_name: val.val})
|
||||
|
||||
if epoch >= 0:
|
||||
log_dict.update({loader_name + '/epoch': epoch})
|
||||
|
||||
self.wandb.log(log_dict, step=self.step*self.interval)
|
Reference in New Issue
Block a user