init commit of samurai

This commit is contained in:
Cheng-Yen Yang
2024-11-19 22:12:54 -08:00
parent f65f4ba181
commit c17e4cecc0
679 changed files with 123982 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .loader import LTRLoader
from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader

View File

@@ -0,0 +1,150 @@
import torch
import numpy as np
def batch_center2corner(boxes):
    """Convert a batch of boxes from (cx, cy, w, h) to (xmin, ymin, xmax, ymax).

    args:
        boxes - N x 4 numpy array or torch tensor, one box per row.
    returns:
        N x 4 numpy array if the input is a numpy array, otherwise a torch tensor.
    """
    cx, cy = boxes[:, 0], boxes[:, 1]
    half_w, half_h = boxes[:, 2] * 0.5, boxes[:, 3] * 0.5
    columns = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    # Pick the stacking routine that matches the input container.
    stack = np.stack if isinstance(boxes, np.ndarray) else torch.stack
    return stack(columns, 1)
def batch_corner2center(boxes):
    """Convert a batch of boxes from (xmin, ymin, xmax, ymax) to (cx, cy, w, h).

    args:
        boxes - N x 4 numpy array or torch tensor, one box per row.
    returns:
        N x 4 numpy array if the input is a numpy array, otherwise a torch tensor.
    """
    x_min, y_min = boxes[:, 0], boxes[:, 1]
    x_max, y_max = boxes[:, 2], boxes[:, 3]
    columns = [(x_min + x_max) * 0.5, (y_min + y_max) * 0.5, (x_max - x_min), (y_max - y_min)]
    stack = np.stack if isinstance(boxes, np.ndarray) else torch.stack
    return stack(columns, 1)
def batch_xywh2center(boxes):
    """Convert a batch of boxes from (x, y, w, h) to (cx, cy, w, h).

    The center is computed as x + (w - 1) / 2 (and likewise for y), i.e. one
    unit is subtracted from the size before halving. See batch_xywh2center2
    for the variant without that offset.

    args:
        boxes - N x 4 numpy array or torch tensor, one box per row.
    returns:
        N x 4 numpy array if the input is a numpy array, otherwise a torch tensor.
    """
    width, height = boxes[:, 2], boxes[:, 3]
    center_x = boxes[:, 0] + (width - 1) / 2
    center_y = boxes[:, 1] + (height - 1) / 2
    stack = np.stack if isinstance(boxes, np.ndarray) else torch.stack
    return stack([center_x, center_y, width, height], 1)
def batch_xywh2center2(boxes):
    """Convert a batch of boxes from (x, y, w, h) to (cx, cy, w, h).

    The center is computed as x + w / 2 (and likewise for y), without the
    "- 1" offset used by batch_xywh2center.

    args:
        boxes - N x 4 numpy array or torch tensor, one box per row.
    returns:
        N x 4 numpy array if the input is a numpy array, otherwise a torch tensor.
    """
    width, height = boxes[:, 2], boxes[:, 3]
    center_x = boxes[:, 0] + width / 2
    center_y = boxes[:, 1] + height / 2
    stack = np.stack if isinstance(boxes, np.ndarray) else torch.stack
    return stack([center_x, center_y, width, height], 1)
def batch_xywh2corner(boxes):
    """Convert a batch of boxes from (x, y, w, h) to (xmin, ymin, xmax, ymax).

    args:
        boxes - N x 4 numpy array or torch tensor, one box per row.
    returns:
        N x 4 numpy array if the input is a numpy array, otherwise a torch tensor.
    """
    left, top = boxes[:, 0], boxes[:, 1]
    right = left + boxes[:, 2]
    bottom = top + boxes[:, 3]
    stack = np.stack if isinstance(boxes, np.ndarray) else torch.stack
    return stack([left, top, right, bottom], 1)
def rect_to_rel(bb, sz_norm=None):
    """Convert boxes from [x, y, w, h] to the relative form [cx/sw, cy/sh, log(w), log(h)].

    [cx, cy] is the box center (x + w/2, y + h/2).

    args:
        bb - tensor of boxes whose last dimension is 4 ([x, y, w, h]).
        sz_norm - optional tensor of [sw, sh] used to normalise the center. When omitted,
                  the box's own [w, h] is used.
    returns:
        Tensor with the same leading dimensions as bb and last dimension 4.
    """
    top_left, size = bb[..., :2], bb[..., 2:]
    center = top_left + 0.5 * size
    # Normalise by the box size itself unless an explicit normaliser is given.
    divisor = size if sz_norm is None else sz_norm
    return torch.cat((center / divisor, torch.log(size)), dim=-1)
def rel_to_rect(bb, sz_norm=None):
    """Invert rect_to_rel: map [cx/sw, cy/sh, log(w), log(h)] back to [x, y, w, h].

    args:
        bb - tensor of relative boxes whose last dimension is 4.
        sz_norm - optional tensor of [sw, sh]; must match the value passed to rect_to_rel.
                  When omitted, the decoded [w, h] is used.
    returns:
        Tensor of [x, y, w, h] boxes with the same leading dimensions as bb.
    """
    size = torch.exp(bb[..., 2:])
    scale = size if sz_norm is None else sz_norm
    center = bb[..., :2] * scale
    top_left = center - 0.5 * size
    return torch.cat((top_left, size), dim=-1)
def masks_to_bboxes(mask, fmt='c'):
    """Convert a tensor of masks into one bounding box per mask.

    The box is spanned by the first/last column and row with a non-zero sum. A mask
    without any such row or column yields the corner box [0, 0, 0, 0].

    :param mask: Tensor of masks, shape = (..., H, W)
    :param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
                             't' => "top left + size" or (x_left, y_top, width, height)
                             'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
    :return: float32 tensor containing a batch of bounding boxes, shape = (..., 4)
    :raises ValueError: if fmt is not one of 'c', 't' or 'v'.
    """
    lead_shape = mask.shape[:-2]
    flat_masks = mask.reshape((-1, *mask.shape[-2:]))

    corners = []
    for single in flat_masks:
        cols = single.sum(dim=-2).nonzero()  # x indices that contain mask pixels
        rows = single.sum(dim=-1).nonzero()  # y indices that contain mask pixels
        if len(cols) > 0 and len(rows) > 0:
            corners.append([cols.min(), rows.min(), cols.max(), rows.max()])
        else:
            corners.append([0, 0, 0, 0])

    bboxes = torch.tensor(corners, dtype=torch.float32, device=flat_masks.device)
    bboxes = bboxes.reshape(lead_shape + (4,))

    if fmt == 'v':
        return bboxes

    top_left = bboxes[..., :2]
    # Inclusive pixel indices, hence the +1 when turning corners into a size.
    size = bboxes[..., 2:] - top_left + 1

    if fmt == 't':
        return torch.cat((top_left, size), dim=-1)
    if fmt == 'c':
        return torch.cat((top_left + 0.5 * size, size), dim=-1)

    raise ValueError("Undefined bounding box layout '%s'" % fmt)
def masks_to_bboxes_multi(mask, ids, fmt='c'):
    """Compute one bounding box per object id from a single 2D label mask.

    args:
        mask - 2D tensor (H, W) in which every pixel holds an object id.
        ids - iterable of object ids to extract boxes for.
        fmt - bbox layout, 'c' (center + size), 't' (top left + size) or
              'v' (x_left, y_top, x_right, y_bottom).
    returns:
        list - one float32 tensor of 4 values per id, in the order of ids. An id that
               does not occur in the mask is treated as the corner box [0, 0, 0, 0].
    raises:
        ValueError - if fmt is not a known layout (only checked when ids is non-empty).
    """
    assert mask.dim() == 2
    boxes = []
    for obj_id in ids:
        hit = (mask == obj_id)
        cols = hit.sum(dim=-2).nonzero()
        rows = hit.float().sum(dim=-1).nonzero()

        if len(cols) > 0 and len(rows) > 0:
            corners = [cols.min(), rows.min(), cols.max(), rows.max()]
        else:
            corners = [0, 0, 0, 0]
        box = torch.tensor(corners, dtype=torch.float32, device=mask.device)

        top_left = box[:2]
        size = box[2:] - top_left + 1

        if fmt == 'c':
            box = torch.cat((top_left + 0.5 * size, size), dim=-1)
        elif fmt == 't':
            box = torch.cat((top_left, size), dim=-1)
        elif fmt != 'v':
            raise ValueError("Undefined bounding box layout '%s'" % fmt)
        boxes.append(box)
    return boxes

