265 lines
12 KiB
Python
265 lines
12 KiB
Python
import random
|
|
import torch.utils.data
|
|
import numpy as np
|
|
from lib.utils import TensorDict
|
|
|
|
|
|
class SequenceSampler(torch.utils.data.Dataset):
    """
    Sample sequences for sequence-level training.

    Each item draws a dataset (weighted by ``p_datasets``), a sequence from that dataset, and
    one template frame plus ``num_search_frames`` search frames from the sequence. Image
    datasets are turned into synthetic videos by repeating the image.
    """

    # Frame sampling modes understood by __getitem__ for video datasets.
    _FRAME_SAMPLE_MODES = ('sequential', 'random_interval')

    # Number of consecutive failed attempts (no valid frames, or an exception while loading)
    # that __getitem__ tolerates before giving up with a RuntimeError.
    max_sample_attempts = 100

    def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
                 num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
        """
        args:
            datasets - List of datasets to be used for training.
            p_datasets - List containing the probabilities by which each dataset will be sampled.
                         If None, each dataset is weighted by len(dataset) (presumably its number
                         of sequences, i.e. uniform over videos). Weights are normalized to sum to 1.
            samples_per_epoch - Number of training samples per epoch (returned by len(self)).
            max_gap - Maximum gap, in frame numbers, between the template frame and the first
                      search frame. -1 makes the first search frame the template frame itself.
            num_search_frames - Number of search frames to sample.
            num_template_frames - Number of template frames. Video sampling always yields a single
                                  template frame; this value is used for image datasets and in the
                                  "enough visible frames" check.
            frame_sample_mode - Either 'sequential' or 'random_interval'. 'sequential' takes the first
                                search frame and the visible frames that follow it; 'random_interval'
                                picks each next search frame within max_interval frames of the previous.
            max_interval - Maximum interval between consecutive search frames in 'random_interval'
                           mode (frames added by the fallback fill may exceed it).
            prob - In 'random_interval' mode, the probability of using random-interval sampling;
                   sequential sampling is used otherwise.
        """
        self.datasets = datasets

        # If p not provided, weight each dataset by its length
        if p_datasets is None:
            p_datasets = [len(d) for d in self.datasets]

        # Normalize
        p_total = sum(p_datasets)
        self.p_datasets = [x / p_total for x in p_datasets]

        self.samples_per_epoch = samples_per_epoch
        self.max_gap = max_gap
        self.max_interval = max_interval
        self.num_search_frames = num_search_frames
        self.num_template_frames = num_template_frames
        self.frame_sample_mode = frame_sample_mode
        self.prob = prob
        # Multiplier applied to max_gap and max_interval at sampling time (1 = unchanged).
        self.extra = 1

    def __len__(self):
        return self.samples_per_epoch

    def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
        """ Samples num_ids frames between min_id and max_id for which target is visible

        args:
            visible - 1d sequence (e.g. a bool Tensor) indicating whether the target is visible in each frame
            num_ids - number of frames to be sampled (drawn with replacement)
            min_id - Minimum allowed frame number (inclusive); None or negative means 0
            max_id - Maximum allowed frame number (exclusive); None or too large means len(visible)

        returns:
            list - List of sampled frame numbers. None if no visible frame lies in [min_id, max_id).
        """
        if num_ids == 0:
            return []
        if min_id is None or min_id < 0:
            min_id = 0
        if max_id is None or max_id > len(visible):
            max_id = len(visible)

        valid_ids = [i for i in range(min_id, max_id) if visible[i]]

        # No visible ids
        if not valid_ids:
            return None

        return random.choices(valid_ids, k=num_ids)

    @staticmethod
    def _fit_interval(max_interval, num_gaps, num_frames):
        """Shrink max_interval one step at a time until num_gaps intervals fit in num_frames.

        Stops at 1 even if the clip still does not fit (the previous inline loop spun forever there).
        """
        interval = max_interval
        while interval * num_gaps > num_frames:
            if interval == 1:
                break
            interval = max(interval - 1, 1)
        return interval

    def _sequential_sample(self, visible, max_gap=None):
        """Sample one template frame and consecutive visible search frames.

        args:
            visible - 1d sequence indicating whether the target is visible in each frame
            max_gap - Overrides self.max_gap when given (see __init__)

        returns:
            (template_frame_ids, search_frame_ids) - a one-element list and a list of
            num_search_frames visible frame numbers, or (None, None) if no valid frames exist.
        """
        if max_gap is None:
            max_gap = self.max_gap
        num_frames = len(visible)

        template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
                                                      max_id=num_frames - self.num_search_frames)
        if template_frame_ids is None:
            return None, None
        template_id = template_frame_ids[0]

        if max_gap == -1:
            left = template_id
        else:
            # template frame (1) ->(max_gap) -> search frame (num_search_frames)
            left_max = min(num_frames - self.num_search_frames, template_id + max_gap)
            left_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_id, max_id=left_max)
            if left_ids is None:
                # Empty gap window (e.g. max_gap == 0): let the caller try another sequence.
                return None, None
            left = left_ids[0]

        search_frame_ids = [i for i in range(left, num_frames) if visible[i]][:self.num_search_frames]

        # If the video ends too early, duplicate the last sampled (visible) frame rather than
        # padding with the video's final frame, which may not be visible.
        while len(search_frame_ids) < self.num_search_frames:
            search_frame_ids.append(search_frame_ids[-1])

        return template_frame_ids, search_frame_ids

    def _random_interval_sample(self, visible, max_gap=None, max_interval=None):
        """Sample one template frame and search frames spaced by random intervals.

        args:
            visible - 1d sequence indicating whether the target is visible in each frame
            max_gap - Overrides self.max_gap when given (see __init__)
            max_interval - Overrides self.max_interval when given (see __init__)

        returns:
            (template_frame_ids, search_frame_ids) - a one-element list and a sorted list of
            num_search_frames frame numbers (duplicates possible), or (None, None) if the
            sequence has no visible frame.
        """
        if max_gap is None:
            max_gap = self.max_gap
        if max_interval is None:
            max_interval = self.max_interval

        num_frames = len(visible)
        num_gaps = self.num_search_frames - 1
        valid_ids = [i for i in range(num_frames) if visible[i]]
        if not valid_ids:
            return None, None

        # Sample template frame, leaving room for the search clip after it; shrink the
        # interval (finally falling back to the first visible frame) if no frame qualifies.
        avg_interval = self._fit_interval(max_interval, num_gaps, num_frames)
        while True:
            template_frame_ids = self._sample_visible_ids(
                visible, num_ids=1, min_id=0, max_id=num_frames - avg_interval * num_gaps)
            if template_frame_ids is not None:
                break
            avg_interval -= 1
            if avg_interval <= 0:
                template_frame_ids = [valid_ids[0]]
                break
        template_id = template_frame_ids[0]

        # Sample first search frame. A new list is built so that appending search frames
        # below never mutates template_frame_ids.
        if max_gap == -1:
            search_frame_ids = [template_id]
        else:
            avg_interval = self._fit_interval(max_interval, num_gaps, num_frames)
            left_max = min(max(num_frames - avg_interval * num_gaps, template_id + 1),
                           template_id + max_gap)
            # The template frame is visible and lies in [template_id, left_max) whenever
            # max_gap >= 1, so this only fails for max_gap <= 0; then the first search
            # frame falls back to the template frame.
            search_frame_ids = self._sample_visible_ids(
                visible, num_ids=1, min_id=template_id, max_id=left_max)
            if search_frame_ids is None:
                search_frame_ids = [template_id]

        # Sample rest of the search frames, each within max_interval frames of the previous one
        last = search_frame_ids[0]
        while last <= num_frames - 1 and len(search_frame_ids) < self.num_search_frames:
            window_end = min(last + max_interval + 1, num_frames)
            next_ids = self._sample_visible_ids(visible, num_ids=1, min_id=last, max_id=window_end)
            if next_ids is None:
                # Nothing visible in this window: move on to the next one. Always advance by
                # at least one frame so a non-positive max_interval cannot loop forever.
                last += max(max_interval, 1)
            else:
                search_frame_ids.append(next_ids[0])
                last = next_ids[0]

        # If length is not enough, randomly add visible frames after the first search frame
        if len(search_frame_ids) < self.num_search_frames:
            taken = set(search_frame_ids)
            candidates = [i for i in valid_ids if i > search_frame_ids[0] and i not in taken]
            if candidates:
                new_ids = random.choices(
                    candidates, k=min(len(candidates), self.num_search_frames - len(search_frame_ids)))
                search_frame_ids = sorted(search_frame_ids + new_ids)

        # If length is still not enough, duplicate last frame
        while len(search_frame_ids) < self.num_search_frames:
            search_frame_ids.append(search_frame_ids[-1])

        return template_frame_ids, search_frame_ids

    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks

        raises:
            NotImplementedError - if a video dataset is drawn and frame_sample_mode is unsupported.
            RuntimeError - if max_sample_attempts consecutive attempts fail; the last exception
                           raised while sampling/loading (if any) is chained as the cause.
        """
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # Checked outside the retry loop: this is a configuration error, and raising it inside
        # the loop would be swallowed and retried forever.
        if is_video_dataset and self.frame_sample_mode not in self._FRAME_SAMPLE_MODES:
            raise NotImplementedError(
                "Unsupported frame_sample_mode '{}'".format(self.frame_sample_mode))

        # Effective limits for this sample, passed to the helpers explicitly instead of
        # temporarily overwriting self.max_gap / self.max_interval. The -1 sentinel is not scaled.
        max_gap = self.max_gap if self.max_gap == -1 else self.max_gap * self.extra
        max_interval = self.max_interval * self.extra

        num_frames_needed = self.num_search_frames + self.num_template_frames
        last_error = None
        for _ in range(self.max_sample_attempts):
            try:
                # Sample a sequence with enough visible frames
                enough_visible_frames = False
                while not enough_visible_frames:
                    seq_id = random.randint(0, dataset.get_num_sequences() - 1)
                    seq_info_dict = dataset.get_sequence_info(seq_id)
                    visible = seq_info_dict['visible']

                    enough_visible_frames = (
                        visible.type(torch.int64).sum().item() > 2 * num_frames_needed
                        and len(visible) >= num_frames_needed)
                    # Image datasets are accepted regardless of visibility.
                    enough_visible_frames = enough_visible_frames or not is_video_dataset

                if is_video_dataset:
                    if self.frame_sample_mode == 'random_interval' and random.random() < self.prob:
                        template_frame_ids, search_frame_ids = self._random_interval_sample(
                            visible, max_gap=max_gap, max_interval=max_interval)
                    else:
                        template_frame_ids, search_frame_ids = self._sequential_sample(
                            visible, max_gap=max_gap)
                    if template_frame_ids is None or search_frame_ids is None:
                        # No valid frames in this sequence; try another one.
                        continue
                else:
                    # In case of image dataset, just repeat the image to generate synthetic video
                    template_frame_ids = [1] * self.num_template_frames
                    search_frame_ids = [1] * self.num_search_frames

                template_frames, template_anno, _ = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
                search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
                template_bbox = [bbox.numpy() for bbox in template_anno['bbox']]  # tensor -> numpy array
                search_bbox = [bbox.numpy() for bbox in search_anno['bbox']]  # tensor -> numpy array

                return TensorDict({'template_images': np.array(template_frames).squeeze(),  # 1 template images
                                   'template_annos': np.array(template_bbox).squeeze(),
                                   'search_images': np.array(search_frames),  # (num_frames) search images
                                   'search_annos': np.array(search_bbox),
                                   'seq_id': seq_id,
                                   'dataset': dataset.get_name(),
                                   'search_class': meta_obj_search.get('object_class_name'),
                                   'num_frames': len(search_frames)
                                   })
            except Exception as err:
                # Best effort: a sequence that fails to load is skipped and another one is drawn.
                last_error = err

        raise RuntimeError(
            "SequenceSampler failed to draw a sample from dataset '{}' in {} attempts".format(
                dataset.get_name(), self.max_sample_attempts)) from last_error