282 lines
10 KiB
Python
282 lines
10 KiB
Python
from . import BaseActor
|
|
from lib.utils.misc import NestedTensor
|
|
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
|
import torch
|
|
import math
|
|
import numpy as np
|
|
from lib.utils.merge import merge_template_search
|
|
from ...utils.heapmap_utils import generate_heatmap
|
|
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
|
|
|
def fp16_clamp(x, min=None, max=None):
|
|
if not x.is_cuda and x.dtype == torch.float16:
|
|
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
|
return x.float().clamp(min, max).half()
|
|
|
|
return x.clamp(min, max)
|
|
|
|
def generate_sa_simdr(joints):
|
|
'''
|
|
:param joints: [num_joints, 3]
|
|
:param joints_vis: [num_joints, 3]
|
|
:return: target, target_weight(1: visible, 0: invisible)
|
|
'''
|
|
num_joints = 48
|
|
image_size = [256, 256]
|
|
simdr_split_ratio = 1.5625
|
|
sigma = 6
|
|
|
|
target_x1 = np.zeros((num_joints,
|
|
int(image_size[0] * simdr_split_ratio)),
|
|
dtype=np.float32)
|
|
target_y1 = np.zeros((num_joints,
|
|
int(image_size[1] * simdr_split_ratio)),
|
|
dtype=np.float32)
|
|
target_x2 = np.zeros((num_joints,
|
|
int(image_size[0] * simdr_split_ratio)),
|
|
dtype=np.float32)
|
|
target_y2 = np.zeros((num_joints,
|
|
int(image_size[1] * simdr_split_ratio)),
|
|
dtype=np.float32)
|
|
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
|
|
|
tmp_size = sigma * 3
|
|
|
|
for joint_id in range(num_joints):
|
|
|
|
mu_x1 = joints[joint_id][0]
|
|
mu_y1 = joints[joint_id][1]
|
|
mu_x2 = joints[joint_id][2]
|
|
mu_y2 = joints[joint_id][3]
|
|
|
|
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
|
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
|
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
|
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
|
|
|
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
|
sigma * np.sqrt(np.pi * 2))
|
|
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
|
sigma * np.sqrt(np.pi * 2))
|
|
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
|
sigma * np.sqrt(np.pi * 2))
|
|
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
|
sigma * np.sqrt(np.pi * 2))
|
|
return target_x1, target_y1, target_x2, target_y2
|
|
|
|
# angle cost
|
|
def SIoU_loss(test1, test2, theta=4):
|
|
eps = 1e-7
|
|
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
|
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
|
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
|
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
|
|
|
dist = ((cx_pred - cx_gt)**2 + (cy_pred - cy_gt)**2) ** 0.5
|
|
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
|
x = ch / (dist + eps)
|
|
|
|
angle = 1 - 2*torch.sin(torch.arcsin(x)-torch.pi/4)**2
|
|
# distance cost
|
|
xmin = torch.min(test1[:, 0], test2[:, 0])
|
|
xmax = torch.max(test1[:, 2], test2[:, 2])
|
|
ymin = torch.min(test1[:, 1], test2[:, 1])
|
|
ymax = torch.max(test1[:, 3], test2[:, 3])
|
|
cw = xmax - xmin
|
|
ch = ymax - ymin
|
|
px = ((cx_gt - cx_pred) / (cw+eps))**2
|
|
py = ((cy_gt - cy_pred) / (ch+eps))**2
|
|
gama = 2 - angle
|
|
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
|
|
|
#shape cost
|
|
w_pred = test1[:, 2] - test1[:, 0]
|
|
h_pred = test1[:, 3] - test1[:, 1]
|
|
w_gt = test2[:, 2] - test2[:, 0]
|
|
h_gt = test2[:, 3] - test2[:, 1]
|
|
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
|
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
|
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
|
|
|
#IoU loss
|
|
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
|
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
|
|
|
wh = fp16_clamp(rb - lt, min=0)
|
|
overlap = wh[..., 0] * wh[..., 1]
|
|
area1 = (test1[..., 2] - test1[..., 0]) * (
|
|
test1[..., 3] - test1[..., 1])
|
|
area2 = (test2[..., 2] - test2[..., 0]) * (
|
|
test2[..., 3] - test2[..., 1])
|
|
iou = overlap / (area1 + area2 - overlap)
|
|
|
|
SIoU = 1 - iou + (omega + dis) / 2
|
|
return SIoU, iou
|
|
|
|
def ciou(pred, target, eps=1e-7):
|
|
# overlap
|
|
lt = torch.max(pred[:, :2], target[:, :2])
|
|
rb = torch.min(pred[:, 2:], target[:, 2:])
|
|
wh = (rb - lt).clamp(min=0)
|
|
overlap = wh[:, 0] * wh[:, 1]
|
|
|
|
# union
|
|
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
|
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
|
union = ap + ag - overlap + eps
|
|
|
|
# IoU
|
|
ious = overlap / union
|
|
|
|
# enclose area
|
|
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
|
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
|
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
|
|
|
cw = enclose_wh[:, 0]
|
|
ch = enclose_wh[:, 1]
|
|
|
|
c2 = cw**2 + ch**2 + eps
|
|
|
|
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
|
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
|
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
|
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
|
|
|
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
|
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
|
|
|
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
|
|
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
|
|
rho2 = left + right
|
|
|
|
factor = 4 / math.pi**2
|
|
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
|
|
|
# CIoU
|
|
cious = ious - (rho2 / c2 + v**2 / (1 - ious + v))
|
|
return cious, ious
|
|
|
|
class ARTrackActor(BaseActor):
|
|
""" Actor for training ARTrack models """
|
|
|
|
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
|
super().__init__(net, objective)
|
|
self.loss_weight = loss_weight
|
|
self.settings = settings
|
|
self.bs = self.settings.batchsize # batch size
|
|
self.cfg = cfg
|
|
self.bins = bins
|
|
self.range = self.cfg.MODEL.RANGE
|
|
self.search_size = search_size
|
|
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
|
self.focal = None
|
|
self.loss_weight['KL'] = 100
|
|
self.loss_weight['focal'] = 2
|
|
|
|
def __call__(self, data):
|
|
"""
|
|
args:
|
|
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
|
template_images: (N_t, batch, 3, H, W)
|
|
search_images: (N_s, batch, 3, H, W)
|
|
returns:
|
|
loss - the training loss
|
|
status - dict containing detailed losses
|
|
"""
|
|
# forward pass
|
|
out_dict = self.forward_pass(data)
|
|
|
|
# compute losses
|
|
loss, status = self.compute_losses(out_dict, data)
|
|
|
|
return loss, status
|
|
|
|
def forward_pass(self, data):
|
|
# currently only support 1 template and 1 search region
|
|
assert len(data['template_images']) == 1
|
|
assert len(data['search_images']) == 1
|
|
|
|
template_list = []
|
|
for i in range(self.settings.num_template):
|
|
template_img_i = data['template_images'][i].view(-1,
|
|
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
|
template_list.append(template_img_i)
|
|
|
|
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
|
|
|
if len(template_list) == 1:
|
|
template_list = template_list[0]
|
|
gt_bbox = data['search_anno'][-1]
|
|
begin = self.bins * self.range
|
|
end = self.bins * self.range + 1
|
|
|
|
magic_num = (self.range - 1) * 0.5
|
|
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
|
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
|
gt_bbox = gt_bbox.clamp(min=(-1*magic_num), max=(1+magic_num))
|
|
data['real_bbox'] = gt_bbox
|
|
|
|
seq_ori = (gt_bbox + magic_num) * (self.bins - 1)
|
|
|
|
seq_ori = seq_ori.int().to(search_img)
|
|
B = seq_ori.shape[0]
|
|
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
|
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
|
data['seq_input'] = seq_input
|
|
data['seq_output'] = seq_output
|
|
out_dict = self.net(template=template_list,
|
|
search=search_img,
|
|
seq_input=seq_input)
|
|
|
|
return out_dict
|
|
|
|
def compute_losses(self, pred_dict, gt_dict, return_status=True):
|
|
bins = self.bins
|
|
magic_num = (self.range - 1) * 0.5
|
|
seq_output = gt_dict['seq_output']
|
|
pred_feat = pred_dict["feat"]
|
|
if self.focal == None:
|
|
weight = torch.ones(bins*self.range+2) * 1
|
|
weight[bins*self.range+1] = 0.1
|
|
weight[bins*self.range] = 0.1
|
|
weight.to(pred_feat)
|
|
self.klloss = torch.nn.KLDivLoss(reduction='none').to(pred_feat)
|
|
|
|
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
|
|
# compute varfifocal loss
|
|
pred = pred_feat.permute(1, 0, 2).reshape(-1, bins*2+2)
|
|
target = seq_output.reshape(-1).to(torch.int64)
|
|
varifocal_loss = self.focal(pred, target)
|
|
# compute giou and L1 loss
|
|
beta = 1
|
|
pred = pred_feat[0:4, :, 0:bins*self.range] * beta
|
|
target = seq_output[:, 0:4].to(pred_feat)
|
|
|
|
out = pred.softmax(-1).to(pred)
|
|
mul = torch.range((-1*magic_num+1/(self.bins*self.range)), (1+magic_num-1/(self.bins*self.range)), 2/(self.bins*self.range)).to(pred)
|
|
ans = out * mul
|
|
ans = ans.sum(dim=-1)
|
|
ans = ans.permute(1, 0).to(pred)
|
|
target = target / (bins - 1) - magic_num
|
|
extra_seq = ans
|
|
extra_seq = extra_seq.to(pred)
|
|
sious, iou = SIoU_loss(extra_seq, target, 4)
|
|
sious = sious.mean()
|
|
siou_loss = sious
|
|
l1_loss = self.objective['l1'](extra_seq, target)
|
|
|
|
loss = self.loss_weight['giou'] * siou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight['focal'] * varifocal_loss
|
|
|
|
if return_status:
|
|
# status for log
|
|
mean_iou = iou.detach().mean()
|
|
status = {"Loss/total": loss.item(),
|
|
"Loss/giou": siou_loss.item(),
|
|
"Loss/l1": l1_loss.item(),
|
|
"Loss/location": varifocal_loss.item(),
|
|
"IoU": mean_iou.item()}
|
|
return loss, status
|
|
else:
|
|
return loss
|