Files
Grounded-SAM-2/lib/test/tracker/basetracker.py
2024-11-19 22:12:54 -08:00

90 lines
3.6 KiB
Python

import time
import torch
from _collections import OrderedDict
from lib.train.data.processing_utils import transform_image_to_crop
from lib.vis.visdom_cus import Visdom
class BaseTracker:
    """Base class for all trackers.

    Subclasses must override :meth:`initialize` and :meth:`track`. The base
    class also provides optional Visdom debug visualisation and a helper that
    maps a box into template/search crop coordinates.
    """

    def __init__(self, params):
        """Store the tracker parameters.

        Args:
            params: Parameter object. ``transform_bbox_to_crop`` reads its
                ``template_size`` and ``search_size`` attributes.
        """
        self.params = params
        # Only set by ``_init_visdom`` when Visdom is enabled and starts successfully.
        self.visdom = None

    def predicts_segmentation_mask(self):
        """Return whether this tracker outputs a segmentation mask (``False`` here)."""
        return False

    def initialize(self, image, info: dict) -> dict:
        """Overload this function in your tracker. This should initialize the model."""
        raise NotImplementedError

    def track(self, image, info: dict = None) -> dict:
        """Overload this function in your tracker. This should track in the frame and update the model."""
        raise NotImplementedError

    def visdom_draw_tracking(self, image, box, segmentation=None):
        """Register the frame and box(es) with the ``'Tracking'`` Visdom window.

        Args:
            image: Frame to display.
            box: A single box, or an ``OrderedDict`` whose values are boxes
                (every value is drawn, in insertion order).
            segmentation: Optional mask, appended after the boxes when given.

        Requires ``self.visdom`` to have been created by ``_init_visdom``;
        otherwise ``self.visdom`` is ``None`` and this raises ``AttributeError``.
        """
        boxes = list(box.values()) if isinstance(box, OrderedDict) else [box]
        if segmentation is None:
            data = (image, *boxes)
        else:
            data = (image, *boxes, segmentation)
        self.visdom.register(data, 'Tracking', 1, 'Tracking')

    def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'):
        """Express ``box_in`` in the normalised coordinates of a template or search crop.

        Args:
            box_in: Box ``[x1, y1, w, h]`` in image pixels (not normalised).
            resize_factor: Resize factor forwarded to ``transform_image_to_crop``.
            device: Device the returned tensor is moved to.
            box_extract: Box the crop was extracted around, same format as
                ``box_in``; defaults to ``box_in``.
            crop_type: ``'template'`` (square crop of ``params.template_size``)
                or ``'search'`` (square crop of ``params.search_size``).

        Returns:
            torch.Tensor viewed as ``[1, 1, 4]``, x1y1wh, normalised
            (``transform_image_to_crop(..., normalize=True)``).

        Raises:
            NotImplementedError: If ``crop_type`` is not ``'template'`` or ``'search'``.
        """
        if crop_type == 'template':
            size = self.params.template_size
        elif crop_type == 'search':
            size = self.params.search_size
        else:
            raise NotImplementedError(
                f"Unsupported crop_type {crop_type!r}; expected 'template' or 'search'")
        crop_sz = torch.Tensor([size, size])

        box_in = torch.tensor(box_in)
        box_extract = box_in if box_extract is None else torch.tensor(box_extract)

        template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True)
        return template_bbox.view(1, 1, 4).to(device)

    def _init_visdom(self, visdom_info, debug):
        """Reset the UI control flags and, when debugging, try to start Visdom.

        Args:
            visdom_info: Optional dict of Visdom settings, forwarded to
                ``Visdom``. Setting ``'use_visdom'`` to a falsy value disables it.
            debug: Debug level; Visdom is only started when ``debug > 0``.

        If creating ``Visdom`` fails, a warning is printed and ``self.visdom``
        is left as it was.
        """
        visdom_info = {} if visdom_info is None else visdom_info
        # Flags toggled by _visdom_ui_handler from key presses in the Visdom UI.
        self.pause_mode = False
        self.step = False
        self.next_seq = False

        if debug > 0 and visdom_info.get('use_visdom', True):
            try:
                self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
                                     visdom_info=visdom_info)
            except Exception:
                # Best-effort: keep tracking without the Visdom window. Catch only
                # ordinary errors so Ctrl-C / SystemExit still propagate.
                time.sleep(0.5)
                print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
                      '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')

    def _visdom_ui_handler(self, data):
        """React to Visdom key presses.

        Space toggles ``pause_mode``. ArrowRight sets ``step`` only while
        paused. ``'n'`` sets ``next_seq``. Non-``'KeyPress'`` events are ignored.
        """
        if data['event_type'] != 'KeyPress':
            return
        key = data['key']
        if key == ' ':
            self.pause_mode = not self.pause_mode
        elif key == 'ArrowRight' and self.pause_mode:
            self.step = True
        elif key == 'n':
            self.next_seq = True