View File

@@ -0,0 +1,103 @@
import jpeg4py
import cv2 as cv
from PIL import Image
import numpy as np
# Default colour palette for indexed PNG annotations: a 256 x 3 uint8 table that starts
# as a grey ramp (entry i == [i, i, i]) and has its first 22 entries replaced by the
# fixed colours below.
davis_palette = np.tile(np.arange(256, dtype=np.uint8).reshape(-1, 1), (1, 3))
davis_palette[:22, :] = [
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
    [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
    [64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
    [0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
    [0, 64, 128], [128, 64, 128],
]
def default_image_loader(path):
    """Read the image at path with jpeg4py when it works, otherwise with OpenCV.

    The backend is decided on the first call: jpeg4py_loader is tried and, if it
    returns None, every call from then on (including this one) uses opencv_loader.
    The decision is cached in the function attribute default_image_loader.use_jpeg4py
    (None = undecided, True = jpeg4py, False = OpenCV).
    """
    if default_image_loader.use_jpeg4py is None:
        # First call: probe jpeg4py with the requested image.
        probe = jpeg4py_loader(path)
        if probe is not None:
            default_image_loader.use_jpeg4py = True
            return probe
        default_image_loader.use_jpeg4py = False
        print('Using opencv_loader instead.')

    loader = jpeg4py_loader if default_image_loader.use_jpeg4py else opencv_loader
    return loader(path)


default_image_loader.use_jpeg4py = None
def jpeg4py_loader(path):
    """Decode the image at path with jpeg4py (https://github.com/ajkxyz/jpeg4py).

    returns:
        The decoded image, or None if decoding raised (the error is printed).
    """
    try:
        decoded = jpeg4py.JPEG(path).decode()
    except Exception as err:
        print('ERROR: Could not read image "{}"'.format(path))
        print(err)
        return None
    return decoded
def opencv_loader(path):
    """Read the image at path with OpenCV's imread and return it in RGB channel order.

    returns:
        The RGB image, or None if reading or colour conversion raised (the error is printed).
    """
    try:
        bgr = cv.imread(path, cv.IMREAD_COLOR)
        # OpenCV loads images as BGR; convert to RGB before returning.
        rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    except Exception as err:
        print('ERROR: Could not read image "{}"'.format(path))
        print(err)
        return None
    return rgb
def jpeg4py_loader_w_failsafe(path):
    """Decode the image at path with jpeg4py, falling back to OpenCV if that fails.

    jpeg4py: https://github.com/ajkxyz/jpeg4py

    returns:
        The decoded image (RGB order on the OpenCV path), or None if both
        backends failed (the OpenCV error is printed).
    """
    try:
        return jpeg4py.JPEG(path).decode()
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        try:
            im = cv.imread(path, cv.IMREAD_COLOR)
            # convert to rgb and return
            return cv.cvtColor(im, cv.COLOR_BGR2RGB)
        except Exception as e:
            print('ERROR: Could not read image "{}"'.format(path))
            print(e)
            return None
def opencv_seg_loader(path):
    """Read a segmentation annotation with OpenCV's imread (default flags).

    returns:
        Whatever cv.imread returns for path, or None if it raised (the error is printed).
    """
    try:
        annotation = cv.imread(path)
    except Exception as err:
        print('ERROR: Could not read image "{}"'.format(path))
        print(err)
        return None
    return annotation
def imread_indexed(filename):
    """Load an indexed image with the given filename. Used to read segmentation annotations.

    returns:
        2D numpy array: the first channel of the image after promoting it to 3 dimensions.
    """
    # Use the image as a context manager so the underlying file handle is closed
    # deterministically instead of lingering until garbage collection.
    with Image.open(filename) as im:
        # The pixel data is copied into a numpy array here, so the result stays
        # valid after the file has been closed.
        annotation = np.atleast_3d(im)[..., 0]
    return annotation
def imwrite_indexed(filename, array, color_palette=None):
    """Save an indexed image as PNG. Used to save segmentation annotations.

    args:
        filename - destination path.
        array - label image; must have a single channel.
        color_palette - palette array attached to the image; davis_palette when None.
    raises:
        Exception - if array has more than one channel.
    """
    palette = davis_palette if color_palette is None else color_palette

    channels = np.atleast_3d(array).shape[2]
    if channels != 1:
        raise Exception("Saving indexed PNGs requires 2D array.")

    indexed = Image.fromarray(array)
    indexed.putpalette(palette.ravel())
    indexed.save(filename, format='PNG')

199
lib/train/data/loader.py Normal file
View File

@@ -0,0 +1,199 @@
import torch
import torch.utils.data.dataloader
import importlib
import collections
# from torch._six import string_classes
from lib.utils import TensorDict, TensorList
# torch._six only exists in older PyTorch releases. Try the import and fall back to
# the builtin int instead of parsing torch.__version__, which is fragile for version
# strings such as "1.10.0".
try:
    from torch._six import int_classes
except ImportError:
    int_classes = int
import warnings
# NOTE: this silences every warning for the whole process, not only this module.
warnings.filterwarnings("ignore")
string_classes = str
def _check_use_shared_memory():
    """Report whether collated tensors should be written straight into shared memory.

    Looks for a `_use_shared_memory` flag first on torch.utils.data.dataloader and then
    on torch.utils.data._utils.collate. If neither module has it, the answer is whether
    the current process is a DataLoader worker.
    """
    flag_name = '_use_shared_memory'

    dataloader_module = torch.utils.data.dataloader
    if hasattr(dataloader_module, flag_name):
        return getattr(dataloader_module, flag_name)

    collate_module = importlib.import_module('torch.utils.data._utils.collate')
    if hasattr(collate_module, flag_name):
        return getattr(collate_module, flag_name)

    # No explicit flag available: worker processes have worker info, the main process does not.
    return torch.utils.data.get_worker_info() is not None
def _ltr_collate_impl(batch, stack_dim):
    """Shared implementation of ltr_collate and ltr_collate_stack1.

    args:
        batch - list of samples; the type of the first sample selects the strategy.
        stack_dim - dimension along which tensors / numpy arrays are stacked.
    returns:
        The collated batch (tensor, list, dict, TensorDict or TensorList).
    raises:
        TypeError - if the sample type is not supported.
    """
    # Imported locally so this block does not depend on the file-level imports.
    # collections.Mapping / collections.Sequence were removed in Python 3.10, and
    # current PyTorch no longer exposes `re` or `numpy_type_map` through
    # torch.utils.data.dataloader.
    import re
    from collections.abc import Mapping, Sequence

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)

    def recurse(samples):
        return _ltr_collate_impl(samples, stack_dim)

    if isinstance(elem, torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, stack_dim, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        if elem_type.__name__ == 'ndarray':
            # arrays of strings and objects cannot be converted to tensors
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], stack_dim)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            # Keep the numpy dtype: numpy and torch dtypes share their names.
            return torch.tensor(list(map(py_type, batch)), dtype=getattr(torch, elem.dtype.name))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, TensorDict):
        return TensorDict({key: recurse([d[key] for d in batch]) for key in elem})
    elif isinstance(elem, Mapping):
        return {key: recurse([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, TensorList):
        return TensorList([recurse(samples) for samples in zip(*batch)])
    elif isinstance(elem, Sequence):
        return [recurse(samples) for samples in zip(*batch)]
    elif elem is None:
        return batch

    raise TypeError((error_msg.format(type(elem))))


def ltr_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size"""
    return _ltr_collate_impl(batch, 0)


def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
    return _ltr_collate_impl(batch, 1)
class LTRLoader(torch.utils.data.dataloader.DataLoader):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Note: Compared to the default pytorch DataLoader, an additional option stack_dim is
    available to select along which dimension the data should be stacked to form a batch,
    and the loader carries a name, a training flag and an epoch interval.

    Arguments:
        name (str): identifier stored on the loader as ``self.name``.
        dataset (Dataset): dataset from which to load the data.
        training (bool, optional): stored as ``self.training`` (default: True).
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        epoch_interval (int, optional): stored as ``self.epoch_interval`` (default: 1).
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            When omitted, it is chosen from stack_dim.
        stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        print("pin_memory is", pin_memory)

        if collate_fn is None:
            # Derive the collate function from the requested stacking dimension.
            if stack_dim not in (0, 1):
                raise ValueError('Stack dim no supported. Must be 0 or 1.')
            collate_fn = ltr_collate if stack_dim == 0 else ltr_collate_stack1

        super().__init__(dataset=dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)

        self.name = name
        self.training = training
        self.epoch_interval = epoch_interval
        self.stack_dim = stack_dim

View File

@@ -0,0 +1,155 @@
import torch
import torchvision.transforms as transforms
from lib.utils import TensorDict
import lib.train.data.processing_utils as prutils
import torch.nn.functional as F
def stack_tensors(x):
    """Stack a list/tuple of tensors into a single tensor; return anything else unchanged.

    Only the first element is inspected to decide whether x holds tensors. An empty
    list or tuple is returned as-is instead of raising IndexError.
    """
    if isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], torch.Tensor):
        return torch.stack(x)
    return x
class BaseProcessing:
    """ Base class for Processing. A Processing object is applied to the data returned by a dataset before it is
    passed through the network, e.g. to crop a search region around the object or to apply data augmentations.
    Subclasses implement __call__."""

    def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None, joint_transform=None):
        """
        args:
            transform - Fallback set of transformations applied on the images. Used for the template and/or search
                        images whenever the corresponding specific transform is None.
            template_transform - Transformations applied on the template images. If None, 'transform' is used.
            search_transform - Transformations applied on the search images. If None, 'transform' is used.
            joint_transform - Transformations applied 'jointly' on the template and search images, e.g. converting
                              both to grayscale. Stored as given (may be None).
        """
        if template_transform is None:
            template_transform = transform
        if search_transform is None:
            search_transform = transform

        self.transform = {'template': template_transform,
                          'search': search_transform,
                          'joint': joint_transform}

    def __call__(self, data: TensorDict):
        raise NotImplementedError
class STARKProcessing(BaseProcessing):
    """ Processing used to build template/search crops for training. The images are processed in the following way.
    First, the target bounding box is jittered by adding some noise. Next, a square region centered at the jittered
    target center, and of area search_area_factor^2 times the area of the jittered box, is cropped from the image.
    The box is jittered so that the network does not learn the bias that the target is always at the center of the
    crop. The crop is then resized to a fixed size given by the argument output_sz.
    """

    def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
                 mode='pair', settings=None, *args, **kwargs):
        """
        args:
            search_area_factor - Size of the cropped region relative to the target size; indexed by
                                 'template' / 'search'.
            output_sz - Size to which the (always square) crop is resized; indexed by 'template' / 'search'.
            center_jitter_factor - A dict containing the amount of jittering applied to the target center before
                                   extracting the crop. See _get_jittered_box for how the jittering is done.
            scale_jitter_factor - A dict containing the amount of jittering applied to the target size before
                                  extracting the crop. See _get_jittered_box for how the jittering is done.
            mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames.
            settings - Stored on the instance as self.settings.
            *args, **kwargs - Forwarded to BaseProcessing.
        """
        super().__init__(*args, **kwargs)
        self.settings = settings
        self.mode = mode
        self.output_sz = output_sz
        self.search_area_factor = search_area_factor
        self.center_jitter_factor = center_jitter_factor
        self.scale_jitter_factor = scale_jitter_factor

    def _get_jittered_box(self, box, mode):
        """ Jitter the input box.
        args:
            box - input bounding box [x, y, w, h]
            mode - string 'template' or 'search' indicating template or search data
        returns:
            torch.Tensor - jittered box [x, y, w, h]
        """
        size = box[2:4]
        # Log-normal noise on the size.
        scale_noise = torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
        jittered_size = size * scale_noise

        # Uniform noise on the center, proportional to the jittered box scale.
        center_factor = torch.tensor(self.center_jitter_factor[mode]).float()
        max_offset = jittered_size.prod().sqrt() * center_factor
        jittered_center = box[0:2] + 0.5 * size + max_offset * (torch.rand(2) - 0.5)

        return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)

    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'template_images', 'search_images', 'template_anno', 'search_anno',
                'template_masks', 'search_masks'
        returns:
            TensorDict - the same data block with images, annotations and masks replaced by their cropped and
                         transformed versions, plus 'template_att', 'search_att' and a boolean 'valid'. When
                         'valid' is False the block is returned early and may be only partially processed.
        """
        joint_transform = self.transform['joint']
        if joint_transform is not None:
            data['template_images'], data['template_anno'], data['template_masks'] = joint_transform(
                image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
            data['search_images'], data['search_anno'], data['search_masks'] = joint_transform(
                image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)

        for s in ('template', 'search'):
            images_key, anno_key = s + '_images', s + '_anno'
            att_key, masks_key = s + '_att', s + '_masks'

            assert self.mode == 'sequence' or len(data[images_key]) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Jitter every annotated box.
            jittered_anno = [self._get_jittered_box(a, s) for a in data[anno_key]]

            # Reject samples whose jittered box would give a crop smaller than one pixel.
            stacked = torch.stack(jittered_anno, dim=0)
            w, h = stacked[:, 2], stacked[:, 3]
            crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
            if (crop_sz < 1).any():
                data['valid'] = False
                return data

            # Crop the region centered at the jittered box; also yields the attention mask.
            crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(
                data[images_key], jittered_anno, data[anno_key],
                self.search_area_factor[s], self.output_sz[s], masks=data[masks_key])

            # Per-branch transforms.
            data[images_key], data[anno_key], data[att_key], data[masks_key] = self.transform[s](
                image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)

            # Reject samples whose attention mask is all ones.
            # data[att_key] is a tuple whose elements are tensors.
            for ele in data[att_key]:
                if (ele == 1).all():
                    data['valid'] = False
                    return data

            # Stricter check: the down-sampled mask must not be all ones either.
            feat_size = self.output_sz[s] // 16  # 16 is the backbone stride
            for ele in data[att_key]:
                # e.g. (1,1,128,128) -> (1,1,8,8) and (1,1,256,256) -> (1,1,16,16)
                mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
                if (mask_down == 1).all():
                    data['valid'] = False
                    return data

        data['valid'] = True

        # If either set of masks is missing, replace both with all-zero masks.
        if data["template_masks"] is None or data["search_masks"] is None:
            data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
            data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))

        # Prepare output: stack frames in sequence mode, otherwise unwrap single-element lists.
        if self.mode == 'sequence':
            return data.apply(stack_tensors)
        return data.apply(lambda x: x[0] if isinstance(x, list) else x)

View File

@@ -0,0 +1,168 @@
import torch
import math
import cv2 as cv
import torch.nn.functional as F
import numpy as np
'''modified from the original test implementation
Replace cv.BORDER_REPLICATE with cv.BORDER_CONSTANT
Add a variable called att_mask for computing attention and positional encoding later'''
def sample_target(im, target_bb, search_area_factor, output_sz=None, mask=None):
    """ Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area.
    Regions of the crop that fall outside the image are padded with a constant border.

    args:
        im - cv image (H x W x C array)
        target_bb - target box [x, y, w, h], as a list or an object with .tolist()
        search_area_factor - Ratio of crop size to target size
        output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.
        mask - optional mask cropped, zero-padded and resized alongside the image. It is passed to
               torch.nn.functional.pad / interpolate, so it must be a tensor.

    returns:
        cv image - extracted crop
        float - the factor by which the crop has been resized to make the crop size equal output_size
                (1.0 when output_sz is None)
        boolean array - attention mask of the crop: True on padded pixels, False on image pixels
        tensor - the cropped mask; only returned when mask is given

    raises:
        Exception - if the crop size derived from the box is smaller than 1.
    """
    if not isinstance(target_bb, list):
        x, y, w, h = target_bb.tolist()
    else:
        x, y, w, h = target_bb

    # Side length of the square crop
    crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
    if crop_sz < 1:
        raise Exception('Too small bounding box.')

    x1 = round(x + 0.5 * w - crop_sz * 0.5)
    x2 = x1 + crop_sz
    y1 = round(y + 0.5 * h - crop_sz * 0.5)
    y2 = y1 + crop_sz

    # Amount of the crop lying outside the image on each side
    x1_pad = max(0, -x1)
    x2_pad = max(x2 - im.shape[1] + 1, 0)
    y1_pad = max(0, -y1)
    y2_pad = max(y2 - im.shape[0] + 1, 0)

    # Crop target
    im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
    if mask is not None:
        mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

    # Pad
    im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_CONSTANT)

    # Attention mask: 1 on the padded border, 0 where real image content is present
    H, W, _ = im_crop_padded.shape
    att_mask = np.ones((H, W))
    att_mask[y1_pad:H - y2_pad, x1_pad:W - x2_pad] = 0

    if mask is not None:
        mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

    if output_sz is None:
        att_mask = att_mask.astype(np.bool_)
        if mask is None:
            # Same (crop, resize_factor, att_mask) order as the resized branch below;
            # the previous order (crop, att_mask, 1.0) did not match the other return paths.
            return im_crop_padded, 1.0, att_mask
        return im_crop_padded, 1.0, att_mask, mask_crop_padded

    resize_factor = output_sz / crop_sz
    im_crop_padded = cv.resize(im_crop_padded, (output_sz, output_sz))
    att_mask = cv.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
    if mask is None:
        return im_crop_padded, resize_factor, att_mask
    mask_crop_padded = \
        F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
    return im_crop_padded, resize_factor, att_mask, mask_crop_padded
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
                            crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
    """ Map a box from original image co-ordinates into the co-ordinates of a crop centered at box_extract.
    args:
        box_in - the box [x, y, w, h] whose co-ordinates are to be transformed
        box_extract - the box [x, y, w, h] about which the image crop has been extracted
        resize_factor - scale applied to the crop, i.e. crop pixels per original image pixel
        crop_sz - size of the cropped image
        normalize - if True, the result is divided by crop_sz[0]

    returns:
        torch.Tensor - transformed box [x, y, w, h]
    """
    extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]
    in_center = box_in[0:2] + 0.5 * box_in[2:4]

    # The extraction center maps to the crop center; offsets are scaled by resize_factor.
    out_center = (crop_sz - 1) / 2 + (in_center - extract_center) * resize_factor
    out_wh = box_in[2:4] * resize_factor
    box_out = torch.cat((out_center - 0.5 * out_wh, out_wh))

    return box_out / crop_sz[0] if normalize else box_out
