169 lines
6.8 KiB
Python
169 lines
6.8 KiB
Python
import numpy as np
|
|
from lib.test.evaluation.environment import env_settings
|
|
from lib.train.data.image_loader import imread_indexed
|
|
from collections import OrderedDict
|
|
|
|
|
|
class BaseDataset:
    """Abstract base for evaluation datasets.

    Concrete datasets must override ``__len__`` and ``get_sequence_list``;
    the implementations here only raise ``NotImplementedError``.
    """

    def __init__(self):
        # Settings object returned by lib.test.evaluation.environment.env_settings().
        settings = env_settings()
        self.env_settings = settings

    def __len__(self):
        """Return how many sequences the dataset holds. Subclasses must override."""
        raise NotImplementedError

    def get_sequence_list(self):
        """Return the dataset's list of sequences. Subclasses must override."""
        raise NotImplementedError
|
|
|
|
|
|
class Sequence:
    """Class for the sequence in an evaluation.

    Holds the frames and ground truth of one sequence together with per-frame
    initialization data (``self.init_data``: frame index -> dict with keys such
    as 'bbox', 'mask', 'object_ids'). On construction the sequence is trimmed so
    that its earliest initialization frame becomes frame 0.

    Args:
        name: Sequence name (``SequenceList`` looks sequences up by it).
        frames: Per-frame entries; only sliced and counted here
            (presumably image paths - confirm with the dataset classes).
        dataset: Dataset identifier, stored as-is.
        ground_truth_rect: 2-D array with one box row per frame (indexed as
            ``[frame, :]``), or a dict mapping object id -> such an array, or None.
        ground_truth_seg: Per-frame segmentation entries, or None. Frame 0's
            entry becomes the 'mask' init value, which ``object_init_data``
            passes to ``imread_indexed`` (so presumably mask file paths).
        init_data: Optional dict frame index -> init dict. It is copied, not
            modified. When None, frame-0 init data is derived from
            ``object_ids`` and the ground truth.
        object_class: Value returned by ``target_class``.
        target_visible: Per-frame visibility entries, trimmed with the frames.
        object_ids: Ids of the tracked objects, or None.
        multiobj_mode: If False, at most one object id is allowed and
            dict-valued boxes are reduced to the entry for ``object_ids[0]``.
    """

    def __init__(self, name, frames, dataset, ground_truth_rect, ground_truth_seg=None, init_data=None,
                 object_class=None, target_visible=None, object_ids=None, multiobj_mode=False):
        self.name = name
        self.frames = frames
        self.dataset = dataset
        self.ground_truth_rect = ground_truth_rect
        self.ground_truth_seg = ground_truth_seg
        self.object_class = object_class
        self.target_visible = target_visible
        self.object_ids = object_ids
        self.multiobj_mode = multiobj_mode

        # Both helpers read the attributes above, so they must run last.
        self.init_data = self._construct_init_data(init_data)
        self._ensure_start_frame()

    def _ensure_start_frame(self):
        """Trim the sequence so that the earliest init frame becomes frame 0.

        Frames, ground-truth boxes, segmentations and visibility entries are
        sliced from that frame onward, and the ``init_data`` keys are shifted
        down by the same amount. Does nothing when ``init_data`` is empty or
        already starts at frame 0 (or earlier).
        """
        if not self.init_data:
            # min() over an empty key set would raise; there is nothing to shift.
            return

        start_frame = min(self.init_data)
        if start_frame <= 0:
            return

        self.frames = self.frames[start_frame:]

        if self.ground_truth_rect is not None:
            if isinstance(self.ground_truth_rect, dict):
                # NOTE: the dict is updated in place, so a dict passed in by the
                # caller sees the trimmed arrays as well.
                for obj_id, gt in self.ground_truth_rect.items():
                    self.ground_truth_rect[obj_id] = gt[start_frame:, :]
            else:
                self.ground_truth_rect = self.ground_truth_rect[start_frame:, :]

        if self.ground_truth_seg is not None:
            self.ground_truth_seg = self.ground_truth_seg[start_frame:]
            assert len(self.frames) == len(self.ground_truth_seg)

        if self.target_visible is not None:
            self.target_visible = self.target_visible[start_frame:]

        self.init_data = {frame - start_frame: val for frame, val in self.init_data.items()}

    def _construct_init_data(self, init_data):
        """Build the per-frame initialization dict for this sequence.

        Args:
            init_data: dict frame index -> init dict, or None.

        Returns:
            A new dict frame index -> init dict.

            If ``init_data`` was given, each frame's 'bbox' is converted to a
            list, or to an OrderedDict of lists when it is per-object. In
            single-object mode, a dict-valued 'bbox' is first reduced to the
            entry for ``object_ids[0]``.

            If ``init_data`` is None, only frame 0 is filled, from
            ``object_ids``, the first ground-truth box(es) and the first
            segmentation entry.
        """
        if init_data is not None:
            if not self.multiobj_mode:
                assert self.object_ids is None or len(self.object_ids) == 1

            # Copy each frame's dict so that rewriting 'bbox' below does not
            # modify the mapping the caller passed in.
            init_data = {frame: dict(init_val) for frame, init_val in init_data.items()}

            for init_val in init_data.values():
                if 'bbox' not in init_val:
                    continue
                bbox = init_val['bbox']
                if not self.multiobj_mode and isinstance(bbox, dict):
                    # Single-object mode: keep only the tracked object's box.
                    bbox = bbox[self.object_ids[0]]
                # Normalize boxes to lists (per object when still a dict).
                if isinstance(bbox, dict):
                    init_val['bbox'] = OrderedDict((obj_id, list(box)) for obj_id, box in bbox.items())
                else:
                    init_val['bbox'] = list(bbox)
            return init_data

        init_data = {0: dict()}  # No init data given: assume the sequence starts at frame 0.

        if self.object_ids is not None:
            init_data[0]['object_ids'] = self.object_ids

        if self.ground_truth_rect is not None:
            if self.multiobj_mode:
                assert isinstance(self.ground_truth_rect, dict)
                init_data[0]['bbox'] = OrderedDict(
                    (obj_id, list(gt[0, :])) for obj_id, gt in self.ground_truth_rect.items())
            else:
                assert self.object_ids is None or len(self.object_ids) == 1
                if isinstance(self.ground_truth_rect, dict):
                    init_data[0]['bbox'] = list(self.ground_truth_rect[self.object_ids[0]][0, :])
                else:
                    init_data[0]['bbox'] = list(self.ground_truth_rect[0, :])

        if self.ground_truth_seg is not None:
            init_data[0]['mask'] = self.ground_truth_seg[0]

        return init_data

    def init_info(self):
        """Return the initialization info of frame 0 (see ``object_init_data``)."""
        return self.frame_info(frame_num=0)

    def frame_info(self, frame_num):
        """Return the initialization info of ``frame_num`` (see ``object_init_data``)."""
        return self.object_init_data(frame_num=frame_num)

    def init_bbox(self, frame_num=0):
        """Return the initial box(es) of ``frame_num``, or None if it has none."""
        return self.object_init_data(frame_num=frame_num).get('init_bbox')

    def init_mask(self, frame_num=0):
        """Return the decoded initial mask of ``frame_num``, or None if it has none."""
        return self.object_init_data(frame_num=frame_num).get('init_mask')

    def get_info(self, keys, frame_num=None):
        """Return {key: self.get(key, frame_num)} for each key whose value is not None."""
        info = dict()
        for k in keys:
            val = self.get(k, frame_num=frame_num)
            if val is not None:
                info[k] = val
        return info

    def object_init_data(self, frame_num=None) -> dict:
        """Return the init data of a frame with every key prefixed by 'init_'.

        Entries whose value is None are dropped. An 'init_mask' entry is read
        with ``imread_indexed``. Outside multi-object mode with known
        ``object_ids``, it is then turned into a uint8 mask that is 1 where the
        label equals ``int(object_ids[0])``. 'object_ids' and
        'sequence_object_ids' are added when ``object_ids`` is set.

        Args:
            frame_num: Frame index; None means frame 0.

        Returns:
            The dict described above, or an empty dict if the frame has no
            init data.
        """
        if frame_num is None:
            frame_num = 0

        if frame_num not in self.init_data:
            return dict()

        init_data = {'init_' + key: val
                     for key, val in self.init_data[frame_num].items()
                     if val is not None}

        if init_data.get('init_mask') is not None:
            anno = imread_indexed(init_data['init_mask'])
            if not self.multiobj_mode and self.object_ids is not None:
                assert len(self.object_ids) == 1
                anno = (anno == int(self.object_ids[0])).astype(np.uint8)
            init_data['init_mask'] = anno

        if self.object_ids is not None:
            init_data['object_ids'] = self.object_ids
            init_data['sequence_object_ids'] = self.object_ids

        return init_data

    def target_class(self, frame_num=None):
        """Return the object class given at construction (``frame_num`` is ignored)."""
        return self.object_class

    def get(self, name, frame_num=None):
        """Call this sequence's method ``name`` with ``frame_num`` and return the result."""
        return getattr(self, name)(frame_num)

    def __repr__(self):
        return "{self.__class__.__name__} {self.name}, length={len} frames".format(self=self, len=len(self.frames))
