init commit of samurai
This commit is contained in:
80
lib/utils/ce_utils.py
Normal file
80
lib/utils/ce_utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def generate_bbox_mask(bbox_mask, bbox):
|
||||
b, h, w = bbox_mask.shape
|
||||
for i in range(b):
|
||||
bbox_i = bbox[i].cpu().tolist()
|
||||
bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1
|
||||
return bbox_mask
|
||||
|
||||
|
||||
def generate_mask_cond(cfg, bs, device, gt_bbox):
|
||||
template_size = cfg.DATA.TEMPLATE.SIZE
|
||||
stride = cfg.MODEL.BACKBONE.STRIDE
|
||||
template_feat_size = template_size // stride
|
||||
|
||||
if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL':
|
||||
box_mask_z = None
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT':
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 6)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 14:
|
||||
index = slice(6, 7)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC':
|
||||
# use fixed 4x4 region, 3:5 for 8x8
|
||||
# use fixed 4x4 region 5:6 for 12x12
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 5)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 7)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX':
|
||||
box_mask_z = torch.zeros([bs, template_size, template_size], device=device)
|
||||
# box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128)
|
||||
box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to(
|
||||
torch.float) # (batch, 1, 128, 128)
|
||||
# box_mask_z_vis = box_mask_z.cpu().numpy()
|
||||
box_mask_z = F.interpolate(box_mask_z, scale_factor=1. / cfg.MODEL.BACKBONE.STRIDE, mode='bilinear',
|
||||
align_corners=False)
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
# box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy()
|
||||
# gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return box_mask_z
|
||||
|
||||
|
||||
def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1):
|
||||
if epoch < warmup_epochs:
|
||||
return 1
|
||||
if epoch >= total_epochs:
|
||||
return base_keep_rate
|
||||
if iters == -1:
|
||||
iters = epoch * ITERS_PER_EPOCH
|
||||
total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs)
|
||||
iters = iters - ITERS_PER_EPOCH * warmup_epochs
|
||||
keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \
|
||||
* (math.cos(iters / total_iters * math.pi) + 1) * 0.5
|
||||
|
||||
return keep_rate
|
Reference in New Issue
Block a user