init commit of samurai
This commit is contained in:
335
lib/train/data/transforms.py
Normal file
335
lib/train/data/transforms.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
|
||||
|
||||
class Transform:
|
||||
"""A set of transformations, used for e.g. data augmentation.
|
||||
Args of constructor:
|
||||
transforms: An arbitrary number of transformations, derived from the TransformBase class.
|
||||
They are applied in the order they are given.
|
||||
|
||||
The Transform object can jointly transform images, bounding boxes and segmentation masks.
|
||||
This is done by calling the object with the following key-word arguments (all are optional).
|
||||
|
||||
The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
|
||||
image - Image
|
||||
coords - 2xN dimensional Tensor of 2D image coordinates [y, x]
|
||||
bbox - Bounding box on the form [x, y, w, h]
|
||||
mask - Segmentation mask with discrete classes
|
||||
|
||||
The following parameters can be supplied with calling the transform object:
|
||||
joint [Bool] - If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
|
||||
Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
|
||||
different random rolls. Default: True.
|
||||
new_roll [Bool] - If False, then no new random roll is performed, and the saved result from the previous roll
|
||||
is used instead. Default: True.
|
||||
|
||||
Check the DiMPProcessing class for examples.
|
||||
"""
|
||||
|
||||
def __init__(self, *transforms):
|
||||
if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
|
||||
transforms = transforms[0]
|
||||
self.transforms = transforms
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['joint', 'new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
|
||||
def __call__(self, **inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
for v in inputs.keys():
|
||||
if v not in self._valid_all:
|
||||
raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))
|
||||
|
||||
joint_mode = inputs.get('joint', True)
|
||||
new_roll = inputs.get('new_roll', True)
|
||||
|
||||
if not joint_mode:
|
||||
out = zip(*[self(**inp) for inp in self._split_inputs(inputs)])
|
||||
return tuple(list(o) for o in out)
|
||||
|
||||
out = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
|
||||
for t in self.transforms:
|
||||
out = t(**out, joint=joint_mode, new_roll=new_roll)
|
||||
if len(var_names) == 1:
|
||||
return out[var_names[0]]
|
||||
# Make sure order is correct
|
||||
return tuple(out[v] for v in var_names)
|
||||
|
||||
def _split_inputs(self, inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
split_inputs = [{k: v for k, v in zip(var_names, vals)} for vals in zip(*[inputs[vn] for vn in var_names])]
|
||||
for arg_name, arg_val in filter(lambda it: it[0]!='joint' and it[0] in self._valid_args, inputs.items()):
|
||||
if isinstance(arg_val, list):
|
||||
for inp, av in zip(split_inputs, arg_val):
|
||||
inp[arg_name] = av
|
||||
else:
|
||||
for inp in split_inputs:
|
||||
inp[arg_name] = arg_val
|
||||
return split_inputs
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class TransformBase:
|
||||
"""Base class for transformation objects. See the Transform class for details."""
|
||||
def __init__(self):
|
||||
"""2020.12.24 Add 'att' to valid inputs"""
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
self._rand_params = None
|
||||
|
||||
def __call__(self, **inputs):
|
||||
# Split input
|
||||
input_vars = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
input_args = {k: v for k, v in inputs.items() if k in self._valid_args}
|
||||
|
||||
# Roll random parameters for the transform
|
||||
if input_args.get('new_roll', True):
|
||||
rand_params = self.roll()
|
||||
if rand_params is None:
|
||||
rand_params = ()
|
||||
elif not isinstance(rand_params, tuple):
|
||||
rand_params = (rand_params,)
|
||||
self._rand_params = rand_params
|
||||
|
||||
outputs = dict()
|
||||
for var_name, var in input_vars.items():
|
||||
if var is not None:
|
||||
transform_func = getattr(self, 'transform_' + var_name)
|
||||
if var_name in ['coords', 'bbox']:
|
||||
params = (self._get_image_size(input_vars),) + self._rand_params
|
||||
else:
|
||||
params = self._rand_params
|
||||
if isinstance(var, (list, tuple)):
|
||||
outputs[var_name] = [transform_func(x, *params) for x in var]
|
||||
else:
|
||||
outputs[var_name] = transform_func(var, *params)
|
||||
return outputs
|
||||
|
||||
def _get_image_size(self, inputs):
|
||||
im = None
|
||||
for var_name in ['image', 'mask']:
|
||||
if inputs.get(var_name) is not None:
|
||||
im = inputs[var_name]
|
||||
break
|
||||
if im is None:
|
||||
return None
|
||||
if isinstance(im, (list, tuple)):
|
||||
im = im[0]
|
||||
if isinstance(im, np.ndarray):
|
||||
return im.shape[:2]
|
||||
if torch.is_tensor(im):
|
||||
return (im.shape[-2], im.shape[-1])
|
||||
raise Exception('Unknown image type')
|
||||
|
||||
def roll(self):
|
||||
return None
|
||||
|
||||
def transform_image(self, image, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return coords
|
||||
|
||||
def transform_bbox(self, bbox, image_shape, *rand_params):
|
||||
"""Assumes [x, y, w, h]"""
|
||||
# Check if not overloaded
|
||||
if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
|
||||
return bbox
|
||||
|
||||
coord = bbox.clone().view(-1,2).t().flip(0)
|
||||
|
||||
x1 = coord[1, 0]
|
||||
x2 = coord[1, 0] + coord[1, 1]
|
||||
|
||||
y1 = coord[0, 0]
|
||||
y2 = coord[0, 0] + coord[0, 1]
|
||||
|
||||
coord_all = torch.tensor([[y1, y1, y2, y2], [x1, x2, x2, x1]])
|
||||
|
||||
coord_transf = self.transform_coords(coord_all, image_shape, *rand_params).flip(0)
|
||||
tl = torch.min(coord_transf, dim=1)[0]
|
||||
sz = torch.max(coord_transf, dim=1)[0] - tl
|
||||
bbox_out = torch.cat((tl, sz), dim=-1).reshape(bbox.shape)
|
||||
return bbox_out
|
||||
|
||||
def transform_mask(self, mask, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, *rand_params):
|
||||
"""2020.12.24 Added to deal with attention masks"""
|
||||
return att
|
||||
|
||||
|
||||
class ToTensor(TransformBase):
|
||||
"""Convert to a Tensor"""
|
||||
|
||||
def transform_image(self, image):
|
||||
# handle numpy array
|
||||
if image.ndim == 2:
|
||||
image = image[:, :, None]
|
||||
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(image, torch.ByteTensor):
|
||||
return image.float().div(255)
|
||||
else:
|
||||
return image
|
||||
|
||||
def transfrom_mask(self, mask):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
|
||||
def transform_att(self, att):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class ToTensorAndJitter(TransformBase):
|
||||
"""Convert to a Tensor and jitter brightness"""
|
||||
def __init__(self, brightness_jitter=0.0, normalize=True):
|
||||
super().__init__()
|
||||
self.brightness_jitter = brightness_jitter
|
||||
self.normalize = normalize
|
||||
|
||||
def roll(self):
|
||||
return np.random.uniform(max(0, 1 - self.brightness_jitter), 1 + self.brightness_jitter)
|
||||
|
||||
def transform_image(self, image, brightness_factor):
|
||||
# handle numpy array
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
|
||||
# backward compatibility
|
||||
if self.normalize:
|
||||
return image.float().mul(brightness_factor/255.0).clamp(0.0, 1.0)
|
||||
else:
|
||||
return image.float().mul(brightness_factor).clamp(0.0, 255.0)
|
||||
|
||||
def transform_mask(self, mask, brightness_factor):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
else:
|
||||
return mask
|
||||
def transform_att(self, att, brightness_factor):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class Normalize(TransformBase):
|
||||
"""Normalize image"""
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
super().__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def transform_image(self, image):
|
||||
return tvisf.normalize(image, self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToGrayscale(TransformBase):
|
||||
"""Converts image to grayscale with probability"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_grayscale):
|
||||
if do_grayscale:
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
|
||||
return np.stack([img_gray, img_gray, img_gray], axis=2)
|
||||
# return np.repeat(np.sum(img * self.color_weights, axis=2, keepdims=True).astype(np.uint8), 3, axis=2)
|
||||
return image
|
||||
|
||||
|
||||
class ToBGR(TransformBase):
|
||||
"""Converts image to BGR"""
|
||||
def transform_image(self, image):
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
|
||||
return img_bgr
|
||||
|
||||
|
||||
class RandomHorizontalFlip(TransformBase):
|
||||
"""Horizontally flip image randomly with a probability p."""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(image):
|
||||
return image.flip((2,))
|
||||
return np.fliplr(image).copy()
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = (image_shape[1] - 1) - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
|
||||
def transform_mask(self, mask, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(mask):
|
||||
return mask.flip((-1,))
|
||||
return np.fliplr(mask).copy()
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(att):
|
||||
return att.flip((-1,))
|
||||
return np.fliplr(att).copy()
|
||||
return att
|
||||
|
||||
|
||||
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
|
||||
"""Horizontally flip image randomly with a probability p.
|
||||
The difference is that the coord is normalized to [0,1]"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
"""we should use 1 rather than image_shape"""
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = 1 - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
Reference in New Issue
Block a user