def jittered_center_crop(frames, box_extract, box_gt, search_area_factor, output_sz, masks=None):
    """ For each frame in frames, extracts a square crop centered at box_extract, of area search_area_factor^2
    times box_extract area. The extracted crops are then resized to output_sz. Further, the co-ordinates of the box
    box_gt are transformed to the (normalized) image crop co-ordinates.

    args:
        frames - list of frames
        box_extract - list of boxes of same length as frames. The crops are extracted around these boxes
        box_gt - list of boxes of same length as frames. The co-ordinates of these boxes are transformed from
                 image co-ordinates to the crop co-ordinates
        search_area_factor - The area of the extracted crop is search_area_factor^2 times box_extract area
        output_sz - The size to which the extracted crops are resized
        masks - optional list of masks of same length as frames, cropped alongside the frames

    returns:
        tuple - image crops
        list - box_gt locations in the crop co-ordinates, (x1, y1, w, h) divided by output_sz
        tuple - attention masks of the crops
        tuple - cropped masks, or None when masks is None
    """
    if masks is None:
        per_frame_args = zip(frames, box_extract)
    else:
        per_frame_args = zip(frames, box_extract, masks)

    # extra is empty without masks and holds the frame's mask otherwise.
    samples = [sample_target(frame, box, search_area_factor, output_sz, *extra)
               for frame, box, *extra in per_frame_args]
    columns = tuple(zip(*samples))

    if masks is None:
        frames_crop, resize_factors, att_mask = columns
        masks_crop = None
    else:
        frames_crop, resize_factors, att_mask, masks_crop = columns
    # frames_crop: tuple of ndarray (e.g. 128,128,3), att_mask: tuple of ndarray (e.g. 128,128)

    crop_sz = torch.Tensor([output_sz, output_sz])

    # Locate the ground-truth boxes inside the crops. Note that normalized co-ordinates are used.
    box_crop = []
    for gt_box, extract_box, factor in zip(box_gt, box_extract, resize_factors):
        box_crop.append(transform_image_to_crop(gt_box, extract_box, factor, crop_sz, normalize=True))

    return frames_crop, box_crop, att_mask, masks_crop
