import torch
import torchvision.transforms as transforms
from lib.utils import TensorDict
import lib.train.data.processing_utils as prutils
import torch.nn.functional as F


def stack_tensors(x):
    # Stack a list/tuple of tensors into a single tensor; leave any other input unchanged.
    if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
        return torch.stack(x)
    return x


class BaseProcessing:
    """Base class for Processing. A Processing class is used to process the data returned by a dataset
    before it is passed through the network. For example, it can be used to crop a search region around
    the object, apply various data augmentations, etc."""

    def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None,
                 joint_transform=None):
        """
        args:
            transform          - The set of transformations applied to the images. Used only if
                                 template_transform or search_transform is None.
            template_transform - The set of transformations applied to the template images. If None,
                                 the 'transform' argument is used instead.
            search_transform   - The set of transformations applied to the search images. If None,
                                 the 'transform' argument is used instead.
            joint_transform    - The set of transformations applied 'jointly' to the template and search
                                 images. For example, it can be used to convert both to grayscale.
        """
        self.transform = {'template': transform if template_transform is None else template_transform,
                          'search': transform if search_transform is None else search_transform,
                          'joint': joint_transform}

    def __call__(self, data: TensorDict):
        raise NotImplementedError


class STARKProcessing(BaseProcessing):
    """The processing class used for training LittleBoy. The images are processed as follows.
    First, the target bounding box is jittered by adding some noise. Next, a square region (called the
    search region), centered at the jittered target center and of area search_area_factor^2 times the
    area of the jittered box, is cropped from the image. The target box is jittered to avoid learning
    the bias that the target is always at the center of the search region. The search region is then
    resized to a fixed size given by the argument output_sz.
    """

    def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
                 mode='pair', settings=None, *args, **kwargs):
        """
        args:
            search_area_factor   - A dict giving the size of the search region relative to the target
                                   size, for 'template' and 'search' data.
            output_sz            - A dict of integers denoting the size to which each search region is
                                   resized. The search region is always square.
            center_jitter_factor - A dict containing the amount of jittering applied to the target center
                                   before extracting the search region. See _get_jittered_box for how the
                                   jittering is done.
            scale_jitter_factor  - A dict containing the amount of jittering applied to the target size
                                   before extracting the search region. See _get_jittered_box for how the
                                   jittering is done.
            mode                 - Either 'pair' or 'sequence'. If mode='sequence', the output has an
                                   extra dimension for frames.
        """
        super().__init__(*args, **kwargs)
        self.search_area_factor = search_area_factor
        self.output_sz = output_sz
        self.center_jitter_factor = center_jitter_factor
        self.scale_jitter_factor = scale_jitter_factor
        self.mode = mode
        self.settings = settings

    def _get_jittered_box(self, box, mode):
        """Jitter the input box.
        args:
            box  - input bounding box in (x, y, w, h) format
            mode - string 'template' or 'search' indicating template or search data

        returns:
            torch.Tensor - jittered box
        """
        # Scale jitter: multiply width and height by a log-normal random factor.
        jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
        # Center jitter: shift the box center by a uniform offset whose magnitude scales with
        # the jittered box size.
        max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode]).float())
        jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)

        return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
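    # A worked numeric example of the jitter above (illustrative factors, not values from this file):
    # with scale_jitter_factor[mode] = 0.25, each side of a 100x100 box is scaled by exp(N(0, 0.25)),
    # i.e. roughly 0.78x-1.28x within one standard deviation; with center_jitter_factor[mode] = 3,
    # the center shifts uniformly by up to +/- 0.5 * 3 * sqrt(w * h) = +/- 150 px along each axis.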
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, which should contain the following fields:
                   'template_images', 'search_images', 'template_anno', 'search_anno'
        returns:
            TensorDict - output data block with the following fields:
                         'template_images', 'search_images', 'template_anno', 'search_anno',
                         'test_proposals', 'proposal_iou'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['template_images'], data['template_anno'], data['template_masks'] = self.transform['joint'](
                image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
            data['search_images'], data['search_anno'], data['search_masks'] = self.transform['joint'](
                image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)

        for s in ['template', 'search']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
            w, h = torch.stack(jittered_anno, dim=0)[:, 2], torch.stack(jittered_anno, dim=0)[:, 3]

            crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
            if (crop_sz < 1).any():
                data['valid'] = False
                # print("Too small box is found. Replace it with new data.")
                return data

            # Crop image region centered at the jittered_anno box and get the attention mask
            crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno,
                                                                              data[s + '_anno'], self.search_area_factor[s],
                                                                              self.output_sz[s], masks=data[s + '_masks'])
            # Apply transforms
            data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[s + '_masks'] = self.transform[s](
                image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)

            # 2021.1.9 Check whether the elements in data[s + '_att'] are all 1
            # Note that data[s + '_att'] is a tuple whose elements are torch.Tensors
            for ele in data[s + '_att']:
                if (ele == 1).all():
                    data['valid'] = False
                    # print("Values of original attention mask are all one. Replace it with new data.")
                    return data
            # 2021.1.10 More strict condition: require the downsampled masks not to be all 1
            for ele in data[s + '_att']:
                feat_size = self.output_sz[s] // 16  # 16 is the backbone stride
                # (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
                mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
                if (mask_down == 1).all():
                    data['valid'] = False
                    # print("Values of down-sampled attention mask are all one. "
                    #       "Replace it with new data.")
                    return data

        data['valid'] = True
        # if we use copy-and-paste augmentation
        if data["template_masks"] is None or data["search_masks"] is None:
            data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
            data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
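
# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file).
# The factor/size values below are assumed example settings, not values
# taken from this repository's configs.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    processing = STARKProcessing(
        search_area_factor={'template': 2.0, 'search': 5.0},
        output_sz={'template': 128, 'search': 320},
        center_jitter_factor={'template': 0.0, 'search': 4.5},
        scale_jitter_factor={'template': 0.0, 'search': 0.5},
        mode='pair',
        transform=transforms.ToTensor(),
    )
    # Jitter a dummy (x, y, w, h) annotation for the search frame.
    box = torch.tensor([50.0, 60.0, 100.0, 80.0])
    print(processing._get_jittered_box(box, 'search'))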