|
|
|
|
|
|
|
|
class SequenceList(list):
    """List of sequences. Supports the addition operator to concatenate sequence lists.

    Indexing accepts:
      * a ``str``: returns the first sequence whose ``name`` equals it
        (raises ``IndexError`` if there is none);
      * a tuple or list of indices: returns a new ``SequenceList`` of those items;
      * a slice: returns a new ``SequenceList``;
      * any other index (``int``, or an integer-like value such as a NumPy
        integer): returns the single item, as ``list`` indexing does.
    """

    def __getitem__(self, item):
        if isinstance(item, str):
            for seq in self:
                if seq.name == item:
                    return seq
            raise IndexError('Sequence name not in the dataset.')
        if isinstance(item, (tuple, list)):
            # Explicit super() form: the zero-argument form fails inside a comprehension.
            return SequenceList([super(SequenceList, self).__getitem__(i) for i in item])
        if isinstance(item, slice):
            return SequenceList(super(SequenceList, self).__getitem__(item))
        # Scalar index. Dispatch on slice rather than on int: np.int64 and friends
        # are not int subclasses, and previously they got wrapped as if slicing
        # (SequenceList(<single Sequence>) -> TypeError).
        return super(SequenceList, self).__getitem__(item)

    def __add__(self, other):
        """Concatenate with another list, keeping the ``SequenceList`` type."""
        return SequenceList(super(SequenceList, self).__add__(other))

    def copy(self):
        """Return a shallow copy as a ``SequenceList``."""
        return SequenceList(super(SequenceList, self).copy())