90 lines
3.6 KiB
Python
90 lines
3.6 KiB
Python
![]() |
import time
|
||
|
|
||
|
import torch
|
||
|
from _collections import OrderedDict
|
||
|
|
||
|
from lib.train.data.processing_utils import transform_image_to_crop
|
||
|
from lib.vis.visdom_cus import Visdom
|
||
|
|
||
|
|
||
|
class BaseTracker:
|
||
|
"""Base class for all trackers."""
|
||
|
|
||
|
def __init__(self, params):
|
||
|
self.params = params
|
||
|
self.visdom = None
|
||
|
|
||
|
def predicts_segmentation_mask(self):
|
||
|
return False
|
||
|
|
||
|
def initialize(self, image, info: dict) -> dict:
|
||
|
"""Overload this function in your tracker. This should initialize the model."""
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def track(self, image, info: dict = None) -> dict:
|
||
|
"""Overload this function in your tracker. This should track in the frame and update the model."""
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def visdom_draw_tracking(self, image, box, segmentation=None):
|
||
|
if isinstance(box, OrderedDict):
|
||
|
box = [v for k, v in box.items()]
|
||
|
else:
|
||
|
box = (box,)
|
||
|
if segmentation is None:
|
||
|
self.visdom.register((image, *box), 'Tracking', 1, 'Tracking')
|
||
|
else:
|
||
|
self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking')
|
||
|
|
||
|
def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'):
|
||
|
# box_in: list [x1, y1, w, h], not normalized
|
||
|
# box_extract: same as box_in
|
||
|
# out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized
|
||
|
if crop_type == 'template':
|
||
|
crop_sz = torch.Tensor([self.params.template_size, self.params.template_size])
|
||
|
elif crop_type == 'search':
|
||
|
crop_sz = torch.Tensor([self.params.search_size, self.params.search_size])
|
||
|
else:
|
||
|
raise NotImplementedError
|
||
|
|
||
|
box_in = torch.tensor(box_in)
|
||
|
if box_extract is None:
|
||
|
box_extract = box_in
|
||
|
else:
|
||
|
box_extract = torch.tensor(box_extract)
|
||
|
template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True)
|
||
|
template_bbox = template_bbox.view(1, 1, 4).to(device)
|
||
|
|
||
|
return template_bbox
|
||
|
|
||
|
def _init_visdom(self, visdom_info, debug):
|
||
|
visdom_info = {} if visdom_info is None else visdom_info
|
||
|
self.pause_mode = False
|
||
|
self.step = False
|
||
|
self.next_seq = False
|
||
|
if debug > 0 and visdom_info.get('use_visdom', True):
|
||
|
try:
|
||
|
self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
|
||
|
visdom_info=visdom_info)
|
||
|
|
||
|
# # Show help
|
||
|
# help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
|
||
|
# 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
|
||
|
# 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
|
||
|
# 'block list.'
|
||
|
# self.visdom.register(help_text, 'text', 1, 'Help')
|
||
|
except:
|
||
|
time.sleep(0.5)
|
||
|
print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
|
||
|
'!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')
|
||
|
|
||
|
def _visdom_ui_handler(self, data):
|
||
|
if data['event_type'] == 'KeyPress':
|
||
|
if data['key'] == ' ':
|
||
|
self.pause_mode = not self.pause_mode
|
||
|
|
||
|
elif data['key'] == 'ArrowRight' and self.pause_mode:
|
||
|
self.step = True
|
||
|
|
||
|
elif data['key'] == 'n':
|
||
|
self.next_seq = True
|