def transform_box_to_crop(box: torch.Tensor, crop_box: torch.Tensor, crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
    """ Map a box from original image co-ordinates into the co-ordinates of a crop given by crop_box.
    args:
        box - the box [x, y, w, h] whose co-ordinates are to be transformed (left unmodified)
        crop_box - bounding box [x, y, w, h] defining the crop in the original image
        crop_sz - size of the cropped image
        normalize - if True, the result is divided by crop_sz[0]

    returns:
        torch.Tensor - transformed box [x, y, w, h]
    """
    scale_factor = crop_sz / crop_box[2:]

    transformed = box.clone()
    # Shift the top-left corner into the crop frame, then scale position and size.
    transformed[:2].sub_(crop_box[:2]).mul_(scale_factor)
    transformed[2:].mul_(scale_factor)

    return transformed / crop_sz[0] if normalize else transformed

349
lib/train/data/sampler.py Normal file
View File

@@ -0,0 +1,349 @@
import random
import torch.utils.data
from lib.utils import TensorDict
import numpy as np
def no_processing(data):
    """Identity processing function: return the data block unchanged."""
    return data
class TrackingSampler(torch.utils.data.Dataset):
""" Class responsible for sampling frames from training sequences to form batches.
The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.
The sampled frames are then passed through the input 'processing' function for the necessary processing-
"""
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
             num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
             train_cls=False, pos_prob=0.5):
    """
    args:
        datasets - List of datasets to be used for training
        p_datasets - List containing the weights by which each dataset will be sampled. They are normalized
                     to sum to one. If None, each dataset is weighted by its length.
        samples_per_epoch - Number of training samples per epoch
        max_gap - Maximum gap, in frame numbers, between the train frames and the test frames.
        num_search_frames - Number of search frames to sample.
        num_template_frames - Number of template frames to sample.
        processing - Callable which performs the necessary processing of the data.
        frame_sample_mode - Name of the frame sampling strategy, e.g. 'causal'.
        train_cls - Whether we are training classification.
        pos_prob - Probability of sampling the positive class when training classification.
    """
    self.datasets = datasets
    self.train_cls = train_cls
    self.pos_prob = pos_prob

    # Without explicit weights, weight each dataset by its length.
    weights = [len(d) for d in self.datasets] if p_datasets is None else p_datasets

    # Normalize the weights into probabilities.
    total = sum(weights)
    self.p_datasets = [weight / total for weight in weights]

    self.samples_per_epoch = samples_per_epoch
    self.max_gap = max_gap
    self.num_search_frames = num_search_frames
    self.num_template_frames = num_template_frames
    self.processing = processing
    self.frame_sample_mode = frame_sample_mode
