import math

import torch
from lib.models.artrack_seq import build_artrack_seq
from lib.test.tracker.basetracker import BaseTracker
from lib.test.tracker.vis_utils import gen_visualization
from lib.test.utils.hann import hann2d
from lib.train.data.processing_utils import sample_target, transform_image_to_crop
# for debug
import cv2
import os

from lib.test.tracker.data_utils import Preprocessor
from lib.utils.box_ops import clip_box
from lib.utils.ce_utils import generate_mask_cond


class ARTrackSeq(BaseTracker):
    def __init__(self, params, dataset_name):
        super(ARTrackSeq, self).__init__(params)
        network = build_artrack_seq(params.cfg, training=False)
        print(self.params.checkpoint)
        network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
        self.cfg = params.cfg
        self.bins = self.cfg.MODEL.BINS
        self.network = network.cuda()
        self.network.eval()
        self.preprocessor = Preprocessor()
        self.state = None

        self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
        # motion constraint: Hann window over the score-map grid
        self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()

        # for debug
        self.debug = params.debug
        self.use_visdom = params.debug
        self.frame_id = 0
        if self.debug:
            if not self.use_visdom:
                self.save_dir = "debug"
                if not os.path.exists(self.save_dir):
                    os.makedirs(self.save_dir)
            else:
                # self.add_hook()
                self._init_visdom(None, 1)

        # for save boxes from all queries
        self.save_all_boxes = params.save_all_boxes
        self.z_dict1 = {}
        self.store_result = None
        self.save_all = 7  # number of past boxes kept as trajectory history
        self.x_feat = None
        self.update = None
        self.update_threshold = 5.0
        self.update_intervals = 1

    def initialize(self, image, info: dict):
        # forward the template once
        self.x_feat = None
        z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
                                                                output_sz=self.params.template_size)
        self.z_patch_arr = z_patch_arr
        template = self.preprocessor.process(z_patch_arr, z_amask_arr)
        with torch.no_grad():
            self.z_dict1 = template

        self.box_mask_z = None
        # if self.cfg.MODEL.BACKBONE.CE_LOC:
        #     template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor,
        #                                                 template.tensors.device).squeeze(1)
        #     self.box_mask_z = generate_mask_cond(self.cfg, 1, template.tensors.device, template_bbox)

        # save states
        self.state = info['init_bbox']
        # seed the trajectory history with `save_all` copies of the initial box
        self.store_result = [info['init_bbox'].copy() for _ in range(self.save_all)]
        self.frame_id = 0
        self.update = None
        if self.save_all_boxes:
            '''save all predicted boxes'''
            all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
            return {"all_boxes": all_boxes_save}
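
    # Illustrative helper (not part of the original tracker): the coordinate
    # quantization used in track() below, extracted for clarity. The clamp
    # range [-0.5, 1.5] spans twice the search crop, so a box that partially
    # leaves the crop still maps onto valid token ids in [0, bins - 1]; e.g.
    # with bins = 400, a normalized coordinate of 0.0 (the crop's left edge)
    # maps to (0.0 + 0.5) * 399 = 199.5.
    @staticmethod
    def _quantize_box_demo(box_norm: torch.Tensor, bins: int) -> torch.Tensor:
        return (box_norm.clamp(min=-0.5, max=1.5) + 0.5) * (bins - 1)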

    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
                                                                output_sz=self.params.search_size)  # (x1, y1, w, h)

        # encode the stored trajectory as discrete coordinate tokens in the
        # current search-crop frame
        for i in range(len(self.store_result)):
            box_out_i = transform_image_to_crop(torch.Tensor(self.store_result[i]), torch.Tensor(self.state),
                                                resize_factor,
                                                torch.Tensor([self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]),
                                                normalize=True)
            # (x1, y1, w, h) -> (x1, y1, x2, y2)
            box_out_i[2] = box_out_i[2] + box_out_i[0]
            box_out_i[3] = box_out_i[3] + box_out_i[1]
            box_out_i = box_out_i.clamp(min=-0.5, max=1.5)
            box_out_i = (box_out_i + 0.5) * (self.bins - 1)
            if i == 0:
                seqs_out = box_out_i
            else:
                seqs_out = torch.cat((seqs_out, box_out_i), dim=-1)
        seqs_out = seqs_out.unsqueeze(0)

        search = self.preprocessor.process(x_patch_arr, x_amask_arr)
        with torch.no_grad():
            x_dict = search
            # merge the template and the search, then run the transformer
            out_dict = self.network.forward(template=self.z_dict1.tensors, search=x_dict.tensors,
                                            seq_input=seqs_out, stage="sequence",
                                            search_feature=self.x_feat, update=None)

        self.x_feat = out_dict['x_feat']

        # decode the predicted tokens back to a normalized (x1, y1, x2, y2) box
        pred_boxes = out_dict['seqs'][:, 0:4] / (self.bins - 1) - 0.5
        pred_boxes = pred_boxes.view(-1, 4).mean(dim=0)
        # convert (x1, y1, x2, y2) to (cx, cy, w, h); clone to avoid the
        # in-place aliasing of the original code, which only worked by
        # accident of write order
        pred_new = pred_boxes.clone()
        pred_new[2] = pred_boxes[2] - pred_boxes[0]
        pred_new[3] = pred_boxes[3] - pred_boxes[1]
        pred_new[0] = pred_boxes[0] + pred_new[2] / 2
        pred_new[1] = pred_boxes[1] + pred_new[3] / 2
        pred_boxes = (pred_new * self.params.search_size / resize_factor).tolist()

        # Baseline: Take the mean of all pred boxes as the final result
        # pred_box = (pred_boxes.mean(
        #     dim=0) * self.params.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = clip_box(self.map_box_back(pred_boxes, resize_factor), H, W, margin=10)

        # push the new state into the fixed-length trajectory history
        if len(self.store_result) < self.save_all:
            self.store_result.append(self.state.copy())
        else:
            for i in range(self.save_all):
                if i != self.save_all - 1:
                    self.store_result[i] = self.store_result[i + 1]
                else:
                    self.store_result[i] = self.state.copy()

        # for debug
        if self.debug:
            if not self.use_visdom:
                x1, y1, w, h = self.state
                image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)
                save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
                cv2.imwrite(save_path, image_BGR)
            else:
                self.visdom.register((image, info['gt_bbox'].tolist(), self.state), 'Tracking', 1, 'Tracking')
                self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
                self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
                # the sequence head does not produce a score map; only plot one
                # if the network happens to return it (the original code
                # referenced an undefined `pred_score_map` here)
                pred_score_map = out_dict.get('score_map', None)
                if pred_score_map is not None:
                    self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
                    self.visdom.register((pred_score_map * self.output_window).view(self.feat_sz, self.feat_sz),
                                         'heatmap', 1, 'score_map_hann')

                if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
                    removed_indexes_s = out_dict['removed_indexes_s']
                    removed_indexes_s = [removed_indexes_s_i.cpu().numpy() for removed_indexes_s_i in removed_indexes_s]
                    masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
                    self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')

                while self.pause_mode:
                    if self.step:
                        self.step = False
                        break

        if self.save_all_boxes:
            '''save all predictions'''
            # `pred_boxes` is already in crop-pixel units here, so no extra
            # rescaling is needed (the original line rescaled a Python list,
            # which would raise a TypeError)
            all_boxes = self.map_box_back_batch(torch.tensor(pred_boxes).view(-1, 4), resize_factor)
            all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )
            return {"target_bbox": self.state,
                    "all_boxes": all_boxes_save}
        else:
            return {"target_bbox": self.state}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        # cx_real = cx + cx_prev
        # cy_real = cy + cy_prev
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
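
    # Worked example for map_box_back (illustrative numbers): with
    # search_size = 256 and resize_factor = 2.0, the crop spans 128 px of the
    # original image, so half_side = 64. A predicted crop-frame center of
    # (cx, cy) = (70, 60) around a previous image-frame center of (300, 200)
    # maps back to (300 - 64 + 70, 200 - 64 + 60) = (306, 196).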

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N, 4) --> (N,)
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def add_hook(self):
        conv_features, enc_attn_weights, dec_attn_weights = [], [], []

        for i in range(12):
            self.network.backbone.blocks[i].attn.register_forward_hook(
                # capture the attention weights (second element of the output)
                lambda module, inputs, output: enc_attn_weights.append(output[1])
            )
        self.enc_attn_weights = enc_attn_weights


def get_tracker_class():
    return ARTrackSeq
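
# Minimal usage sketch (illustrative; the parameter plumbing follows the usual
# lib.test.parameter convention of this codebase, and the module path and yaml
# name below are placeholders, not guaranteed entry points):
#
#     from lib.test.parameter.artrack_seq import parameters  # hypothetical
#     params = parameters("artrack_seq_256")                 # hypothetical yaml name
#     tracker = ARTrackSeq(params, dataset_name="lasot")
#     tracker.initialize(first_frame_rgb, {"init_bbox": [x, y, w, h]})
#     for frame_rgb in remaining_frames:
#         out = tracker.track(frame_rgb)
#         x1, y1, w, h = out["target_bbox"]  # box in image coordinates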