def __len__(self):
return self.samples_per_epoch
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
allow_invisible=False, force_invisible=False):
""" Samples num_ids frames between min_id and max_id for which target is visible
args:
visible - 1d Tensor indicating whether target is visible for each frame
num_ids - number of frames to be samples
min_id - Minimum allowed frame number
max_id - Maximum allowed frame number
returns:
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
"""
if num_ids == 0:
return []
if min_id is None or min_id < 0:
min_id = 0
if max_id is None or max_id > len(visible):
max_id = len(visible)
# get valid ids
if force_invisible:
valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
else:
if allow_invisible:
valid_ids = [i for i in range(min_id, max_id)]
else:
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
# No visible ids
if len(valid_ids) == 0:
return None
return random.choices(valid_ids, k=num_ids)
def __getitem__(self, index):
if self.train_cls:
return self.getitem_cls()
else:
return self.getitem()
def getitem(self):
    """Draw template/search frames from a random sequence until processing reports a valid sample.

    For video datasets the frame ids are chosen according to `self.frame_sample_mode`
    ('causal', 'trident', 'trident_pro' or 'stark'); for image datasets the single image is
    repeated. Loading and processing failures are treated as an invalid sample and a new
    sequence is drawn.

    returns:
        TensorDict - dict containing all the data blocks

    raises:
        ValueError - if a video dataset is sampled and `self.frame_sample_mode` is not supported
    """
    valid = False
    while not valid:
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # sample a sequence from the given dataset
        seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)

        if is_video_dataset:
            template_frame_ids = None
            search_frame_ids = None
            gap_increase = 0

            if self.frame_sample_mode == 'causal':
                # Sample test and train frames in a causal manner, i.e. search_frame_ids > template_frame_ids
                # (in this mode self.max_gap is used as a number)
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1,
                                                             min_id=self.num_template_frames - 1,
                                                             max_id=len(visible) - self.num_search_frames)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        # Widen the window and try again
                        gap_increase += 5
                        continue
                    template_frame_ids = base_frame_id + prev_frame_ids
                    search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
                                                                max_id=template_frame_ids[0] + self.max_gap + gap_increase,
                                                                num_ids=self.num_search_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5
            elif self.frame_sample_mode == "trident" or self.frame_sample_mode == "trident_pro":
                template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
            elif self.frame_sample_mode == "stark":
                template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
            else:
                raise ValueError("Illegal frame sample mode")
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            template_frame_ids = [1] * self.num_template_frames
            search_frame_ids = [1] * self.num_search_frames

        try:
            template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
            search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)

            # Datasets without segmentation annotations get all-zero masks of the template frame size
            H, W, _ = template_frames[0].shape
            template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
            search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames

            data = TensorDict({'template_images': template_frames,
                               'template_anno': template_anno['bbox'],
                               'template_masks': template_masks,
                               'search_images': search_frames,
                               'search_anno': search_anno['bbox'],
                               'search_masks': search_masks,
                               'dataset': dataset.get_name(),
                               'test_class': meta_obj_test.get('object_class_name')})
            # make data augmentation
            data = self.processing(data)

            # check whether data is valid
            valid = data['valid']
        except Exception:
            # Best-effort retry on data failures. `Exception` (not a bare except) so that
            # KeyboardInterrupt / SystemExit can still stop this otherwise unbounded loop.
            valid = False

    return data
def getitem_cls(self):
    """Draw a template/search pair together with a binary classification label.

    With probability `self.pos_prob` the search frames come from the same sequence as the
    template (label 1). Otherwise (label 0) the search frame is a frame of the same video in
    which the target is flagged invisible, paired with a centered dummy box, or, when no such
    frame exists or the dataset is an image dataset, a frame obtained from `get_one_search`.
    Loading and processing failures are treated as an invalid sample and a new sequence is drawn.

    returns:
        TensorDict - dict containing all the data blocks, plus the entry "label"

    raises:
        ValueError - if a video dataset is sampled and `self.frame_sample_mode` is not
                     'trident', 'trident_pro' or 'stark'
    """
    valid = False
    label = None
    while not valid:
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # sample a sequence from the given dataset
        seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)

        # sample template and search frame ids
        if is_video_dataset:
            if self.frame_sample_mode in ["trident", "trident_pro"]:
                template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
            elif self.frame_sample_mode == "stark":
                template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
            else:
                raise ValueError("illegal frame sample mode")
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            template_frame_ids = [1] * self.num_template_frames
            search_frame_ids = [1] * self.num_search_frames

        try:
            # "try" is used to handle trackingnet data failure
            # get images and bounding boxes (for templates)
            template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
            H, W, _ = template_frames[0].shape
            template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
                (H, W))] * self.num_template_frames

            # get images and bounding boxes (for searches)
            if random.random() < self.pos_prob:
                # positive samples
                label = torch.ones(1,)
                search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
                search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                    (H, W))] * self.num_search_frames
            else:
                # negative samples
                label = torch.zeros(1,)
                if is_video_dataset:
                    search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
                    if search_frame_ids is None:
                        # every frame is visible: fall back to a frame from another draw
                        search_frames, search_anno, meta_obj_test = self.get_one_search()
                    else:
                        search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
                                                                                       seq_info_dict)
                        # the target is not visible here, so use a centered placeholder box
                        search_anno["bbox"] = [self.get_center_box(H, W)]
                else:
                    search_frames, search_anno, meta_obj_test = self.get_one_search()
                # the negative frame may have a different size than the template
                H, W, _ = search_frames[0].shape
                search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
                    (H, W))] * self.num_search_frames

            data = TensorDict({'template_images': template_frames,
                               'template_anno': template_anno['bbox'],
                               'template_masks': template_masks,
                               'search_images': search_frames,
                               'search_anno': search_anno['bbox'],
                               'search_masks': search_masks,
                               'dataset': dataset.get_name(),
                               'test_class': meta_obj_test.get('object_class_name')})

            # make data augmentation
            data = self.processing(data)
            # add classification label
            data["label"] = label

            # check whether data is valid
            valid = data['valid']
        except Exception:
            # Best-effort retry on data failures. `Exception` (not a bare except) so that
            # KeyboardInterrupt / SystemExit can still stop this otherwise unbounded loop.
            valid = False

    return data
def get_center_box(self, H, W, ratio=1/8):
    """Return an [x, y, w, h] box centered in an H x W image.

    The box measures W * ratio by H * ratio; all four entries are truncated to integers.
    """
    box_w = W * ratio
    box_h = H * ratio
    left = int(W / 2 - box_w / 2)
    top = int(H / 2 - box_h / 2)
    return torch.tensor([left, top, int(box_w), int(box_h)])
def sample_seq_from_dataset(self, dataset, is_video_dataset):
    """Draw random sequences from `dataset` until one has enough visible frames.

    A video sequence is accepted when it has at least 20 frames and more than
    2 * (num_search_frames + num_template_frames) of them are flagged visible.
    For a non-video dataset the first draw is always accepted.

    returns:
        (seq_id, visible, seq_info_dict) of the accepted sequence
    """
    min_visible = 2 * (self.num_search_frames + self.num_template_frames)
    while True:
        # Sample a sequence and look up its per-frame visibility
        seq_id = random.randint(0, dataset.get_num_sequences() - 1)
        seq_info_dict = dataset.get_sequence_info(seq_id)
        visible = seq_info_dict['visible']

        num_visible = visible.type(torch.int64).sum().item()
        if not is_video_dataset:
            break
        if num_visible > min_visible and len(visible) >= 20:
            break
    return seq_id, visible, seq_info_dict
def get_one_search(self):
    """Load a single search frame from a freshly drawn dataset and sequence.

    For video datasets the frame is drawn among frames flagged in seq_info_dict["valid"]
    when the sample mode is "stark", and among all frames (visible or not) otherwise.
    Image datasets always use frame id 1.

    returns:
        (search_frames, search_anno, meta_obj_test) as produced by dataset.get_frames
    """
    # Select a dataset and a sequence
    dataset = random.choices(self.datasets, self.p_datasets)[0]
    is_video_dataset = dataset.is_video_sequence()
    seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)

    # Select a frame
    if not is_video_dataset:
        search_frame_ids = [1]
    elif self.frame_sample_mode == "stark":
        search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
    else:
        search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)

    # Get the image, bounding box and other info
    frames, anno, meta = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
    return frames, anno, meta
def get_frame_ids_trident(self, visible):
    """Sample template and search frame ids in the 'trident' manner.

    One initial template frame and one search frame are drawn from the visible frames. Then,
    for every entry of `self.max_gap`, one additional template frame is drawn from a window of
    that size next to the search frame, on the side of the initial template. In "trident_pro"
    mode the additional frames may be invisible. The whole draw is repeated until every window
    yielded a frame.

    returns:
        (template_frame_ids, search_frame_ids) - the initial template id followed by the
        additional ones, and the single search id
    """
    permit_invisible = self.frame_sample_mode == "trident_pro"
    while True:
        # first randomly sample two frames from a video
        template_frame_id1 = self._sample_visible_ids(visible, num_ids=1)  # the initial template id
        search_frame_ids = self._sample_visible_ids(visible, num_ids=1)  # the search region id
        anchor = search_frame_ids[0]

        # get the dynamic template ids, one per configured gap
        extra_ids = []
        for max_gap in self.max_gap:
            if template_frame_id1[0] >= anchor:
                min_id, max_id = anchor, anchor + max_gap
            else:
                min_id, max_id = anchor - max_gap, anchor
            if permit_invisible:
                f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id,
                                                allow_invisible=True)
            else:
                f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id)
            extra_ids += [None] if f_id is None else f_id

        # accept only a complete, non-empty set of additional templates
        if extra_ids and None not in extra_ids:
            break

    return template_frame_id1 + extra_ids, search_frame_ids
def get_frame_ids_stark(self, visible, valid):
    """Sample template and search frame ids in the 'stark' manner.

    One initial template frame and one search frame are drawn from the frames flagged in
    `visible`. Then, for every entry of `self.max_gap`, one additional template frame is drawn
    from a window of that size next to the search frame, on the side of the initial template,
    among the frames flagged in `valid`. The whole draw is repeated until every window yielded
    a frame.

    returns:
        (template_frame_ids, search_frame_ids) - the initial template id followed by the
        additional ones, and the single search id
    """
    while True:
        # first randomly sample two frames from a video
        template_frame_id1 = self._sample_visible_ids(visible, num_ids=1)  # the initial template id
        search_frame_ids = self._sample_visible_ids(visible, num_ids=1)  # the search region id
        anchor = search_frame_ids[0]

        # get the dynamic template ids, one per configured gap
        extra_ids = []
        for max_gap in self.max_gap:
            if template_frame_id1[0] >= anchor:
                min_id, max_id = anchor, anchor + max_gap
            else:
                min_id, max_id = anchor - max_gap, anchor
            # we require the frame to be valid but not necessarily visible
            f_id = self._sample_visible_ids(valid, num_ids=1, min_id=min_id, max_id=max_id)
            extra_ids += [None] if f_id is None else f_id

        # accept only a complete, non-empty set of additional templates
        if extra_ids and None not in extra_ids:
            break

    return template_frame_id1 + extra_ids, search_frame_ids

View File

@@ -0,0 +1,265 @@
import random
import torch.utils.data
import numpy as np
from lib.utils import TensorDict
class SequenceSampler(torch.utils.data.Dataset):
    """
    Sample sequence for sequence-level training.

    Every item consists of template frame(s) and `num_search_frames` search frames taken from a
    randomly chosen sequence of a randomly chosen dataset.
    """

    def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
                 num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
        """
        args:
            datasets - List of datasets to be used for training
            p_datasets - List containing the weights by which each dataset will be sampled.
                         If None, each dataset is weighted by its length.
            samples_per_epoch - Number of training samples per epoch
            max_gap - Maximum gap, in frame numbers, between the template frame and the first search frame.
                      -1 makes the search frames start at the template frame.
            num_search_frames - Number of search frames to sample.
            num_template_frames - Number of template frames. Only used for image datasets and for the
                                  minimum-length check; the video samplers always draw one template frame.
            frame_sample_mode - Either 'sequential' or 'random_interval'. If 'sequential', the search frames are
                                consecutive visible frames, otherwise they are spaced by random intervals.
            max_interval - Maximum interval between sampled frames
            prob - in 'random_interval' mode: interval sampling by prob / sequential sampling by 1-prob
        """
        self.datasets = datasets

        # If p not provided, sample uniformly from all videos
        if p_datasets is None:
            p_datasets = [len(d) for d in self.datasets]

        # Normalize
        p_total = sum(p_datasets)
        self.p_datasets = [x / p_total for x in p_datasets]

        self.samples_per_epoch = samples_per_epoch
        self.max_gap = max_gap
        self.max_interval = max_interval
        self.num_search_frames = num_search_frames
        self.num_template_frames = num_template_frames
        self.frame_sample_mode = frame_sample_mode
        self.prob = prob
        # Factor applied to max_gap / max_interval while frame ids are being sampled
        self.extra = 1

    def __len__(self):
        return self.samples_per_epoch

    def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
        """ Samples num_ids frames between min_id and max_id for which target is visible

        args:
            visible - 1d Tensor indicating whether target is visible for each frame
            num_ids - number of frames to be samples
            min_id - Minimum allowed frame number
            max_id - Maximum allowed frame number (exclusive)

        returns:
            list - List of sampled frame numbers. None if not sufficient visible frames could be found.
        """
        if num_ids == 0:
            return []
        if min_id is None or min_id < 0:
            min_id = 0
        if max_id is None or max_id > len(visible):
            max_id = len(visible)

        valid_ids = [i for i in range(min_id, max_id) if visible[i]]

        # No visible ids
        if len(valid_ids) == 0:
            return None

        return random.choices(valid_ids, k=num_ids)

    def _sequential_sample(self, visible):
        """Sample one template frame and `num_search_frames` consecutive visible search frames.

        returns:
            (template_frame_ids, search_frame_ids)
        """
        # Sample frames in sequential manner
        template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
                                                      max_id=len(visible) - self.num_search_frames)
        if self.max_gap == -1:
            left = template_frame_ids[0]
        else:
            # template frame (1) ->(max_gap) -> search frame (num_search_frames)
            left_max = min(len(visible) - self.num_search_frames, template_frame_ids[0] + self.max_gap)
            left = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
                                            max_id=left_max)[0]

        valid_ids = [i for i in range(left, len(visible)) if visible[i]]
        search_frame_ids = valid_ids[:self.num_search_frames]

        # if length is not enough, walk to the end of the sequence and pad from there
        # NOTE(review): the padding id is the last frame index even when that frame is not visible
        last = search_frame_ids[-1]
        while len(search_frame_ids) < self.num_search_frames:
            if last >= len(visible) - 1:
                search_frame_ids.append(last)
            else:
                last += 1
                if visible[last]:
                    search_frame_ids.append(last)

        return template_frame_ids, search_frame_ids

    def _initial_interval(self, num_frames):
        """Shrink max_interval (not below 1) until the search frames can span `num_frames` frames."""
        avg_interval = self.max_interval
        # `avg_interval > 1` guarantees termination even if the frames can never fit
        while avg_interval > 1 and avg_interval * (self.num_search_frames - 1) > num_frames:
            avg_interval -= 1
        return avg_interval

    def _random_interval_sample(self, visible):
        """Sample one template frame and `num_search_frames` search frames spaced by random intervals.

        returns:
            (template_frame_ids, search_frame_ids) - two independent lists
        """
        # Get valid ids
        valid_ids = [i for i in range(len(visible)) if visible[i]]
        span = self.num_search_frames - 1

        # Sample template frame, leaving room for the search frames after it where possible
        avg_interval = self._initial_interval(len(visible))
        while True:
            template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
                                                          max_id=len(visible) - avg_interval * span)
            if template_frame_ids is not None:
                break
            avg_interval -= 1
            if avg_interval == 0:
                template_frame_ids = [valid_ids[0]]
                break

        # Sample first search frame
        if self.max_gap == -1:
            # copy, so that extending the search list below does not grow the template list
            search_frame_ids = list(template_frame_ids)
        else:
            avg_interval = self._initial_interval(len(visible))
            while True:
                left_max = min(max(len(visible) - avg_interval * span, template_frame_ids[0] + 1),
                               template_frame_ids[0] + self.max_gap)
                search_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
                                                            max_id=left_max)
                if search_frame_ids is not None:
                    break
                avg_interval -= 1
                if avg_interval == -1:
                    search_frame_ids = list(template_frame_ids)
                    break

        # Sample rest of the search frames with random interval
        last = search_frame_ids[0]
        while last <= len(visible) - 1 and len(search_frame_ids) < self.num_search_frames:
            # sample id with interval
            max_id = min(last + self.max_interval + 1, len(visible))
            next_id = self._sample_visible_ids(visible, num_ids=1, min_id=last, max_id=max_id)
            if next_id is None:
                # Nothing visible in the current window, move on to the next one
                last = last + self.max_interval
            else:
                search_frame_ids.append(next_id[0])
                last = search_frame_ids[-1]

        # if length is not enough, randomly sample new ids
        if len(search_frame_ids) < self.num_search_frames:
            valid_ids = [x for x in valid_ids if x > search_frame_ids[0] and x not in search_frame_ids]
            if len(valid_ids) > 0:
                new_ids = random.choices(valid_ids, k=min(len(valid_ids),
                                                          self.num_search_frames - len(search_frame_ids)))
                search_frame_ids = search_frame_ids + new_ids
                search_frame_ids = sorted(search_frame_ids, key=int)

        # if length is still not enough, duplicate last frame
        while len(search_frame_ids) < self.num_search_frames:
            search_frame_ids.append(search_frame_ids[-1])

        return template_frame_ids, search_frame_ids

    def _sample_frame_ids(self, visible):
        """Dispatch to the sampler selected by frame_sample_mode; raises NotImplementedError for unknown modes."""
        if self.frame_sample_mode == 'sequential':
            return self._sequential_sample(visible)
        if self.frame_sample_mode == 'random_interval':
            if random.random() < self.prob:
                return self._random_interval_sample(visible)
            return self._sequential_sample(visible)
        raise NotImplementedError(
            "Unsupported frame_sample_mode: {!r}".format(self.frame_sample_mode))

    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks

        raises:
            NotImplementedError - if a video dataset is sampled with an unsupported frame_sample_mode
        """
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        dataset_name = dataset.get_name()
        is_video_dataset = dataset.is_video_sequence()

        min_visible = 2 * (self.num_search_frames + self.num_template_frames)
        min_length = self.num_search_frames + self.num_template_frames

        while True:
            try:
                # Sample a sequence with enough visible frames
                while True:
                    seq_id = random.randint(0, dataset.get_num_sequences() - 1)
                    seq_info_dict = dataset.get_sequence_info(seq_id)
                    visible = seq_info_dict['visible']
                    num_visible = visible.type(torch.int64).sum().item()
                    if not is_video_dataset or (num_visible > min_visible and len(visible) >= min_length):
                        break

                if is_video_dataset:
                    # Scale the gap / interval only for the duration of the sampling call
                    max_gap, max_interval = self.max_gap, self.max_interval
                    self.max_gap = max_gap * self.extra
                    self.max_interval = max_interval * self.extra
                    try:
                        template_frame_ids, search_frame_ids = self._sample_frame_ids(visible)
                    finally:
                        self.max_gap = max_gap
                        self.max_interval = max_interval
                else:
                    # In case of image dataset, just repeat the image to generate synthetic video
                    template_frame_ids = [1] * self.num_template_frames
                    search_frame_ids = [1] * self.num_search_frames

                template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
                search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)

                template_bbox = [bbox.numpy() for bbox in template_anno['bbox']]  # tensor -> numpy array
                search_bbox = [bbox.numpy() for bbox in search_anno['bbox']]  # tensor -> numpy array

                return TensorDict({'template_images': np.array(template_frames).squeeze(),  # 1 template images
                                   'template_annos': np.array(template_bbox).squeeze(),
                                   'search_images': np.array(search_frames),  # (num_frames) search images
                                   'search_annos': np.array(search_bbox),
                                   'seq_id': seq_id,
                                   'dataset': dataset_name,
                                   'search_class': meta_obj_search.get('object_class_name'),
                                   'num_frames': len(search_frames)
                                   })
            except NotImplementedError:
                # A configuration error can never succeed on retry; surface it instead of spinning
                raise
            except Exception:
                # Best-effort: a sequence that fails to sample or load is skipped and another one is drawn
                continue

View File

@@ -0,0 +1,335 @@
import random
import numpy as np
import math
import cv2 as cv
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvisf
class Transform:
    """A set of transformations, used for e.g. data augmentation.
    Args of constructor:
        transforms: An arbitrary number of transformations, derived from the TransformBase class.
                    They are applied in the order they are given. A single list or tuple of
                    transformations is accepted as well.

    The Transform object can jointly transform images, bounding boxes and segmentation masks.
    This is done by calling the object with the following key-word arguments (all are optional).

    The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
        image  -  Image
        coords  -  2xN dimensional Tensor of 2D image coordinates [y, x]
        bbox  -  Bounding box on the form [x, y, w, h]
        mask  -  Segmentation mask with discrete classes
        att  -  Attention mask

    The following parameters can be supplied with calling the transform object:
        joint [Bool]  -  If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
                         Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
                         different random rolls. Default: True.
        new_roll [Bool]  -  If False, then no new random roll is performed, and the saved result from the previous roll
                            is used instead. Default: True.

    The result is the transformed value when a single kind of input is given, otherwise a tuple
    with one entry per input kind, in the order the inputs were passed.

    Check the DiMPProcessing class for examples.
    """

    def __init__(self, *transforms):
        if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
            transforms = transforms[0]
        self.transforms = transforms
        self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
        self._valid_args = ['joint', 'new_roll']
        self._valid_all = self._valid_inputs + self._valid_args

    def __call__(self, **inputs):
        var_names = [k for k in inputs if k in self._valid_inputs]
        for v in inputs:
            if v not in self._valid_all:
                raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))

        joint_mode = inputs.get('joint', True)
        new_roll = inputs.get('new_roll', True)

        if not joint_mode:
            # Transform every sample on its own (each call performs its own roll)
            per_sample = [self(**inp) for inp in self._split_inputs(inputs)]
            if len(var_names) == 1:
                # A single input kind yields bare results, which must not be unzipped
                return per_sample
            # Regroup from per-sample tuples to one list per input kind
            return tuple(list(o) for o in zip(*per_sample))

        out = {k: v for k, v in inputs.items() if k in self._valid_inputs}

        for t in self.transforms:
            out = t(**out, joint=joint_mode, new_roll=new_roll)
        if len(var_names) == 1:
            return out[var_names[0]]
        # Make sure order is correct
        return tuple(out[v] for v in var_names)

    def _split_inputs(self, inputs):
        """Turn keyword arguments holding lists into one keyword dict per sample.

        Arguments other than 'joint' are passed on to every sample; if such an argument is a
        list, its entries are distributed over the samples instead.
        """
        var_names = [k for k in inputs if k in self._valid_inputs]
        split_inputs = [dict(zip(var_names, vals)) for vals in zip(*[inputs[vn] for vn in var_names])]
        for arg_name, arg_val in inputs.items():
            if arg_name == 'joint' or arg_name not in self._valid_args:
                continue
            if isinstance(arg_val, list):
                for inp, av in zip(split_inputs, arg_val):
                    inp[arg_name] = av
            else:
                for inp in split_inputs:
                    inp[arg_name] = arg_val
        return split_inputs

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend(' {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)
class TransformBase:
    """Base class for transformation objects. See the Transform class for details.

    Subclasses override `roll` to draw random parameters and the `transform_<input>` methods to
    apply them; the defaults here leave every input unchanged.
    """

    def __init__(self):
        # 'att' (attention masks) is accepted next to the image / geometry inputs
        self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
        self._valid_args = ['new_roll']
        self._valid_all = self._valid_inputs + self._valid_args
        self._rand_params = None

    def __call__(self, **inputs):
        # Separate the data to transform from the control arguments
        input_vars = {}
        input_args = {}
        for key, value in inputs.items():
            if key in self._valid_inputs:
                input_vars[key] = value
            if key in self._valid_args:
                input_args[key] = value

        # Roll random parameters for the transform (stored as a tuple for later reuse)
        if input_args.get('new_roll', True):
            rolled = self.roll()
            if rolled is None:
                rolled = ()
            elif not isinstance(rolled, tuple):
                rolled = (rolled,)
            self._rand_params = rolled

        outputs = {}
        for var_name, var in input_vars.items():
            if var is None:
                continue
            transform_func = getattr(self, 'transform_' + var_name)
            # Geometric inputs additionally receive the image size as first parameter
            if var_name in ('coords', 'bbox'):
                params = (self._get_image_size(input_vars),) + self._rand_params
            else:
                params = self._rand_params
            if isinstance(var, (list, tuple)):
                outputs[var_name] = [transform_func(item, *params) for item in var]
            else:
                outputs[var_name] = transform_func(var, *params)
        return outputs

    def _get_image_size(self, inputs):
        """Return (height, width) taken from the 'image' input, or else the 'mask' input; None if both are missing."""
        im = inputs.get('image')
        if im is None:
            im = inputs.get('mask')
        if im is None:
            return None

        # For a list of images the first one defines the size
        if isinstance(im, (list, tuple)):
            im = im[0]
        if isinstance(im, np.ndarray):
            return im.shape[:2]
        if torch.is_tensor(im):
            return (im.shape[-2], im.shape[-1])
        raise Exception('Unknown image type')

    def roll(self):
        """Draw the random parameters of the transform; the base class has none."""
        return None

    def transform_image(self, image, *rand_params):
        """Must be deterministic"""
        return image

    def transform_coords(self, coords, image_shape, *rand_params):
        """Must be deterministic"""
        return coords

    def transform_bbox(self, bbox, image_shape, *rand_params):
        """Assumes [x, y, w, h]"""
        # Nothing to do unless a subclass provides its own transform_coords
        if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
            return bbox

        # Rows of `yx` are (y, h) and (x, w)
        yx = bbox.clone().view(-1, 2).t().flip(0)
        left = yx[1, 0]
        right = yx[1, 0] + yx[1, 1]
        top = yx[0, 0]
        bottom = yx[0, 0] + yx[0, 1]

        # Transform the four corners and take their axis-aligned bounding box
        corners = torch.tensor([[top, top, bottom, bottom], [left, right, right, left]])
        corners_xy = self.transform_coords(corners, image_shape, *rand_params).flip(0)
        top_left = torch.min(corners_xy, dim=1)[0]
        size = torch.max(corners_xy, dim=1)[0] - top_left
        return torch.cat((top_left, size), dim=-1).reshape(bbox.shape)

    def transform_mask(self, mask, *rand_params):
        """Must be deterministic"""
        return mask

    def transform_att(self, att, *rand_params):
        """Transform an attention mask; unchanged by default"""
        return att
class ToTensor(TransformBase):
    """Convert to a Tensor"""

    def transform_image(self, image):
        # handle numpy array: add a channel axis to 2-d images, then go from HWC to CHW
        if image.ndim == 2:
            image = image[:, :, None]

        image = torch.from_numpy(image.transpose((2, 0, 1)))
        # backward compatibility: byte images are rescaled by 1/255 as floats
        if isinstance(image, torch.ByteTensor):
            return image.float().div(255)
        return image

    def transform_mask(self, mask):
        """Convert a numpy mask to a tensor; anything else is returned unchanged."""
        if isinstance(mask, np.ndarray):
            return torch.from_numpy(mask)
        return mask

    # The method used to be misspelled, so it never overrode TransformBase.transform_mask.
    # The old name is kept as an alias for code that calls it directly.
    transfrom_mask = transform_mask

    def transform_att(self, att):
        """Convert an attention mask to a boolean tensor."""
        if isinstance(att, np.ndarray):
            return torch.from_numpy(att).to(torch.bool)
        elif isinstance(att, torch.Tensor):
            return att.to(torch.bool)
        else:
            raise ValueError ("dtype must be np.ndarray or torch.Tensor")
class ToTensorAndJitter(TransformBase):
    """Convert to a Tensor and jitter brightness"""

    def __init__(self, brightness_jitter=0.0, normalize=True):
        super().__init__()
        self.brightness_jitter = brightness_jitter
        self.normalize = normalize

    def roll(self):
        # Brightness factor drawn uniformly around 1, never below 0
        lowest = max(0, 1 - self.brightness_jitter)
        highest = 1 + self.brightness_jitter
        return np.random.uniform(lowest, highest)

    def transform_image(self, image, brightness_factor):
        # handle numpy array: HWC -> CHW float tensor
        chw = torch.from_numpy(image.transpose((2, 0, 1))).float()

        # With normalize the result lies in [0, 1], otherwise in [0, 255]
        if self.normalize:
            scale, ceiling = brightness_factor/255.0, 1.0
        else:
            scale, ceiling = brightness_factor, 255.0
        return chw.mul(scale).clamp(0.0, ceiling)

    def transform_mask(self, mask, brightness_factor):
        # Masks are only converted, never jittered
        return torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask

    def transform_att(self, att, brightness_factor):
        # Attention masks are converted to boolean tensors, never jittered
        if isinstance(att, np.ndarray):
            att = torch.from_numpy(att)
        elif not isinstance(att, torch.Tensor):
            raise ValueError ("dtype must be np.ndarray or torch.Tensor")
        return att.to(torch.bool)
class Normalize(TransformBase):
    """Normalize image"""

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean, self.std, self.inplace = mean, std, inplace

    def transform_image(self, image):
        # Delegates to torchvision's functional normalize with the stored statistics
        normalized = tvisf.normalize(image, self.mean, self.std, self.inplace)
        return normalized
class ToGrayscale(TransformBase):
    """Converts image to grayscale with probability"""

    def __init__(self, probability = 0.5):
        super().__init__()
        self.probability = probability
        self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)

    def roll(self):
        # True when this roll should convert the image
        return random.random() < self.probability

    def transform_image(self, image, do_grayscale):
        if not do_grayscale:
            return image
        if torch.is_tensor(image):
            raise NotImplementedError('Implement torch variant.')

        # Keep three channels so the image layout does not change
        gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
        return np.stack([gray] * 3, axis=2)
class ToBGR(TransformBase):
    """Converts image to BGR"""

    def transform_image(self, image):
        # Only array images are supported; tensors are rejected explicitly
        if torch.is_tensor(image):
            raise NotImplementedError('Implement torch variant.')
        return cv.cvtColor(image, cv.COLOR_RGB2BGR)
class RandomHorizontalFlip(TransformBase):
    """Horizontally flip image randomly with a probability p."""

    def __init__(self, probability = 0.5):
        super().__init__()
        self.probability = probability

    def roll(self):
        # True when this roll should flip
        return random.random() < self.probability

    @staticmethod
    def _flip_last_axis(data):
        """Mirror a mask-like input: last axis for tensors, left/right for arrays."""
        if torch.is_tensor(data):
            return data.flip((-1,))
        return np.fliplr(data).copy()

    def transform_image(self, image, do_flip):
        if not do_flip:
            return image
        if torch.is_tensor(image):
            # tensor images are flipped along dimension 2
            return image.flip((2,))
        return np.fliplr(image).copy()

    def transform_coords(self, coords, image_shape, do_flip):
        if not do_flip:
            return coords
        # Mirror the second coordinate row about the image width
        mirrored = coords.clone()
        mirrored[1,:] = (image_shape[1] - 1) - coords[1,:]
        return mirrored

    def transform_mask(self, mask, do_flip):
        return self._flip_last_axis(mask) if do_flip else mask

    def transform_att(self, att, do_flip):
        return self._flip_last_axis(att) if do_flip else att
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
    """Horizontally flip image randomly with a probability p.
    The difference is that the coord is normalized to [0,1]"""

    def __init__(self, probability = 0.5):
        # The parent constructor stores the probability
        super().__init__(probability)

    def transform_coords(self, coords, image_shape, do_flip):
        """Mirror the second coordinate row about 1 instead of the image width"""
        if not do_flip:
            return coords
        mirrored = coords.clone()
        mirrored[1,:] = 1 - coords[1,:]
        return mirrored

View File

@@ -0,0 +1,33 @@
from collections import OrderedDict
try:
import wandb
except ImportError:
raise ImportError(
'Please run "pip install wandb" to install wandb')
class WandbWriter:
    """Writes per-loader training statistics to a wandb run."""

    def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
        """Start a wandb run in project "tracking" named `exp_name`, with `cfg` as its config."""
        self.wandb = wandb
        self.step = cur_step
        self.interval = step_interval
        wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)

    def write_log(self, stats: OrderedDict, epoch=-1):
        """Log the statistics of every loader at step `self.step * self.interval`.

        args:
            stats - mapping from loader name to a mapping of statistic name to meter; loaders
                    mapped to None are skipped. A meter's `avg` is logged if it has one,
                    otherwise its `val`.
            epoch - if non-negative, also logged as '<loader>/epoch'
        """
        # NOTE(review): with the default step_interval=0 the logged step is always 0
        self.step += 1
        for loader_name, loader_stats in stats.items():
            if loader_stats is None:
                continue

            prefix = loader_name + '/'
            log_dict = {}
            for var_name, meter in loader_stats.items():
                log_dict[prefix + var_name] = meter.avg if hasattr(meter, 'avg') else meter.val
                if epoch >= 0:
                    log_dict[prefix + 'epoch'] = epoch

            self.wandb.log(log_dict, step=self.step*self.interval)