init commit of samurai
This commit is contained in:
@@ -0,0 +1,226 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import numpy as np
|
||||
from lib.test.utils.load_text import load_text
|
||||
import torch
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
env_path = os.path.join(os.path.dirname(__file__), '../../..')
|
||||
if env_path not in sys.path:
|
||||
sys.path.append(env_path)
|
||||
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
|
||||
|
||||
def calc_err_center(pred_bb, anno_bb, normalized=False):
|
||||
pred_center = pred_bb[:, :2] + 0.5 * (pred_bb[:, 2:] - 1.0)
|
||||
anno_center = anno_bb[:, :2] + 0.5 * (anno_bb[:, 2:] - 1.0)
|
||||
|
||||
if normalized:
|
||||
pred_center = pred_center / anno_bb[:, 2:]
|
||||
anno_center = anno_center / anno_bb[:, 2:]
|
||||
|
||||
err_center = ((pred_center - anno_center)**2).sum(1).sqrt()
|
||||
return err_center
|
||||
|
||||
|
||||
def calc_iou_overlap(pred_bb, anno_bb):
|
||||
tl = torch.max(pred_bb[:, :2], anno_bb[:, :2])
|
||||
br = torch.min(pred_bb[:, :2] + pred_bb[:, 2:] - 1.0, anno_bb[:, :2] + anno_bb[:, 2:] - 1.0)
|
||||
sz = (br - tl + 1.0).clamp(0)
|
||||
|
||||
# Area
|
||||
intersection = sz.prod(dim=1)
|
||||
union = pred_bb[:, 2:].prod(dim=1) + anno_bb[:, 2:].prod(dim=1) - intersection
|
||||
|
||||
return intersection / union
|
||||
|
||||
|
||||
def calc_seq_err_robust(pred_bb, anno_bb, dataset, target_visible=None):
|
||||
pred_bb = pred_bb.clone()
|
||||
|
||||
# Check if invalid values are present
|
||||
if torch.isnan(pred_bb).any() or (pred_bb[:, 2:] < 0.0).any():
|
||||
raise Exception('Error: Invalid results')
|
||||
|
||||
if torch.isnan(anno_bb).any():
|
||||
if dataset == 'uav':
|
||||
pass
|
||||
else:
|
||||
raise Exception('Warning: NaNs in annotation')
|
||||
|
||||
if (pred_bb[:, 2:] == 0.0).any():
|
||||
for i in range(1, pred_bb.shape[0]):
|
||||
if i >= anno_bb.shape[0]:
|
||||
continue
|
||||
if (pred_bb[i, 2:] == 0.0).any() and not torch.isnan(anno_bb[i, :]).any():
|
||||
pred_bb[i, :] = pred_bb[i-1, :]
|
||||
|
||||
if pred_bb.shape[0] != anno_bb.shape[0]:
|
||||
if dataset == 'lasot':
|
||||
if pred_bb.shape[0] > anno_bb.shape[0]:
|
||||
# For monkey-17, there is a mismatch for some trackers.
|
||||
pred_bb = pred_bb[:anno_bb.shape[0], :]
|
||||
else:
|
||||
raise Exception('Mis-match in tracker prediction and GT lengths')
|
||||
else:
|
||||
# print('Warning: Mis-match in tracker prediction and GT lengths')
|
||||
if pred_bb.shape[0] > anno_bb.shape[0]:
|
||||
pred_bb = pred_bb[:anno_bb.shape[0], :]
|
||||
else:
|
||||
pad = torch.zeros((anno_bb.shape[0] - pred_bb.shape[0], 4)).type_as(pred_bb)
|
||||
pred_bb = torch.cat((pred_bb, pad), dim=0)
|
||||
|
||||
pred_bb[0, :] = anno_bb[0, :]
|
||||
|
||||
if target_visible is not None:
|
||||
target_visible = target_visible.bool()
|
||||
valid = ((anno_bb[:, 2:] > 0.0).sum(1) == 2) & target_visible
|
||||
else:
|
||||
valid = ((anno_bb[:, 2:] > 0.0).sum(1) == 2)
|
||||
|
||||
err_center = calc_err_center(pred_bb, anno_bb)
|
||||
err_center_normalized = calc_err_center(pred_bb, anno_bb, normalized=True)
|
||||
err_overlap = calc_iou_overlap(pred_bb, anno_bb)
|
||||
|
||||
# handle invalid anno cases
|
||||
if dataset in ['uav']:
|
||||
err_center[~valid] = -1.0
|
||||
else:
|
||||
err_center[~valid] = float("Inf")
|
||||
err_center_normalized[~valid] = -1.0
|
||||
err_overlap[~valid] = -1.0
|
||||
|
||||
if dataset == 'lasot':
|
||||
err_center_normalized[~target_visible] = float("Inf")
|
||||
err_center[~target_visible] = float("Inf")
|
||||
|
||||
if torch.isnan(err_overlap).any():
|
||||
raise Exception('Nans in calculated overlap')
|
||||
return err_overlap, err_center, err_center_normalized, valid
|
||||
|
||||
|
||||
def extract_results(trackers, dataset, report_name, skip_missing_seq=False, plot_bin_gap=0.05,
|
||||
exclude_invalid_frames=False):
|
||||
settings = env_settings()
|
||||
eps = 1e-16
|
||||
|
||||
result_plot_path = os.path.join(settings.result_plot_path, report_name)
|
||||
|
||||
if not os.path.exists(result_plot_path):
|
||||
os.makedirs(result_plot_path)
|
||||
|
||||
threshold_set_overlap = torch.arange(0.0, 1.0 + plot_bin_gap, plot_bin_gap, dtype=torch.float64)
|
||||
threshold_set_center = torch.arange(0, 51, dtype=torch.float64)
|
||||
threshold_set_center_norm = torch.arange(0, 51, dtype=torch.float64) / 100.0
|
||||
|
||||
avg_overlap_all = torch.zeros((len(dataset), len(trackers)), dtype=torch.float64)
|
||||
ave_success_rate_plot_overlap = torch.zeros((len(dataset), len(trackers), threshold_set_overlap.numel()),
|
||||
dtype=torch.float32)
|
||||
ave_success_rate_plot_center = torch.zeros((len(dataset), len(trackers), threshold_set_center.numel()),
|
||||
dtype=torch.float32)
|
||||
ave_success_rate_plot_center_norm = torch.zeros((len(dataset), len(trackers), threshold_set_center.numel()),
|
||||
dtype=torch.float32)
|
||||
|
||||
from collections import defaultdict
|
||||
# default dict of default dict of list
|
||||
|
||||
|
||||
valid_sequence = torch.ones(len(dataset), dtype=torch.uint8)
|
||||
|
||||
for seq_id, seq in enumerate(tqdm(dataset)):
|
||||
frame_success_rate_plot_overlap = defaultdict(lambda: defaultdict(list))
|
||||
frame_success_rate_plot_center = defaultdict(lambda: defaultdict(list))
|
||||
frame_success_rate_plot_center_norm = defaultdict(lambda: defaultdict(list))
|
||||
# Load anno
|
||||
anno_bb = torch.tensor(seq.ground_truth_rect)
|
||||
target_visible = torch.tensor(seq.target_visible, dtype=torch.uint8) if seq.target_visible is not None else None
|
||||
for trk_id, trk in enumerate(trackers):
|
||||
# Load results
|
||||
base_results_path = '{}/{}'.format(trk.results_dir, seq.name)
|
||||
results_path = '{}.txt'.format(base_results_path)
|
||||
|
||||
if os.path.isfile(results_path):
|
||||
pred_bb = torch.tensor(load_text(str(results_path), delimiter=('\t', ','), dtype=np.float64))
|
||||
else:
|
||||
if skip_missing_seq:
|
||||
valid_sequence[seq_id] = 0
|
||||
break
|
||||
else:
|
||||
raise Exception('Result not found. {}'.format(results_path))
|
||||
|
||||
# Calculate measures
|
||||
err_overlap, err_center, err_center_normalized, valid_frame = calc_seq_err_robust(
|
||||
pred_bb, anno_bb, seq.dataset, target_visible)
|
||||
|
||||
avg_overlap_all[seq_id, trk_id] = err_overlap[valid_frame].mean()
|
||||
|
||||
if exclude_invalid_frames:
|
||||
seq_length = valid_frame.long().sum()
|
||||
else:
|
||||
seq_length = anno_bb.shape[0]
|
||||
|
||||
if seq_length <= 0:
|
||||
raise Exception('Seq length zero')
|
||||
|
||||
ave_success_rate_plot_overlap[seq_id, trk_id, :] = (err_overlap.view(-1, 1) > threshold_set_overlap.view(1, -1)).sum(0).float() / seq_length
|
||||
ave_success_rate_plot_center[seq_id, trk_id, :] = (err_center.view(-1, 1) <= threshold_set_center.view(1, -1)).sum(0).float() / seq_length
|
||||
ave_success_rate_plot_center_norm[seq_id, trk_id, :] = (err_center_normalized.view(-1, 1) <= threshold_set_center_norm.view(1, -1)).sum(0).float() / seq_length
|
||||
|
||||
# for frame_id in range(seq_length):
|
||||
# frame_success_rate_plot_overlap[trk_id][frame_id].append((err_overlap[frame_id]).item())
|
||||
# frame_success_rate_plot_center[trk_id][frame_id].append((err_center[frame_id]).item())
|
||||
# frame_success_rate_plot_center_norm[trk_id][frame_id].append((err_center_normalized[frame_id] < 0.2).item())
|
||||
|
||||
# output_folder = "../cvpr2025/per_frame_success_rate"
|
||||
# os.makedirs(output_folder, exist_ok=True)
|
||||
# with open(osp.join(output_folder, f"{seq.name}.txt"), 'w') as f:
|
||||
# for frame_id in range(seq_length):
|
||||
# suc_score = frame_success_rate_plot_overlap[trk_id][frame_id][0]
|
||||
# f.write(f"{suc_score}\n")
|
||||
|
||||
# # plot the average success rate, center normalized for each tracker
|
||||
# # y axis: success rate
|
||||
# # x axis: frame number
|
||||
# # different color for each tracker
|
||||
# # save the plot as a figure
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.figure(figsize=(10, 6))
|
||||
# for trk_id, trk in enumerate(trackers):
|
||||
# list_to_plot = [np.mean(frame_success_rate_plot_overlap[trk_id][frame_id]) for frame_id in range(2000)]
|
||||
# # smooth the curve; window size = 10
|
||||
# smooth_list_to_plot = np.convolve(list_to_plot, np.ones((10,))/10, mode='valid')
|
||||
# # the smooth curve and non smooth curve have the same label
|
||||
# plt.plot(smooth_list_to_plot, label=trk.display_name, alpha=1)
|
||||
# plt.xlabel('Frame Number')
|
||||
# plt.ylabel('Success Rate')
|
||||
# plt.title('Average Success Rate Over Frames')
|
||||
# plt.legend()
|
||||
# plt.grid(True)
|
||||
# plt.savefig('average_success_rate_plot_overlap.png')
|
||||
# plt.close()
|
||||
|
||||
|
||||
print('\n\nComputed results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))
|
||||
|
||||
# Prepare dictionary for saving data
|
||||
seq_names = [s.name for s in dataset]
|
||||
tracker_names = [{'name': t.name, 'param': t.parameter_name, 'run_id': t.run_id, 'disp_name': t.display_name}
|
||||
for t in trackers]
|
||||
|
||||
eval_data = {'sequences': seq_names, 'trackers': tracker_names,
|
||||
'valid_sequence': valid_sequence.tolist(),
|
||||
'ave_success_rate_plot_overlap': ave_success_rate_plot_overlap.tolist(),
|
||||
'ave_success_rate_plot_center': ave_success_rate_plot_center.tolist(),
|
||||
'ave_success_rate_plot_center_norm': ave_success_rate_plot_center_norm.tolist(),
|
||||
'avg_overlap_all': avg_overlap_all.tolist(),
|
||||
'threshold_set_overlap': threshold_set_overlap.tolist(),
|
||||
'threshold_set_center': threshold_set_center.tolist(),
|
||||
'threshold_set_center_norm': threshold_set_center_norm.tolist()}
|
||||
|
||||
with open(result_plot_path + '/eval_data.pkl', 'wb') as fh:
|
||||
pickle.dump(eval_data, fh)
|
||||
|
||||
return eval_data
|
||||
@@ -0,0 +1,796 @@
|
||||
import tikzplotlib
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path as osp
|
||||
import torch
|
||||
import pickle
|
||||
import json
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.test.analysis.extract_results import extract_results
|
||||
|
||||
|
||||
def get_plot_draw_styles():
|
||||
plot_draw_style = [
|
||||
# {'color': (1.0, 0.0, 0.0), 'line_style': '-'},
|
||||
# {'color': (0.0, 1.0, 0.0), 'line_style': '-'},
|
||||
{'color': (0.0, 1.0, 0.0), 'line_style': '-'},
|
||||
{'color': (0.0, 0.0, 0.0), 'line_style': '-'},
|
||||
{'color': (1.0, 0.0, 1.0), 'line_style': '-'},
|
||||
{'color': (0.0, 1.0, 1.0), 'line_style': '-'},
|
||||
{'color': (0.5, 0.5, 0.5), 'line_style': '-'},
|
||||
{'color': (136.0 / 255.0, 0.0, 21.0 / 255.0), 'line_style': '-'},
|
||||
{'color': (1.0, 127.0 / 255.0, 39.0 / 255.0), 'line_style': '-'},
|
||||
{'color': (0.0, 162.0 / 255.0, 232.0 / 255.0), 'line_style': '-'},
|
||||
{'color': (0.0, 0.5, 0.0), 'line_style': '-'},
|
||||
{'color': (1.0, 0.5, 0.2), 'line_style': '-'},
|
||||
{'color': (0.1, 0.4, 0.0), 'line_style': '-'},
|
||||
{'color': (0.6, 0.3, 0.9), 'line_style': '-'},
|
||||
{'color': (0.4, 0.7, 0.1), 'line_style': '-'},
|
||||
{'color': (0.2, 0.1, 0.7), 'line_style': '-'},
|
||||
{'color': (0.7, 0.6, 0.2), 'line_style': '-'}]
|
||||
|
||||
return plot_draw_style
|
||||
|
||||
|
||||
def check_eval_data_is_valid(eval_data, trackers, dataset):
|
||||
""" Checks if the pre-computed results are valid"""
|
||||
seq_names = [s.name for s in dataset]
|
||||
seq_names_saved = eval_data['sequences']
|
||||
|
||||
tracker_names_f = [(t.name, t.parameter_name, t.run_id) for t in trackers]
|
||||
tracker_names_f_saved = [(t['name'], t['param'], t['run_id']) for t in eval_data['trackers']]
|
||||
|
||||
return seq_names == seq_names_saved and tracker_names_f == tracker_names_f_saved
|
||||
|
||||
|
||||
def merge_multiple_runs(eval_data):
|
||||
new_tracker_names = []
|
||||
ave_success_rate_plot_overlap_merged = []
|
||||
ave_success_rate_plot_center_merged = []
|
||||
ave_success_rate_plot_center_norm_merged = []
|
||||
avg_overlap_all_merged = []
|
||||
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
avg_overlap_all = torch.tensor(eval_data['avg_overlap_all'])
|
||||
|
||||
trackers = eval_data['trackers']
|
||||
merged = torch.zeros(len(trackers), dtype=torch.uint8)
|
||||
for i in range(len(trackers)):
|
||||
if merged[i]:
|
||||
continue
|
||||
base_tracker = trackers[i]
|
||||
new_tracker_names.append(base_tracker)
|
||||
|
||||
match = [t['name'] == base_tracker['name'] and t['param'] == base_tracker['param'] for t in trackers]
|
||||
match = torch.tensor(match)
|
||||
|
||||
ave_success_rate_plot_overlap_merged.append(ave_success_rate_plot_overlap[:, match, :].mean(1))
|
||||
ave_success_rate_plot_center_merged.append(ave_success_rate_plot_center[:, match, :].mean(1))
|
||||
ave_success_rate_plot_center_norm_merged.append(ave_success_rate_plot_center_norm[:, match, :].mean(1))
|
||||
avg_overlap_all_merged.append(avg_overlap_all[:, match].mean(1))
|
||||
|
||||
merged[match] = 1
|
||||
|
||||
ave_success_rate_plot_overlap_merged = torch.stack(ave_success_rate_plot_overlap_merged, dim=1)
|
||||
ave_success_rate_plot_center_merged = torch.stack(ave_success_rate_plot_center_merged, dim=1)
|
||||
ave_success_rate_plot_center_norm_merged = torch.stack(ave_success_rate_plot_center_norm_merged, dim=1)
|
||||
avg_overlap_all_merged = torch.stack(avg_overlap_all_merged, dim=1)
|
||||
|
||||
eval_data['trackers'] = new_tracker_names
|
||||
eval_data['ave_success_rate_plot_overlap'] = ave_success_rate_plot_overlap_merged.tolist()
|
||||
eval_data['ave_success_rate_plot_center'] = ave_success_rate_plot_center_merged.tolist()
|
||||
eval_data['ave_success_rate_plot_center_norm'] = ave_success_rate_plot_center_norm_merged.tolist()
|
||||
eval_data['avg_overlap_all'] = avg_overlap_all_merged.tolist()
|
||||
|
||||
return eval_data
|
||||
|
||||
|
||||
def get_tracker_display_name(tracker):
|
||||
if tracker['disp_name'] is None:
|
||||
if tracker['run_id'] is None:
|
||||
disp_name = '{}_{}'.format(tracker['name'], tracker['param'])
|
||||
else:
|
||||
disp_name = '{}_{}_{:03d}'.format(tracker['name'], tracker['param'],
|
||||
tracker['run_id'])
|
||||
else:
|
||||
disp_name = tracker['disp_name']
|
||||
|
||||
return disp_name
|
||||
|
||||
|
||||
def plot_draw_save(y, x, scores, trackers, plot_draw_styles, result_plot_path, plot_opts):
|
||||
plt.rcParams['text.usetex']=True
|
||||
plt.rcParams["font.family"] = "Times New Roman"
|
||||
# Plot settings
|
||||
font_size = plot_opts.get('font_size', 25)
|
||||
font_size_axis = plot_opts.get('font_size_axis', 20)
|
||||
line_width = plot_opts.get('line_width', 2)
|
||||
font_size_legend = plot_opts.get('font_size_legend', 15)
|
||||
|
||||
plot_type = plot_opts['plot_type']
|
||||
legend_loc = plot_opts['legend_loc']
|
||||
if 'attr' in plot_opts:
|
||||
attr = plot_opts['attr']
|
||||
else:
|
||||
attr = None
|
||||
|
||||
xlabel = plot_opts['xlabel']
|
||||
ylabel = plot_opts['ylabel']
|
||||
ylabel = "%s"%(ylabel.replace('%','\%'))
|
||||
xlim = plot_opts['xlim']
|
||||
ylim = plot_opts['ylim']
|
||||
|
||||
title = r"\textbf{%s}" %(plot_opts['title'])
|
||||
print
|
||||
|
||||
matplotlib.rcParams.update({'font.size': font_size})
|
||||
matplotlib.rcParams.update({'axes.titlesize': font_size_axis})
|
||||
matplotlib.rcParams.update({'axes.titleweight': 'black'})
|
||||
matplotlib.rcParams.update({'axes.labelsize': font_size_axis})
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
index_sort = scores.argsort(descending=False)
|
||||
|
||||
plotted_lines = []
|
||||
legend_text = []
|
||||
|
||||
for id, id_sort in enumerate(index_sort):
|
||||
if trackers[id_sort]['disp_name'].startswith('SAMURAI'):
|
||||
alpha = 1.0
|
||||
line_style = '-'
|
||||
if trackers[id_sort]['disp_name'] == 'SAMURAI-L':
|
||||
color = (1.0, 0.0, 0.0)
|
||||
elif trackers[id_sort]['disp_name'] == 'SAMURAI-B':
|
||||
color = (0.0, 0.0, 1.0)
|
||||
elif trackers[id_sort]['disp_name'].startswith('SAM2.1'):
|
||||
alpha = 0.8
|
||||
line_style = '--'
|
||||
if trackers[id_sort]['disp_name'] == 'SAM2.1-L':
|
||||
color = (1.0, 0.0, 0.0)
|
||||
elif trackers[id_sort]['disp_name'] == 'SAM2.1-B':
|
||||
color = (0.0, 0.0, 1.0)
|
||||
else:
|
||||
alpha = 0.5
|
||||
color = plot_draw_styles[index_sort.numel() - id - 1]['color']
|
||||
line_style = ":"
|
||||
line = ax.plot(x.tolist(), y[id_sort, :].tolist(),
|
||||
linewidth=line_width,
|
||||
color=color,
|
||||
linestyle=line_style,
|
||||
alpha=alpha)
|
||||
|
||||
plotted_lines.append(line[0])
|
||||
|
||||
tracker = trackers[id_sort]
|
||||
disp_name = get_tracker_display_name(tracker)
|
||||
|
||||
legend_text.append('{} [{:.1f}]'.format(disp_name, scores[id_sort]))
|
||||
|
||||
try:
|
||||
# add bold to top method
|
||||
# for i in range(1,2):
|
||||
# legend_text[-i] = r'\textbf{%s}'%(legend_text[-i])
|
||||
|
||||
for id, id_sort in enumerate(index_sort):
|
||||
if trackers[id_sort]['disp_name'].startswith('SAMTrack'):
|
||||
legend_text[id] = r'\textbf{%s}'%(legend_text[id])
|
||||
|
||||
ax.legend(plotted_lines[::-1], legend_text[::-1], loc=legend_loc, fancybox=False, edgecolor='black',
|
||||
fontsize=font_size_legend, framealpha=1.0)
|
||||
except:
|
||||
pass
|
||||
|
||||
ax.set(xlabel=xlabel,
|
||||
ylabel=ylabel,
|
||||
xlim=xlim, ylim=ylim,
|
||||
title=title)
|
||||
|
||||
ax.grid(True, linestyle='-.')
|
||||
fig.tight_layout()
|
||||
|
||||
def tikzplotlib_fix_ncols(obj):
|
||||
"""
|
||||
workaround for matplotlib 3.6 renamed legend's _ncol to _ncols, which breaks tikzplotlib
|
||||
"""
|
||||
if hasattr(obj, "_ncols"):
|
||||
obj._ncol = obj._ncols
|
||||
for child in obj.get_children():
|
||||
tikzplotlib_fix_ncols(child)
|
||||
|
||||
tikzplotlib_fix_ncols(fig)
|
||||
|
||||
# tikzplotlib.save('{}/{}_plot.tex'.format(result_plot_path, plot_type))
|
||||
if attr is not None:
|
||||
fig.savefig('{}/{}_{}_plot.pdf'.format(result_plot_path, plot_type, attr), dpi=300, format='pdf', transparent=True)
|
||||
else:
|
||||
fig.savefig('{}/{}_plot.pdf'.format(result_plot_path, plot_type), dpi=300, format='pdf', transparent=True)
|
||||
plt.draw()
|
||||
|
||||
|
||||
def check_and_load_precomputed_results(trackers, dataset, report_name, force_evaluation=False, **kwargs):
|
||||
# Load data
|
||||
settings = env_settings()
|
||||
|
||||
# Load pre-computed results
|
||||
result_plot_path = os.path.join(settings.result_plot_path, report_name)
|
||||
eval_data_path = os.path.join(result_plot_path, 'eval_data.pkl')
|
||||
|
||||
if os.path.isfile(eval_data_path) and not force_evaluation:
|
||||
with open(eval_data_path, 'rb') as fh:
|
||||
eval_data = pickle.load(fh)
|
||||
else:
|
||||
# print('Pre-computed evaluation data not found. Computing results!')
|
||||
eval_data = extract_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
if not check_eval_data_is_valid(eval_data, trackers, dataset):
|
||||
# print('Pre-computed evaluation data invalid. Re-computing results!')
|
||||
eval_data = extract_results(trackers, dataset, report_name, **kwargs)
|
||||
# pass
|
||||
else:
|
||||
# Update display names
|
||||
tracker_names = [{'name': t.name, 'param': t.parameter_name, 'run_id': t.run_id, 'disp_name': t.display_name}
|
||||
for t in trackers]
|
||||
eval_data['trackers'] = tracker_names
|
||||
with open(eval_data_path, 'wb') as fh:
|
||||
pickle.dump(eval_data, fh)
|
||||
return eval_data
|
||||
|
||||
|
||||
def get_auc_curve(ave_success_rate_plot_overlap, valid_sequence):
|
||||
ave_success_rate_plot_overlap = ave_success_rate_plot_overlap[valid_sequence, :, :]
|
||||
auc_curve = ave_success_rate_plot_overlap.mean(0) * 100.0
|
||||
auc = auc_curve.mean(-1)
|
||||
|
||||
return auc_curve, auc
|
||||
|
||||
|
||||
def get_prec_curve(ave_success_rate_plot_center, valid_sequence):
|
||||
ave_success_rate_plot_center = ave_success_rate_plot_center[valid_sequence, :, :]
|
||||
prec_curve = ave_success_rate_plot_center.mean(0) * 100.0
|
||||
prec_score = prec_curve[:, 20]
|
||||
|
||||
return prec_curve, prec_score
|
||||
|
||||
def plot_per_attribute_results(trackers, dataset, report_name, merge_results=False,
|
||||
plot_types=('success'), **kwargs):
|
||||
# Load data
|
||||
settings = env_settings()
|
||||
|
||||
plot_draw_styles = get_plot_draw_styles()
|
||||
|
||||
# Load pre-computed results
|
||||
result_plot_path = os.path.join(settings.result_plot_path, report_name)
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
attr_folder = 'data/LaSOT/att'
|
||||
|
||||
attr_list = ['Illumination Variation', 'Partial Occlusion', 'Deformation', 'Motion Blur', 'Camera Motion', 'Rotation', 'Background Clutter', 'Viewpoint Change', 'Scale Variation', 'Full Occlusion', 'Fast Motion', 'Out-of-View', 'Low Resolution', 'Aspect Ration Change']
|
||||
attr_list = ['IV', 'POC', 'DEF', 'MB', 'CM', 'ROT', 'BC', 'VC', 'SV', 'FOC', 'FM', 'OV', 'LR', 'ARC']
|
||||
|
||||
# Iterate over the sequence and construct a valid_sequence for each attribute
|
||||
valid_sequence_attr = {}
|
||||
for attr in attr_list:
|
||||
valid_sequence_attr[attr] = torch.zeros(valid_sequence.shape[0], dtype=torch.bool)
|
||||
for seq_id, seq_obj in enumerate(dataset):
|
||||
seq_name = seq_obj.name
|
||||
attr_txt = osp.join(attr_folder, f'{seq_name}.txt')
|
||||
if osp.exists(attr_txt):
|
||||
# read the attribute file into a list of True and False
|
||||
# the attribute file looks like this: 0,0,0,0,0,1,0,1,1,0,0,0,0,0
|
||||
attr_anno = np.loadtxt(attr_txt, dtype=int, delimiter=',')
|
||||
# broadcast the valid_sequence to the attribute list
|
||||
for attr_id, attr in enumerate(attr_list):
|
||||
valid_sequence_attr[attr][seq_id] = attr_anno[attr_id]
|
||||
else:
|
||||
raise Exception(f'Attribute file not found for sequence {seq_name}')
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
|
||||
if report_name == 'LaSOT-ext':
|
||||
ylim_success = (0, 95)
|
||||
ylim_precision = (0, 85)
|
||||
ylim_norm_precision = (0, 85)
|
||||
report_name = "LaSOT_{ext}"
|
||||
else:
|
||||
ylim_success = (0, 95)
|
||||
ylim_precision = (0, 100)
|
||||
ylim_norm_precision = (0, 88)
|
||||
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
for attr in attr_list:
|
||||
scores = {}
|
||||
|
||||
print(f'{attr}: {valid_sequence_attr[attr].sum().item()}')
|
||||
valid_sequence_attr[attr] = valid_sequence_attr[attr] & valid_sequence
|
||||
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence_attr[attr])
|
||||
scores['AUC'] = auc
|
||||
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence_attr[attr])
|
||||
scores['Precision'] = prec_score
|
||||
|
||||
norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence_attr[attr])
|
||||
scores['Norm Precision'] = norm_prec_score
|
||||
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
report_text = generate_formatted_report(tracker_disp_names, scores, table_name=attr)
|
||||
print(report_text)
|
||||
|
||||
success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold', 'attr': attr,
|
||||
'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': ylim_success, 'title': f'Success\ of\ {attr}\ ({report_name})'}
|
||||
plot_draw_save(auc_curve, threshold_set_overlap, auc, tracker_names, plot_draw_styles, result_plot_path, success_plot_opts)
|
||||
|
||||
|
||||
|
||||
|
||||
def plot_results(trackers, dataset, report_name, merge_results=False,
|
||||
plot_types=('success'), force_evaluation=False, **kwargs):
|
||||
"""
|
||||
Plot results for the given trackers
|
||||
|
||||
args:
|
||||
trackers - List of trackers to evaluate
|
||||
dataset - List of sequences to evaluate
|
||||
report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
|
||||
merge_results - If True, multiple random runs for a non-deterministic trackers are averaged
|
||||
plot_types - List of scores to display. Can contain 'success',
|
||||
'prec' (precision), and 'norm_prec' (normalized precision)
|
||||
"""
|
||||
# Load data
|
||||
settings = env_settings()
|
||||
|
||||
plot_draw_styles = get_plot_draw_styles()
|
||||
|
||||
# Load pre-computed results
|
||||
result_plot_path = os.path.join(settings.result_plot_path, report_name)
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, force_evaluation, **kwargs)
|
||||
|
||||
# Merge results from multiple runs
|
||||
if merge_results:
|
||||
eval_data = merge_multiple_runs(eval_data)
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
print('\nPlotting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))
|
||||
|
||||
print('\nGenerating plots for: {}'.format(report_name))
|
||||
|
||||
print(report_name)
|
||||
if report_name == 'LaSOT':
|
||||
ylim_success = (0, 95)
|
||||
ylim_precision = (0, 95)
|
||||
ylim_norm_precision = (0, 95)
|
||||
elif report_name == 'LaSOT-ext':
|
||||
ylim_success = (0, 85)
|
||||
ylim_precision = (0, 85)
|
||||
ylim_norm_precision = (0, 85)
|
||||
else:
|
||||
ylim_success = (0, 85)
|
||||
ylim_precision = (0, 85)
|
||||
ylim_norm_precision = (0, 85)
|
||||
|
||||
# ******************************** Success Plot **************************************
|
||||
if 'success' in plot_types:
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
|
||||
# Index out valid sequences
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
|
||||
success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold',
|
||||
'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': ylim_success, 'title': f'Success\ ({report_name})'}
|
||||
plot_draw_save(auc_curve, threshold_set_overlap, auc, tracker_names, plot_draw_styles, result_plot_path, success_plot_opts)
|
||||
|
||||
# ******************************** Precision Plot **************************************
|
||||
if 'prec' in plot_types:
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
|
||||
# Index out valid sequences
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
|
||||
threshold_set_center = torch.tensor(eval_data['threshold_set_center'])
|
||||
|
||||
precision_plot_opts = {'plot_type': 'precision', 'legend_loc': 'lower right',
|
||||
'xlabel': 'Location error threshold [pixels]', 'ylabel': 'Distance Precision [%]',
|
||||
'xlim': (0, 50), 'ylim': ylim_precision, 'title': f'Precision\ ({report_name})'}
|
||||
plot_draw_save(prec_curve, threshold_set_center, prec_score, tracker_names, plot_draw_styles, result_plot_path,
|
||||
precision_plot_opts)
|
||||
|
||||
# ******************************** Norm Precision Plot **************************************
|
||||
if 'norm_prec' in plot_types:
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
|
||||
# Index out valid sequences
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
|
||||
threshold_set_center_norm = torch.tensor(eval_data['threshold_set_center_norm'])
|
||||
|
||||
norm_precision_plot_opts = {'plot_type': 'norm_precision', 'legend_loc': 'lower right',
|
||||
'xlabel': 'Location error threshold', 'ylabel': 'Distance Precision [%]',
|
||||
'xlim': (0, 0.5), 'ylim': ylim_norm_precision, 'title': f'Normalized\ Precision\ ({report_name})'}
|
||||
plot_draw_save(prec_curve, threshold_set_center_norm, prec_score, tracker_names, plot_draw_styles, result_plot_path,
|
||||
norm_precision_plot_opts)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def generate_formatted_report(row_labels, scores, table_name=''):
|
||||
name_width = max([len(d) for d in row_labels] + [len(table_name)]) + 5
|
||||
min_score_width = 10
|
||||
|
||||
report_text = '\n{label: <{width}} |'.format(label=table_name, width=name_width)
|
||||
|
||||
score_widths = [max(min_score_width, len(k) + 3) for k in scores.keys()]
|
||||
|
||||
for s, s_w in zip(scores.keys(), score_widths):
|
||||
report_text = '{prev} {s: <{width}} |'.format(prev=report_text, s=s, width=s_w)
|
||||
|
||||
report_text = '{prev}\n'.format(prev=report_text)
|
||||
|
||||
for trk_id, d_name in enumerate(row_labels):
|
||||
# display name
|
||||
report_text = '{prev}{tracker: <{width}} |'.format(prev=report_text, tracker=d_name,
|
||||
width=name_width)
|
||||
for (score_type, score_value), s_w in zip(scores.items(), score_widths):
|
||||
report_text = '{prev} {score: <{width}} |'.format(prev=report_text,
|
||||
score='{:0.2f}'.format(score_value[trk_id].item()),
|
||||
width=s_w)
|
||||
report_text = '{prev}\n'.format(prev=report_text)
|
||||
|
||||
return report_text
|
||||
|
||||
def print_per_attribute_results(trackers, dataset, report_name, merge_results=False,
|
||||
plot_types=('success'), **kwargs):
|
||||
# Load pre-computed results
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
attr_folder = 'data/LaSOT/att'
|
||||
|
||||
attr_list = ['Illumination Variation', 'Partial Occlusion', 'Deformation', 'Motion Blur', 'Camera Motion', 'Rotation', 'Background Clutter', 'Viewpoint Change', 'Scale Variation', 'Full Occlusion', 'Fast Motion', 'Out-of-View', 'Low Resolution', 'Aspect Ration Change']
|
||||
attr_list = ['IV', 'POC', 'DEF', 'MB', 'CM', 'ROT', 'BC', 'VC', 'SV', 'FOC', 'FM', 'OV', 'LR', 'ARC']
|
||||
|
||||
# Iterate over the sequence and construct a valid_sequence for each attribute
|
||||
valid_sequence_attr = {}
|
||||
for attr in attr_list:
|
||||
valid_sequence_attr[attr] = torch.zeros(valid_sequence.shape[0], dtype=torch.bool)
|
||||
for seq_id, seq_obj in enumerate(dataset):
|
||||
seq_name = seq_obj.name
|
||||
attr_txt = osp.join(attr_folder, f'{seq_name}.txt')
|
||||
if osp.exists(attr_txt):
|
||||
# read the attribute file into a list of True and False
|
||||
# the attribute file looks like this: 0,0,0,0,0,1,0,1,1,0,0,0,0,0
|
||||
attr_anno = np.loadtxt(attr_txt, dtype=int, delimiter=',')
|
||||
# broadcast the valid_sequence to the attribute list
|
||||
for attr_id, attr in enumerate(attr_list):
|
||||
valid_sequence_attr[attr][seq_id] = attr_anno[attr_id]
|
||||
else:
|
||||
raise Exception(f'Attribute file not found for sequence {seq_name}')
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
for attr in attr_list:
|
||||
scores = {}
|
||||
|
||||
print(f'{attr}: {valid_sequence_attr[attr].sum().item()}')
|
||||
valid_sequence_attr[attr] = valid_sequence_attr[attr] & valid_sequence
|
||||
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence_attr[attr])
|
||||
scores['AUC'] = auc
|
||||
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence_attr[attr])
|
||||
scores['Precision'] = prec_score
|
||||
|
||||
norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence_attr[attr])
|
||||
scores['Norm Precision'] = norm_prec_score
|
||||
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
report_text = generate_formatted_report(tracker_disp_names, scores, table_name=attr)
|
||||
print(report_text)
|
||||
|
||||
|
||||
def print_results(trackers, dataset, report_name, merge_results=False,
|
||||
plot_types=('success'), **kwargs):
|
||||
""" Print the results for the given trackers in a formatted table
|
||||
args:
|
||||
trackers - List of trackers to evaluate
|
||||
dataset - List of sequences to evaluate
|
||||
report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
|
||||
merge_results - If True, multiple random runs for a non-deterministic trackers are averaged
|
||||
plot_types - List of scores to display. Can contain 'success' (prints AUC, OP50, and OP75 scores),
|
||||
'prec' (prints precision score), and 'norm_prec' (prints normalized precision score)
|
||||
"""
|
||||
# Load pre-computed results
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
# Merge results from multiple runs
|
||||
if merge_results:
|
||||
eval_data = merge_multiple_runs(eval_data)
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
print('\nReporting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))
|
||||
|
||||
scores = {}
|
||||
|
||||
# ******************************** Success Plot **************************************
|
||||
if 'success' in plot_types:
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
|
||||
# Index out valid sequences
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
|
||||
scores['AUC'] = auc
|
||||
scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
|
||||
scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]
|
||||
|
||||
# ******************************** Precision Plot **************************************
|
||||
if 'prec' in plot_types:
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
|
||||
# Index out valid sequences
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
|
||||
scores['Precision'] = prec_score
|
||||
|
||||
# ******************************** Norm Precision Plot *********************************
|
||||
if 'norm_prec' in plot_types:
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
|
||||
# Index out valid sequences
|
||||
norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
|
||||
scores['Norm Precision'] = norm_prec_score
|
||||
|
||||
# Print
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
|
||||
print(report_text)
|
||||
|
||||
|
||||
def plot_got_success(trackers, report_name):
|
||||
""" Plot success plot for GOT-10k dataset using the json reports.
|
||||
Save the json reports from http://got-10k.aitestunion.com/leaderboard in the directory set to
|
||||
env_settings.got_reports_path
|
||||
|
||||
The tracker name in the experiment file should be set to the name of the report file for that tracker,
|
||||
e.g. DiMP50_report_2019_09_02_15_44_25 if the report is name DiMP50_report_2019_09_02_15_44_25.json
|
||||
|
||||
args:
|
||||
trackers - List of trackers to evaluate
|
||||
report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
|
||||
"""
|
||||
# Load data
|
||||
settings = env_settings()
|
||||
plot_draw_styles = get_plot_draw_styles()
|
||||
|
||||
result_plot_path = os.path.join(settings.result_plot_path, report_name)
|
||||
|
||||
auc_curve = torch.zeros((len(trackers), 101))
|
||||
scores = torch.zeros(len(trackers))
|
||||
|
||||
# Load results
|
||||
tracker_names = []
|
||||
for trk_id, trk in enumerate(trackers):
|
||||
json_path = '{}/{}.json'.format(settings.got_reports_path, trk.name)
|
||||
|
||||
if os.path.isfile(json_path):
|
||||
with open(json_path, 'r') as f:
|
||||
eval_data = json.load(f)
|
||||
else:
|
||||
raise Exception('Report not found {}'.format(json_path))
|
||||
|
||||
if len(eval_data.keys()) > 1:
|
||||
raise Exception
|
||||
|
||||
# First field is the tracker name. Index it out
|
||||
eval_data = eval_data[list(eval_data.keys())[0]]
|
||||
if 'succ_curve' in eval_data.keys():
|
||||
curve = eval_data['succ_curve']
|
||||
ao = eval_data['ao']
|
||||
elif 'overall' in eval_data.keys() and 'succ_curve' in eval_data['overall'].keys():
|
||||
curve = eval_data['overall']['succ_curve']
|
||||
ao = eval_data['overall']['ao']
|
||||
else:
|
||||
raise Exception('Invalid JSON file {}'.format(json_path))
|
||||
|
||||
auc_curve[trk_id, :] = torch.tensor(curve) * 100.0
|
||||
scores[trk_id] = ao * 100.0
|
||||
|
||||
tracker_names.append({'name': trk.name, 'param': trk.parameter_name, 'run_id': trk.run_id,
|
||||
'disp_name': trk.display_name})
|
||||
|
||||
threshold_set_overlap = torch.arange(0.0, 1.01, 0.01, dtype=torch.float64)
|
||||
|
||||
success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold',
|
||||
'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': (0, 100), 'title': 'Success plot'}
|
||||
plot_draw_save(auc_curve, threshold_set_overlap, scores, tracker_names, plot_draw_styles, result_plot_path,
|
||||
success_plot_opts)
|
||||
plt.show()
|
||||
|
||||
|
||||
def print_per_sequence_results(trackers, dataset, report_name, merge_results=False,
|
||||
filter_criteria=None, **kwargs):
|
||||
""" Print per-sequence results for the given trackers. Additionally, the sequences to list can be filtered using
|
||||
the filter criteria.
|
||||
|
||||
args:
|
||||
trackers - List of trackers to evaluate
|
||||
dataset - List of sequences to evaluate
|
||||
report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
|
||||
merge_results - If True, multiple random runs for a non-deterministic trackers are averaged
|
||||
filter_criteria - Filter sequence results which are reported. Following modes are supported
|
||||
None: No filtering. Display results for all sequences in dataset
|
||||
'ao_min': Only display sequences for which the minimum average overlap (AO) score over the
|
||||
trackers is less than a threshold filter_criteria['threshold']. This mode can
|
||||
be used to select sequences where at least one tracker performs poorly.
|
||||
'ao_max': Only display sequences for which the maximum average overlap (AO) score over the
|
||||
trackers is less than a threshold filter_criteria['threshold']. This mode can
|
||||
be used to select sequences all tracker performs poorly.
|
||||
'delta_ao': Only display sequences for which the performance of different trackers vary by at
|
||||
least filter_criteria['threshold'] in average overlap (AO) score. This mode can
|
||||
be used to select sequences where the behaviour of the trackers greatly differ
|
||||
between each other.
|
||||
"""
|
||||
# Load pre-computed results
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
# Merge results from multiple runs
|
||||
if merge_results:
|
||||
eval_data = merge_multiple_runs(eval_data)
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
sequence_names = eval_data['sequences']
|
||||
avg_overlap_all = torch.tensor(eval_data['avg_overlap_all']) * 100.0
|
||||
|
||||
# Filter sequences
|
||||
if filter_criteria is not None:
|
||||
if filter_criteria['mode'] == 'ao_min':
|
||||
min_ao = avg_overlap_all.min(dim=1)[0]
|
||||
valid_sequence = valid_sequence & (min_ao < filter_criteria['threshold'])
|
||||
elif filter_criteria['mode'] == 'ao_max':
|
||||
max_ao = avg_overlap_all.max(dim=1)[0]
|
||||
valid_sequence = valid_sequence & (max_ao < filter_criteria['threshold'])
|
||||
elif filter_criteria['mode'] == 'delta_ao':
|
||||
min_ao = avg_overlap_all.min(dim=1)[0]
|
||||
max_ao = avg_overlap_all.max(dim=1)[0]
|
||||
valid_sequence = valid_sequence & ((max_ao - min_ao) > filter_criteria['threshold'])
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
avg_overlap_all = avg_overlap_all[valid_sequence, :]
|
||||
sequence_names = [s + ' (ID={})'.format(i) for i, (s, v) in enumerate(zip(sequence_names, valid_sequence.tolist())) if v]
|
||||
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
|
||||
scores_per_tracker = {k: avg_overlap_all[:, i] for i, k in enumerate(tracker_disp_names)}
|
||||
report_text = generate_formatted_report(sequence_names, scores_per_tracker)
|
||||
|
||||
print(report_text)
|
||||
|
||||
|
||||
def print_results_per_video(trackers, dataset, report_name, merge_results=False,
|
||||
plot_types=('success'), per_video=False, **kwargs):
|
||||
""" Print the results for the given trackers in a formatted table
|
||||
args:
|
||||
trackers - List of trackers to evaluate
|
||||
dataset - List of sequences to evaluate
|
||||
report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
|
||||
merge_results - If True, multiple random runs for a non-deterministic trackers are averaged
|
||||
plot_types - List of scores to display. Can contain 'success' (prints AUC, OP50, and OP75 scores),
|
||||
'prec' (prints precision score), and 'norm_prec' (prints normalized precision score)
|
||||
"""
|
||||
# Load pre-computed results
|
||||
eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)
|
||||
|
||||
# Merge results from multiple runs
|
||||
if merge_results:
|
||||
eval_data = merge_multiple_runs(eval_data)
|
||||
|
||||
seq_lens = len(eval_data['sequences'])
|
||||
eval_datas = [{} for _ in range(seq_lens)]
|
||||
if per_video:
|
||||
for key, value in eval_data.items():
|
||||
if len(value) == seq_lens:
|
||||
for i in range(seq_lens):
|
||||
eval_datas[i][key] = [value[i]]
|
||||
else:
|
||||
for i in range(seq_lens):
|
||||
eval_datas[i][key] = value
|
||||
|
||||
tracker_names = eval_data['trackers']
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
print('\nReporting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))
|
||||
|
||||
scores = {}
|
||||
|
||||
# ******************************** Success Plot **************************************
|
||||
if 'success' in plot_types:
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
|
||||
# Index out valid sequences
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
|
||||
scores['AUC'] = auc
|
||||
scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
|
||||
scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]
|
||||
|
||||
# ******************************** Precision Plot **************************************
|
||||
if 'prec' in plot_types:
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
|
||||
# Index out valid sequences
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
|
||||
scores['Precision'] = prec_score
|
||||
|
||||
# ******************************** Norm Precision Plot *********************************
|
||||
if 'norm_prec' in plot_types:
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
|
||||
# Index out valid sequences
|
||||
norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
|
||||
scores['Norm Precision'] = norm_prec_score
|
||||
|
||||
# Print
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
|
||||
print(report_text)
|
||||
|
||||
if per_video:
|
||||
for i in range(seq_lens):
|
||||
eval_data = eval_datas[i]
|
||||
|
||||
print('\n{} sequences'.format(eval_data['sequences'][0]))
|
||||
|
||||
scores = {}
|
||||
valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
|
||||
|
||||
# ******************************** Success Plot **************************************
|
||||
if 'success' in plot_types:
|
||||
threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
|
||||
ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
|
||||
|
||||
# Index out valid sequences
|
||||
auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
|
||||
scores['AUC'] = auc
|
||||
scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
|
||||
scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]
|
||||
|
||||
# ******************************** Precision Plot **************************************
|
||||
if 'prec' in plot_types:
|
||||
ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
|
||||
|
||||
# Index out valid sequences
|
||||
prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
|
||||
scores['Precision'] = prec_score
|
||||
|
||||
# ******************************** Norm Precision Plot *********************************
|
||||
if 'norm_prec' in plot_types:
|
||||
ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
|
||||
|
||||
# Index out valid sequences
|
||||
norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
|
||||
scores['Norm Precision'] = norm_prec_score
|
||||
|
||||
# Print
|
||||
tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
|
||||
report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
|
||||
print(report_text)
|
||||
@@ -0,0 +1,4 @@
|
||||
from .data import Sequence
|
||||
from .tracker import Tracker, trackerlist
|
||||
from .datasets import get_dataset
|
||||
from .environment import create_default_local_file_ITP_test
|
||||
@@ -0,0 +1,169 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.train.data.image_loader import imread_indexed
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class BaseDataset:
|
||||
"""Base class for all datasets."""
|
||||
def __init__(self):
|
||||
self.env_settings = env_settings()
|
||||
|
||||
def __len__(self):
|
||||
"""Overload this function in your dataset. This should return number of sequences in the dataset."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_sequence_list(self):
|
||||
"""Overload this in your dataset. Should return the list of sequences in the dataset."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Sequence:
|
||||
"""Class for the sequence in an evaluation."""
|
||||
def __init__(self, name, frames, dataset, ground_truth_rect, ground_truth_seg=None, init_data=None,
|
||||
object_class=None, target_visible=None, object_ids=None, multiobj_mode=False):
|
||||
self.name = name
|
||||
self.frames = frames
|
||||
self.dataset = dataset
|
||||
self.ground_truth_rect = ground_truth_rect
|
||||
self.ground_truth_seg = ground_truth_seg
|
||||
self.object_class = object_class
|
||||
self.target_visible = target_visible
|
||||
self.object_ids = object_ids
|
||||
self.multiobj_mode = multiobj_mode
|
||||
self.init_data = self._construct_init_data(init_data)
|
||||
self._ensure_start_frame()
|
||||
|
||||
def _ensure_start_frame(self):
|
||||
# Ensure start frame is 0
|
||||
start_frame = min(list(self.init_data.keys()))
|
||||
if start_frame > 0:
|
||||
self.frames = self.frames[start_frame:]
|
||||
if self.ground_truth_rect is not None:
|
||||
if isinstance(self.ground_truth_rect, (dict, OrderedDict)):
|
||||
for obj_id, gt in self.ground_truth_rect.items():
|
||||
self.ground_truth_rect[obj_id] = gt[start_frame:,:]
|
||||
else:
|
||||
self.ground_truth_rect = self.ground_truth_rect[start_frame:,:]
|
||||
if self.ground_truth_seg is not None:
|
||||
self.ground_truth_seg = self.ground_truth_seg[start_frame:]
|
||||
assert len(self.frames) == len(self.ground_truth_seg)
|
||||
|
||||
if self.target_visible is not None:
|
||||
self.target_visible = self.target_visible[start_frame:]
|
||||
self.init_data = {frame-start_frame: val for frame, val in self.init_data.items()}
|
||||
|
||||
def _construct_init_data(self, init_data):
|
||||
if init_data is not None:
|
||||
if not self.multiobj_mode:
|
||||
assert self.object_ids is None or len(self.object_ids) == 1
|
||||
for frame, init_val in init_data.items():
|
||||
if 'bbox' in init_val and isinstance(init_val['bbox'], (dict, OrderedDict)):
|
||||
init_val['bbox'] = init_val['bbox'][self.object_ids[0]]
|
||||
# convert to list
|
||||
for frame, init_val in init_data.items():
|
||||
if 'bbox' in init_val:
|
||||
if isinstance(init_val['bbox'], (dict, OrderedDict)):
|
||||
init_val['bbox'] = OrderedDict({obj_id: list(init) for obj_id, init in init_val['bbox'].items()})
|
||||
else:
|
||||
init_val['bbox'] = list(init_val['bbox'])
|
||||
else:
|
||||
init_data = {0: dict()} # Assume start from frame 0
|
||||
|
||||
if self.object_ids is not None:
|
||||
init_data[0]['object_ids'] = self.object_ids
|
||||
|
||||
if self.ground_truth_rect is not None:
|
||||
if self.multiobj_mode:
|
||||
assert isinstance(self.ground_truth_rect, (dict, OrderedDict))
|
||||
init_data[0]['bbox'] = OrderedDict({obj_id: list(gt[0,:]) for obj_id, gt in self.ground_truth_rect.items()})
|
||||
else:
|
||||
assert self.object_ids is None or len(self.object_ids) == 1
|
||||
if isinstance(self.ground_truth_rect, (dict, OrderedDict)):
|
||||
init_data[0]['bbox'] = list(self.ground_truth_rect[self.object_ids[0]][0, :])
|
||||
else:
|
||||
init_data[0]['bbox'] = list(self.ground_truth_rect[0,:])
|
||||
|
||||
if self.ground_truth_seg is not None:
|
||||
init_data[0]['mask'] = self.ground_truth_seg[0]
|
||||
|
||||
return init_data
|
||||
|
||||
def init_info(self):
|
||||
info = self.frame_info(frame_num=0)
|
||||
return info
|
||||
|
||||
def frame_info(self, frame_num):
|
||||
info = self.object_init_data(frame_num=frame_num)
|
||||
return info
|
||||
|
||||
def init_bbox(self, frame_num=0):
|
||||
return self.object_init_data(frame_num=frame_num).get('init_bbox')
|
||||
|
||||
def init_mask(self, frame_num=0):
|
||||
return self.object_init_data(frame_num=frame_num).get('init_mask')
|
||||
|
||||
def get_info(self, keys, frame_num=None):
|
||||
info = dict()
|
||||
for k in keys:
|
||||
val = self.get(k, frame_num=frame_num)
|
||||
if val is not None:
|
||||
info[k] = val
|
||||
return info
|
||||
|
||||
def object_init_data(self, frame_num=None) -> dict:
|
||||
if frame_num is None:
|
||||
frame_num = 0
|
||||
if frame_num not in self.init_data:
|
||||
return dict()
|
||||
|
||||
init_data = dict()
|
||||
for key, val in self.init_data[frame_num].items():
|
||||
if val is None:
|
||||
continue
|
||||
init_data['init_'+key] = val
|
||||
|
||||
if 'init_mask' in init_data and init_data['init_mask'] is not None:
|
||||
anno = imread_indexed(init_data['init_mask'])
|
||||
if not self.multiobj_mode and self.object_ids is not None:
|
||||
assert len(self.object_ids) == 1
|
||||
anno = (anno == int(self.object_ids[0])).astype(np.uint8)
|
||||
init_data['init_mask'] = anno
|
||||
|
||||
if self.object_ids is not None:
|
||||
init_data['object_ids'] = self.object_ids
|
||||
init_data['sequence_object_ids'] = self.object_ids
|
||||
|
||||
return init_data
|
||||
|
||||
def target_class(self, frame_num=None):
|
||||
return self.object_class
|
||||
|
||||
def get(self, name, frame_num=None):
|
||||
return getattr(self, name)(frame_num)
|
||||
|
||||
def __repr__(self):
|
||||
return "{self.__class__.__name__} {self.name}, length={len} frames".format(self=self, len=len(self.frames))
|
||||
|
||||
|
||||
|
||||
class SequenceList(list):
|
||||
"""List of sequences. Supports the addition operator to concatenate sequence lists."""
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, str):
|
||||
for seq in self:
|
||||
if seq.name == item:
|
||||
return seq
|
||||
raise IndexError('Sequence name not in the dataset.')
|
||||
elif isinstance(item, int):
|
||||
return super(SequenceList, self).__getitem__(item)
|
||||
elif isinstance(item, (tuple, list)):
|
||||
return SequenceList([super(SequenceList, self).__getitem__(i) for i in item])
|
||||
else:
|
||||
return SequenceList(super(SequenceList, self).__getitem__(item))
|
||||
|
||||
def __add__(self, other):
|
||||
return SequenceList(super(SequenceList, self).__add__(other))
|
||||
|
||||
def copy(self):
|
||||
return SequenceList(super(SequenceList, self).copy())
|
||||
@@ -0,0 +1,48 @@
|
||||
from collections import namedtuple
|
||||
import importlib
|
||||
from lib.test.evaluation.data import SequenceList
|
||||
|
||||
DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs'])
|
||||
|
||||
pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter
|
||||
|
||||
dataset_dict = dict(
|
||||
otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()),
|
||||
nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()),
|
||||
uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()),
|
||||
tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()),
|
||||
tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()),
|
||||
trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()),
|
||||
got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')),
|
||||
got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')),
|
||||
got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')),
|
||||
lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()),
|
||||
lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()),
|
||||
|
||||
vot18=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict()),
|
||||
vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)),
|
||||
itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()),
|
||||
tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()),
|
||||
lasot_extension_subset=DatasetInfo(module=pt % "lasotextensionsubset", class_name="LaSOTExtensionSubsetDataset",
|
||||
kwargs=dict()),
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(name: str):
|
||||
""" Import and load a single dataset."""
|
||||
name = name.lower()
|
||||
dset_info = dataset_dict.get(name)
|
||||
if dset_info is None:
|
||||
raise ValueError('Unknown dataset \'%s\'' % name)
|
||||
|
||||
m = importlib.import_module(dset_info.module)
|
||||
dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor
|
||||
return dataset.get_sequence_list()
|
||||
|
||||
|
||||
def get_dataset(*args):
|
||||
""" Get a single or set of datasets."""
|
||||
dset = SequenceList()
|
||||
for name in args:
|
||||
dset.extend(load_dataset(name))
|
||||
return dset
|
||||
@@ -0,0 +1,124 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
class EnvSettings:
|
||||
def __init__(self):
|
||||
test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
self.results_path = '{}/results/'.format(test_path)
|
||||
self.segmentation_path = '{}/segmentation_results/'.format(test_path)
|
||||
self.network_path = '{}/networks/'.format(test_path)
|
||||
self.result_plot_path = '{}/result_plots/'.format(test_path)
|
||||
self.otb_path = ''
|
||||
self.nfs_path = ''
|
||||
self.uav_path = ''
|
||||
self.tpl_path = ''
|
||||
self.vot_path = ''
|
||||
self.got10k_path = ''
|
||||
self.lasot_path = ''
|
||||
self.trackingnet_path = ''
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
||||
|
||||
self.got_packed_results_path = ''
|
||||
self.got_reports_path = ''
|
||||
self.tn_packed_results_path = ''
|
||||
|
||||
|
||||
def create_default_local_file():
|
||||
comment = {'results_path': 'Where to store tracking results',
|
||||
'network_path': 'Where tracking networks are stored.'}
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
with open(path, 'w') as f:
|
||||
settings = EnvSettings()
|
||||
|
||||
f.write('from test.evaluation.environment import EnvSettings\n\n')
|
||||
f.write('def local_env_settings():\n')
|
||||
f.write(' settings = EnvSettings()\n\n')
|
||||
f.write(' # Set your local paths here.\n\n')
|
||||
|
||||
for attr in dir(settings):
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
attr_val = getattr(settings, attr)
|
||||
if not attr.startswith('__') and not callable(attr_val):
|
||||
if comment_str is None:
|
||||
f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
f.write('\n return settings\n\n')
|
||||
|
||||
|
||||
class EnvSettings_ITP:
|
||||
def __init__(self, workspace_dir, data_dir, save_dir):
|
||||
self.prj_dir = workspace_dir
|
||||
self.save_dir = save_dir
|
||||
self.results_path = os.path.join(save_dir, 'test/tracking_results')
|
||||
self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results')
|
||||
self.network_path = os.path.join(save_dir, 'test/networks')
|
||||
self.result_plot_path = os.path.join(save_dir, 'test/result_plots')
|
||||
self.otb_path = os.path.join(data_dir, 'otb')
|
||||
self.nfs_path = os.path.join(data_dir, 'nfs')
|
||||
self.uav_path = os.path.join(data_dir, 'uav')
|
||||
self.tc128_path = os.path.join(data_dir, 'TC128')
|
||||
self.tpl_path = ''
|
||||
self.vot_path = os.path.join(data_dir, 'VOT2019')
|
||||
self.got10k_path = os.path.join(data_dir, 'got10k')
|
||||
self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb')
|
||||
self.lasot_path = os.path.join(data_dir, 'lasot')
|
||||
self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb')
|
||||
self.trackingnet_path = os.path.join(data_dir, 'trackingnet')
|
||||
self.vot18_path = os.path.join(data_dir, 'vot2018')
|
||||
self.vot22_path = os.path.join(data_dir, 'vot2022')
|
||||
self.itb_path = os.path.join(data_dir, 'itb')
|
||||
self.tnl2k_path = os.path.join(data_dir, 'tnl2k')
|
||||
self.lasot_extension_subset_path_path = os.path.join(data_dir, 'lasot_extension_subset')
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
||||
|
||||
self.got_packed_results_path = ''
|
||||
self.got_reports_path = ''
|
||||
self.tn_packed_results_path = ''
|
||||
|
||||
|
||||
def create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir):
|
||||
comment = {'results_path': 'Where to store tracking results',
|
||||
'network_path': 'Where tracking networks are stored.'}
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
with open(path, 'w') as f:
|
||||
settings = EnvSettings_ITP(workspace_dir, data_dir, save_dir)
|
||||
|
||||
f.write('from lib.test.evaluation.environment import EnvSettings\n\n')
|
||||
f.write('def local_env_settings():\n')
|
||||
f.write(' settings = EnvSettings()\n\n')
|
||||
f.write(' # Set your local paths here.\n\n')
|
||||
|
||||
for attr in dir(settings):
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
attr_val = getattr(settings, attr)
|
||||
if not attr.startswith('__') and not callable(attr_val):
|
||||
if comment_str is None:
|
||||
f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
f.write('\n return settings\n\n')
|
||||
|
||||
|
||||
def env_settings():
|
||||
env_module_name = 'lib.test.evaluation.local'
|
||||
try:
|
||||
env_module = importlib.import_module(env_module_name)
|
||||
return env_module.local_env_settings()
|
||||
except:
|
||||
env_file = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
# Create a default file
|
||||
create_default_local_file()
|
||||
raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. '
|
||||
'Then try to run again.'.format(env_file))
|
||||
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
import os
|
||||
|
||||
|
||||
class GOT10KDataset(BaseDataset):
|
||||
""" GOT-10k dataset.
|
||||
|
||||
Publication:
|
||||
GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
|
||||
Lianghua Huang, Xin Zhao, and Kaiqi Huang
|
||||
arXiv:1810.11981, 2018
|
||||
https://arxiv.org/pdf/1810.11981.pdf
|
||||
|
||||
Download dataset from http://got-10k.aitestunion.com/downloads
|
||||
"""
|
||||
def __init__(self, split):
|
||||
super().__init__()
|
||||
# Split can be test, val, or ltrval (a validation split consisting of videos from the official train set)
|
||||
if split == 'test' or split == 'val':
|
||||
self.base_path = os.path.join(self.env_settings.got10k_path, split)
|
||||
else:
|
||||
self.base_path = os.path.join(self.env_settings.got10k_path, 'train')
|
||||
|
||||
self.sequence_list = self._get_sequence_list(split)
|
||||
self.split = split
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
frames_path = '{}/{}'.format(self.base_path, sequence_name)
|
||||
frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
|
||||
frame_list.sort(key=lambda f: int(f[:-4]))
|
||||
frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
|
||||
|
||||
return Sequence(sequence_name, frames_list, 'got10k', ground_truth_rect.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self, split):
|
||||
with open('{}/list.txt'.format(self.base_path)) as f:
|
||||
sequence_list = f.read().splitlines()
|
||||
|
||||
if split == 'ltrval':
|
||||
with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f:
|
||||
seq_ids = f.read().splitlines()
|
||||
|
||||
sequence_list = [sequence_list[int(x)] for x in seq_ids]
|
||||
return sequence_list
|
||||
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
import os
|
||||
|
||||
|
||||
class ITBDataset(BaseDataset):
|
||||
""" NUS-PRO dataset
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.itb_path
|
||||
self.sequence_info_list = self._get_sequence_info_list(self.base_path)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num,
|
||||
nz=nz, ext=ext) for frame_num in
|
||||
range(start_frame + init_omit, end_frame + 1)]
|
||||
|
||||
anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
|
||||
# NOTE: NUS has some weird annos which panda cannot handle
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy')
|
||||
return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:, :],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def get_fileNames(self, rootdir):
|
||||
fs = []
|
||||
fs_all = []
|
||||
for root, dirs, files in os.walk(rootdir, topdown=True):
|
||||
files.sort()
|
||||
files.sort(key=len)
|
||||
if files is not None:
|
||||
for name in files:
|
||||
_, ending = os.path.splitext(name)
|
||||
if ending == ".jpg":
|
||||
_, root_ = os.path.split(root)
|
||||
fs.append(os.path.join(root_, name))
|
||||
fs_all.append(os.path.join(root, name))
|
||||
|
||||
return fs_all, fs
|
||||
|
||||
def _get_sequence_info_list(self, base_path):
|
||||
sequence_info_list = []
|
||||
for scene in os.listdir(base_path):
|
||||
if '.' in scene:
|
||||
continue
|
||||
videos = os.listdir(os.path.join(base_path, scene))
|
||||
for video in videos:
|
||||
_, fs = self.get_fileNames(os.path.join(base_path, scene, video))
|
||||
video_tmp = {"name": video, "path": scene + '/' + video, "startFrame": 1, "endFrame": len(fs),
|
||||
"nz": len(fs[0].split('/')[-1].split('.')[0]), "ext": "jpg",
|
||||
"anno_path": scene + '/' + video + "/groundtruth.txt",
|
||||
"object_class": "unknown"}
|
||||
sequence_info_list.append(video_tmp)
|
||||
|
||||
return sequence_info_list # sequence_info_list_50 #
|
||||
@@ -0,0 +1,345 @@
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
'''2021.1.27 LaSOT dataset using lmdb data'''
|
||||
|
||||
|
||||
class LaSOTlmdbDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT test set consisting of 280 videos (see Protocol-II in the LaSOT paper)
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_lmdb_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = str('{}/{}/groundtruth.txt'.format(class_name, sequence_name))
|
||||
# decode the groundtruth
|
||||
gt_str_list = decode_str(self.base_path, anno_path).split('\n')[:-1] # the last line is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
ground_truth_rect = np.array(gt_list).astype(np.float64)
|
||||
# decode occlusion file
|
||||
occlusion_label_path = str('{}/{}/full_occlusion.txt'.format(class_name, sequence_name))
|
||||
occ_list = list(map(int, decode_str(self.base_path, occlusion_label_path).split(',')))
|
||||
full_occlusion = np.array(occ_list).astype(np.float64)
|
||||
# decode out of view file
|
||||
out_of_view_label_path = str('{}/{}/out_of_view.txt'.format(class_name, sequence_name))
|
||||
out_of_view_list = list(map(int, decode_str(self.base_path, out_of_view_label_path).split(',')))
|
||||
out_of_view = np.array(out_of_view_list).astype(np.float64)
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/img'.format(class_name, sequence_name)
|
||||
|
||||
frames_list = [[self.base_path, '{}/{:08d}.jpg'.format(frames_path, frame_number)] for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['airplane-1',
|
||||
'airplane-9',
|
||||
'airplane-13',
|
||||
'airplane-15',
|
||||
'basketball-1',
|
||||
'basketball-6',
|
||||
'basketball-7',
|
||||
'basketball-11',
|
||||
'bear-2',
|
||||
'bear-4',
|
||||
'bear-6',
|
||||
'bear-17',
|
||||
'bicycle-2',
|
||||
'bicycle-7',
|
||||
'bicycle-9',
|
||||
'bicycle-18',
|
||||
'bird-2',
|
||||
'bird-3',
|
||||
'bird-15',
|
||||
'bird-17',
|
||||
'boat-3',
|
||||
'boat-4',
|
||||
'boat-12',
|
||||
'boat-17',
|
||||
'book-3',
|
||||
'book-10',
|
||||
'book-11',
|
||||
'book-19',
|
||||
'bottle-1',
|
||||
'bottle-12',
|
||||
'bottle-14',
|
||||
'bottle-18',
|
||||
'bus-2',
|
||||
'bus-5',
|
||||
'bus-17',
|
||||
'bus-19',
|
||||
'car-2',
|
||||
'car-6',
|
||||
'car-9',
|
||||
'car-17',
|
||||
'cat-1',
|
||||
'cat-3',
|
||||
'cat-18',
|
||||
'cat-20',
|
||||
'cattle-2',
|
||||
'cattle-7',
|
||||
'cattle-12',
|
||||
'cattle-13',
|
||||
'spider-14',
|
||||
'spider-16',
|
||||
'spider-18',
|
||||
'spider-20',
|
||||
'coin-3',
|
||||
'coin-6',
|
||||
'coin-7',
|
||||
'coin-18',
|
||||
'crab-3',
|
||||
'crab-6',
|
||||
'crab-12',
|
||||
'crab-18',
|
||||
'surfboard-12',
|
||||
'surfboard-4',
|
||||
'surfboard-5',
|
||||
'surfboard-8',
|
||||
'cup-1',
|
||||
'cup-4',
|
||||
'cup-7',
|
||||
'cup-17',
|
||||
'deer-4',
|
||||
'deer-8',
|
||||
'deer-10',
|
||||
'deer-14',
|
||||
'dog-1',
|
||||
'dog-7',
|
||||
'dog-15',
|
||||
'dog-19',
|
||||
'guitar-3',
|
||||
'guitar-8',
|
||||
'guitar-10',
|
||||
'guitar-16',
|
||||
'person-1',
|
||||
'person-5',
|
||||
'person-10',
|
||||
'person-12',
|
||||
'pig-2',
|
||||
'pig-10',
|
||||
'pig-13',
|
||||
'pig-18',
|
||||
'rubicCube-1',
|
||||
'rubicCube-6',
|
||||
'rubicCube-14',
|
||||
'rubicCube-19',
|
||||
'swing-10',
|
||||
'swing-14',
|
||||
'swing-17',
|
||||
'swing-20',
|
||||
'drone-13',
|
||||
'drone-15',
|
||||
'drone-2',
|
||||
'drone-7',
|
||||
'pool-12',
|
||||
'pool-15',
|
||||
'pool-3',
|
||||
'pool-7',
|
||||
'rabbit-10',
|
||||
'rabbit-13',
|
||||
'rabbit-17',
|
||||
'rabbit-19',
|
||||
'racing-10',
|
||||
'racing-15',
|
||||
'racing-16',
|
||||
'racing-20',
|
||||
'robot-1',
|
||||
'robot-19',
|
||||
'robot-5',
|
||||
'robot-8',
|
||||
'sepia-13',
|
||||
'sepia-16',
|
||||
'sepia-6',
|
||||
'sepia-8',
|
||||
'sheep-3',
|
||||
'sheep-5',
|
||||
'sheep-7',
|
||||
'sheep-9',
|
||||
'skateboard-16',
|
||||
'skateboard-19',
|
||||
'skateboard-3',
|
||||
'skateboard-8',
|
||||
'tank-14',
|
||||
'tank-16',
|
||||
'tank-6',
|
||||
'tank-9',
|
||||
'tiger-12',
|
||||
'tiger-18',
|
||||
'tiger-4',
|
||||
'tiger-6',
|
||||
'train-1',
|
||||
'train-11',
|
||||
'train-20',
|
||||
'train-7',
|
||||
'truck-16',
|
||||
'truck-3',
|
||||
'truck-6',
|
||||
'truck-7',
|
||||
'turtle-16',
|
||||
'turtle-5',
|
||||
'turtle-8',
|
||||
'turtle-9',
|
||||
'umbrella-17',
|
||||
'umbrella-19',
|
||||
'umbrella-2',
|
||||
'umbrella-9',
|
||||
'yoyo-15',
|
||||
'yoyo-17',
|
||||
'yoyo-19',
|
||||
'yoyo-7',
|
||||
'zebra-10',
|
||||
'zebra-14',
|
||||
'zebra-16',
|
||||
'zebra-17',
|
||||
'elephant-1',
|
||||
'elephant-12',
|
||||
'elephant-16',
|
||||
'elephant-18',
|
||||
'goldfish-3',
|
||||
'goldfish-7',
|
||||
'goldfish-8',
|
||||
'goldfish-10',
|
||||
'hat-1',
|
||||
'hat-2',
|
||||
'hat-5',
|
||||
'hat-18',
|
||||
'kite-4',
|
||||
'kite-6',
|
||||
'kite-10',
|
||||
'kite-15',
|
||||
'motorcycle-1',
|
||||
'motorcycle-3',
|
||||
'motorcycle-9',
|
||||
'motorcycle-18',
|
||||
'mouse-1',
|
||||
'mouse-8',
|
||||
'mouse-9',
|
||||
'mouse-17',
|
||||
'flag-3',
|
||||
'flag-9',
|
||||
'flag-5',
|
||||
'flag-2',
|
||||
'frog-3',
|
||||
'frog-4',
|
||||
'frog-20',
|
||||
'frog-9',
|
||||
'gametarget-1',
|
||||
'gametarget-2',
|
||||
'gametarget-7',
|
||||
'gametarget-13',
|
||||
'hand-2',
|
||||
'hand-3',
|
||||
'hand-9',
|
||||
'hand-16',
|
||||
'helmet-5',
|
||||
'helmet-11',
|
||||
'helmet-19',
|
||||
'helmet-13',
|
||||
'licenseplate-6',
|
||||
'licenseplate-12',
|
||||
'licenseplate-13',
|
||||
'licenseplate-15',
|
||||
'electricfan-1',
|
||||
'electricfan-10',
|
||||
'electricfan-18',
|
||||
'electricfan-20',
|
||||
'chameleon-3',
|
||||
'chameleon-6',
|
||||
'chameleon-11',
|
||||
'chameleon-20',
|
||||
'crocodile-3',
|
||||
'crocodile-4',
|
||||
'crocodile-10',
|
||||
'crocodile-14',
|
||||
'gecko-1',
|
||||
'gecko-5',
|
||||
'gecko-16',
|
||||
'gecko-19',
|
||||
'fox-2',
|
||||
'fox-3',
|
||||
'fox-5',
|
||||
'fox-20',
|
||||
'giraffe-2',
|
||||
'giraffe-10',
|
||||
'giraffe-13',
|
||||
'giraffe-15',
|
||||
'gorilla-4',
|
||||
'gorilla-6',
|
||||
'gorilla-9',
|
||||
'gorilla-13',
|
||||
'hippo-1',
|
||||
'hippo-7',
|
||||
'hippo-9',
|
||||
'hippo-20',
|
||||
'horse-1',
|
||||
'horse-4',
|
||||
'horse-12',
|
||||
'horse-15',
|
||||
'kangaroo-2',
|
||||
'kangaroo-5',
|
||||
'kangaroo-11',
|
||||
'kangaroo-14',
|
||||
'leopard-1',
|
||||
'leopard-7',
|
||||
'leopard-16',
|
||||
'leopard-20',
|
||||
'lion-1',
|
||||
'lion-5',
|
||||
'lion-12',
|
||||
'lion-20',
|
||||
'lizard-1',
|
||||
'lizard-3',
|
||||
'lizard-6',
|
||||
'lizard-13',
|
||||
'microphone-2',
|
||||
'microphone-6',
|
||||
'microphone-14',
|
||||
'microphone-16',
|
||||
'monkey-3',
|
||||
'monkey-4',
|
||||
'monkey-9',
|
||||
'monkey-17',
|
||||
'shark-2',
|
||||
'shark-3',
|
||||
'shark-5',
|
||||
'shark-6',
|
||||
'squirrel-8',
|
||||
'squirrel-11',
|
||||
'squirrel-13',
|
||||
'squirrel-19',
|
||||
'volleyball-1',
|
||||
'volleyball-13',
|
||||
'volleyball-18',
|
||||
'volleyball-19']
|
||||
return sequence_list
|
||||
@@ -0,0 +1,342 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class LaSOTDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT test set consisting of 280 videos (see Protocol-II in the LaSOT paper)
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/{}/groundtruth.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
occlusion_label_path = '{}/{}/{}/full_occlusion.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
# NOTE: pandas backed seems super super slow for loading occlusion/oov masks
|
||||
full_occlusion = load_text(str(occlusion_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
out_of_view_label_path = '{}/{}/{}/out_of_view.txt'.format(self.base_path, class_name, sequence_name)
|
||||
out_of_view = load_text(str(out_of_view_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/{}/img'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
frames_list = ['{}/{:08d}.jpg'.format(frames_path, frame_number) for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['airplane-1',
|
||||
'airplane-9',
|
||||
'airplane-13',
|
||||
'airplane-15',
|
||||
'basketball-1',
|
||||
'basketball-6',
|
||||
'basketball-7',
|
||||
'basketball-11',
|
||||
'bear-2',
|
||||
'bear-4',
|
||||
'bear-6',
|
||||
'bear-17',
|
||||
'bicycle-2',
|
||||
'bicycle-7',
|
||||
'bicycle-9',
|
||||
'bicycle-18',
|
||||
'bird-2',
|
||||
'bird-3',
|
||||
'bird-15',
|
||||
'bird-17',
|
||||
'boat-3',
|
||||
'boat-4',
|
||||
'boat-12',
|
||||
'boat-17',
|
||||
'book-3',
|
||||
'book-10',
|
||||
'book-11',
|
||||
'book-19',
|
||||
'bottle-1',
|
||||
'bottle-12',
|
||||
'bottle-14',
|
||||
'bottle-18',
|
||||
'bus-2',
|
||||
'bus-5',
|
||||
'bus-17',
|
||||
'bus-19',
|
||||
'car-2',
|
||||
'car-6',
|
||||
'car-9',
|
||||
'car-17',
|
||||
'cat-1',
|
||||
'cat-3',
|
||||
'cat-18',
|
||||
'cat-20',
|
||||
'cattle-2',
|
||||
'cattle-7',
|
||||
'cattle-12',
|
||||
'cattle-13',
|
||||
'spider-14',
|
||||
'spider-16',
|
||||
'spider-18',
|
||||
'spider-20',
|
||||
'coin-3',
|
||||
'coin-6',
|
||||
'coin-7',
|
||||
'coin-18',
|
||||
'crab-3',
|
||||
'crab-6',
|
||||
'crab-12',
|
||||
'crab-18',
|
||||
'surfboard-12',
|
||||
'surfboard-4',
|
||||
'surfboard-5',
|
||||
'surfboard-8',
|
||||
'cup-1',
|
||||
'cup-4',
|
||||
'cup-7',
|
||||
'cup-17',
|
||||
'deer-4',
|
||||
'deer-8',
|
||||
'deer-10',
|
||||
'deer-14',
|
||||
'dog-1',
|
||||
'dog-7',
|
||||
'dog-15',
|
||||
'dog-19',
|
||||
'guitar-3',
|
||||
'guitar-8',
|
||||
'guitar-10',
|
||||
'guitar-16',
|
||||
'person-1',
|
||||
'person-5',
|
||||
'person-10',
|
||||
'person-12',
|
||||
'pig-2',
|
||||
'pig-10',
|
||||
'pig-13',
|
||||
'pig-18',
|
||||
'rubicCube-1',
|
||||
'rubicCube-6',
|
||||
'rubicCube-14',
|
||||
'rubicCube-19',
|
||||
'swing-10',
|
||||
'swing-14',
|
||||
'swing-17',
|
||||
'swing-20',
|
||||
'drone-13',
|
||||
'drone-15',
|
||||
'drone-2',
|
||||
'drone-7',
|
||||
'pool-12',
|
||||
'pool-15',
|
||||
'pool-3',
|
||||
'pool-7',
|
||||
'rabbit-10',
|
||||
'rabbit-13',
|
||||
'rabbit-17',
|
||||
'rabbit-19',
|
||||
'racing-10',
|
||||
'racing-15',
|
||||
'racing-16',
|
||||
'racing-20',
|
||||
'robot-1',
|
||||
'robot-19',
|
||||
'robot-5',
|
||||
'robot-8',
|
||||
'sepia-13',
|
||||
'sepia-16',
|
||||
'sepia-6',
|
||||
'sepia-8',
|
||||
'sheep-3',
|
||||
'sheep-5',
|
||||
'sheep-7',
|
||||
'sheep-9',
|
||||
'skateboard-16',
|
||||
'skateboard-19',
|
||||
'skateboard-3',
|
||||
'skateboard-8',
|
||||
'tank-14',
|
||||
'tank-16',
|
||||
'tank-6',
|
||||
'tank-9',
|
||||
'tiger-12',
|
||||
'tiger-18',
|
||||
'tiger-4',
|
||||
'tiger-6',
|
||||
'train-1',
|
||||
'train-11',
|
||||
'train-20',
|
||||
'train-7',
|
||||
'truck-16',
|
||||
'truck-3',
|
||||
'truck-6',
|
||||
'truck-7',
|
||||
'turtle-16',
|
||||
'turtle-5',
|
||||
'turtle-8',
|
||||
'turtle-9',
|
||||
'umbrella-17',
|
||||
'umbrella-19',
|
||||
'umbrella-2',
|
||||
'umbrella-9',
|
||||
'yoyo-15',
|
||||
'yoyo-17',
|
||||
'yoyo-19',
|
||||
'yoyo-7',
|
||||
'zebra-10',
|
||||
'zebra-14',
|
||||
'zebra-16',
|
||||
'zebra-17',
|
||||
'elephant-1',
|
||||
'elephant-12',
|
||||
'elephant-16',
|
||||
'elephant-18',
|
||||
'goldfish-3',
|
||||
'goldfish-7',
|
||||
'goldfish-8',
|
||||
'goldfish-10',
|
||||
'hat-1',
|
||||
'hat-2',
|
||||
'hat-5',
|
||||
'hat-18',
|
||||
'kite-4',
|
||||
'kite-6',
|
||||
'kite-10',
|
||||
'kite-15',
|
||||
'motorcycle-1',
|
||||
'motorcycle-3',
|
||||
'motorcycle-9',
|
||||
'motorcycle-18',
|
||||
'mouse-1',
|
||||
'mouse-8',
|
||||
'mouse-9',
|
||||
'mouse-17',
|
||||
'flag-3',
|
||||
'flag-9',
|
||||
'flag-5',
|
||||
'flag-2',
|
||||
'frog-3',
|
||||
'frog-4',
|
||||
'frog-20',
|
||||
'frog-9',
|
||||
'gametarget-1',
|
||||
'gametarget-2',
|
||||
'gametarget-7',
|
||||
'gametarget-13',
|
||||
'hand-2',
|
||||
'hand-3',
|
||||
'hand-9',
|
||||
'hand-16',
|
||||
'helmet-5',
|
||||
'helmet-11',
|
||||
'helmet-19',
|
||||
'helmet-13',
|
||||
'licenseplate-6',
|
||||
'licenseplate-12',
|
||||
'licenseplate-13',
|
||||
'licenseplate-15',
|
||||
'electricfan-1',
|
||||
'electricfan-10',
|
||||
'electricfan-18',
|
||||
'electricfan-20',
|
||||
'chameleon-3',
|
||||
'chameleon-6',
|
||||
'chameleon-11',
|
||||
'chameleon-20',
|
||||
'crocodile-3',
|
||||
'crocodile-4',
|
||||
'crocodile-10',
|
||||
'crocodile-14',
|
||||
'gecko-1',
|
||||
'gecko-5',
|
||||
'gecko-16',
|
||||
'gecko-19',
|
||||
'fox-2',
|
||||
'fox-3',
|
||||
'fox-5',
|
||||
'fox-20',
|
||||
'giraffe-2',
|
||||
'giraffe-10',
|
||||
'giraffe-13',
|
||||
'giraffe-15',
|
||||
'gorilla-4',
|
||||
'gorilla-6',
|
||||
'gorilla-9',
|
||||
'gorilla-13',
|
||||
'hippo-1',
|
||||
'hippo-7',
|
||||
'hippo-9',
|
||||
'hippo-20',
|
||||
'horse-1',
|
||||
'horse-4',
|
||||
'horse-12',
|
||||
'horse-15',
|
||||
'kangaroo-2',
|
||||
'kangaroo-5',
|
||||
'kangaroo-11',
|
||||
'kangaroo-14',
|
||||
'leopard-1',
|
||||
'leopard-7',
|
||||
'leopard-16',
|
||||
'leopard-20',
|
||||
'lion-1',
|
||||
'lion-5',
|
||||
'lion-12',
|
||||
'lion-20',
|
||||
'lizard-1',
|
||||
'lizard-3',
|
||||
'lizard-6',
|
||||
'lizard-13',
|
||||
'microphone-2',
|
||||
'microphone-6',
|
||||
'microphone-14',
|
||||
'microphone-16',
|
||||
'monkey-3',
|
||||
'monkey-4',
|
||||
'monkey-9',
|
||||
'monkey-17',
|
||||
'shark-2',
|
||||
'shark-3',
|
||||
'shark-5',
|
||||
'shark-6',
|
||||
'squirrel-8',
|
||||
'squirrel-11',
|
||||
'squirrel-13',
|
||||
'squirrel-19',
|
||||
'volleyball-1',
|
||||
'volleyball-13',
|
||||
'volleyball-18',
|
||||
'volleyball-19']
|
||||
return sequence_list
|
||||
@@ -0,0 +1,211 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class LaSOTExtensionSubsetDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT test set consisting of 280 videos (see Protocol-II in the LaSOT paper)
|
||||
Publication:
|
||||
LaSOT: A High-quality Large-scale Single Object Tracking Benchmark
|
||||
Heng Fan, Hexin Bai, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Harshit, Mingzhen Huang, Juehuan Liu,
|
||||
Yong Xu, Chunyuan Liao, Lin Yuan, Haibin Ling
|
||||
IJCV, 2020
|
||||
https://arxiv.org/pdf/2009.03465.pdf
|
||||
Download the dataset from http://vision.cs.stonybrook.edu/~lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_extension_subset_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/{}/groundtruth.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
occlusion_label_path = '{}/{}/{}/full_occlusion.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
# NOTE: pandas backed seems super super slow for loading occlusion/oov masks
|
||||
full_occlusion = load_text(str(occlusion_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
out_of_view_label_path = '{}/{}/{}/out_of_view.txt'.format(self.base_path, class_name, sequence_name)
|
||||
out_of_view = load_text(str(out_of_view_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/{}/img'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
frames_list = ['{}/{:08d}.jpg'.format(frames_path, frame_number) for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot_extension_subset', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['atv-1',
|
||||
'atv-2',
|
||||
'atv-3',
|
||||
'atv-4',
|
||||
'atv-5',
|
||||
'atv-6',
|
||||
'atv-7',
|
||||
'atv-8',
|
||||
'atv-9',
|
||||
'atv-10',
|
||||
'badminton-1',
|
||||
'badminton-2',
|
||||
'badminton-3',
|
||||
'badminton-4',
|
||||
'badminton-5',
|
||||
'badminton-6',
|
||||
'badminton-7',
|
||||
'badminton-8',
|
||||
'badminton-9',
|
||||
'badminton-10',
|
||||
'cosplay-1',
|
||||
'cosplay-10',
|
||||
'cosplay-2',
|
||||
'cosplay-3',
|
||||
'cosplay-4',
|
||||
'cosplay-5',
|
||||
'cosplay-6',
|
||||
'cosplay-7',
|
||||
'cosplay-8',
|
||||
'cosplay-9',
|
||||
'dancingshoe-1',
|
||||
'dancingshoe-2',
|
||||
'dancingshoe-3',
|
||||
'dancingshoe-4',
|
||||
'dancingshoe-5',
|
||||
'dancingshoe-6',
|
||||
'dancingshoe-7',
|
||||
'dancingshoe-8',
|
||||
'dancingshoe-9',
|
||||
'dancingshoe-10',
|
||||
'footbag-1',
|
||||
'footbag-2',
|
||||
'footbag-3',
|
||||
'footbag-4',
|
||||
'footbag-5',
|
||||
'footbag-6',
|
||||
'footbag-7',
|
||||
'footbag-8',
|
||||
'footbag-9',
|
||||
'footbag-10',
|
||||
'frisbee-1',
|
||||
'frisbee-2',
|
||||
'frisbee-3',
|
||||
'frisbee-4',
|
||||
'frisbee-5',
|
||||
'frisbee-6',
|
||||
'frisbee-7',
|
||||
'frisbee-8',
|
||||
'frisbee-9',
|
||||
'frisbee-10',
|
||||
'jianzi-1',
|
||||
'jianzi-2',
|
||||
'jianzi-3',
|
||||
'jianzi-4',
|
||||
'jianzi-5',
|
||||
'jianzi-6',
|
||||
'jianzi-7',
|
||||
'jianzi-8',
|
||||
'jianzi-9',
|
||||
'jianzi-10',
|
||||
'lantern-1',
|
||||
'lantern-2',
|
||||
'lantern-3',
|
||||
'lantern-4',
|
||||
'lantern-5',
|
||||
'lantern-6',
|
||||
'lantern-7',
|
||||
'lantern-8',
|
||||
'lantern-9',
|
||||
'lantern-10',
|
||||
'misc-1',
|
||||
'misc-2',
|
||||
'misc-3',
|
||||
'misc-4',
|
||||
'misc-5',
|
||||
'misc-6',
|
||||
'misc-7',
|
||||
'misc-8',
|
||||
'misc-9',
|
||||
'misc-10',
|
||||
'opossum-1',
|
||||
'opossum-2',
|
||||
'opossum-3',
|
||||
'opossum-4',
|
||||
'opossum-5',
|
||||
'opossum-6',
|
||||
'opossum-7',
|
||||
'opossum-8',
|
||||
'opossum-9',
|
||||
'opossum-10',
|
||||
'paddle-1',
|
||||
'paddle-2',
|
||||
'paddle-3',
|
||||
'paddle-4',
|
||||
'paddle-5',
|
||||
'paddle-6',
|
||||
'paddle-7',
|
||||
'paddle-8',
|
||||
'paddle-9',
|
||||
'paddle-10',
|
||||
'raccoon-1',
|
||||
'raccoon-2',
|
||||
'raccoon-3',
|
||||
'raccoon-4',
|
||||
'raccoon-5',
|
||||
'raccoon-6',
|
||||
'raccoon-7',
|
||||
'raccoon-8',
|
||||
'raccoon-9',
|
||||
'raccoon-10',
|
||||
'rhino-1',
|
||||
'rhino-2',
|
||||
'rhino-3',
|
||||
'rhino-4',
|
||||
'rhino-5',
|
||||
'rhino-6',
|
||||
'rhino-7',
|
||||
'rhino-8',
|
||||
'rhino-9',
|
||||
'rhino-10',
|
||||
'skatingshoe-1',
|
||||
'skatingshoe-2',
|
||||
'skatingshoe-3',
|
||||
'skatingshoe-4',
|
||||
'skatingshoe-5',
|
||||
'skatingshoe-6',
|
||||
'skatingshoe-7',
|
||||
'skatingshoe-8',
|
||||
'skatingshoe-9',
|
||||
'skatingshoe-10',
|
||||
'wingsuit-1',
|
||||
'wingsuit-2',
|
||||
'wingsuit-3',
|
||||
'wingsuit-4',
|
||||
'wingsuit-5',
|
||||
'wingsuit-6',
|
||||
'wingsuit-7',
|
||||
'wingsuit-8',
|
||||
'wingsuit-9',
|
||||
'wingsuit-10']
|
||||
return sequence_list
|
||||
@@ -0,0 +1,38 @@
|
||||
from lib.test.evaluation.environment import EnvSettings
|
||||
|
||||
def local_env_settings():
|
||||
settings = EnvSettings()
|
||||
|
||||
# Set your local paths here.
|
||||
|
||||
settings.lasot_path = '/home/cycyang/code/vot-sam/data/LaSOT'
|
||||
settings.lasot_extension_subset_path = '/home/cycyang/code/vot-sam/data/LaSOT-ext'
|
||||
settings.nfs_path = '/home/cycyang/code/vot-sam/data/NFS'
|
||||
settings.otb_path = '/home/cycyang/code/vot-sam/data/otb'
|
||||
settings.uav_path = '//home/cycyang/code/vot-sam/data/uav'
|
||||
settings.results_path = '/home/cycyang/code/vot-sam/raw_results'
|
||||
settings.result_plot_path = '/home/cycyang/code/vot-sam/evaluation_results'
|
||||
settings.save_dir = '/home/cycyang/code/vot-sam/evaluation_results'
|
||||
|
||||
settings.davis_dir = ''
|
||||
settings.got10k_lmdb_path = '/home/baiyifan/code/OSTrack/data/got10k_lmdb'
|
||||
settings.got10k_path = '/home/baiyifan/GOT-10k'
|
||||
settings.got_packed_results_path = ''
|
||||
settings.got_reports_path = ''
|
||||
settings.itb_path = '/home/baiyifan/code/OSTrack/data/itb'
|
||||
settings.lasot_lmdb_path = '/home/baiyifan/code/OSTrack/data/lasot_lmdb'
|
||||
settings.network_path = '/ssddata/baiyifan/artrack_256_full_re/' # Where tracking networks are stored.
|
||||
settings.prj_dir = '/home/baiyifan/code/2d_autoregressive/bins_mask'
|
||||
settings.segmentation_path = '/data1/os/test/segmentation_results'
|
||||
settings.tc128_path = '/home/baiyifan/code/OSTrack/data/TC128'
|
||||
settings.tn_packed_results_path = ''
|
||||
settings.tnl2k_path = '/home/baiyifan/code/OSTrack/data/tnl2k'
|
||||
settings.tpl_path = ''
|
||||
settings.trackingnet_path = '/ssddata/TrackingNet/all_zip'
|
||||
settings.vot18_path = '/home/baiyifan/code/OSTrack/data/vot2018'
|
||||
settings.vot22_path = '/home/baiyifan/code/OSTrack/data/vot2022'
|
||||
settings.vot_path = '/home/baiyifan/code/OSTrack/data/VOT2019'
|
||||
settings.youtubevos_dir = ''
|
||||
|
||||
return settings
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class NFSDataset(BaseDataset):
|
||||
""" NFS dataset.
|
||||
Publication:
|
||||
Need for Speed: A Benchmark for Higher Frame Rate Object Tracking
|
||||
H. Kiani Galoogahi, A. Fagg, C. Huang, D. Ramanan, and S.Lucey
|
||||
ICCV, 2017
|
||||
http://openaccess.thecvf.com/content_ICCV_2017/papers/Galoogahi_Need_for_Speed_ICCV_2017_paper.pdf
|
||||
Download the dataset from http://ci2cv.net/nfs/index.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.nfs_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
# anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
anno_path = f"{self.base_path}/{sequence_info['name'][4:]}/30/groundtruth.txt"
|
||||
|
||||
# ground_truth_rect = load_text(str(anno_path), delimiter='\t', dtype=np.float64)
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
return Sequence(sequence_info['name'][4:], frames, 'nfs', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "nfs_Gymnastics", "path": "sequences/Gymnastics", "startFrame": 1, "endFrame": 368, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Gymnastics.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_MachLoop_jet", "path": "sequences/MachLoop_jet", "startFrame": 1, "endFrame": 99, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_MachLoop_jet.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_Skiing_red", "path": "sequences/Skiing_red", "startFrame": 1, "endFrame": 69, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Skiing_red.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_Skydiving", "path": "sequences/Skydiving", "startFrame": 1, "endFrame": 196, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Skydiving.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_airboard_1", "path": "sequences/airboard_1", "startFrame": 1, "endFrame": 425, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airboard_1.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_airplane_landing", "path": "sequences/airplane_landing", "startFrame": 1, "endFrame": 81, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airplane_landing.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_airtable_3", "path": "sequences/airtable_3", "startFrame": 1, "endFrame": 482, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airtable_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_1", "path": "sequences/basketball_1", "startFrame": 1, "endFrame": 282, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_1.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_2", "path": "sequences/basketball_2", "startFrame": 1, "endFrame": 102, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_3", "path": "sequences/basketball_3", "startFrame": 1, "endFrame": 421, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_6", "path": "sequences/basketball_6", "startFrame": 1, "endFrame": 224, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_7", "path": "sequences/basketball_7", "startFrame": 1, "endFrame": 240, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_7.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_basketball_player", "path": "sequences/basketball_player", "startFrame": 1, "endFrame": 369, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_player.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_basketball_player_2", "path": "sequences/basketball_player_2", "startFrame": 1, "endFrame": 437, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_player_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_beach_flipback_person", "path": "sequences/beach_flipback_person", "startFrame": 1, "endFrame": 61, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_beach_flipback_person.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_bee", "path": "sequences/bee", "startFrame": 1, "endFrame": 45, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bee.txt", "object_class": "insect", 'occlusion': False},
|
||||
{"name": "nfs_biker_acrobat", "path": "sequences/biker_acrobat", "startFrame": 1, "endFrame": 128, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_acrobat.txt", "object_class": "bicycle", 'occlusion': False},
|
||||
{"name": "nfs_biker_all_1", "path": "sequences/biker_all_1", "startFrame": 1, "endFrame": 113, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_all_1.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_biker_head_2", "path": "sequences/biker_head_2", "startFrame": 1, "endFrame": 132, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_head_2.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_biker_head_3", "path": "sequences/biker_head_3", "startFrame": 1, "endFrame": 254, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_head_3.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_biker_upper_body", "path": "sequences/biker_upper_body", "startFrame": 1, "endFrame": 194, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_upper_body.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_biker_whole_body", "path": "sequences/biker_whole_body", "startFrame": 1, "endFrame": 572, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_whole_body.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_billiard_2", "path": "sequences/billiard_2", "startFrame": 1, "endFrame": 604, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_3", "path": "sequences/billiard_3", "startFrame": 1, "endFrame": 698, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_6", "path": "sequences/billiard_6", "startFrame": 1, "endFrame": 771, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_7", "path": "sequences/billiard_7", "startFrame": 1, "endFrame": 724, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_7.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_8", "path": "sequences/billiard_8", "startFrame": 1, "endFrame": 778, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_8.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_bird_2", "path": "sequences/bird_2", "startFrame": 1, "endFrame": 476, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bird_2.txt", "object_class": "bird", 'occlusion': False},
|
||||
{"name": "nfs_book", "path": "sequences/book", "startFrame": 1, "endFrame": 288, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_book.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_bottle", "path": "sequences/bottle", "startFrame": 1, "endFrame": 2103, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bottle.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_bowling_1", "path": "sequences/bowling_1", "startFrame": 1, "endFrame": 303, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_1.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_2", "path": "sequences/bowling_2", "startFrame": 1, "endFrame": 710, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_2.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_3", "path": "sequences/bowling_3", "startFrame": 1, "endFrame": 271, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_3.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_6", "path": "sequences/bowling_6", "startFrame": 1, "endFrame": 260, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_bowling_ball", "path": "sequences/bowling_ball", "startFrame": 1, "endFrame": 275, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_ball.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bunny", "path": "sequences/bunny", "startFrame": 1, "endFrame": 705, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bunny.txt", "object_class": "mammal", 'occlusion': False},
|
||||
{"name": "nfs_car", "path": "sequences/car", "startFrame": 1, "endFrame": 2020, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car.txt", "object_class": "car", 'occlusion': True},
|
||||
{"name": "nfs_car_camaro", "path": "sequences/car_camaro", "startFrame": 1, "endFrame": 36, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_camaro.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_drifting", "path": "sequences/car_drifting", "startFrame": 1, "endFrame": 173, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_drifting.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_jumping", "path": "sequences/car_jumping", "startFrame": 1, "endFrame": 22, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_jumping.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_rc_rolling", "path": "sequences/car_rc_rolling", "startFrame": 1, "endFrame": 62, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_rc_rolling.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_rc_rotating", "path": "sequences/car_rc_rotating", "startFrame": 1, "endFrame": 80, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_rc_rotating.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_side", "path": "sequences/car_side", "startFrame": 1, "endFrame": 108, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_side.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_white", "path": "sequences/car_white", "startFrame": 1, "endFrame": 2063, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_white.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_cheetah", "path": "sequences/cheetah", "startFrame": 1, "endFrame": 167, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cheetah.txt", "object_class": "mammal", 'occlusion': True},
|
||||
{"name": "nfs_cup", "path": "sequences/cup", "startFrame": 1, "endFrame": 1281, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cup.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_cup_2", "path": "sequences/cup_2", "startFrame": 1, "endFrame": 182, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cup_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_dog", "path": "sequences/dog", "startFrame": 1, "endFrame": 1030, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dog_1", "path": "sequences/dog_1", "startFrame": 1, "endFrame": 168, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_1.txt", "object_class": "dog", 'occlusion': False},
|
||||
# {"name": "nfs_dog_2", "path": "sequences/dog_2", "startFrame": 1, "endFrame": 594, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_2.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dog_3", "path": "sequences/dog_3", "startFrame": 1, "endFrame": 200, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_3.txt", "object_class": "dog", 'occlusion': False},
|
||||
{"name": "nfs_dogs", "path": "sequences/dogs", "startFrame": 1, "endFrame": 198, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dogs.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dollar", "path": "sequences/dollar", "startFrame": 1, "endFrame": 1426, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dollar.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_drone", "path": "sequences/drone", "startFrame": 1, "endFrame": 70, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_drone.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_ducks_lake", "path": "sequences/ducks_lake", "startFrame": 1, "endFrame": 107, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_ducks_lake.txt", "object_class": "bird", 'occlusion': False},
|
||||
{"name": "nfs_exit", "path": "sequences/exit", "startFrame": 1, "endFrame": 359, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_exit.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_first", "path": "sequences/first", "startFrame": 1, "endFrame": 435, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_first.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_flower", "path": "sequences/flower", "startFrame": 1, "endFrame": 448, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_flower.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_footbal_skill", "path": "sequences/footbal_skill", "startFrame": 1, "endFrame": 131, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_footbal_skill.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_helicopter", "path": "sequences/helicopter", "startFrame": 1, "endFrame": 310, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_helicopter.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_horse_jumping", "path": "sequences/horse_jumping", "startFrame": 1, "endFrame": 117, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_horse_jumping.txt", "object_class": "horse", 'occlusion': True},
|
||||
{"name": "nfs_horse_running", "path": "sequences/horse_running", "startFrame": 1, "endFrame": 139, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_horse_running.txt", "object_class": "horse", 'occlusion': False},
|
||||
{"name": "nfs_iceskating_6", "path": "sequences/iceskating_6", "startFrame": 1, "endFrame": 603, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_iceskating_6.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_jellyfish_5", "path": "sequences/jellyfish_5", "startFrame": 1, "endFrame": 746, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_jellyfish_5.txt", "object_class": "invertebrate", 'occlusion': False},
|
||||
{"name": "nfs_kid_swing", "path": "sequences/kid_swing", "startFrame": 1, "endFrame": 169, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_kid_swing.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_motorcross", "path": "sequences/motorcross", "startFrame": 1, "endFrame": 39, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_motorcross.txt", "object_class": "vehicle", 'occlusion': True},
|
||||
{"name": "nfs_motorcross_kawasaki", "path": "sequences/motorcross_kawasaki", "startFrame": 1, "endFrame": 65, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_motorcross_kawasaki.txt", "object_class": "vehicle", 'occlusion': False},
|
||||
{"name": "nfs_parkour", "path": "sequences/parkour", "startFrame": 1, "endFrame": 58, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_parkour.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_person_scooter", "path": "sequences/person_scooter", "startFrame": 1, "endFrame": 413, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_person_scooter.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_pingpong_2", "path": "sequences/pingpong_2", "startFrame": 1, "endFrame": 1277, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_pingpong_7", "path": "sequences/pingpong_7", "startFrame": 1, "endFrame": 1290, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_7.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_pingpong_8", "path": "sequences/pingpong_8", "startFrame": 1, "endFrame": 296, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_8.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_purse", "path": "sequences/purse", "startFrame": 1, "endFrame": 968, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_purse.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_rubber", "path": "sequences/rubber", "startFrame": 1, "endFrame": 1328, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_rubber.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_running", "path": "sequences/running", "startFrame": 1, "endFrame": 677, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_running_100_m", "path": "sequences/running_100_m", "startFrame": 1, "endFrame": 313, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_100_m.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_running_100_m_2", "path": "sequences/running_100_m_2", "startFrame": 1, "endFrame": 337, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_100_m_2.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_running_2", "path": "sequences/running_2", "startFrame": 1, "endFrame": 363, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_1", "path": "sequences/shuffleboard_1", "startFrame": 1, "endFrame": 42, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_1.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_2", "path": "sequences/shuffleboard_2", "startFrame": 1, "endFrame": 41, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_4", "path": "sequences/shuffleboard_4", "startFrame": 1, "endFrame": 62, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_4.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_5", "path": "sequences/shuffleboard_5", "startFrame": 1, "endFrame": 32, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_5.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_6", "path": "sequences/shuffleboard_6", "startFrame": 1, "endFrame": 52, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_6.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_2", "path": "sequences/shuffletable_2", "startFrame": 1, "endFrame": 372, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_3", "path": "sequences/shuffletable_3", "startFrame": 1, "endFrame": 368, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_3.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_4", "path": "sequences/shuffletable_4", "startFrame": 1, "endFrame": 101, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_4.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_ski_long", "path": "sequences/ski_long", "startFrame": 1, "endFrame": 274, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_ski_long.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball", "path": "sequences/soccer_ball", "startFrame": 1, "endFrame": 163, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball_2", "path": "sequences/soccer_ball_2", "startFrame": 1, "endFrame": 1934, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball_3", "path": "sequences/soccer_ball_3", "startFrame": 1, "endFrame": 1381, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_player_2", "path": "sequences/soccer_player_2", "startFrame": 1, "endFrame": 475, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_player_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_soccer_player_3", "path": "sequences/soccer_player_3", "startFrame": 1, "endFrame": 319, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_player_3.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_stop_sign", "path": "sequences/stop_sign", "startFrame": 1, "endFrame": 302, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_stop_sign.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_suv", "path": "sequences/suv", "startFrame": 1, "endFrame": 2584, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_suv.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_tiger", "path": "sequences/tiger", "startFrame": 1, "endFrame": 1556, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_tiger.txt", "object_class": "mammal", 'occlusion': False},
|
||||
{"name": "nfs_walking", "path": "sequences/walking", "startFrame": 1, "endFrame": 555, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_walking.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_walking_3", "path": "sequences/walking_3", "startFrame": 1, "endFrame": 1427, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_walking_3.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_water_ski_2", "path": "sequences/water_ski_2", "startFrame": 1, "endFrame": 47, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_water_ski_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_yoyo", "path": "sequences/yoyo", "startFrame": 1, "endFrame": 67, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_yoyo.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_zebra_fish", "path": "sequences/zebra_fish", "startFrame": 1, "endFrame": 671, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_zebra_fish.txt", "object_class": "fish", 'occlusion': False},
|
||||
]
|
||||
|
||||
return sequence_info_list
|
||||
@@ -0,0 +1,259 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class OTBDataset(BaseDataset):
|
||||
""" OTB-2015 dataset
|
||||
Publication:
|
||||
Object Tracking Benchmark
|
||||
Wu, Yi, Jongwoo Lim, and Ming-hsuan Yan
|
||||
TPAMI, 2015
|
||||
http://faculty.ucmerced.edu/mhyang/papers/pami15_tracking_benchmark.pdf
|
||||
Download the dataset from http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.otb_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
# anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_info['name'])
|
||||
|
||||
# NOTE: OTB has some weird annos which panda cannot handle
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy')
|
||||
|
||||
return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "Basketball", "path": "Basketball/img", "startFrame": 1, "endFrame": 725, "nz": 4, "ext": "jpg", "anno_path": "Basketball/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Biker", "path": "Biker/img", "startFrame": 1, "endFrame": 142, "nz": 4, "ext": "jpg", "anno_path": "Biker/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Bird1", "path": "Bird1/img", "startFrame": 1, "endFrame": 408, "nz": 4, "ext": "jpg", "anno_path": "Bird1/groundtruth_rect.txt",
|
||||
"object_class": "bird"},
|
||||
{"name": "Bird2", "path": "Bird2/img", "startFrame": 1, "endFrame": 99, "nz": 4, "ext": "jpg", "anno_path": "Bird2/groundtruth_rect.txt",
|
||||
"object_class": "bird"},
|
||||
{"name": "BlurBody", "path": "BlurBody/img", "startFrame": 1, "endFrame": 334, "nz": 4, "ext": "jpg", "anno_path": "BlurBody/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "BlurCar1", "path": "BlurCar1/img", "startFrame": 247, "endFrame": 988, "nz": 4, "ext": "jpg", "anno_path": "BlurCar1/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar2", "path": "BlurCar2/img", "startFrame": 1, "endFrame": 585, "nz": 4, "ext": "jpg", "anno_path": "BlurCar2/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar3", "path": "BlurCar3/img", "startFrame": 3, "endFrame": 359, "nz": 4, "ext": "jpg", "anno_path": "BlurCar3/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar4", "path": "BlurCar4/img", "startFrame": 18, "endFrame": 397, "nz": 4, "ext": "jpg", "anno_path": "BlurCar4/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurFace", "path": "BlurFace/img", "startFrame": 1, "endFrame": 493, "nz": 4, "ext": "jpg", "anno_path": "BlurFace/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "BlurOwl", "path": "BlurOwl/img", "startFrame": 1, "endFrame": 631, "nz": 4, "ext": "jpg", "anno_path": "BlurOwl/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Board", "path": "Board/img", "startFrame": 1, "endFrame": 698, "nz": 5, "ext": "jpg", "anno_path": "Board/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Bolt", "path": "Bolt/img", "startFrame": 1, "endFrame": 350, "nz": 4, "ext": "jpg", "anno_path": "Bolt/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Bolt2", "path": "Bolt2/img", "startFrame": 1, "endFrame": 293, "nz": 4, "ext": "jpg", "anno_path": "Bolt2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Box", "path": "Box/img", "startFrame": 1, "endFrame": 1161, "nz": 4, "ext": "jpg", "anno_path": "Box/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Boy", "path": "Boy/img", "startFrame": 1, "endFrame": 602, "nz": 4, "ext": "jpg", "anno_path": "Boy/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Car1", "path": "Car1/img", "startFrame": 1, "endFrame": 1020, "nz": 4, "ext": "jpg", "anno_path": "Car1/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car2", "path": "Car2/img", "startFrame": 1, "endFrame": 913, "nz": 4, "ext": "jpg", "anno_path": "Car2/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car24", "path": "Car24/img", "startFrame": 1, "endFrame": 3059, "nz": 4, "ext": "jpg", "anno_path": "Car24/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car4", "path": "Car4/img", "startFrame": 1, "endFrame": 659, "nz": 4, "ext": "jpg", "anno_path": "Car4/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "CarDark", "path": "CarDark/img", "startFrame": 1, "endFrame": 393, "nz": 4, "ext": "jpg", "anno_path": "CarDark/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "CarScale", "path": "CarScale/img", "startFrame": 1, "endFrame": 252, "nz": 4, "ext": "jpg", "anno_path": "CarScale/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "ClifBar", "path": "ClifBar/img", "startFrame": 1, "endFrame": 472, "nz": 4, "ext": "jpg", "anno_path": "ClifBar/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Coke", "path": "Coke/img", "startFrame": 1, "endFrame": 291, "nz": 4, "ext": "jpg", "anno_path": "Coke/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Couple", "path": "Couple/img", "startFrame": 1, "endFrame": 140, "nz": 4, "ext": "jpg", "anno_path": "Couple/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Coupon", "path": "Coupon/img", "startFrame": 1, "endFrame": 327, "nz": 4, "ext": "jpg", "anno_path": "Coupon/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Crossing", "path": "Crossing/img", "startFrame": 1, "endFrame": 120, "nz": 4, "ext": "jpg", "anno_path": "Crossing/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Crowds", "path": "Crowds/img", "startFrame": 1, "endFrame": 347, "nz": 4, "ext": "jpg", "anno_path": "Crowds/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dancer", "path": "Dancer/img", "startFrame": 1, "endFrame": 225, "nz": 4, "ext": "jpg", "anno_path": "Dancer/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dancer2", "path": "Dancer2/img", "startFrame": 1, "endFrame": 150, "nz": 4, "ext": "jpg", "anno_path": "Dancer2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "David", "path": "David/img", "startFrame": 300, "endFrame": 770, "nz": 4, "ext": "jpg", "anno_path": "David/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "David2", "path": "David2/img", "startFrame": 1, "endFrame": 537, "nz": 4, "ext": "jpg", "anno_path": "David2/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "David3", "path": "David3/img", "startFrame": 1, "endFrame": 252, "nz": 4, "ext": "jpg", "anno_path": "David3/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Deer", "path": "Deer/img", "startFrame": 1, "endFrame": 71, "nz": 4, "ext": "jpg", "anno_path": "Deer/groundtruth_rect.txt",
|
||||
"object_class": "mammal"},
|
||||
{"name": "Diving", "path": "Diving/img", "startFrame": 1, "endFrame": 215, "nz": 4, "ext": "jpg", "anno_path": "Diving/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dog", "path": "Dog/img", "startFrame": 1, "endFrame": 127, "nz": 4, "ext": "jpg", "anno_path": "Dog/groundtruth_rect.txt",
|
||||
"object_class": "dog"},
|
||||
{"name": "Dog1", "path": "Dog1/img", "startFrame": 1, "endFrame": 1350, "nz": 4, "ext": "jpg", "anno_path": "Dog1/groundtruth_rect.txt",
|
||||
"object_class": "dog"},
|
||||
{"name": "Doll", "path": "Doll/img", "startFrame": 1, "endFrame": 3872, "nz": 4, "ext": "jpg", "anno_path": "Doll/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "DragonBaby", "path": "DragonBaby/img", "startFrame": 1, "endFrame": 113, "nz": 4, "ext": "jpg", "anno_path": "DragonBaby/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Dudek", "path": "Dudek/img", "startFrame": 1, "endFrame": 1145, "nz": 4, "ext": "jpg", "anno_path": "Dudek/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "FaceOcc1", "path": "FaceOcc1/img", "startFrame": 1, "endFrame": 892, "nz": 4, "ext": "jpg", "anno_path": "FaceOcc1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "FaceOcc2", "path": "FaceOcc2/img", "startFrame": 1, "endFrame": 812, "nz": 4, "ext": "jpg", "anno_path": "FaceOcc2/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Fish", "path": "Fish/img", "startFrame": 1, "endFrame": 476, "nz": 4, "ext": "jpg", "anno_path": "Fish/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "FleetFace", "path": "FleetFace/img", "startFrame": 1, "endFrame": 707, "nz": 4, "ext": "jpg", "anno_path": "FleetFace/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Football", "path": "Football/img", "startFrame": 1, "endFrame": 362, "nz": 4, "ext": "jpg", "anno_path": "Football/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Football1", "path": "Football1/img", "startFrame": 1, "endFrame": 74, "nz": 4, "ext": "jpg", "anno_path": "Football1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman1", "path": "Freeman1/img", "startFrame": 1, "endFrame": 326, "nz": 4, "ext": "jpg", "anno_path": "Freeman1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman3", "path": "Freeman3/img", "startFrame": 1, "endFrame": 460, "nz": 4, "ext": "jpg", "anno_path": "Freeman3/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman4", "path": "Freeman4/img", "startFrame": 1, "endFrame": 283, "nz": 4, "ext": "jpg", "anno_path": "Freeman4/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Girl", "path": "Girl/img", "startFrame": 1, "endFrame": 500, "nz": 4, "ext": "jpg", "anno_path": "Girl/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Girl2", "path": "Girl2/img", "startFrame": 1, "endFrame": 1500, "nz": 4, "ext": "jpg", "anno_path": "Girl2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Gym", "path": "Gym/img", "startFrame": 1, "endFrame": 767, "nz": 4, "ext": "jpg", "anno_path": "Gym/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human2", "path": "Human2/img", "startFrame": 1, "endFrame": 1128, "nz": 4, "ext": "jpg", "anno_path": "Human2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human3", "path": "Human3/img", "startFrame": 1, "endFrame": 1698, "nz": 4, "ext": "jpg", "anno_path": "Human3/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human4_2", "path": "Human4/img", "startFrame": 1, "endFrame": 667, "nz": 4, "ext": "jpg", "anno_path": "Human4/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human4", "path": "Human4/img", "startFrame": 1, "endFrame": 667, "nz": 4, "ext": "jpg", "anno_path": "Human4/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human5", "path": "Human5/img", "startFrame": 1, "endFrame": 713, "nz": 4, "ext": "jpg", "anno_path": "Human5/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human6", "path": "Human6/img", "startFrame": 1, "endFrame": 792, "nz": 4, "ext": "jpg", "anno_path": "Human6/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human7", "path": "Human7/img", "startFrame": 1, "endFrame": 250, "nz": 4, "ext": "jpg", "anno_path": "Human7/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human8", "path": "Human8/img", "startFrame": 1, "endFrame": 128, "nz": 4, "ext": "jpg", "anno_path": "Human8/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human9", "path": "Human9/img", "startFrame": 1, "endFrame": 305, "nz": 4, "ext": "jpg", "anno_path": "Human9/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Ironman", "path": "Ironman/img", "startFrame": 1, "endFrame": 166, "nz": 4, "ext": "jpg", "anno_path": "Ironman/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Jogging", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
# {"name": "Jogging_1", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.1.txt",
|
||||
# "object_class": "person"},
|
||||
# {"name": "Jogging_2", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.2.txt",
|
||||
# "object_class": "person"},
|
||||
{"name": "Jump", "path": "Jump/img", "startFrame": 1, "endFrame": 122, "nz": 4, "ext": "jpg", "anno_path": "Jump/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Jumping", "path": "Jumping/img", "startFrame": 1, "endFrame": 313, "nz": 4, "ext": "jpg", "anno_path": "Jumping/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "KiteSurf", "path": "KiteSurf/img", "startFrame": 1, "endFrame": 84, "nz": 4, "ext": "jpg", "anno_path": "KiteSurf/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Lemming", "path": "Lemming/img", "startFrame": 1, "endFrame": 1336, "nz": 4, "ext": "jpg", "anno_path": "Lemming/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Liquor", "path": "Liquor/img", "startFrame": 1, "endFrame": 1741, "nz": 4, "ext": "jpg", "anno_path": "Liquor/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Man", "path": "Man/img", "startFrame": 1, "endFrame": 134, "nz": 4, "ext": "jpg", "anno_path": "Man/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Matrix", "path": "Matrix/img", "startFrame": 1, "endFrame": 100, "nz": 4, "ext": "jpg", "anno_path": "Matrix/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Mhyang", "path": "Mhyang/img", "startFrame": 1, "endFrame": 1490, "nz": 4, "ext": "jpg", "anno_path": "Mhyang/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "MotorRolling", "path": "MotorRolling/img", "startFrame": 1, "endFrame": 164, "nz": 4, "ext": "jpg", "anno_path": "MotorRolling/groundtruth_rect.txt",
|
||||
"object_class": "vehicle"},
|
||||
{"name": "MountainBike", "path": "MountainBike/img", "startFrame": 1, "endFrame": 228, "nz": 4, "ext": "jpg", "anno_path": "MountainBike/groundtruth_rect.txt",
|
||||
"object_class": "bicycle"},
|
||||
{"name": "Panda", "path": "Panda/img", "startFrame": 1, "endFrame": 1000, "nz": 4, "ext": "jpg", "anno_path": "Panda/groundtruth_rect.txt",
|
||||
"object_class": "mammal"},
|
||||
{"name": "RedTeam", "path": "RedTeam/img", "startFrame": 1, "endFrame": 1918, "nz": 4, "ext": "jpg", "anno_path": "RedTeam/groundtruth_rect.txt",
|
||||
"object_class": "vehicle"},
|
||||
{"name": "Rubik", "path": "Rubik/img", "startFrame": 1, "endFrame": 1997, "nz": 4, "ext": "jpg", "anno_path": "Rubik/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Shaking", "path": "Shaking/img", "startFrame": 1, "endFrame": 365, "nz": 4, "ext": "jpg", "anno_path": "Shaking/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Singer1", "path": "Singer1/img", "startFrame": 1, "endFrame": 351, "nz": 4, "ext": "jpg", "anno_path": "Singer1/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Singer2", "path": "Singer2/img", "startFrame": 1, "endFrame": 366, "nz": 4, "ext": "jpg", "anno_path": "Singer2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skater", "path": "Skater/img", "startFrame": 1, "endFrame": 160, "nz": 4, "ext": "jpg", "anno_path": "Skater/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skater2", "path": "Skater2/img", "startFrame": 1, "endFrame": 435, "nz": 4, "ext": "jpg", "anno_path": "Skater2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating1", "path": "Skating1/img", "startFrame": 1, "endFrame": 400, "nz": 4, "ext": "jpg", "anno_path": "Skating1/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2_1", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2_2", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skiing", "path": "Skiing/img", "startFrame": 1, "endFrame": 81, "nz": 4, "ext": "jpg", "anno_path": "Skiing/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Soccer", "path": "Soccer/img", "startFrame": 1, "endFrame": 392, "nz": 4, "ext": "jpg", "anno_path": "Soccer/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Subway", "path": "Subway/img", "startFrame": 1, "endFrame": 175, "nz": 4, "ext": "jpg", "anno_path": "Subway/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Surfer", "path": "Surfer/img", "startFrame": 1, "endFrame": 376, "nz": 4, "ext": "jpg", "anno_path": "Surfer/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Suv", "path": "Suv/img", "startFrame": 1, "endFrame": 945, "nz": 4, "ext": "jpg", "anno_path": "Suv/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Sylvester", "path": "Sylvester/img", "startFrame": 1, "endFrame": 1345, "nz": 4, "ext": "jpg", "anno_path": "Sylvester/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Tiger1", "path": "Tiger1/img", "startFrame": 1, "endFrame": 354, "nz": 4, "ext": "jpg", "anno_path": "Tiger1/groundtruth_rect.txt", "initOmit": 5,
|
||||
"object_class": "other"},
|
||||
{"name": "Tiger2", "path": "Tiger2/img", "startFrame": 1, "endFrame": 365, "nz": 4, "ext": "jpg", "anno_path": "Tiger2/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Toy", "path": "Toy/img", "startFrame": 1, "endFrame": 271, "nz": 4, "ext": "jpg", "anno_path": "Toy/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Trans", "path": "Trans/img", "startFrame": 1, "endFrame": 124, "nz": 4, "ext": "jpg", "anno_path": "Trans/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Trellis", "path": "Trellis/img", "startFrame": 1, "endFrame": 569, "nz": 4, "ext": "jpg", "anno_path": "Trellis/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Twinnings", "path": "Twinnings/img", "startFrame": 1, "endFrame": 472, "nz": 4, "ext": "jpg", "anno_path": "Twinnings/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Vase", "path": "Vase/img", "startFrame": 1, "endFrame": 271, "nz": 4, "ext": "jpg", "anno_path": "Vase/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Walking", "path": "Walking/img", "startFrame": 1, "endFrame": 412, "nz": 4, "ext": "jpg", "anno_path": "Walking/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Walking2", "path": "Walking2/img", "startFrame": 1, "endFrame": 500, "nz": 4, "ext": "jpg", "anno_path": "Walking2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Woman", "path": "Woman/img", "startFrame": 1, "endFrame": 597, "nz": 4, "ext": "jpg", "anno_path": "Woman/groundtruth_rect.txt",
|
||||
"object_class": "person"}
|
||||
]
|
||||
|
||||
return sequence_info_list
|
||||
@@ -0,0 +1,183 @@
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from itertools import product
|
||||
from collections import OrderedDict
|
||||
from lib.test.evaluation import Sequence, Tracker
|
||||
import torch
|
||||
|
||||
|
||||
def _save_tracker_output(seq: Sequence, tracker: Tracker, output: dict):
|
||||
"""Saves the output of the tracker."""
|
||||
|
||||
if not os.path.exists(tracker.results_dir):
|
||||
print("create tracking result dir:", tracker.results_dir)
|
||||
os.makedirs(tracker.results_dir)
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
if not os.path.exists(os.path.join(tracker.results_dir, seq.dataset)):
|
||||
os.makedirs(os.path.join(tracker.results_dir, seq.dataset))
|
||||
'''2021.1.5 create new folder for these two datasets'''
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.dataset, seq.name)
|
||||
else:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.name)
|
||||
|
||||
def save_bb(file, data):
|
||||
tracked_bb = np.array(data).astype(int)
|
||||
np.savetxt(file, tracked_bb, delimiter='\t', fmt='%d')
|
||||
|
||||
def save_time(file, data):
|
||||
exec_times = np.array(data).astype(float)
|
||||
np.savetxt(file, exec_times, delimiter='\t', fmt='%f')
|
||||
|
||||
def save_score(file, data):
|
||||
scores = np.array(data).astype(float)
|
||||
np.savetxt(file, scores, delimiter='\t', fmt='%.2f')
|
||||
|
||||
def _convert_dict(input_dict):
|
||||
data_dict = {}
|
||||
for elem in input_dict:
|
||||
for k, v in elem.items():
|
||||
if k in data_dict.keys():
|
||||
data_dict[k].append(v)
|
||||
else:
|
||||
data_dict[k] = [v, ]
|
||||
return data_dict
|
||||
|
||||
for key, data in output.items():
|
||||
# If data is empty
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if key == 'target_bbox':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}.txt'.format(base_results_path, obj_id)
|
||||
save_bb(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
save_bb(bbox_file, data)
|
||||
|
||||
if key == 'all_boxes':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}_all_boxes.txt'.format(base_results_path, obj_id)
|
||||
save_bb(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
bbox_file = '{}_all_boxes.txt'.format(base_results_path)
|
||||
save_bb(bbox_file, data)
|
||||
|
||||
if key == 'all_scores':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}_all_scores.txt'.format(base_results_path, obj_id)
|
||||
save_score(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
print("saving scores...")
|
||||
bbox_file = '{}_all_scores.txt'.format(base_results_path)
|
||||
save_score(bbox_file, data)
|
||||
|
||||
elif key == 'time':
|
||||
if isinstance(data[0], dict):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
timings_file = '{}_{}_time.txt'.format(base_results_path, obj_id)
|
||||
save_time(timings_file, d)
|
||||
else:
|
||||
timings_file = '{}_time.txt'.format(base_results_path)
|
||||
save_time(timings_file, data)
|
||||
|
||||
|
||||
def run_sequence(seq: Sequence, tracker: Tracker, debug=False, num_gpu=8):
|
||||
"""Runs a tracker on a sequence."""
|
||||
'''2021.1.2 Add multiple gpu support'''
|
||||
try:
|
||||
worker_name = multiprocessing.current_process().name
|
||||
worker_id = int(worker_name[worker_name.find('-') + 1:]) - 1
|
||||
gpu_id = worker_id % num_gpu
|
||||
torch.cuda.set_device(gpu_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
def _results_exist():
|
||||
if seq.object_ids is None:
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.dataset, seq.name)
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
else:
|
||||
bbox_file = '{}/{}.txt'.format(tracker.results_dir, seq.name)
|
||||
return os.path.isfile(bbox_file)
|
||||
else:
|
||||
bbox_files = ['{}/{}_{}.txt'.format(tracker.results_dir, seq.name, obj_id) for obj_id in seq.object_ids]
|
||||
missing = [not os.path.isfile(f) for f in bbox_files]
|
||||
return sum(missing) == 0
|
||||
|
||||
if _results_exist() and not debug:
|
||||
print('FPS: {}'.format(-1))
|
||||
return
|
||||
|
||||
print('Tracker: {} {} {} , Sequence: {}'.format(tracker.name, tracker.parameter_name, tracker.run_id, seq.name))
|
||||
|
||||
if debug:
|
||||
output = tracker.run_sequence(seq, debug=debug)
|
||||
else:
|
||||
try:
|
||||
output = tracker.run_sequence(seq, debug=debug)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return
|
||||
|
||||
sys.stdout.flush()
|
||||
|
||||
if isinstance(output['time'][0], (dict, OrderedDict)):
|
||||
exec_time = sum([sum(times.values()) for times in output['time']])
|
||||
num_frames = len(output['time'])
|
||||
else:
|
||||
exec_time = sum(output['time'])
|
||||
num_frames = len(output['time'])
|
||||
|
||||
print('FPS: {}'.format(num_frames / exec_time))
|
||||
|
||||
if not debug:
|
||||
_save_tracker_output(seq, tracker, output)
|
||||
|
||||
|
||||
def run_dataset(dataset, trackers, debug=False, threads=0, num_gpus=8):
|
||||
"""Runs a list of trackers on a dataset.
|
||||
args:
|
||||
dataset: List of Sequence instances, forming a dataset.
|
||||
trackers: List of Tracker instances.
|
||||
debug: Debug level.
|
||||
threads: Number of threads to use (default 0).
|
||||
"""
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
print('Evaluating {:4d} trackers on {:5d} sequences'.format(len(trackers), len(dataset)))
|
||||
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
if threads == 0:
|
||||
mode = 'sequential'
|
||||
else:
|
||||
mode = 'parallel'
|
||||
|
||||
if mode == 'sequential':
|
||||
for seq in dataset:
|
||||
for tracker_info in trackers:
|
||||
run_sequence(seq, tracker_info, debug=debug)
|
||||
elif mode == 'parallel':
|
||||
param_list = [(seq, tracker_info, debug, num_gpus) for seq, tracker_info in product(dataset, trackers)]
|
||||
with multiprocessing.Pool(processes=threads) as pool:
|
||||
pool.starmap(run_sequence, param_list)
|
||||
print('Done')
|
||||
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
import glob
|
||||
import six
|
||||
|
||||
|
||||
class TC128CEDataset(BaseDataset):
|
||||
"""
|
||||
TC-128 Dataset (78 newly added sequences)
|
||||
modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tc128_path
|
||||
self.anno_files = sorted(glob.glob(
|
||||
os.path.join(self.base_path, '*/*_gt.txt')))
|
||||
"""filter the newly added sequences (_ce)"""
|
||||
self.anno_files = [s for s in self.anno_files if "_ce" in s]
|
||||
self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
|
||||
self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
|
||||
# valid frame range for each sequence
|
||||
self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.seq_names])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
if isinstance(sequence_name, six.string_types):
|
||||
if not sequence_name in self.seq_names:
|
||||
raise Exception('Sequence {} not found.'.format(sequence_name))
|
||||
index = self.seq_names.index(sequence_name)
|
||||
# load valid frame range
|
||||
frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
|
||||
img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
|
||||
|
||||
# load annotations
|
||||
anno = np.loadtxt(self.anno_files[index], delimiter=',')
|
||||
assert len(img_files) == len(anno)
|
||||
assert anno.shape[1] == 4
|
||||
|
||||
# return img_files, anno
|
||||
return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
@@ -0,0 +1,44 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
import glob
|
||||
import six
|
||||
|
||||
|
||||
class TC128Dataset(BaseDataset):
|
||||
"""
|
||||
TC-128 Dataset
|
||||
modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tc128_path
|
||||
self.anno_files = sorted(glob.glob(
|
||||
os.path.join(self.base_path, '*/*_gt.txt')))
|
||||
self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
|
||||
self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
|
||||
# valid frame range for each sequence
|
||||
self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.seq_names])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
if isinstance(sequence_name, six.string_types):
|
||||
if not sequence_name in self.seq_names:
|
||||
raise Exception('Sequence {} not found.'.format(sequence_name))
|
||||
index = self.seq_names.index(sequence_name)
|
||||
# load valid frame range
|
||||
frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
|
||||
img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
|
||||
|
||||
# load annotations
|
||||
anno = np.loadtxt(self.anno_files[index], delimiter=',')
|
||||
assert len(img_files) == len(anno)
|
||||
assert anno.shape[1] == 4
|
||||
|
||||
# return img_files, anno
|
||||
return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text, load_str
|
||||
|
||||
############
|
||||
# current 00000492.png of test_015_Sord_video_Q01_done is damaged and replaced by a copy of 00000491.png
|
||||
############
|
||||
|
||||
|
||||
class TNL2kDataset(BaseDataset):
|
||||
"""
|
||||
TNL2k test set
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tnl2k_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
# class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
text_dsp_path = '{}/{}/language.txt'.format(self.base_path, sequence_name)
|
||||
text_dsp = load_str(text_dsp_path)
|
||||
|
||||
frames_path = '{}/{}/imgs'.format(self.base_path, sequence_name)
|
||||
frames_list = [f for f in os.listdir(frames_path)]
|
||||
frames_list = sorted(frames_list)
|
||||
frames_list = ['{}/{}'.format(frames_path, frame_i) for frame_i in frames_list]
|
||||
|
||||
# target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'tnl2k', ground_truth_rect.reshape(-1, 4), text_dsp=text_dsp)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = []
|
||||
for seq in os.listdir(self.base_path):
|
||||
if os.path.isdir(os.path.join(self.base_path, seq)):
|
||||
sequence_list.append(seq)
|
||||
|
||||
return sequence_list
|
||||
@@ -0,0 +1,291 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
import time
|
||||
import cv2 as cv
|
||||
|
||||
from lib.utils.lmdb_utils import decode_img
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
|
||||
def trackerlist(name: str, parameter_name: str, dataset_name: str, run_ids = None, display_name: str = None,
|
||||
result_only=False):
|
||||
"""Generate list of trackers.
|
||||
args:
|
||||
name: Name of tracking method.
|
||||
parameter_name: Name of parameter file.
|
||||
run_ids: A single or list of run_ids.
|
||||
display_name: Name to be displayed in the result plots.
|
||||
"""
|
||||
if run_ids is None or isinstance(run_ids, int):
|
||||
run_ids = [run_ids]
|
||||
return [Tracker(name, parameter_name, dataset_name, run_id, display_name, result_only) for run_id in run_ids]
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Wraps the tracker for evaluation and running purposes.
|
||||
args:
|
||||
name: Name of tracking method.
|
||||
parameter_name: Name of parameter file.
|
||||
run_id: The run id.
|
||||
display_name: Name to be displayed in the result plots.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, parameter_name: str, dataset_name: str, run_id: int = None, display_name: str = None,
|
||||
result_only=False):
|
||||
assert run_id is None or isinstance(run_id, int)
|
||||
|
||||
self.name = name
|
||||
self.parameter_name = parameter_name
|
||||
self.dataset_name = dataset_name
|
||||
self.run_id = run_id
|
||||
self.display_name = display_name
|
||||
|
||||
env = env_settings()
|
||||
if self.run_id is None:
|
||||
self.results_dir = '{}/{}/{}'.format(env.results_path, self.name, self.parameter_name)
|
||||
else:
|
||||
self.results_dir = '{}/{}/{}_{:03d}'.format(env.results_path, self.name, self.parameter_name, self.run_id)
|
||||
if result_only:
|
||||
self.results_dir = '{}/{}'.format(env.results_path, self.name)
|
||||
|
||||
tracker_module_abspath = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'..', 'tracker', '%s.py' % self.name))
|
||||
if os.path.isfile(tracker_module_abspath):
|
||||
tracker_module = importlib.import_module('lib.test.tracker.{}'.format(self.name))
|
||||
self.tracker_class = tracker_module.get_tracker_class()
|
||||
else:
|
||||
self.tracker_class = None
|
||||
|
||||
def create_tracker(self, params):
|
||||
tracker = self.tracker_class(params, self.dataset_name)
|
||||
return tracker
|
||||
|
||||
def run_sequence(self, seq, debug=None):
|
||||
"""Run tracker on sequence.
|
||||
args:
|
||||
seq: Sequence to run the tracker on.
|
||||
visualization: Set visualization flag (None means default value specified in the parameters).
|
||||
debug: Set debug level (None means default value specified in the parameters).
|
||||
multiobj_mode: Which mode to use for multiple objects.
|
||||
"""
|
||||
params = self.get_parameters()
|
||||
|
||||
debug_ = debug
|
||||
if debug is None:
|
||||
debug_ = getattr(params, 'debug', 0)
|
||||
|
||||
params.debug = debug_
|
||||
|
||||
# Get init information
|
||||
init_info = seq.init_info()
|
||||
|
||||
tracker = self.create_tracker(params)
|
||||
|
||||
output = self._track_sequence(tracker, seq, init_info)
|
||||
return output
|
||||
|
||||
def _track_sequence(self, tracker, seq, init_info):
|
||||
# Define outputs
|
||||
# Each field in output is a list containing tracker prediction for each frame.
|
||||
|
||||
# In case of single object tracking mode:
|
||||
# target_bbox[i] is the predicted bounding box for frame i
|
||||
# time[i] is the processing time for frame i
|
||||
|
||||
# In case of multi object tracking mode:
|
||||
# target_bbox[i] is an OrderedDict, where target_bbox[i][obj_id] is the predicted box for target obj_id in
|
||||
# frame i
|
||||
# time[i] is either the processing time for frame i, or an OrderedDict containing processing times for each
|
||||
# object in frame i
|
||||
|
||||
output = {'target_bbox': [],
|
||||
'time': []}
|
||||
if tracker.params.save_all_boxes:
|
||||
output['all_boxes'] = []
|
||||
output['all_scores'] = []
|
||||
|
||||
def _store_outputs(tracker_out: dict, defaults=None):
|
||||
defaults = {} if defaults is None else defaults
|
||||
for key in output.keys():
|
||||
val = tracker_out.get(key, defaults.get(key, None))
|
||||
if key in tracker_out or val is not None:
|
||||
output[key].append(val)
|
||||
|
||||
# Initialize
|
||||
image = self._read_image(seq.frames[0])
|
||||
|
||||
start_time = time.time()
|
||||
out = tracker.initialize(image, init_info)
|
||||
if out is None:
|
||||
out = {}
|
||||
|
||||
prev_output = OrderedDict(out)
|
||||
init_default = {'target_bbox': init_info.get('init_bbox'),
|
||||
'time': time.time() - start_time}
|
||||
if tracker.params.save_all_boxes:
|
||||
init_default['all_boxes'] = out['all_boxes']
|
||||
init_default['all_scores'] = out['all_scores']
|
||||
|
||||
_store_outputs(out, init_default)
|
||||
|
||||
for frame_num, frame_path in enumerate(seq.frames[1:], start=1):
|
||||
image = self._read_image(frame_path)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
info = seq.frame_info(frame_num)
|
||||
info['previous_output'] = prev_output
|
||||
|
||||
if len(seq.ground_truth_rect) > 1:
|
||||
info['gt_bbox'] = seq.ground_truth_rect[frame_num]
|
||||
out = tracker.track(image, info)
|
||||
prev_output = OrderedDict(out)
|
||||
_store_outputs(out, {'time': time.time() - start_time})
|
||||
|
||||
for key in ['target_bbox', 'all_boxes', 'all_scores']:
|
||||
if key in output and len(output[key]) <= 1:
|
||||
output.pop(key)
|
||||
|
||||
return output
|
||||
|
||||
def run_video(self, videofilepath, optional_box=None, debug=None, visdom_info=None, save_results=False):
|
||||
"""Run the tracker with the vieofile.
|
||||
args:
|
||||
debug: Debug level.
|
||||
"""
|
||||
|
||||
params = self.get_parameters()
|
||||
|
||||
debug_ = debug
|
||||
if debug is None:
|
||||
debug_ = getattr(params, 'debug', 0)
|
||||
params.debug = debug_
|
||||
|
||||
params.tracker_name = self.name
|
||||
params.param_name = self.parameter_name
|
||||
# self._init_visdom(visdom_info, debug_)
|
||||
|
||||
multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))
|
||||
|
||||
if multiobj_mode == 'default':
|
||||
tracker = self.create_tracker(params)
|
||||
|
||||
elif multiobj_mode == 'parallel':
|
||||
tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
|
||||
else:
|
||||
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))
|
||||
|
||||
assert os.path.isfile(videofilepath), "Invalid param {}".format(videofilepath)
|
||||
", videofilepath must be a valid videofile"
|
||||
|
||||
output_boxes = []
|
||||
|
||||
cap = cv.VideoCapture(videofilepath)
|
||||
display_name = 'Display: ' + tracker.params.tracker_name
|
||||
cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
|
||||
cv.resizeWindow(display_name, 960, 720)
|
||||
success, frame = cap.read()
|
||||
cv.imshow(display_name, frame)
|
||||
|
||||
def _build_init_info(box):
|
||||
return {'init_bbox': box}
|
||||
|
||||
if success is not True:
|
||||
print("Read frame from {} failed.".format(videofilepath))
|
||||
exit(-1)
|
||||
if optional_box is not None:
|
||||
assert isinstance(optional_box, (list, tuple))
|
||||
assert len(optional_box) == 4, "valid box's foramt is [x,y,w,h]"
|
||||
tracker.initialize(frame, _build_init_info(optional_box))
|
||||
output_boxes.append(optional_box)
|
||||
else:
|
||||
while True:
|
||||
# cv.waitKey()
|
||||
frame_disp = frame.copy()
|
||||
|
||||
cv.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL,
|
||||
1.5, (0, 0, 0), 1)
|
||||
|
||||
x, y, w, h = cv.selectROI(display_name, frame_disp, fromCenter=False)
|
||||
init_state = [x, y, w, h]
|
||||
tracker.initialize(frame, _build_init_info(init_state))
|
||||
output_boxes.append(init_state)
|
||||
break
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
frame_disp = frame.copy()
|
||||
|
||||
# Draw box
|
||||
out = tracker.track(frame)
|
||||
state = [int(s) for s in out['target_bbox']]
|
||||
output_boxes.append(state)
|
||||
|
||||
cv.rectangle(frame_disp, (state[0], state[1]), (state[2] + state[0], state[3] + state[1]),
|
||||
(0, 255, 0), 5)
|
||||
|
||||
font_color = (0, 0, 0)
|
||||
cv.putText(frame_disp, 'Tracking!', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
cv.putText(frame_disp, 'Press r to reset', (20, 55), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
cv.putText(frame_disp, 'Press q to quit', (20, 80), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
|
||||
# Display the resulting frame
|
||||
cv.imshow(display_name, frame_disp)
|
||||
key = cv.waitKey(1)
|
||||
if key == ord('q'):
|
||||
break
|
||||
elif key == ord('r'):
|
||||
ret, frame = cap.read()
|
||||
frame_disp = frame.copy()
|
||||
|
||||
cv.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.5,
|
||||
(0, 0, 0), 1)
|
||||
|
||||
cv.imshow(display_name, frame_disp)
|
||||
x, y, w, h = cv.selectROI(display_name, frame_disp, fromCenter=False)
|
||||
init_state = [x, y, w, h]
|
||||
tracker.initialize(frame, _build_init_info(init_state))
|
||||
output_boxes.append(init_state)
|
||||
|
||||
# When everything done, release the capture
|
||||
cap.release()
|
||||
cv.destroyAllWindows()
|
||||
|
||||
if save_results:
|
||||
if not os.path.exists(self.results_dir):
|
||||
os.makedirs(self.results_dir)
|
||||
video_name = Path(videofilepath).stem
|
||||
base_results_path = os.path.join(self.results_dir, 'video_{}'.format(video_name))
|
||||
|
||||
tracked_bb = np.array(output_boxes).astype(int)
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
np.savetxt(bbox_file, tracked_bb, delimiter='\t', fmt='%d')
|
||||
|
||||
|
||||
def get_parameters(self):
|
||||
"""Get parameters."""
|
||||
param_module = importlib.import_module('lib.test.parameter.{}'.format(self.name))
|
||||
params = param_module.parameters(self.parameter_name)
|
||||
return params
|
||||
|
||||
def _read_image(self, image_file: str):
|
||||
if isinstance(image_file, str):
|
||||
im = cv.imread(image_file)
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
elif isinstance(image_file, list) and len(image_file) == 2:
|
||||
return decode_img(image_file[0], image_file[1])
|
||||
else:
|
||||
raise ValueError("type of image_file should be str or list")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class TrackingNetDataset(BaseDataset):
|
||||
""" TrackingNet test set.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.trackingnet_path
|
||||
|
||||
sets = 'TEST'
|
||||
if not isinstance(sets, (list, tuple)):
|
||||
if sets == 'TEST':
|
||||
sets = ['TEST']
|
||||
elif sets == 'TRAIN':
|
||||
sets = ['TRAIN_{}'.format(i) for i in range(5)]
|
||||
|
||||
self.sequence_list = self._list_sequences(self.base_path, sets)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(set, seq_name) for set, seq_name in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, set, sequence_name):
|
||||
anno_path = '{}/{}/anno/{}.txt'.format(self.base_path, set, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
frames_path = '{}/{}/frames/{}'.format(self.base_path, set, sequence_name)
|
||||
frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
|
||||
frame_list.sort(key=lambda f: int(f[:-4]))
|
||||
frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
|
||||
|
||||
return Sequence(sequence_name, frames_list, 'trackingnet', ground_truth_rect.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _list_sequences(self, root, set_ids):
|
||||
sequence_list = []
|
||||
|
||||
for s in set_ids:
|
||||
anno_dir = os.path.join(root, s, "anno")
|
||||
sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
|
||||
|
||||
sequence_list += sequences_cur_set
|
||||
|
||||
return sequence_list
|
||||
@@ -0,0 +1,298 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class UAVDataset(BaseDataset):
|
||||
""" UAV123 dataset.
|
||||
Publication:
|
||||
A Benchmark and Simulator for UAV Tracking.
|
||||
Matthias Mueller, Neil Smith and Bernard Ghanem
|
||||
ECCV, 2016
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2016/A%20Benchmark%20and%20Simulator%20for%20UAV%20Tracking.pdf
|
||||
Download the dataset from https://ivul.kaust.edu.sa/Pages/pub-benchmark-simulator-uav.aspx
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.uav_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
# return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
return Sequence(sequence_info['name'][4:], frames, 'uav', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "uav_bike1", "path": "data_seq/UAV123/bike1", "startFrame": 1, "endFrame": 3085, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike1.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bike2", "path": "data_seq/UAV123/bike2", "startFrame": 1, "endFrame": 553, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike2.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bike3", "path": "data_seq/UAV123/bike3", "startFrame": 1, "endFrame": 433, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike3.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bird1_1", "path": "data_seq/UAV123/bird1", "startFrame": 1, "endFrame": 253, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_1.txt", "object_class": "bird"},
|
||||
{"name": "uav_bird1_2", "path": "data_seq/UAV123/bird1", "startFrame": 775, "endFrame": 1477, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_2.txt", "object_class": "bird"},
|
||||
{"name": "uav_bird1_3", "path": "data_seq/UAV123/bird1", "startFrame": 1573, "endFrame": 2437, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_3.txt", "object_class": "bird"},
|
||||
{"name": "uav_boat1", "path": "data_seq/UAV123/boat1", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat1.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat2", "path": "data_seq/UAV123/boat2", "startFrame": 1, "endFrame": 799, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat2.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat3", "path": "data_seq/UAV123/boat3", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat3.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat4", "path": "data_seq/UAV123/boat4", "startFrame": 1, "endFrame": 553, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat4.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat5", "path": "data_seq/UAV123/boat5", "startFrame": 1, "endFrame": 505, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat5.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat6", "path": "data_seq/UAV123/boat6", "startFrame": 1, "endFrame": 805, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat6.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat7", "path": "data_seq/UAV123/boat7", "startFrame": 1, "endFrame": 535, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat7.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat8", "path": "data_seq/UAV123/boat8", "startFrame": 1, "endFrame": 685, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat8.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat9", "path": "data_seq/UAV123/boat9", "startFrame": 1, "endFrame": 1399, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat9.txt", "object_class": "vessel"},
|
||||
{"name": "uav_building1", "path": "data_seq/UAV123/building1", "startFrame": 1, "endFrame": 469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building1.txt", "object_class": "other"},
|
||||
{"name": "uav_building2", "path": "data_seq/UAV123/building2", "startFrame": 1, "endFrame": 577, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building2.txt", "object_class": "other"},
|
||||
{"name": "uav_building3", "path": "data_seq/UAV123/building3", "startFrame": 1, "endFrame": 829, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building3.txt", "object_class": "other"},
|
||||
{"name": "uav_building4", "path": "data_seq/UAV123/building4", "startFrame": 1, "endFrame": 787, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building4.txt", "object_class": "other"},
|
||||
{"name": "uav_building5", "path": "data_seq/UAV123/building5", "startFrame": 1, "endFrame": 481, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building5.txt", "object_class": "other"},
|
||||
{"name": "uav_car1_1", "path": "data_seq/UAV123/car1", "startFrame": 1, "endFrame": 751, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_2", "path": "data_seq/UAV123/car1", "startFrame": 751, "endFrame": 1627, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_3", "path": "data_seq/UAV123/car1", "startFrame": 1627, "endFrame": 2629, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_3.txt", "object_class": "car"},
|
||||
{"name": "uav_car10", "path": "data_seq/UAV123/car10", "startFrame": 1, "endFrame": 1405, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car10.txt", "object_class": "car"},
|
||||
{"name": "uav_car11", "path": "data_seq/UAV123/car11", "startFrame": 1, "endFrame": 337, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car11.txt", "object_class": "car"},
|
||||
{"name": "uav_car12", "path": "data_seq/UAV123/car12", "startFrame": 1, "endFrame": 499, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car12.txt", "object_class": "car"},
|
||||
{"name": "uav_car13", "path": "data_seq/UAV123/car13", "startFrame": 1, "endFrame": 415, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car13.txt", "object_class": "car"},
|
||||
{"name": "uav_car14", "path": "data_seq/UAV123/car14", "startFrame": 1, "endFrame": 1327, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car14.txt", "object_class": "car"},
|
||||
{"name": "uav_car15", "path": "data_seq/UAV123/car15", "startFrame": 1, "endFrame": 469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car15.txt", "object_class": "car"},
|
||||
{"name": "uav_car16_1", "path": "data_seq/UAV123/car16", "startFrame": 1, "endFrame": 415, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car16_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car16_2", "path": "data_seq/UAV123/car16", "startFrame": 415, "endFrame": 1993, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car16_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car17", "path": "data_seq/UAV123/car17", "startFrame": 1, "endFrame": 1057, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car17.txt", "object_class": "car"},
|
||||
{"name": "uav_car18", "path": "data_seq/UAV123/car18", "startFrame": 1, "endFrame": 1207, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car18.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_s", "path": "data_seq/UAV123/car1_s", "startFrame": 1, "endFrame": 1475, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car2", "path": "data_seq/UAV123/car2", "startFrame": 1, "endFrame": 1321, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car2.txt", "object_class": "car"},
|
||||
{"name": "uav_car2_s", "path": "data_seq/UAV123/car2_s", "startFrame": 1, "endFrame": 320, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car2_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car3", "path": "data_seq/UAV123/car3", "startFrame": 1, "endFrame": 1717, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car3.txt", "object_class": "car"},
|
||||
{"name": "uav_car3_s", "path": "data_seq/UAV123/car3_s", "startFrame": 1, "endFrame": 1300, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car3_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car4", "path": "data_seq/UAV123/car4", "startFrame": 1, "endFrame": 1345, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car4.txt", "object_class": "car"},
|
||||
{"name": "uav_car4_s", "path": "data_seq/UAV123/car4_s", "startFrame": 1, "endFrame": 830, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car4_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car5", "path": "data_seq/UAV123/car5", "startFrame": 1, "endFrame": 745, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car5.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_1", "path": "data_seq/UAV123/car6", "startFrame": 1, "endFrame": 487, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_2", "path": "data_seq/UAV123/car6", "startFrame": 487, "endFrame": 1807, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_3", "path": "data_seq/UAV123/car6", "startFrame": 1807, "endFrame": 2953, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_3.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_4", "path": "data_seq/UAV123/car6", "startFrame": 2953, "endFrame": 3925, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_4.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_5", "path": "data_seq/UAV123/car6", "startFrame": 3925, "endFrame": 4861, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_5.txt", "object_class": "car"},
|
||||
{"name": "uav_car7", "path": "data_seq/UAV123/car7", "startFrame": 1, "endFrame": 1033, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car7.txt", "object_class": "car"},
|
||||
{"name": "uav_car8_1", "path": "data_seq/UAV123/car8", "startFrame": 1, "endFrame": 1357, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car8_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car8_2", "path": "data_seq/UAV123/car8", "startFrame": 1357, "endFrame": 2575, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car8_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car9", "path": "data_seq/UAV123/car9", "startFrame": 1, "endFrame": 1879, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car9.txt", "object_class": "car"},
|
||||
{"name": "uav_group1_1", "path": "data_seq/UAV123/group1", "startFrame": 1, "endFrame": 1333, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_2", "path": "data_seq/UAV123/group1", "startFrame": 1333, "endFrame": 2515, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_3", "path": "data_seq/UAV123/group1", "startFrame": 2515, "endFrame": 3925, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_4", "path": "data_seq/UAV123/group1", "startFrame": 3925, "endFrame": 4873, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_4.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_1", "path": "data_seq/UAV123/group2", "startFrame": 1, "endFrame": 907, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_2", "path": "data_seq/UAV123/group2", "startFrame": 907, "endFrame": 1771, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_3", "path": "data_seq/UAV123/group2", "startFrame": 1771, "endFrame": 2683, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_1", "path": "data_seq/UAV123/group3", "startFrame": 1, "endFrame": 1567, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_2", "path": "data_seq/UAV123/group3", "startFrame": 1567, "endFrame": 2827, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_3", "path": "data_seq/UAV123/group3", "startFrame": 2827, "endFrame": 4369, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_4", "path": "data_seq/UAV123/group3", "startFrame": 4369, "endFrame": 5527, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_4.txt", "object_class": "person"},
|
||||
{"name": "uav_person1", "path": "data_seq/UAV123/person1", "startFrame": 1, "endFrame": 799, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person1.txt", "object_class": "person"},
|
||||
{"name": "uav_person10", "path": "data_seq/UAV123/person10", "startFrame": 1, "endFrame": 1021, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person10.txt", "object_class": "person"},
|
||||
{"name": "uav_person11", "path": "data_seq/UAV123/person11", "startFrame": 1, "endFrame": 721, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person11.txt", "object_class": "person"},
|
||||
{"name": "uav_person12_1", "path": "data_seq/UAV123/person12", "startFrame": 1, "endFrame": 601, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person12_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person12_2", "path": "data_seq/UAV123/person12", "startFrame": 601, "endFrame": 1621, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person12_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person13", "path": "data_seq/UAV123/person13", "startFrame": 1, "endFrame": 883, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person13.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_1", "path": "data_seq/UAV123/person14", "startFrame": 1, "endFrame": 847, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person14_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_2", "path": "data_seq/UAV123/person14", "startFrame": 847, "endFrame": 1813, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person14_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_3", "path": "data_seq/UAV123/person14", "startFrame": 1813, "endFrame": 2923,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person14_3.txt", "object_class": "person"},
|
||||
{"name": "uav_person15", "path": "data_seq/UAV123/person15", "startFrame": 1, "endFrame": 1339, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person15.txt", "object_class": "person"},
|
||||
{"name": "uav_person16", "path": "data_seq/UAV123/person16", "startFrame": 1, "endFrame": 1147, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person16.txt", "object_class": "person"},
|
||||
{"name": "uav_person17_1", "path": "data_seq/UAV123/person17", "startFrame": 1, "endFrame": 1501, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person17_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person17_2", "path": "data_seq/UAV123/person17", "startFrame": 1501, "endFrame": 2347,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person17_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person18", "path": "data_seq/UAV123/person18", "startFrame": 1, "endFrame": 1393, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person18.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_1", "path": "data_seq/UAV123/person19", "startFrame": 1, "endFrame": 1243, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person19_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_2", "path": "data_seq/UAV123/person19", "startFrame": 1243, "endFrame": 2791,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person19_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_3", "path": "data_seq/UAV123/person19", "startFrame": 2791, "endFrame": 4357,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person19_3.txt", "object_class": "person"},
|
||||
{"name": "uav_person1_s", "path": "data_seq/UAV123/person1_s", "startFrame": 1, "endFrame": 1600, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person1_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_1", "path": "data_seq/UAV123/person2", "startFrame": 1, "endFrame": 1189, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_2", "path": "data_seq/UAV123/person2", "startFrame": 1189, "endFrame": 2623, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person20", "path": "data_seq/UAV123/person20", "startFrame": 1, "endFrame": 1783, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person20.txt", "object_class": "person"},
|
||||
{"name": "uav_person21", "path": "data_seq/UAV123/person21", "startFrame": 1, "endFrame": 487, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person21.txt", "object_class": "person"},
|
||||
{"name": "uav_person22", "path": "data_seq/UAV123/person22", "startFrame": 1, "endFrame": 199, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person22.txt", "object_class": "person"},
|
||||
{"name": "uav_person23", "path": "data_seq/UAV123/person23", "startFrame": 1, "endFrame": 397, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person23.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_s", "path": "data_seq/UAV123/person2_s", "startFrame": 1, "endFrame": 250, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person3", "path": "data_seq/UAV123/person3", "startFrame": 1, "endFrame": 643, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person3.txt", "object_class": "person"},
|
||||
{"name": "uav_person3_s", "path": "data_seq/UAV123/person3_s", "startFrame": 1, "endFrame": 505, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person3_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person4_1", "path": "data_seq/UAV123/person4", "startFrame": 1, "endFrame": 1501, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person4_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person4_2", "path": "data_seq/UAV123/person4", "startFrame": 1501, "endFrame": 2743, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person4_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person5_1", "path": "data_seq/UAV123/person5", "startFrame": 1, "endFrame": 877, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person5_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person5_2", "path": "data_seq/UAV123/person5", "startFrame": 877, "endFrame": 2101, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person5_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person6", "path": "data_seq/UAV123/person6", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person6.txt", "object_class": "person"},
|
||||
{"name": "uav_person7_1", "path": "data_seq/UAV123/person7", "startFrame": 1, "endFrame": 1249, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person7_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person7_2", "path": "data_seq/UAV123/person7", "startFrame": 1249, "endFrame": 2065, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person7_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person8_1", "path": "data_seq/UAV123/person8", "startFrame": 1, "endFrame": 1075, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person8_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person8_2", "path": "data_seq/UAV123/person8", "startFrame": 1075, "endFrame": 1525, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person8_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person9", "path": "data_seq/UAV123/person9", "startFrame": 1, "endFrame": 661, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person9.txt", "object_class": "person"},
|
||||
{"name": "uav_truck1", "path": "data_seq/UAV123/truck1", "startFrame": 1, "endFrame": 463, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck1.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck2", "path": "data_seq/UAV123/truck2", "startFrame": 1, "endFrame": 385, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck2.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck3", "path": "data_seq/UAV123/truck3", "startFrame": 1, "endFrame": 535, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck3.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck4_1", "path": "data_seq/UAV123/truck4", "startFrame": 1, "endFrame": 577, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck4_1.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck4_2", "path": "data_seq/UAV123/truck4", "startFrame": 577, "endFrame": 1261, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck4_2.txt", "object_class": "truck"},
|
||||
{"name": "uav_uav1_1", "path": "data_seq/UAV123/uav1", "startFrame": 1, "endFrame": 1555, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_1.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav1_2", "path": "data_seq/UAV123/uav1", "startFrame": 1555, "endFrame": 2377, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_2.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav1_3", "path": "data_seq/UAV123/uav1", "startFrame": 2473, "endFrame": 3469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_3.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav2", "path": "data_seq/UAV123/uav2", "startFrame": 1, "endFrame": 133, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav2.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav3", "path": "data_seq/UAV123/uav3", "startFrame": 1, "endFrame": 265, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav3.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav4", "path": "data_seq/UAV123/uav4", "startFrame": 1, "endFrame": 157, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav4.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav5", "path": "data_seq/UAV123/uav5", "startFrame": 1, "endFrame": 139, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav5.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav6", "path": "data_seq/UAV123/uav6", "startFrame": 1, "endFrame": 109, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav6.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav7", "path": "data_seq/UAV123/uav7", "startFrame": 1, "endFrame": 373, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav7.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav8", "path": "data_seq/UAV123/uav8", "startFrame": 1, "endFrame": 301, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav8.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_wakeboard1", "path": "data_seq/UAV123/wakeboard1", "startFrame": 1, "endFrame": 421, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard1.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard10", "path": "data_seq/UAV123/wakeboard10", "startFrame": 1, "endFrame": 469,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/wakeboard10.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard2", "path": "data_seq/UAV123/wakeboard2", "startFrame": 1, "endFrame": 733, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard2.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard3", "path": "data_seq/UAV123/wakeboard3", "startFrame": 1, "endFrame": 823, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard3.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard4", "path": "data_seq/UAV123/wakeboard4", "startFrame": 1, "endFrame": 697, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard4.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard5", "path": "data_seq/UAV123/wakeboard5", "startFrame": 1, "endFrame": 1675, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard5.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard6", "path": "data_seq/UAV123/wakeboard6", "startFrame": 1, "endFrame": 1165, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard6.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard7", "path": "data_seq/UAV123/wakeboard7", "startFrame": 1, "endFrame": 199, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard7.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard8", "path": "data_seq/UAV123/wakeboard8", "startFrame": 1, "endFrame": 1543, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard8.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard9", "path": "data_seq/UAV123/wakeboard9", "startFrame": 1, "endFrame": 355, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard9.txt", "object_class": "person"}
|
||||
]
|
||||
|
||||
return sequence_info_list
|
||||
@@ -0,0 +1,349 @@
|
||||
from typing import Union, TextIO
|
||||
|
||||
import numpy as np
|
||||
from numba import jit
|
||||
|
||||
from lib.test.evaluation.data import SequenceList, BaseDataset, Sequence
|
||||
|
||||
|
||||
class VOTDataset(BaseDataset):
|
||||
"""
|
||||
VOT2018 dataset
|
||||
|
||||
Publication:
|
||||
The sixth Visual Object Tracking VOT2018 challenge results.
|
||||
Matej Kristan, Ales Leonardis, Jiri Matas, Michael Felsberg, Roman Pfugfelder, Luka Cehovin Zajc, Tomas Vojir,
|
||||
Goutam Bhat, Alan Lukezic et al.
|
||||
ECCV, 2018
|
||||
https://prints.vicos.si/publications/365
|
||||
|
||||
Download the dataset from http://www.votchallenge.net/vot2018/dataset.html
|
||||
"""
|
||||
def __init__(self, year=18):
|
||||
super().__init__()
|
||||
self.year = year
|
||||
if year == 18:
|
||||
self.base_path = self.env_settings.vot18_path
|
||||
elif year == 20:
|
||||
self.base_path = self.env_settings.vot20_path
|
||||
elif year == 22:
|
||||
self.base_path = self.env_settings.vot22_path
|
||||
self.sequence_list = self._get_sequence_list(year)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
sequence_path = sequence_name
|
||||
nz = 8
|
||||
ext = 'jpg'
|
||||
start_frame = 1
|
||||
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
if self.year == 18 or self.year == 22:
|
||||
try:
|
||||
ground_truth_rect = np.loadtxt(str(anno_path), dtype=np.float64)
|
||||
except:
|
||||
ground_truth_rect = np.loadtxt(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
end_frame = ground_truth_rect.shape[0]
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/color/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext)
|
||||
for frame_num in range(start_frame, end_frame+1)]
|
||||
|
||||
# Convert gt
|
||||
if ground_truth_rect.shape[1] > 4:
|
||||
gt_x_all = ground_truth_rect[:, [0, 2, 4, 6]]
|
||||
gt_y_all = ground_truth_rect[:, [1, 3, 5, 7]]
|
||||
|
||||
x1 = np.amin(gt_x_all, 1).reshape(-1,1)
|
||||
y1 = np.amin(gt_y_all, 1).reshape(-1,1)
|
||||
x2 = np.amax(gt_x_all, 1).reshape(-1,1)
|
||||
y2 = np.amax(gt_y_all, 1).reshape(-1,1)
|
||||
|
||||
ground_truth_rect = np.concatenate((x1, y1, x2-x1, y2-y1), 1)
|
||||
|
||||
elif self.year == 20:
|
||||
ground_truth_rect = read_file(str(anno_path))
|
||||
ground_truth_rect = np.array(ground_truth_rect, dtype=np.float64)
|
||||
end_frame = ground_truth_rect.shape[0]
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/color/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path,
|
||||
frame=frame_num, nz=nz, ext=ext)
|
||||
for frame_num in range(start_frame, end_frame + 1)]
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return Sequence(sequence_name, frames, 'vot', ground_truth_rect)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self, year):
|
||||
if year == 18:
|
||||
sequence_list= ['ants1',
|
||||
'ants3',
|
||||
'bag',
|
||||
'ball1',
|
||||
'ball2',
|
||||
'basketball',
|
||||
'birds1',
|
||||
'blanket',
|
||||
'bmx',
|
||||
'bolt1',
|
||||
'bolt2',
|
||||
'book',
|
||||
'butterfly',
|
||||
'car1',
|
||||
'conduction1',
|
||||
'crabs1',
|
||||
'crossing',
|
||||
'dinosaur',
|
||||
'drone_across',
|
||||
'drone_flip',
|
||||
'drone1',
|
||||
'fernando',
|
||||
'fish1',
|
||||
'fish2',
|
||||
'fish3',
|
||||
'flamingo1',
|
||||
'frisbee',
|
||||
'girl',
|
||||
'glove',
|
||||
'godfather',
|
||||
'graduate',
|
||||
'gymnastics1',
|
||||
'gymnastics2',
|
||||
'gymnastics3',
|
||||
'hand',
|
||||
'handball1',
|
||||
'handball2',
|
||||
'helicopter',
|
||||
'iceskater1',
|
||||
'iceskater2',
|
||||
'leaves',
|
||||
'matrix',
|
||||
'motocross1',
|
||||
'motocross2',
|
||||
'nature',
|
||||
'pedestrian1',
|
||||
'rabbit',
|
||||
'racing',
|
||||
'road',
|
||||
'shaking',
|
||||
'sheep',
|
||||
'singer2',
|
||||
'singer3',
|
||||
'soccer1',
|
||||
'soccer2',
|
||||
'soldier',
|
||||
'tiger',
|
||||
'traffic',
|
||||
'wiper',
|
||||
'zebrafish1']
|
||||
elif year == 20:
|
||||
|
||||
sequence_list= ['agility',
|
||||
'ants1',
|
||||
'ball2',
|
||||
'ball3',
|
||||
'basketball',
|
||||
'birds1',
|
||||
'bolt1',
|
||||
'book',
|
||||
'butterfly',
|
||||
'car1',
|
||||
'conduction1',
|
||||
'crabs1',
|
||||
'dinosaur',
|
||||
'dribble',
|
||||
'drone1',
|
||||
'drone_across',
|
||||
'drone_flip',
|
||||
'fernando',
|
||||
'fish1',
|
||||
'fish2',
|
||||
'flamingo1',
|
||||
'frisbee',
|
||||
'girl',
|
||||
'glove',
|
||||
'godfather',
|
||||
'graduate',
|
||||
'gymnastics1',
|
||||
'gymnastics2',
|
||||
'gymnastics3',
|
||||
'hand',
|
||||
'hand02',
|
||||
'hand2',
|
||||
'handball1',
|
||||
'handball2',
|
||||
'helicopter',
|
||||
'iceskater1',
|
||||
'iceskater2',
|
||||
'lamb',
|
||||
'leaves',
|
||||
'marathon',
|
||||
'matrix',
|
||||
'monkey',
|
||||
'motocross1',
|
||||
'nature',
|
||||
'polo',
|
||||
'rabbit',
|
||||
'rabbit2',
|
||||
'road',
|
||||
'rowing',
|
||||
'shaking',
|
||||
'singer2',
|
||||
'singer3',
|
||||
'soccer1',
|
||||
'soccer2',
|
||||
'soldier',
|
||||
'surfing',
|
||||
'tiger',
|
||||
'wheel',
|
||||
'wiper',
|
||||
'zebrafish1']
|
||||
elif year == 22:
|
||||
sequence_list= ['agility',
|
||||
'animal',
|
||||
'ants1',
|
||||
'bag',
|
||||
'ball2',
|
||||
'ball3',
|
||||
'basketball',
|
||||
'birds1',
|
||||
'birds2',
|
||||
'bolt1',
|
||||
'book',
|
||||
'bubble',
|
||||
'butterfly',
|
||||
'car1',
|
||||
'conduction1',
|
||||
'crabs1',
|
||||
'dinosaur',
|
||||
'diver',
|
||||
'drone1',
|
||||
'drone_across',
|
||||
'fernando',
|
||||
'fish1',
|
||||
'fish2',
|
||||
'flamingo1',
|
||||
'frisbee',
|
||||
'girl',
|
||||
'graduate',
|
||||
'gymnastics1',
|
||||
'gymnastics2',
|
||||
'gymnastics3',
|
||||
'hand',
|
||||
'hand2',
|
||||
'handball1',
|
||||
'handball2',
|
||||
'helicopter',
|
||||
'iceskater1',
|
||||
'iceskater2',
|
||||
'kangaroo',
|
||||
'lamb',
|
||||
'leaves',
|
||||
'marathon',
|
||||
'matrix',
|
||||
'monkey',
|
||||
'motocross1',
|
||||
'nature',
|
||||
'polo',
|
||||
'rabbit',
|
||||
'rabbit2',
|
||||
'rowing',
|
||||
'shaking',
|
||||
'singer2',
|
||||
'singer3',
|
||||
'snake',
|
||||
'soccer1',
|
||||
'soccer2',
|
||||
'soldier',
|
||||
'surfing',
|
||||
'tennis',
|
||||
'tiger',
|
||||
'wheel',
|
||||
'wiper',
|
||||
'zebrafish1']
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
def parse(string):
|
||||
"""
|
||||
parse string to the appropriate region format and return region object
|
||||
"""
|
||||
from vot.region.shapes import Rectangle, Polygon, Mask
|
||||
|
||||
|
||||
if string[0] == 'm':
|
||||
# input is a mask - decode it
|
||||
m_, offset_, region = create_mask_from_string(string[1:].split(','))
|
||||
# return Mask(m_, offset=offset_)
|
||||
return region
|
||||
else:
|
||||
# input is not a mask - check if special, rectangle or polygon
|
||||
raise NotImplementedError
|
||||
print('Unknown region format.')
|
||||
return None
|
||||
|
||||
|
||||
def read_file(fp: Union[str, TextIO]):
|
||||
if isinstance(fp, str):
|
||||
with open(fp) as file:
|
||||
lines = file.readlines()
|
||||
else:
|
||||
lines = fp.readlines()
|
||||
|
||||
regions = []
|
||||
# iterate over all lines in the file
|
||||
for i, line in enumerate(lines):
|
||||
regions.append(parse(line.strip()))
|
||||
return regions
|
||||
|
||||
|
||||
def create_mask_from_string(mask_encoding):
|
||||
"""
|
||||
mask_encoding: a string in the following format: x0, y0, w, h, RLE
|
||||
output: mask, offset
|
||||
mask: 2-D binary mask, size defined in the mask encoding
|
||||
offset: (x, y) offset of the mask in the image coordinates
|
||||
"""
|
||||
elements = [int(el) for el in mask_encoding]
|
||||
tl_x, tl_y, region_w, region_h = elements[:4]
|
||||
rle = np.array([el for el in elements[4:]], dtype=np.int32)
|
||||
|
||||
# create mask from RLE within target region
|
||||
mask = rle_to_mask(rle, region_w, region_h)
|
||||
region = [tl_x, tl_y, region_w, region_h]
|
||||
|
||||
return mask, (tl_x, tl_y), region
|
||||
|
||||
@jit(nopython=True)
|
||||
def rle_to_mask(rle, width, height):
|
||||
"""
|
||||
rle: input rle mask encoding
|
||||
each evenly-indexed element represents number of consecutive 0s
|
||||
each oddly indexed element represents number of consecutive 1s
|
||||
width and height are dimensions of the mask
|
||||
output: 2-D binary mask
|
||||
"""
|
||||
# allocate list of zeros
|
||||
v = [0] * (width * height)
|
||||
|
||||
# set id of the last different element to the beginning of the vector
|
||||
idx_ = 0
|
||||
for i in range(len(rle)):
|
||||
if i % 2 != 0:
|
||||
# write as many 1s as RLE says (zeros are already in the vector)
|
||||
for j in range(rle[i]):
|
||||
v[idx_+j] = 1
|
||||
idx_ += rle[i]
|
||||
@@ -0,0 +1,30 @@
|
||||
from lib.test.utils import TrackerParams
|
||||
import os
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.config.artrack.config import cfg, update_config_from_file
|
||||
|
||||
|
||||
def parameters(yaml_name: str):
|
||||
params = TrackerParams()
|
||||
prj_dir = env_settings().prj_dir
|
||||
save_dir = env_settings().save_dir
|
||||
# update default config from yaml file
|
||||
yaml_file = os.path.join(prj_dir, 'experiments/artrack/%s.yaml' % yaml_name)
|
||||
update_config_from_file(yaml_file)
|
||||
params.cfg = cfg
|
||||
print("test config: ", cfg)
|
||||
|
||||
# template and search region
|
||||
params.template_factor = cfg.TEST.TEMPLATE_FACTOR
|
||||
params.template_size = cfg.TEST.TEMPLATE_SIZE
|
||||
params.search_factor = cfg.TEST.SEARCH_FACTOR
|
||||
params.search_size = cfg.TEST.SEARCH_SIZE
|
||||
|
||||
# Network checkpoint path
|
||||
params.checkpoint = os.path.join(save_dir, "checkpoints/train/artrack/%s/ARTrack_ep%04d.pth.tar" %
|
||||
(yaml_name, cfg.TEST.EPOCH))
|
||||
|
||||
# whether to save boxes from all queries
|
||||
params.save_all_boxes = False
|
||||
|
||||
return params
|
||||
@@ -0,0 +1,30 @@
|
||||
from lib.test.utils import TrackerParams
|
||||
import os
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.config.artrack_seq.config import cfg, update_config_from_file
|
||||
|
||||
|
||||
def parameters(yaml_name: str):
|
||||
params = TrackerParams()
|
||||
prj_dir = env_settings().prj_dir
|
||||
save_dir = env_settings().save_dir
|
||||
# update default config from yaml file
|
||||
yaml_file = os.path.join(prj_dir, 'experiments/artrack_seq/%s.yaml' % yaml_name)
|
||||
update_config_from_file(yaml_file)
|
||||
params.cfg = cfg
|
||||
print("test config: ", cfg)
|
||||
|
||||
# template and search region
|
||||
params.template_factor = cfg.TEST.TEMPLATE_FACTOR
|
||||
params.template_size = cfg.TEST.TEMPLATE_SIZE
|
||||
params.search_factor = cfg.TEST.SEARCH_FACTOR
|
||||
params.search_size = cfg.TEST.SEARCH_SIZE
|
||||
|
||||
# Network checkpoint path
|
||||
params.checkpoint = os.path.join(save_dir, "checkpoints/train/artrack_seq/%s/ARTrackSeq_ep%04d.pth.tar" %
|
||||
(yaml_name, cfg.TEST.EPOCH))
|
||||
|
||||
# whether to save boxes from all queries
|
||||
params.save_all_boxes = False
|
||||
|
||||
return params
|
||||
@@ -0,0 +1,225 @@
|
||||
import math
|
||||
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.test.tracker.basetracker import BaseTracker
|
||||
import torch
|
||||
|
||||
from lib.test.tracker.vis_utils import gen_visualization
|
||||
from lib.test.utils.hann import hann2d
|
||||
from lib.train.data.processing_utils import sample_target
|
||||
# for debug
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from lib.test.tracker.data_utils import Preprocessor
|
||||
from lib.utils.box_ops import clip_box
|
||||
from lib.utils.ce_utils import generate_mask_cond
|
||||
import random
|
||||
|
||||
class RandomErasing(object):
|
||||
def __init__(self, EPSILON=0.5, sl=0.02, sh=0.33, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
|
||||
self.EPSILON = EPSILON
|
||||
self.mean = mean
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.r1 = r1
|
||||
|
||||
def __call__(self, img):
|
||||
|
||||
if random.uniform(0, 1) > self.EPSILON:
|
||||
return img
|
||||
|
||||
for attempt in range(100):
|
||||
print(img.size())
|
||||
area = img.size()[1] * img.size()[2]
|
||||
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w < img.size()[2] and h < img.size()[1]:
|
||||
x1 = random.randint(0, img.size()[1] - h)
|
||||
y1 = random.randint(0, img.size()[2] - w)
|
||||
if img.size()[0] == 3:
|
||||
# img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
# img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
# img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
|
||||
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
|
||||
# img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w))
|
||||
else:
|
||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||
# img[0, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, h, w))
|
||||
return img
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ARTrack(BaseTracker):
|
||||
def __init__(self, params, dataset_name):
|
||||
super(ARTrack, self).__init__(params)
|
||||
network = build_artrack(params.cfg, training=False)
|
||||
print(self.params.checkpoint)
|
||||
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
|
||||
self.cfg = params.cfg
|
||||
self.bins = self.cfg.MODEL.BINS
|
||||
self.network = network.cuda()
|
||||
self.network.eval()
|
||||
self.preprocessor = Preprocessor()
|
||||
self.state = None
|
||||
self.range = self.cfg.MODEL.RANGE
|
||||
|
||||
self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
|
||||
# motion constrain
|
||||
self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()
|
||||
|
||||
# for debug
|
||||
self.debug = params.debug
|
||||
self.use_visdom = params.debug
|
||||
self.frame_id = 0
|
||||
self.erase = RandomErasing()
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
self.save_dir = "debug"
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
else:
|
||||
# self.add_hook()
|
||||
self._init_visdom(None, 1)
|
||||
# for save boxes from all queries
|
||||
self.save_all_boxes = params.save_all_boxes
|
||||
self.z_dict1 = {}
|
||||
|
||||
def initialize(self, image, info: dict):
|
||||
# forward the template once
|
||||
|
||||
z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
|
||||
output_sz=self.params.template_size)#output_sz=self.params.template_size
|
||||
self.z_patch_arr = z_patch_arr
|
||||
template = self.preprocessor.process(z_patch_arr, z_amask_arr)
|
||||
with torch.no_grad():
|
||||
self.z_dict1 = template
|
||||
|
||||
self.box_mask_z = None
|
||||
#if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
# template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor,
|
||||
# template.tensors.device).squeeze(1)
|
||||
# self.box_mask_z = generate_mask_cond(self.cfg, 1, template.tensors.device, template_bbox)
|
||||
|
||||
# save states
|
||||
self.state = info['init_bbox']
|
||||
self.frame_id = 0
|
||||
if self.save_all_boxes:
|
||||
'''save all predicted boxes'''
|
||||
all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
|
||||
return {"all_boxes": all_boxes_save}
|
||||
|
||||
def track(self, image, info: dict = None):
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
H, W, _ = image.shape
|
||||
self.frame_id += 1
|
||||
x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
|
||||
output_sz=self.params.search_size) # (x1, y1, w, h)
|
||||
search = self.preprocessor.process(x_patch_arr, x_amask_arr)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
x_dict = search
|
||||
# merge the template and the search
|
||||
# run the transformer
|
||||
out_dict = self.network.forward(
|
||||
template=self.z_dict1.tensors, search=x_dict.tensors)
|
||||
|
||||
# add hann windows
|
||||
# pred_score_map = out_dict['score_map']
|
||||
# response = self.output_window * pred_score_map
|
||||
# pred_boxes = self.network.box_head.cal_bbox(response, out_dict['size_map'], out_dict['offset_map'])
|
||||
# pred_boxes = pred_boxes.view(-1, 4)
|
||||
|
||||
pred_boxes = out_dict['seqs'][:, 0:4] / (self.bins - 1) - magic_num
|
||||
pred_boxes = pred_boxes.view(-1, 4).mean(dim=0)
|
||||
pred_new = pred_boxes
|
||||
pred_new[2] = pred_boxes[2] - pred_boxes[0]
|
||||
pred_new[3] = pred_boxes[3] - pred_boxes[1]
|
||||
pred_new[0] = pred_boxes[0] + pred_boxes[2]/2
|
||||
pred_new[1] = pred_boxes[1] + pred_boxes[3]/2
|
||||
|
||||
pred_boxes = (pred_new * self.params.search_size / resize_factor).tolist()
|
||||
|
||||
# Baseline: Take the mean of all pred boxes as the final result
|
||||
#pred_box = (pred_boxes.mean(
|
||||
# dim=0) * self.params.search_size / resize_factor).tolist() # (cx, cy, w, h) [0,1]
|
||||
# get the final box result
|
||||
self.state = clip_box(self.map_box_back(pred_boxes, resize_factor), H, W, margin=10)
|
||||
|
||||
# for debug
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
x1, y1, w, h = self.state
|
||||
image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(image_BGR, (int(x1),int(y1)), (int(x1+w),int(y1+h)), color=(0,0,255), thickness=2)
|
||||
save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
|
||||
cv2.imwrite(save_path, image_BGR)
|
||||
else:
|
||||
self.visdom.register((image, info['gt_bbox'].tolist(), self.state), 'Tracking', 1, 'Tracking')
|
||||
|
||||
self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
|
||||
self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
|
||||
self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
|
||||
self.visdom.register((pred_score_map * self.output_window).view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map_hann')
|
||||
|
||||
if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
|
||||
removed_indexes_s = out_dict['removed_indexes_s']
|
||||
removed_indexes_s = [removed_indexes_s_i.cpu().numpy() for removed_indexes_s_i in removed_indexes_s]
|
||||
masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
|
||||
self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')
|
||||
|
||||
while self.pause_mode:
|
||||
if self.step:
|
||||
self.step = False
|
||||
break
|
||||
|
||||
if self.save_all_boxes:
|
||||
'''save all predictions'''
|
||||
all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
|
||||
all_boxes_save = all_boxes.view(-1).tolist() # (4N, )
|
||||
return {"target_bbox": self.state,
|
||||
"all_boxes": all_boxes_save}
|
||||
else:
|
||||
return {"target_bbox": self.state}
|
||||
|
||||
def map_box_back(self, pred_box: list, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
#cx_real = cx + cx_prev
|
||||
#cy_real = cy + cy_prev
|
||||
return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
|
||||
|
||||
def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,)
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
|
||||
|
||||
def add_hook(self):
|
||||
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
|
||||
|
||||
for i in range(12):
|
||||
self.network.backbone.blocks[i].attn.register_forward_hook(
|
||||
# lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
)
|
||||
|
||||
self.enc_attn_weights = enc_attn_weights
|
||||
|
||||
|
||||
def get_tracker_class():
|
||||
return ARTrack
|
||||
@@ -0,0 +1,209 @@
|
||||
import math
|
||||
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
from lib.test.tracker.basetracker import BaseTracker
|
||||
import torch
|
||||
|
||||
from lib.test.tracker.vis_utils import gen_visualization
|
||||
from lib.test.utils.hann import hann2d
|
||||
from lib.train.data.processing_utils import sample_target, transform_image_to_crop
|
||||
# for debug
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from lib.test.tracker.data_utils import Preprocessor
|
||||
from lib.utils.box_ops import clip_box
|
||||
from lib.utils.ce_utils import generate_mask_cond
|
||||
|
||||
|
||||
class ARTrackSeq(BaseTracker):
|
||||
def __init__(self, params, dataset_name):
|
||||
super(ARTrackSeq, self).__init__(params)
|
||||
network = build_artrack_seq(params.cfg, training=False)
|
||||
print(self.params.checkpoint)
|
||||
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
|
||||
self.cfg = params.cfg
|
||||
self.bins = self.cfg.MODEL.BINS
|
||||
self.network = network.cuda()
|
||||
self.network.eval()
|
||||
self.preprocessor = Preprocessor()
|
||||
self.state = None
|
||||
|
||||
self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
|
||||
# motion constrain
|
||||
self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()
|
||||
|
||||
# for debug
|
||||
self.debug = params.debug
|
||||
self.use_visdom = params.debug
|
||||
self.frame_id = 0
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
self.save_dir = "debug"
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
else:
|
||||
# self.add_hook()
|
||||
self._init_visdom(None, 1)
|
||||
# for save boxes from all queries
|
||||
self.save_all_boxes = params.save_all_boxes
|
||||
self.z_dict1 = {}
|
||||
self.store_result = None
|
||||
self.save_all = 7
|
||||
self.x_feat = None
|
||||
self.update = None
|
||||
self.update_threshold = 5.0
|
||||
self.update_intervals = 1
|
||||
|
||||
def initialize(self, image, info: dict):
|
||||
# forward the template once
|
||||
self.x_feat = None
|
||||
|
||||
z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
|
||||
output_sz=self.params.template_size) # output_sz=self.params.template_size
|
||||
self.z_patch_arr = z_patch_arr
|
||||
template = self.preprocessor.process(z_patch_arr, z_amask_arr)
|
||||
with torch.no_grad():
|
||||
self.z_dict1 = template
|
||||
|
||||
self.box_mask_z = None
|
||||
# if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
# template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor,
|
||||
# template.tensors.device).squeeze(1)
|
||||
# self.box_mask_z = generate_mask_cond(self.cfg, 1, template.tensors.device, template_bbox)
|
||||
|
||||
# save states
|
||||
self.state = info['init_bbox']
|
||||
self.store_result = [info['init_bbox'].copy()]
|
||||
for i in range(self.save_all - 1):
|
||||
self.store_result.append(info['init_bbox'].copy())
|
||||
self.frame_id = 0
|
||||
self.update = None
|
||||
if self.save_all_boxes:
|
||||
'''save all predicted boxes'''
|
||||
all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
|
||||
return {"all_boxes": all_boxes_save}
|
||||
|
||||
def track(self, image, info: dict = None):
|
||||
H, W, _ = image.shape
|
||||
self.frame_id += 1
|
||||
x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
|
||||
output_sz=self.params.search_size) # (x1, y1, w, h)
|
||||
for i in range(len(self.store_result)):
|
||||
box_temp = self.store_result[i].copy()
|
||||
box_out_i = transform_image_to_crop(torch.Tensor(self.store_result[i]), torch.Tensor(self.state),
|
||||
resize_factor,
|
||||
torch.Tensor([self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]),
|
||||
normalize=True)
|
||||
box_out_i[2] = box_out_i[2] + box_out_i[0]
|
||||
box_out_i[3] = box_out_i[3] + box_out_i[1]
|
||||
box_out_i = box_out_i.clamp(min=-0.5, max=1.5)
|
||||
box_out_i = (box_out_i + 0.5) * (self.bins - 1)
|
||||
if i == 0:
|
||||
seqs_out = box_out_i
|
||||
else:
|
||||
seqs_out = torch.cat((seqs_out, box_out_i), dim=-1)
|
||||
seqs_out = seqs_out.unsqueeze(0)
|
||||
search = self.preprocessor.process(x_patch_arr, x_amask_arr)
|
||||
with torch.no_grad():
|
||||
x_dict = search
|
||||
# merge the template and the search
|
||||
# run the transformer
|
||||
out_dict = self.network.forward(
|
||||
template=self.z_dict1.tensors, search=x_dict.tensors,
|
||||
seq_input=seqs_out, stage="sequence", search_feature=self.x_feat, update=None)
|
||||
|
||||
self.x_feat = out_dict['x_feat']
|
||||
|
||||
pred_boxes = out_dict['seqs'][:, 0:4] / (self.bins - 1) - 0.5
|
||||
pred_boxes = pred_boxes.view(-1, 4).mean(dim=0)
|
||||
pred_new = pred_boxes
|
||||
pred_new[2] = pred_boxes[2] - pred_boxes[0]
|
||||
pred_new[3] = pred_boxes[3] - pred_boxes[1]
|
||||
pred_new[0] = pred_boxes[0] + pred_new[2] / 2
|
||||
pred_new[1] = pred_boxes[1] + pred_new[3] / 2
|
||||
pred_boxes = (pred_new * self.params.search_size / resize_factor).tolist()
|
||||
|
||||
# Baseline: Take the mean of all pred boxes as the final result
|
||||
# pred_box = (pred_boxes.mean(
|
||||
# dim=0) * self.params.search_size / resize_factor).tolist() # (cx, cy, w, h) [0,1]
|
||||
# get the final box result
|
||||
self.state = clip_box(self.map_box_back(pred_boxes, resize_factor), H, W, margin=10)
|
||||
if len(self.store_result) < self.save_all:
|
||||
self.store_result.append(self.state.copy())
|
||||
else:
|
||||
for i in range(self.save_all):
|
||||
if i != self.save_all - 1:
|
||||
self.store_result[i] = self.store_result[i + 1]
|
||||
else:
|
||||
self.store_result[i] = self.state.copy()
|
||||
|
||||
# for debug
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
x1, y1, w, h = self.state
|
||||
image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)
|
||||
save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
|
||||
cv2.imwrite(save_path, image_BGR)
|
||||
else:
|
||||
self.visdom.register((image, info['gt_bbox'].tolist(), self.state), 'Tracking', 1, 'Tracking')
|
||||
|
||||
self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
|
||||
self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
|
||||
self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
|
||||
self.visdom.register((pred_score_map * self.output_window).view(self.feat_sz, self.feat_sz), 'heatmap',
|
||||
1, 'score_map_hann')
|
||||
|
||||
if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
|
||||
removed_indexes_s = out_dict['removed_indexes_s']
|
||||
removed_indexes_s = [removed_indexes_s_i.cpu().numpy() for removed_indexes_s_i in removed_indexes_s]
|
||||
masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
|
||||
self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')
|
||||
|
||||
while self.pause_mode:
|
||||
if self.step:
|
||||
self.step = False
|
||||
break
|
||||
|
||||
if self.save_all_boxes:
|
||||
'''save all predictions'''
|
||||
all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
|
||||
all_boxes_save = all_boxes.view(-1).tolist() # (4N, )
|
||||
return {"target_bbox": self.state,
|
||||
"all_boxes": all_boxes_save}
|
||||
else:
|
||||
return {"target_bbox": self.state}
|
||||
|
||||
def map_box_back(self, pred_box: list, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
# cx_real = cx + cx_prev
|
||||
# cy_real = cy + cy_prev
|
||||
return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
|
||||
|
||||
def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,)
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
|
||||
|
||||
def add_hook(self):
|
||||
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
|
||||
|
||||
for i in range(12):
|
||||
self.network.backbone.blocks[i].attn.register_forward_hook(
|
||||
# lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
)
|
||||
|
||||
self.enc_attn_weights = enc_attn_weights
|
||||
|
||||
|
||||
def get_tracker_class():
|
||||
return ARTrackSeq
|
||||
@@ -0,0 +1,89 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from _collections import OrderedDict
|
||||
|
||||
from lib.train.data.processing_utils import transform_image_to_crop
|
||||
from lib.vis.visdom_cus import Visdom
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
"""Base class for all trackers."""
|
||||
|
||||
def __init__(self, params):
|
||||
self.params = params
|
||||
self.visdom = None
|
||||
|
||||
def predicts_segmentation_mask(self):
|
||||
return False
|
||||
|
||||
def initialize(self, image, info: dict) -> dict:
|
||||
"""Overload this function in your tracker. This should initialize the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def track(self, image, info: dict = None) -> dict:
|
||||
"""Overload this function in your tracker. This should track in the frame and update the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def visdom_draw_tracking(self, image, box, segmentation=None):
|
||||
if isinstance(box, OrderedDict):
|
||||
box = [v for k, v in box.items()]
|
||||
else:
|
||||
box = (box,)
|
||||
if segmentation is None:
|
||||
self.visdom.register((image, *box), 'Tracking', 1, 'Tracking')
|
||||
else:
|
||||
self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking')
|
||||
|
||||
def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'):
|
||||
# box_in: list [x1, y1, w, h], not normalized
|
||||
# box_extract: same as box_in
|
||||
# out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized
|
||||
if crop_type == 'template':
|
||||
crop_sz = torch.Tensor([self.params.template_size, self.params.template_size])
|
||||
elif crop_type == 'search':
|
||||
crop_sz = torch.Tensor([self.params.search_size, self.params.search_size])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
box_in = torch.tensor(box_in)
|
||||
if box_extract is None:
|
||||
box_extract = box_in
|
||||
else:
|
||||
box_extract = torch.tensor(box_extract)
|
||||
template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True)
|
||||
template_bbox = template_bbox.view(1, 1, 4).to(device)
|
||||
|
||||
return template_bbox
|
||||
|
||||
def _init_visdom(self, visdom_info, debug):
|
||||
visdom_info = {} if visdom_info is None else visdom_info
|
||||
self.pause_mode = False
|
||||
self.step = False
|
||||
self.next_seq = False
|
||||
if debug > 0 and visdom_info.get('use_visdom', True):
|
||||
try:
|
||||
self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
|
||||
visdom_info=visdom_info)
|
||||
|
||||
# # Show help
|
||||
# help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
|
||||
# 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
|
||||
# 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
|
||||
# 'block list.'
|
||||
# self.visdom.register(help_text, 'text', 1, 'Help')
|
||||
except:
|
||||
time.sleep(0.5)
|
||||
print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
|
||||
'!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')
|
||||
|
||||
def _visdom_ui_handler(self, data):
|
||||
if data['event_type'] == 'KeyPress':
|
||||
if data['key'] == ' ':
|
||||
self.pause_mode = not self.pause_mode
|
||||
|
||||
elif data['key'] == 'ArrowRight' and self.pause_mode:
|
||||
self.step = True
|
||||
|
||||
elif data['key'] == 'n':
|
||||
self.next_seq = True
|
||||
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from lib.utils.misc import NestedTensor
|
||||
|
||||
|
||||
class Preprocessor(object):
|
||||
def __init__(self):
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
# Deal with the image patch
|
||||
img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
|
||||
img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
|
||||
# Deal with the attention mask
|
||||
amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
|
||||
return NestedTensor(img_tensor_norm, amask_tensor)
|
||||
|
||||
|
||||
class PreprocessorX(object):
|
||||
def __init__(self):
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
# Deal with the image patch
|
||||
img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
|
||||
img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
|
||||
# Deal with the attention mask
|
||||
amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
|
||||
return img_tensor_norm, amask_tensor
|
||||
|
||||
|
||||
class PreprocessorX_onnx(object):
|
||||
def __init__(self):
|
||||
self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
|
||||
self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
"""img_arr: (H,W,3), amask_arr: (H,W)"""
|
||||
# Deal with the image patch
|
||||
img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2)
|
||||
img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W)
|
||||
# Deal with the attention mask
|
||||
amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W)
|
||||
return img_arr_4d.astype(np.float32), amask_arr_3d.astype(np.bool)
|
||||
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
############## used for visulize eliminated tokens #################
|
||||
def get_keep_indices(decisions):
|
||||
keep_indices = []
|
||||
for i in range(3):
|
||||
if i == 0:
|
||||
keep_indices.append(decisions[i])
|
||||
else:
|
||||
keep_indices.append(keep_indices[-1][decisions[i]])
|
||||
return keep_indices
|
||||
|
||||
|
||||
def gen_masked_tokens(tokens, indices, alpha=0.2):
|
||||
# indices = [i for i in range(196) if i not in indices]
|
||||
indices = indices[0].astype(int)
|
||||
tokens = tokens.copy()
|
||||
tokens[indices] = alpha * tokens[indices] + (1 - alpha) * 255
|
||||
return tokens
|
||||
|
||||
|
||||
def recover_image(tokens, H, W, Hp, Wp, patch_size):
|
||||
# image: (C, 196, 16, 16)
|
||||
image = tokens.reshape(Hp, Wp, patch_size, patch_size, 3).swapaxes(1, 2).reshape(H, W, 3)
|
||||
return image
|
||||
|
||||
|
||||
def pad_img(img):
|
||||
height, width, channels = img.shape
|
||||
im_bg = np.ones((height, width + 8, channels)) * 255
|
||||
im_bg[0:height, 0:width, :] = img
|
||||
return im_bg
|
||||
|
||||
|
||||
def gen_visualization(image, mask_indices, patch_size=16):
|
||||
# image [224, 224, 3]
|
||||
# mask_indices, list of masked token indices
|
||||
|
||||
# mask mask_indices need to cat
|
||||
# mask_indices = mask_indices[::-1]
|
||||
num_stages = len(mask_indices)
|
||||
for i in range(1, num_stages):
|
||||
mask_indices[i] = np.concatenate([mask_indices[i-1], mask_indices[i]], axis=1)
|
||||
|
||||
# keep_indices = get_keep_indices(decisions)
|
||||
image = np.asarray(image)
|
||||
H, W, C = image.shape
|
||||
Hp, Wp = H // patch_size, W // patch_size
|
||||
image_tokens = image.reshape(Hp, patch_size, Wp, patch_size, 3).swapaxes(1, 2).reshape(Hp * Wp, patch_size, patch_size, 3)
|
||||
|
||||
stages = [
|
||||
recover_image(gen_masked_tokens(image_tokens, mask_indices[i]), H, W, Hp, Wp, patch_size)
|
||||
for i in range(num_stages)
|
||||
]
|
||||
imgs = [image] + stages
|
||||
imgs = [pad_img(img) for img in imgs]
|
||||
viz = np.concatenate(imgs, axis=1)
|
||||
return viz
|
||||
@@ -0,0 +1 @@
|
||||
from .params import TrackerParams, FeatureParams, Choice
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
prj_path = osp.join(this_dir, '..', '..', '..')
|
||||
add_path(prj_path)
|
||||
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def hann1d(sz: int, centered = True) -> torch.Tensor:
|
||||
"""1D cosine window."""
|
||||
if centered:
|
||||
return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float()))
|
||||
w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float()))
|
||||
return torch.cat([w, w[1:sz-sz//2].flip((0,))])
|
||||
|
||||
|
||||
def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""2D cosine window."""
|
||||
return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1)
|
||||
|
||||
|
||||
def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""2D cosine window."""
|
||||
distance = torch.stack([ctr_point, sz-ctr_point], dim=0)
|
||||
max_distance, _ = distance.max(dim=0)
|
||||
|
||||
hann1d_x = hann1d(max_distance[0].item() * 2, centered)
|
||||
hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]]
|
||||
hann1d_y = hann1d(max_distance[1].item() * 2, centered)
|
||||
hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]]
|
||||
|
||||
return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1)
|
||||
|
||||
|
||||
|
||||
def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""1D clipped cosine window."""
|
||||
|
||||
# Ensure that the difference is even
|
||||
effective_sz += (effective_sz - sz) % 2
|
||||
effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1)
|
||||
|
||||
pad = (sz - effective_sz) // 2
|
||||
|
||||
window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate')
|
||||
|
||||
if centered:
|
||||
return window
|
||||
else:
|
||||
mid = (sz / 2).int()
|
||||
window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3)
|
||||
return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2)
|
||||
|
||||
|
||||
def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor:
|
||||
if half:
|
||||
k = torch.arange(0, int(sz/2+1))
|
||||
else:
|
||||
k = torch.arange(-int((sz-1)/2), int(sz/2+1))
|
||||
return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2)
|
||||
|
||||
|
||||
def gauss_spatial(sz, sigma, center=0, end_pad=0):
|
||||
k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad)
|
||||
return torch.exp(-1.0/(2*sigma**2) * (k - center)**2)
|
||||
|
||||
|
||||
def label_function(sz: torch.Tensor, sigma: torch.Tensor):
|
||||
return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1)
|
||||
|
||||
def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)):
|
||||
"""The origin is in the middle of the image."""
|
||||
return gauss_spatial(sz[0].item(), sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \
|
||||
gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1)
|
||||
|
||||
|
||||
def cubic_spline_fourier(f, a):
|
||||
"""The continuous Fourier transform of a cubic spline kernel."""
|
||||
|
||||
bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f))
|
||||
- (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \
|
||||
/ (4 * math.pi**4 * f**4)
|
||||
|
||||
bf[f == 0] = 1
|
||||
|
||||
return bf
|
||||
|
||||
def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor):
|
||||
"""Computes maximum and argmax in the last two dimensions."""
|
||||
|
||||
max_val_row, argmax_row = torch.max(a, dim=-2)
|
||||
max_val, argmax_col = torch.max(max_val_row, dim=-1)
|
||||
argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)]
|
||||
argmax_row = argmax_row.reshape(argmax_col.shape)
|
||||
argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1)
|
||||
return max_val, argmax
|
||||
@@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def load_text_numpy(path, delimiter, dtype):
|
||||
if isinstance(delimiter, (tuple, list)):
|
||||
for d in delimiter:
|
||||
try:
|
||||
ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype)
|
||||
return ground_truth_rect
|
||||
except:
|
||||
pass
|
||||
|
||||
raise Exception('Could not read file {}'.format(path))
|
||||
else:
|
||||
ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype)
|
||||
return ground_truth_rect
|
||||
|
||||
|
||||
def load_text_pandas(path, delimiter, dtype):
|
||||
if isinstance(delimiter, (tuple, list)):
|
||||
for d in delimiter:
|
||||
try:
|
||||
ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False,
|
||||
low_memory=False).values
|
||||
return ground_truth_rect
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
raise Exception('Could not read file {}'.format(path))
|
||||
else:
|
||||
ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False,
|
||||
low_memory=False).values
|
||||
return ground_truth_rect
|
||||
|
||||
|
||||
def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'):
|
||||
if backend == 'numpy':
|
||||
return load_text_numpy(path, delimiter, dtype)
|
||||
elif backend == 'pandas':
|
||||
return load_text_pandas(path, delimiter, dtype)
|
||||
|
||||
|
||||
def load_str(path):
|
||||
with open(path, "r") as f:
|
||||
text_str = f.readline().strip().lower()
|
||||
return text_str
|
||||
@@ -0,0 +1,43 @@
|
||||
from lib.utils import TensorList
|
||||
import random
|
||||
|
||||
|
||||
class TrackerParams:
|
||||
"""Class for tracker parameters."""
|
||||
def set_default_values(self, default_vals: dict):
|
||||
for name, val in default_vals.items():
|
||||
if not hasattr(self, name):
|
||||
setattr(self, name, val)
|
||||
|
||||
def get(self, name: str, *default):
|
||||
"""Get a parameter value with the given name. If it does not exists, it return the default value given as a
|
||||
second argument or returns an error if no default value is given."""
|
||||
if len(default) > 1:
|
||||
raise ValueError('Can only give one default value.')
|
||||
|
||||
if not default:
|
||||
return getattr(self, name)
|
||||
|
||||
return getattr(self, name, default[0])
|
||||
|
||||
def has(self, name: str):
|
||||
"""Check if there exist a parameter with the given name."""
|
||||
return hasattr(self, name)
|
||||
|
||||
|
||||
class FeatureParams:
|
||||
"""Class for feature specific parameters"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
raise ValueError
|
||||
|
||||
for name, val in kwargs.items():
|
||||
if isinstance(val, list):
|
||||
setattr(self, name, TensorList(val))
|
||||
else:
|
||||
setattr(self, name, val)
|
||||
|
||||
|
||||
def Choice(*args):
|
||||
"""Can be used to sample random parameter values."""
|
||||
return random.choice(args)
|
||||
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import _init_paths
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
|
||||
|
||||
def transform_got10k(tracker_name, cfg_name):
|
||||
env = env_settings()
|
||||
result_dir = env.results_path
|
||||
src_dir = os.path.join(result_dir, "%s/%s/got10k/" % (tracker_name, cfg_name))
|
||||
dest_dir = os.path.join(result_dir, "%s/%s/got10k_submit/" % (tracker_name, cfg_name))
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
items = os.listdir(src_dir)
|
||||
for item in items:
|
||||
if "all" in item:
|
||||
continue
|
||||
src_path = os.path.join(src_dir, item)
|
||||
if "time" not in item:
|
||||
seq_name = item.replace(".txt", '')
|
||||
seq_dir = os.path.join(dest_dir, seq_name)
|
||||
if not os.path.exists(seq_dir):
|
||||
os.makedirs(seq_dir)
|
||||
new_item = item.replace(".txt", '_001.txt')
|
||||
dest_path = os.path.join(seq_dir, new_item)
|
||||
bbox_arr = np.loadtxt(src_path, dtype=np.int, delimiter='\t')
|
||||
np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
|
||||
else:
|
||||
seq_name = item.replace("_time.txt", '')
|
||||
seq_dir = os.path.join(dest_dir, seq_name)
|
||||
if not os.path.exists(seq_dir):
|
||||
os.makedirs(seq_dir)
|
||||
dest_path = os.path.join(seq_dir, item)
|
||||
os.system("cp %s %s" % (src_path, dest_path))
|
||||
# make zip archive
|
||||
shutil.make_archive(src_dir, "zip", src_dir)
|
||||
shutil.make_archive(dest_dir, "zip", dest_dir)
|
||||
# Remove the original files
|
||||
shutil.rmtree(src_dir)
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='transform got10k results.')
|
||||
parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
|
||||
parser.add_argument('--cfg_name', type=str, help='Name of config file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
transform_got10k(args.tracker_name, args.cfg_name)
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import _init_paths
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
|
||||
|
||||
def transform_trackingnet(tracker_name, cfg_name):
|
||||
env = env_settings()
|
||||
result_dir = env.results_path
|
||||
src_dir = os.path.join(result_dir, "%s/%s/trackingnet/" % (tracker_name, cfg_name))
|
||||
dest_dir = os.path.join(result_dir, "%s/%s/trackingnet_submit/" % (tracker_name, cfg_name))
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
items = os.listdir(src_dir)
|
||||
for item in items:
|
||||
if "all" in item:
|
||||
continue
|
||||
if "time" not in item:
|
||||
src_path = os.path.join(src_dir, item)
|
||||
dest_path = os.path.join(dest_dir, item)
|
||||
bbox_arr = np.loadtxt(src_path, dtype=np.int, delimiter='\t')
|
||||
np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
|
||||
# make zip archive
|
||||
shutil.make_archive(src_dir, "zip", src_dir)
|
||||
shutil.make_archive(dest_dir, "zip", dest_dir)
|
||||
# Remove the original files
|
||||
shutil.rmtree(src_dir)
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='transform trackingnet results.')
|
||||
parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
|
||||
parser.add_argument('--cfg_name', type=str, help='Name of config file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
transform_trackingnet(args.tracker_name, args.cfg_name)
|
||||
Vendored
BIN
Binary file not shown.
@@ -0,0 +1 @@
|
||||
from .admin.multigpu import MultiGPU
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
prj_path = osp.join(this_dir, '../..')
|
||||
add_path(prj_path)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base_actor import BaseActor
|
||||
from .artrack import ARTrackActor
|
||||
from .artrack_seq import ARTrackSeqActor
|
||||
@@ -0,0 +1,281 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
from lib.utils.merge import merge_template_search
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt)**2 + (cy_pred - cy_gt)**2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2*torch.sin(torch.arcsin(x)-torch.pi/4)**2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw+eps))**2
|
||||
py = ((cy_gt - cy_pred) / (ch+eps))**2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
#shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
#IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw**2 + ch**2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi**2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v**2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
class ARTrackActor(BaseActor):
|
||||
""" Actor for training ARTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.range = self.cfg.MODEL.RANGE
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.loss_weight['KL'] = 100
|
||||
self.loss_weight['focal'] = 2
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins * self.range
|
||||
end = self.bins * self.range + 1
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=(-1*magic_num), max=(1+magic_num))
|
||||
data['real_bbox'] = gt_bbox
|
||||
|
||||
seq_ori = (gt_bbox + magic_num) * (self.bins - 1)
|
||||
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
|
||||
def compute_losses(self, pred_dict, gt_dict, return_status=True):
|
||||
bins = self.bins
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
seq_output = gt_dict['seq_output']
|
||||
pred_feat = pred_dict["feat"]
|
||||
if self.focal == None:
|
||||
weight = torch.ones(bins*self.range+2) * 1
|
||||
weight[bins*self.range+1] = 0.1
|
||||
weight[bins*self.range] = 0.1
|
||||
weight.to(pred_feat)
|
||||
self.klloss = torch.nn.KLDivLoss(reduction='none').to(pred_feat)
|
||||
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
|
||||
# compute varfifocal loss
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, bins*2+2)
|
||||
target = seq_output.reshape(-1).to(torch.int64)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
# compute giou and L1 loss
|
||||
beta = 1
|
||||
pred = pred_feat[0:4, :, 0:bins*self.range] * beta
|
||||
target = seq_output[:, 0:4].to(pred_feat)
|
||||
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range((-1*magic_num+1/(self.bins*self.range)), (1+magic_num-1/(self.bins*self.range)), 2/(self.bins*self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
target = target / (bins - 1) - magic_num
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
sious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
sious = sious.mean()
|
||||
siou_loss = sious
|
||||
l1_loss = self.objective['l1'](extra_seq, target)
|
||||
|
||||
loss = self.loss_weight['giou'] * siou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight['focal'] * varifocal_loss
|
||||
|
||||
if return_status:
|
||||
# status for log
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": loss.item(),
|
||||
"Loss/giou": siou_loss.item(),
|
||||
"Loss/l1": l1_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
return loss, status
|
||||
else:
|
||||
return loss
|
||||
@@ -0,0 +1,629 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
import lib.train.data.bounding_box_utils as bbutils
|
||||
from lib.utils.merge import merge_template_search
|
||||
from torch.distributions.categorical import Categorical
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
|
||||
def IoU(rect1, rect2):
|
||||
""" caculate interection over union
|
||||
Args:
|
||||
rect1: (x1, y1, x2, y2)
|
||||
rect2: (x1, y1, x2, y2)
|
||||
Returns:
|
||||
iou
|
||||
"""
|
||||
# overlap
|
||||
x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
|
||||
tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
|
||||
|
||||
xx1 = np.maximum(tx1, x1)
|
||||
yy1 = np.maximum(ty1, y1)
|
||||
xx2 = np.minimum(tx2, x2)
|
||||
yy2 = np.minimum(ty2, y2)
|
||||
|
||||
ww = np.maximum(0, xx2 - xx1)
|
||||
hh = np.maximum(0, yy2 - yy1)
|
||||
|
||||
area = (x2 - x1) * (y2 - y1)
|
||||
target_a = (tx2 - tx1) * (ty2 - ty1)
|
||||
inter = ww * hh
|
||||
iou = inter / (area + target_a - inter)
|
||||
return iou
|
||||
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt) ** 2 + (cy_pred - cy_gt) ** 2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2 * torch.sin(torch.arcsin(x) - torch.pi / 4) ** 2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw + eps)) ** 2
|
||||
py = ((cy_gt - cy_pred) / (ch + eps)) ** 2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
# shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
# IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
|
||||
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw ** 2 + ch ** 2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi ** 2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v ** 2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
|
||||
class ARTrackSeqActor(BaseActor):
|
||||
""" Actor for training OSTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.range = cfg.MODEL.RANGE
|
||||
self.pre_num = cfg.MODEL.PRENUM
|
||||
self.loss_weight['KL'] = 0
|
||||
self.loss_weight['focal'] = 0
|
||||
self.pre_bbox = None
|
||||
self.x_feat_rem = None
|
||||
self.update_rem = None
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def _bbox_clip(self, cx, cy, width, height, boundary):
|
||||
cx = max(0, min(cx, boundary[1]))
|
||||
cy = max(0, min(cy, boundary[0]))
|
||||
width = max(10, min(width, boundary[1]))
|
||||
height = max(10, min(height, boundary[0]))
|
||||
return cx, cy, width, height
|
||||
|
||||
def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
|
||||
"""
|
||||
args:
|
||||
im: bgr based image
|
||||
pos: center position
|
||||
model_sz: exemplar size
|
||||
s_z: original size
|
||||
avg_chans: channel average
|
||||
"""
|
||||
if isinstance(pos, float):
|
||||
pos = [pos, pos]
|
||||
sz = original_sz
|
||||
im_sz = im.shape
|
||||
c = (original_sz + 1) / 2
|
||||
# context_xmin = round(pos[0] - c) # py2 and py3 round
|
||||
context_xmin = np.floor(pos[0] - c + 0.5)
|
||||
context_xmax = context_xmin + sz - 1
|
||||
# context_ymin = round(pos[1] - c)
|
||||
context_ymin = np.floor(pos[1] - c + 0.5)
|
||||
context_ymax = context_ymin + sz - 1
|
||||
left_pad = int(max(0., -context_xmin))
|
||||
top_pad = int(max(0., -context_ymin))
|
||||
right_pad = int(max(0., context_xmax - im_sz[1] + 1))
|
||||
bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
|
||||
|
||||
context_xmin = context_xmin + left_pad
|
||||
context_xmax = context_xmax + left_pad
|
||||
context_ymin = context_ymin + top_pad
|
||||
context_ymax = context_ymax + top_pad
|
||||
|
||||
r, c, k = im.shape
|
||||
if any([top_pad, bottom_pad, left_pad, right_pad]):
|
||||
size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k)
|
||||
te_im = np.zeros(size, np.uint8)
|
||||
te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
|
||||
if top_pad:
|
||||
te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
|
||||
if bottom_pad:
|
||||
te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
|
||||
if left_pad:
|
||||
te_im[:, 0:left_pad, :] = avg_chans
|
||||
if right_pad:
|
||||
te_im[:, c + left_pad:, :] = avg_chans
|
||||
im_patch = te_im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
else:
|
||||
im_patch = im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
|
||||
if not np.array_equal(model_sz, original_sz):
|
||||
try:
|
||||
im_patch = cv2.resize(im_patch, (model_sz, model_sz))
|
||||
except:
|
||||
return None
|
||||
im_patch = im_patch.transpose(2, 0, 1)
|
||||
im_patch = im_patch[np.newaxis, :, :, :]
|
||||
im_patch = im_patch.astype(np.float32)
|
||||
im_patch = torch.from_numpy(im_patch)
|
||||
im_patch = im_patch.cuda()
|
||||
return im_patch
|
||||
|
||||
def batch_init(self, images, template_bbox, initial_bbox) -> dict:
|
||||
self.frame_num = 1
|
||||
self.device = 'cuda'
|
||||
# Convert bbox (x1, y1, w, h) -> (cx, cy, w, h)
|
||||
|
||||
template_bbox = bbutils.batch_xywh2center2(template_bbox) # ndarray:(2*num_seq,4)
|
||||
initial_bbox = bbutils.batch_xywh2center2(initial_bbox) # ndarray:(2*num_seq,4)
|
||||
self.center_pos = initial_bbox[:, :2] # ndarray:(2*num_seq,2)
|
||||
self.size = initial_bbox[:, 2:] # ndarray:(2*num_seq,2)
|
||||
self.pre_bbox = initial_bbox
|
||||
for i in range(self.pre_num - 1):
|
||||
self.pre_bbox = numpy.concatenate((self.pre_bbox, initial_bbox), axis=1)
|
||||
# print(self.pre_bbox.shape)
|
||||
|
||||
template_factor = self.cfg.DATA.TEMPLATE.FACTOR
|
||||
w_z = template_bbox[:, 2] * template_factor # ndarray:(2*num_seq)
|
||||
h_z = template_bbox[:, 3] * template_factor # ndarray:(2*num_seq)
|
||||
s_z = np.ceil(np.sqrt(w_z * h_z)) # ndarray:(2*num_seq)
|
||||
|
||||
self.channel_average = []
|
||||
for img in images:
|
||||
self.channel_average.append(np.mean(img, axis=(0, 1)))
|
||||
self.channel_average = np.array(self.channel_average) # ndarray:(2*num_seq,3)
|
||||
|
||||
# get crop
|
||||
z_crop_list = []
|
||||
for i in range(len(images)):
|
||||
here_crop = self.get_subwindow(images[i], template_bbox[i, :2],
|
||||
self.cfg.DATA.TEMPLATE.SIZE, s_z[i], self.channel_average[i])
|
||||
z_crop = here_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
self.mean = [0.485, 0.456, 0.406]
|
||||
self.std = [0.229, 0.224, 0.225]
|
||||
self.inplace = False
|
||||
z_crop[0] = tvisf.normalize(z_crop[0], self.mean, self.std, self.inplace)
|
||||
z_crop_list.append(z_crop.clone())
|
||||
z_crop = torch.cat(z_crop_list, dim=0) # Tensor(2*num_seq,3,128,128)
|
||||
|
||||
self.update_rem = None
|
||||
|
||||
out = {'template_images': z_crop}
|
||||
return out
|
||||
|
||||
def batch_track(self, img, gt_boxes, template, action_mode='max') -> dict:
|
||||
search_factor = self.cfg.DATA.SEARCH.FACTOR
|
||||
w_x = self.size[:, 0] * search_factor
|
||||
h_x = self.size[:, 1] * search_factor
|
||||
s_x = np.ceil(np.sqrt(w_x * h_x))
|
||||
|
||||
gt_boxes_corner = bbutils.batch_xywh2corner(gt_boxes) # ndarray:(2*num_seq,4)
|
||||
|
||||
x_crop_list = []
|
||||
gt_in_crop_list = []
|
||||
pre_seq_list = []
|
||||
pre_seq_in_list = []
|
||||
x_feat_list = []
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
for i in range(len(img)):
|
||||
channel_avg = np.mean(img[i], axis=(0, 1))
|
||||
x_crop = self.get_subwindow(img[i], self.center_pos[i], self.cfg.DATA.SEARCH.SIZE,
|
||||
round(s_x[i]), channel_avg)
|
||||
if x_crop == None:
|
||||
return None
|
||||
for q in range(self.pre_num):
|
||||
pre_seq_temp = bbutils.batch_center2corner(self.pre_bbox[:, 0 + 4 * q:4 + 4 * q])
|
||||
if q == 0:
|
||||
pre_seq = pre_seq_temp
|
||||
else:
|
||||
pre_seq = numpy.concatenate((pre_seq, pre_seq_temp), axis=1)
|
||||
|
||||
if gt_boxes_corner is not None and np.sum(np.abs(gt_boxes_corner[i] - np.zeros(4))) > 10:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
for w in range(self.pre_num):
|
||||
|
||||
pre_in[0 + w * 4:2 + w * 4] = pre_seq[i, 0 + w * 4:2 + w * 4] - self.center_pos[i]
|
||||
pre_in[2 + w * 4:4 + w * 4] = pre_seq[i, 2 + w * 4:4 + w * 4] - self.center_pos[i]
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] * (
|
||||
self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] / self.cfg.DATA.SEARCH.SIZE
|
||||
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop = np.zeros(4)
|
||||
gt_in_crop[:2] = gt_boxes_corner[i, :2] - self.center_pos[i]
|
||||
gt_in_crop[2:] = gt_boxes_corner[i, 2:] - self.center_pos[i]
|
||||
gt_in_crop = gt_in_crop * (self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
gt_in_crop[2:] = gt_in_crop[2:] - gt_in_crop[:2] # (x1,y1,x2,y2) to (x1,y1,w,h)
|
||||
gt_in_crop_list.append(gt_in_crop)
|
||||
else:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop_list.append(np.zeros(4))
|
||||
pre_seq_input = torch.from_numpy(pre_in).clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq_input = (pre_seq_input + 0.5) * (self.bins - 1)
|
||||
pre_seq_in_list.append(pre_seq_input.clone())
|
||||
x_crop = x_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
x_crop[0] = tvisf.normalize(x_crop[0], self.mean, self.std, self.inplace)
|
||||
x_crop_list.append(x_crop.clone())
|
||||
|
||||
x_crop = torch.cat(x_crop_list, dim=0)
|
||||
pre_seq_output = torch.cat(pre_seq_in_list, dim=0).reshape(-1, 4 * self.pre_num)
|
||||
|
||||
outputs = self.net(template, x_crop, seq_input=pre_seq_output, head_type=None, stage="batch_track",
|
||||
search_feature=self.x_feat_rem, update=None)
|
||||
selected_indices = outputs['seqs'].detach()
|
||||
x_feat = outputs['x_feat'].detach().cpu()
|
||||
self.x_feat_rem = x_feat.clone()
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
pred_bbox = selected_indices[:, 0:4].data.cpu().numpy()
|
||||
bbox = (pred_bbox / (self.bins - 1) - magic_num) * s_x.reshape(-1, 1)
|
||||
cx = bbox[:, 0] + self.center_pos[:, 0] - s_x / 2
|
||||
cy = bbox[:, 1] + self.center_pos[:, 1] - s_x / 2
|
||||
width = bbox[:, 2] - bbox[:, 0]
|
||||
height = bbox[:, 3] - bbox[:, 1]
|
||||
cx = cx + width / 2
|
||||
cy = cy + height / 2
|
||||
|
||||
for i in range(len(img)):
|
||||
cx[i], cy[i], width[i], height[i] = self._bbox_clip(cx[i], cy[i], width[i],
|
||||
height[i], img[i].shape[:2])
|
||||
self.center_pos = np.stack([cx, cy], 1)
|
||||
self.size = np.stack([width, height], 1)
|
||||
for e in range(self.pre_num):
|
||||
if e != self.pre_num - 1:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = self.pre_bbox[:, 4 + e * 4:8 + e * 4]
|
||||
else:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = numpy.stack([cx, cy, width, height], 1)
|
||||
|
||||
bbox = np.stack([cx - width / 2, cy - height / 2, width, height], 1)
|
||||
|
||||
out = {
|
||||
'search_images': x_crop,
|
||||
'pred_bboxes': bbox,
|
||||
'selected_indices': selected_indices.cpu(),
|
||||
'gt_in_crop': torch.tensor(np.stack(gt_in_crop_list, axis=0), dtype=torch.float),
|
||||
'pre_seq': torch.tensor(np.stack(pre_seq_list, axis=0), dtype=torch.float),
|
||||
'x_feat': torch.tensor([item.cpu().detach().numpy() for item in x_feat_list], dtype=torch.float),
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
def explore(self, data):
|
||||
results = {}
|
||||
search_images_list = []
|
||||
search_anno_list = []
|
||||
iou_list = []
|
||||
pre_seq_list = []
|
||||
x_feat_list = []
|
||||
|
||||
num_frames = data['num_frames']
|
||||
images = data['search_images']
|
||||
gt_bbox = data['search_annos']
|
||||
template = data['template_images']
|
||||
template_bbox = data['template_annos']
|
||||
|
||||
template = template
|
||||
template_bbox = template_bbox
|
||||
template_bbox = np.array(template_bbox)
|
||||
num_seq = len(num_frames)
|
||||
|
||||
for idx in range(np.max(num_frames)):
|
||||
here_images = [img[idx] for img in images] # S, N
|
||||
here_gt_bbox = np.array([gt[idx] for gt in gt_bbox])
|
||||
|
||||
here_images = here_images
|
||||
here_gt_bbox = np.concatenate([here_gt_bbox], 0)
|
||||
|
||||
if idx == 0:
|
||||
outputs_template = self.batch_init(template, template_bbox, here_gt_bbox)
|
||||
results['template_images'] = outputs_template['template_images']
|
||||
|
||||
else:
|
||||
outputs = self.batch_track(here_images, here_gt_bbox, outputs_template['template_images'],
|
||||
action_mode='half')
|
||||
if outputs == None:
|
||||
return None
|
||||
|
||||
x_feat = outputs['x_feat']
|
||||
pred_bbox = outputs['pred_bboxes']
|
||||
search_images_list.append(outputs['search_images'])
|
||||
search_anno_list.append(outputs['gt_in_crop'])
|
||||
if len(outputs['pre_seq']) != 8:
|
||||
print(outputs['pre_seq'])
|
||||
print(len(outputs['pre_seq']))
|
||||
print(idx)
|
||||
print(data['num_frames'])
|
||||
print(data['search_annos'])
|
||||
return None
|
||||
pre_seq_list.append(outputs['pre_seq'])
|
||||
pred_bbox_corner = bbutils.batch_xywh2corner(pred_bbox)
|
||||
gt_bbox_corner = bbutils.batch_xywh2corner(here_gt_bbox)
|
||||
here_iou = []
|
||||
for i in range(num_seq):
|
||||
bbox_iou = IoU(pred_bbox_corner[i], gt_bbox_corner[i])
|
||||
here_iou.append(bbox_iou)
|
||||
iou_list.append(here_iou)
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
results['x_feat'] = torch.cat([torch.stack(x_feat_list)], dim=2)
|
||||
|
||||
results['search_images'] = torch.cat([torch.stack(search_images_list)],
|
||||
dim=1)
|
||||
results['search_anno'] = torch.cat([torch.stack(search_anno_list)],
|
||||
dim=1)
|
||||
results['pre_seq'] = torch.cat([torch.stack(pre_seq_list)], dim=1)
|
||||
|
||||
iou_tensor = torch.tensor(iou_list, dtype=torch.float)
|
||||
results['baseline_iou'] = torch.cat([iou_tensor[:, :num_seq]], dim=1)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
box_mask_z = None
|
||||
ce_keep_rate = None
|
||||
if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
box_mask_z = generate_mask_cond(self.cfg, template_list[0].shape[0], template_list[0].device,
|
||||
data['template_anno'][0])
|
||||
|
||||
ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH
|
||||
ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH
|
||||
ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch,
|
||||
total_epochs=ce_start_epoch + ce_warm_epoch,
|
||||
ITERS_PER_EPOCH=1,
|
||||
base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0])
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins
|
||||
end = self.bins + 1
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=0.5, max=1.5)
|
||||
data['real_bbox'] = gt_bbox
|
||||
seq_ori = gt_bbox * (self.bins - 1)
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
ce_template_mask=box_mask_z,
|
||||
ce_keep_rate=ce_keep_rate,
|
||||
return_last_attn=False,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
|
||||
def compute_sequence_losses(self, data):
|
||||
num_frames = data['search_images'].shape[0]
|
||||
template_images = data['template_images'].repeat(num_frames, 1, 1, 1, 1)
|
||||
template_images = template_images.view(-1, *template_images.size()[2:])
|
||||
search_images = data['search_images'].reshape(-1, *data['search_images'].size()[2:])
|
||||
search_anno = data['search_anno'].reshape(-1, *data['search_anno'].size()[2:])
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
self.loss_weight['focal'] = 0
|
||||
pre_seq = data['pre_seq'].reshape(-1, 4 * self.pre_num)
|
||||
x_feat = data['x_feat'].reshape(-1, *data['x_feat'].size()[2:])
|
||||
pre_seq = pre_seq.clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq = (pre_seq + magic_num) * (self.bins - 1)
|
||||
|
||||
outputs = self.net(template_images, search_images, seq_input=pre_seq, stage="forward_pass",
|
||||
search_feature=x_feat, update=None)
|
||||
|
||||
pred_feat = outputs["feat"]
|
||||
# generate labels
|
||||
if self.focal == None:
|
||||
weight = torch.ones(self.bins * self.range + 2) * 1
|
||||
weight[self.bins * self.range + 1] = 0.1
|
||||
weight[self.bins * self.range] = 0.1
|
||||
weight.to(pred_feat)
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
|
||||
|
||||
search_anno[:, 2] = search_anno[:, 2] + search_anno[:, 0]
|
||||
search_anno[:, 3] = search_anno[:, 3] + search_anno[:, 1]
|
||||
target = (search_anno / self.cfg.DATA.SEARCH.SIZE + 0.5) * (self.bins - 1)
|
||||
|
||||
target = target.clamp(min=0.0, max=(self.bins * self.range - 0.0001))
|
||||
target_iou = target
|
||||
target = torch.cat([target], dim=1)
|
||||
target = target.reshape(-1).to(torch.int64)
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, self.bins * self.range + 2)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
pred = pred_feat[0:4, :, 0:self.bins * self.range]
|
||||
target = target_iou[:, 0:4].to(pred_feat) / (self.bins - 1) - magic_num
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range(-1 * magic_num + 1 / (self.bins * self.range), 1 + magic_num - 1 / (self.bins * self.range), 2 / (self.bins * self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
|
||||
cious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
cious = cious.mean()
|
||||
|
||||
giou_loss = cious
|
||||
loss_bb = self.loss_weight['giou'] * giou_loss + self.loss_weight[
|
||||
'focal'] * varifocal_loss
|
||||
|
||||
total_losses = loss_bb
|
||||
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": total_losses.item(),
|
||||
"Loss/giou": giou_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
|
||||
return total_losses, status
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class BaseActor:
|
||||
""" Base class for actor. The actor class handles the passing of the data through the network
|
||||
and calculation the loss"""
|
||||
def __init__(self, net, objective):
|
||||
"""
|
||||
args:
|
||||
net - The network to train
|
||||
objective - The loss function
|
||||
"""
|
||||
self.net = net
|
||||
self.objective = objective
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
""" Called in each training iteration. Should pass in input data through the network, calculate the loss, and
|
||||
return the training stats for the input data
|
||||
args:
|
||||
data - A TensorDict containing all the necessary data blocks.
|
||||
|
||||
returns:
|
||||
loss - loss for the input data
|
||||
stats - a dict containing detailed losses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device):
|
||||
""" Move the network to device
|
||||
args:
|
||||
device - device to use. 'cpu' or 'cuda'
|
||||
"""
|
||||
self.net.to(device)
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set whether the network is in train mode.
|
||||
args:
|
||||
mode (True) - Bool specifying whether in training mode.
|
||||
"""
|
||||
self.net.train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set network to eval mode"""
|
||||
self.train(False)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .environment import env_settings, create_default_local_file_ITP_train
|
||||
from .stats import AverageMeter, StatValue
|
||||
#from .tensorboard import TensorboardWriter
|
||||
@@ -0,0 +1,102 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_default_local_file():
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': empty_str,
|
||||
'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
|
||||
'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
|
||||
'lasot_dir': empty_str,
|
||||
'got10k_dir': empty_str,
|
||||
'trackingnet_dir': empty_str,
|
||||
'coco_dir': empty_str,
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': empty_str,
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def create_default_local_file_ITP_train(workspace_dir, data_dir):
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': workspace_dir,
|
||||
'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files.
|
||||
'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'),
|
||||
'lasot_dir': os.path.join(data_dir, 'lasot'),
|
||||
'got10k_dir': os.path.join(data_dir, 'got10k/train'),
|
||||
'got10k_val_dir': os.path.join(data_dir, 'got10k/val'),
|
||||
'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'),
|
||||
'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'),
|
||||
'trackingnet_dir': os.path.join(data_dir, 'trackingnet'),
|
||||
'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'),
|
||||
'coco_dir': os.path.join(data_dir, 'coco'),
|
||||
'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'),
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': os.path.join(data_dir, 'vid'),
|
||||
'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'),
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
if attr_val == empty_str:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def env_settings():
|
||||
env_module_name = 'lib.train.admin.local'
|
||||
try:
|
||||
env_module = importlib.import_module(env_module_name)
|
||||
return env_module.EnvironmentSettings()
|
||||
except:
|
||||
env_file = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
create_default_local_file()
|
||||
raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file))
|
||||
@@ -0,0 +1,24 @@
|
||||
class EnvironmentSettings:
|
||||
def __init__(self):
|
||||
self.workspace_dir = '/home/baiyifan/code/2stage_update_intrain' # Base directory for saving network checkpoints.
|
||||
self.tensorboard_dir = '/home/baiyifan/code/2stage/tensorboard' # Directory for tensorboard files.
|
||||
self.pretrained_networks = '/home/baiyifan/code/2stage/pretrained_networks'
|
||||
self.lasot_dir = '/home/baiyifan/LaSOT/LaSOTBenchmark'
|
||||
self.got10k_dir = '/home/baiyifan/GOT-10k/train'
|
||||
self.got10k_val_dir = '/home/baiyifan/GOT-10k/val'
|
||||
self.lasot_lmdb_dir = '/home/baiyifan/code/2stage/data/lasot_lmdb'
|
||||
self.got10k_lmdb_dir = '/home/baiyifan/code/2stage/data/got10k_lmdb'
|
||||
self.trackingnet_dir = '/ssddata/TrackingNet/all_zip'
|
||||
self.trackingnet_lmdb_dir = '/home/baiyifan/code/2stage/data/trackingnet_lmdb'
|
||||
self.coco_dir = '/home/baiyifan/coco'
|
||||
self.coco_lmdb_dir = '/home/baiyifan/code/2stage/data/coco_lmdb'
|
||||
self.lvis_dir = ''
|
||||
self.sbd_dir = ''
|
||||
self.imagenet_dir = '/home/baiyifan/code/2stage/data/vid'
|
||||
self.imagenet_lmdb_dir = '/home/baiyifan/code/2stage/data/vid_lmdb'
|
||||
self.imagenetdet_dir = ''
|
||||
self.ecssd_dir = ''
|
||||
self.hkuis_dir = ''
|
||||
self.msra10k_dir = ''
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
||||
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
# Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training
|
||||
|
||||
|
||||
def is_multi_gpu(net):
|
||||
return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))
|
||||
|
||||
|
||||
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except:
|
||||
pass
|
||||
return getattr(self.module, item)
|
||||
@@ -0,0 +1,13 @@
|
||||
from lib.train.admin.environment import env_settings
|
||||
|
||||
|
||||
class Settings:
|
||||
""" Training settings, e.g. the paths to datasets and networks."""
|
||||
def __init__(self):
|
||||
self.set_default()
|
||||
|
||||
def set_default(self):
|
||||
self.env = env_settings()
|
||||
self.use_gpu = True
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
|
||||
|
||||
class StatValue:
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
|
||||
def clear(self):
|
||||
self.reset()
|
||||
self.history = []
|
||||
|
||||
def update(self, val):
|
||||
self.val = val
|
||||
self.history.append(self.val)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
self.has_new_data = False
|
||||
|
||||
def reset(self):
|
||||
self.avg = 0
|
||||
self.val = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def clear(self):
|
||||
self.reset()
|
||||
self.history = []
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def new_epoch(self):
|
||||
if self.count > 0:
|
||||
self.history.append(self.avg)
|
||||
self.reset()
|
||||
self.has_new_data = True
|
||||
else:
|
||||
self.has_new_data = False
|
||||
|
||||
|
||||
def topk_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
single_input = not isinstance(topk, (tuple, list))
|
||||
if single_input:
|
||||
topk = (topk,)
|
||||
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0]
|
||||
res.append(correct_k * 100.0 / batch_size)
|
||||
|
||||
if single_input:
|
||||
return res[0]
|
||||
|
||||
return res
|
||||
@@ -0,0 +1,27 @@
|
||||
#import os
|
||||
#from collections import OrderedDict
|
||||
#try:
|
||||
# from torch.utils.tensorboard import SummaryWriter
|
||||
#except:
|
||||
# print('WARNING: You are using tensorboardX instead sis you have a too old pytorch version.')
|
||||
# from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
#class TensorboardWriter:
|
||||
# def __init__(self, directory, loader_names):
|
||||
# self.directory = directory
|
||||
# self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})
|
||||
|
||||
# def write_info(self, script_name, description):
|
||||
# tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
|
||||
# tb_info_writer.add_text('Script_name', script_name)
|
||||
# tb_info_writer.add_text('Description', description)
|
||||
# tb_info_writer.close()
|
||||
|
||||
# def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
|
||||
# for loader_name, loader_stats in stats.items():
|
||||
# if loader_stats is None:
|
||||
# continue
|
||||
# for var_name, val in loader_stats.items():
|
||||
# if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
|
||||
# self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
|
||||
@@ -0,0 +1,193 @@
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
# datasets related
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader
|
||||
import lib.train.data.transforms as tfm
|
||||
from lib.utils.misc import is_main_process
|
||||
|
||||
|
||||
def update_settings(settings, cfg):
|
||||
settings.print_interval = cfg.TRAIN.PRINT_INTERVAL
|
||||
settings.search_area_factor = {'template': cfg.DATA.TEMPLATE.FACTOR,
|
||||
'search': cfg.DATA.SEARCH.FACTOR}
|
||||
settings.output_sz = {'template': cfg.DATA.TEMPLATE.SIZE,
|
||||
'search': cfg.DATA.SEARCH.SIZE}
|
||||
settings.center_jitter_factor = {'template': cfg.DATA.TEMPLATE.CENTER_JITTER,
|
||||
'search': cfg.DATA.SEARCH.CENTER_JITTER}
|
||||
settings.scale_jitter_factor = {'template': cfg.DATA.TEMPLATE.SCALE_JITTER,
|
||||
'search': cfg.DATA.SEARCH.SCALE_JITTER}
|
||||
settings.grad_clip_norm = cfg.TRAIN.GRAD_CLIP_NORM
|
||||
settings.print_stats = None
|
||||
settings.batchsize = cfg.TRAIN.BATCH_SIZE
|
||||
settings.scheduler_type = cfg.TRAIN.SCHEDULER.TYPE
|
||||
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
|
||||
def build_dataloaders(cfg, settings):
|
||||
# Data transform
|
||||
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05),
|
||||
tfm.RandomHorizontalFlip(probability=0.5))
|
||||
|
||||
transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
|
||||
tfm.RandomHorizontalFlip_Norm(probability=0.5),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
transform_val = tfm.Transform(tfm.ToTensor(),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
# The tracking pairs processing module
|
||||
output_sz = settings.output_sz
|
||||
search_area_factor = settings.search_area_factor
|
||||
|
||||
data_processing_train = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_train,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
data_processing_val = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_val,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
# Train sampler and loader
|
||||
settings.num_template = getattr(cfg.DATA.TEMPLATE, "NUMBER", 1)
|
||||
settings.num_search = getattr(cfg.DATA.SEARCH, "NUMBER", 1)
|
||||
sampler_mode = getattr(cfg.DATA, "SAMPLER_MODE", "causal")
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
print("sampler_mode", sampler_mode)
|
||||
dataset_train = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_train,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
|
||||
train_sampler = DistributedSampler(dataset_train) if settings.local_rank != -1 else None
|
||||
shuffle = False if settings.local_rank != -1 else True
|
||||
|
||||
loader_train = LTRLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=shuffle,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=train_sampler)
|
||||
|
||||
# Validation samplers and loaders
|
||||
dataset_val = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.VAL.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.VAL.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.VAL.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_val,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
val_sampler = DistributedSampler(dataset_val) if settings.local_rank != -1 else None
|
||||
loader_val = LTRLoader('val', dataset_val, training=False, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=val_sampler,
|
||||
epoch_interval=cfg.TRAIN.VAL_EPOCH_INTERVAL)
|
||||
|
||||
return loader_train, loader_val
|
||||
|
||||
|
||||
def get_optimizer_scheduler(net, cfg):
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
if train_cls:
|
||||
print("Only training classification head. Learnable parameters are shown below.")
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "cls" in n and p.requires_grad]}
|
||||
]
|
||||
|
||||
for n, p in net.named_parameters():
|
||||
if "cls" not in n:
|
||||
p.requires_grad = False
|
||||
else:
|
||||
print(n)
|
||||
else:
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in net.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.TRAIN.LR * cfg.TRAIN.BACKBONE_MULTIPLIER,
|
||||
},
|
||||
]
|
||||
if is_main_process():
|
||||
print("Learnable parameters are shown below.")
|
||||
for n, p in net.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
if cfg.TRAIN.OPTIMIZER == "ADAMW":
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR,
|
||||
weight_decay=cfg.TRAIN.WEIGHT_DECAY)
|
||||
else:
|
||||
raise ValueError("Unsupported Optimizer")
|
||||
if cfg.TRAIN.SCHEDULER.TYPE == 'step':
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP_EPOCH)
|
||||
elif cfg.TRAIN.SCHEDULER.TYPE == "Mstep":
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=cfg.TRAIN.SCHEDULER.MILESTONES,
|
||||
gamma=cfg.TRAIN.SCHEDULER.GAMMA)
|
||||
else:
|
||||
raise ValueError("Unsupported scheduler")
|
||||
return optimizer, lr_scheduler
|
||||
@@ -0,0 +1,2 @@
|
||||
from .loader import LTRLoader
|
||||
from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader
|
||||
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def batch_center2corner(boxes):
|
||||
xmin = boxes[:, 0] - boxes[:, 2] * 0.5
|
||||
ymin = boxes[:, 1] - boxes[:, 3] * 0.5
|
||||
xmax = boxes[:, 0] + boxes[:, 2] * 0.5
|
||||
ymax = boxes[:, 1] + boxes[:, 3] * 0.5
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def batch_corner2center(boxes):
|
||||
cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
|
||||
cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
|
||||
w = (boxes[:, 2] - boxes[:, 0])
|
||||
h = (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center(boxes):
|
||||
cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2
|
||||
cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center2(boxes):
|
||||
cx = boxes[:, 0] + boxes[:, 2] / 2
|
||||
cy = boxes[:, 1] + boxes[:, 3] / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
|
||||
def batch_xywh2corner(boxes):
|
||||
xmin = boxes[:, 0]
|
||||
ymin = boxes[:, 1]
|
||||
xmax = boxes[:, 0] + boxes[:, 2]
|
||||
ymax = boxes[:, 1] + boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def rect_to_rel(bb, sz_norm=None):
|
||||
"""Convert standard rectangular parametrization of the bounding box [x, y, w, h]
|
||||
to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
|
||||
args:
|
||||
bb - N x 4 tensor of boxes.
|
||||
sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
|
||||
"""
|
||||
|
||||
c = bb[...,:2] + 0.5 * bb[...,2:]
|
||||
if sz_norm is None:
|
||||
c_rel = c / bb[...,2:]
|
||||
else:
|
||||
c_rel = c / sz_norm
|
||||
sz_rel = torch.log(bb[...,2:])
|
||||
return torch.cat((c_rel, sz_rel), dim=-1)
|
||||
|
||||
|
||||
def rel_to_rect(bb, sz_norm=None):
|
||||
"""Inverts the effect of rect_to_rel. See above."""
|
||||
|
||||
sz = torch.exp(bb[...,2:])
|
||||
if sz_norm is None:
|
||||
c = bb[...,:2] * sz
|
||||
else:
|
||||
c = bb[...,:2] * sz_norm
|
||||
tl = c - 0.5 * sz
|
||||
return torch.cat((tl, sz), dim=-1)
|
||||
|
||||
|
||||
def masks_to_bboxes(mask, fmt='c'):
|
||||
|
||||
""" Convert a mask tensor to one or more bounding boxes.
|
||||
Note: This function is a bit new, make sure it does what it says. /Andreas
|
||||
:param mask: Tensor of masks, shape = (..., H, W)
|
||||
:param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
|
||||
't' => "top left + size" or (x_left, y_top, width, height)
|
||||
'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
|
||||
:return: tensor containing a batch of bounding boxes, shape = (..., 4)
|
||||
"""
|
||||
batch_shape = mask.shape[:-2]
|
||||
mask = mask.reshape((-1, *mask.shape[-2:]))
|
||||
bboxes = []
|
||||
|
||||
for m in mask:
|
||||
mx = m.sum(dim=-2).nonzero()
|
||||
my = m.sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
bboxes.append(bb)
|
||||
|
||||
bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
|
||||
bboxes = bboxes.reshape(batch_shape + (4,))
|
||||
|
||||
if fmt == 'v':
|
||||
return bboxes
|
||||
|
||||
x1 = bboxes[..., :2]
|
||||
s = bboxes[..., 2:] - x1 + 1
|
||||
|
||||
if fmt == 'c':
|
||||
return torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
return torch.cat((x1, s), dim=-1)
|
||||
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
|
||||
|
||||
def masks_to_bboxes_multi(mask, ids, fmt='c'):
|
||||
assert mask.dim() == 2
|
||||
bboxes = []
|
||||
|
||||
for id in ids:
|
||||
mx = (mask == id).sum(dim=-2).nonzero()
|
||||
my = (mask == id).float().sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
|
||||
bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
|
||||
|
||||
x1 = bb[:2]
|
||||
s = bb[2:] - x1 + 1
|
||||
|
||||
if fmt == 'v':
|
||||
pass
|
||||
elif fmt == 'c':
|
||||
bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
bb = torch.cat((x1, s), dim=-1)
|
||||
else:
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
bboxes.append(bb)
|
||||
|
||||
return bboxes
|
||||
@@ -0,0 +1,103 @@
|
||||
import jpeg4py
|
||||
import cv2 as cv
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
|
||||
davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
|
||||
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
|
||||
[64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
|
||||
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
|
||||
[0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
|
||||
[0, 64, 128], [128, 64, 128]]
|
||||
|
||||
|
||||
def default_image_loader(path):
|
||||
"""The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
|
||||
but reverts to the opencv_loader if the former is not available."""
|
||||
if default_image_loader.use_jpeg4py is None:
|
||||
# Try using jpeg4py
|
||||
im = jpeg4py_loader(path)
|
||||
if im is None:
|
||||
default_image_loader.use_jpeg4py = False
|
||||
print('Using opencv_loader instead.')
|
||||
else:
|
||||
default_image_loader.use_jpeg4py = True
|
||||
return im
|
||||
if default_image_loader.use_jpeg4py:
|
||||
return jpeg4py_loader(path)
|
||||
return opencv_loader(path)
|
||||
|
||||
default_image_loader.use_jpeg4py = None
|
||||
|
||||
|
||||
def jpeg4py_loader(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_loader(path):
|
||||
""" Read image using opencv's imread function and returns it in rgb format"""
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def jpeg4py_loader_w_failsafe(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except:
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_seg_loader(path):
|
||||
""" Read segmentation annotation using opencv's imread function"""
|
||||
try:
|
||||
return cv.imread(path)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def imread_indexed(filename):
|
||||
""" Load indexed image with given filename. Used to read segmentation annotations."""
|
||||
|
||||
im = Image.open(filename)
|
||||
|
||||
annotation = np.atleast_3d(im)[...,0]
|
||||
return annotation
|
||||
|
||||
|
||||
def imwrite_indexed(filename, array, color_palette=None):
|
||||
""" Save indexed image as png. Used to save segmentation annotation."""
|
||||
|
||||
if color_palette is None:
|
||||
color_palette = davis_palette
|
||||
|
||||
if np.atleast_3d(array).shape[2] != 1:
|
||||
raise Exception("Saving indexed PNGs requires 2D array.")
|
||||
|
||||
im = Image.fromarray(array)
|
||||
im.putpalette(color_palette.ravel())
|
||||
im.save(filename, format='PNG')
|
||||
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.utils.data.dataloader
|
||||
import importlib
|
||||
import collections
|
||||
# from torch._six import string_classes
|
||||
from lib.utils import TensorDict, TensorList
|
||||
|
||||
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
|
||||
int_classes = int
|
||||
else:
|
||||
from torch._six import int_classes
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
string_classes = str
|
||||
|
||||
def _check_use_shared_memory():
|
||||
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
|
||||
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
|
||||
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
|
||||
if hasattr(collate_lib, '_use_shared_memory'):
|
||||
return getattr(collate_lib, '_use_shared_memory')
|
||||
return torch.utils.data.get_worker_info() is not None
|
||||
|
||||
|
||||
def ltr_collate(batch):
|
||||
"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 0, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
def ltr_collate_stack1(batch):
|
||||
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 1, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 1)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate_stack1(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
class LTRLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
|
||||
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
|
||||
select along which dimension the data should be stacked to form a batch.
|
||||
|
||||
Arguments:
|
||||
dataset (Dataset): dataset from which to load the data.
|
||||
batch_size (int, optional): how many samples per batch to load
|
||||
(default: 1).
|
||||
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
||||
at every epoch (default: False).
|
||||
sampler (Sampler, optional): defines the strategy to draw samples from
|
||||
the dataset. If specified, ``shuffle`` must be False.
|
||||
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
||||
indices at a time. Mutually exclusive with batch_size, shuffle,
|
||||
sampler, and drop_last.
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means that the data will be loaded in the main process.
|
||||
(default: 0)
|
||||
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
||||
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
|
||||
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
||||
into CUDA pinned memory before returning them.
|
||||
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
||||
if the dataset size is not divisible by the batch size. If ``False`` and
|
||||
the size of dataset is not divisible by the batch size, then the last batch
|
||||
will be smaller. (default: False)
|
||||
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
||||
from workers. Should always be non-negative. (default: 0)
|
||||
worker_init_fn (callable, optional): If not None, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: None)
|
||||
|
||||
.. note:: By default, each worker will have its PyTorch seed set to
|
||||
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
||||
by main process using its RNG. However, seeds for other libraries
|
||||
may be duplicated upon initializing workers (w.g., NumPy), causing
|
||||
each worker to return identical random numbers. (See
|
||||
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
|
||||
use ``torch.initial_seed()`` to access the PyTorch seed for each
|
||||
worker in :attr:`worker_init_fn`, and use it to set other seeds
|
||||
before data loading.
|
||||
|
||||
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
|
||||
unpicklable object, e.g., a lambda function.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
print("pin_memory is", pin_memory)
|
||||
if collate_fn is None:
|
||||
if stack_dim == 0:
|
||||
collate_fn = ltr_collate
|
||||
elif stack_dim == 1:
|
||||
collate_fn = ltr_collate_stack1
|
||||
else:
|
||||
raise ValueError('Stack dim no supported. Must be 0 or 1.')
|
||||
|
||||
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from lib.utils import TensorDict
|
||||
import lib.train.data.processing_utils as prutils
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def stack_tensors(x):
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
|
||||
return torch.stack(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaseProcessing:
|
||||
""" Base class for Processing. Processing class is used to process the data returned by a dataset, before passing it
|
||||
through the network. For example, it can be used to crop a search region around the object, apply various data
|
||||
augmentations, etc."""
|
||||
def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None, joint_transform=None):
|
||||
"""
|
||||
args:
|
||||
transform - The set of transformations to be applied on the images. Used only if template_transform or
|
||||
search_transform is None.
|
||||
template_transform - The set of transformations to be applied on the template images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
search_transform - The set of transformations to be applied on the search images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
joint_transform - The set of transformations to be applied 'jointly' on the template and search images. For
|
||||
example, it can be used to convert both template and search images to grayscale.
|
||||
"""
|
||||
self.transform = {'template': transform if template_transform is None else template_transform,
|
||||
'search': transform if search_transform is None else search_transform,
|
||||
'joint': joint_transform}
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class STARKProcessing(BaseProcessing):
|
||||
""" The processing class used for training LittleBoy. The images are processed in the following way.
|
||||
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
|
||||
centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is
|
||||
cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is
|
||||
always at the center of the search region. The search region is then resized to a fixed size given by the
|
||||
argument output_sz.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
|
||||
mode='pair', settings=None, *args, **kwargs):
|
||||
"""
|
||||
args:
|
||||
search_area_factor - The size of the search region relative to the target size.
|
||||
output_sz - An integer, denoting the size to which the search region is resized. The search region is always
|
||||
square.
|
||||
center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.search_area_factor = search_area_factor
|
||||
self.output_sz = output_sz
|
||||
self.center_jitter_factor = center_jitter_factor
|
||||
self.scale_jitter_factor = scale_jitter_factor
|
||||
self.mode = mode
|
||||
self.settings = settings
|
||||
|
||||
def _get_jittered_box(self, box, mode):
|
||||
""" Jitter the input box
|
||||
args:
|
||||
box - input bounding box
|
||||
mode - string 'template' or 'search' indicating template or search data
|
||||
|
||||
returns:
|
||||
torch.Tensor - jittered box
|
||||
"""
|
||||
|
||||
jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
|
||||
max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode]).float())
|
||||
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
|
||||
|
||||
return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the following fields:
|
||||
'template_images', search_images', 'template_anno', 'search_anno'
|
||||
returns:
|
||||
TensorDict - output data block with following fields:
|
||||
'template_images', 'search_images', 'template_anno', 'search_anno', 'test_proposals', 'proposal_iou'
|
||||
"""
|
||||
# Apply joint transforms
|
||||
if self.transform['joint'] is not None:
|
||||
data['template_images'], data['template_anno'], data['template_masks'] = self.transform['joint'](
|
||||
image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
|
||||
data['search_images'], data['search_anno'], data['search_masks'] = self.transform['joint'](
|
||||
image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)
|
||||
|
||||
for s in ['template', 'search']:
|
||||
assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
|
||||
"In pair mode, num train/test frames must be 1"
|
||||
|
||||
# Add a uniform noise to the center pos
|
||||
jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
|
||||
|
||||
# 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
|
||||
w, h = torch.stack(jittered_anno, dim=0)[:, 2], torch.stack(jittered_anno, dim=0)[:, 3]
|
||||
|
||||
crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
|
||||
if (crop_sz < 1).any():
|
||||
data['valid'] = False
|
||||
# print("Too small box is found. Replace it with new data.")
|
||||
return data
|
||||
|
||||
# Crop image region centered at jittered_anno box and get the attention mask
|
||||
crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno,
|
||||
data[s + '_anno'], self.search_area_factor[s],
|
||||
self.output_sz[s], masks=data[s + '_masks'])
|
||||
# Apply transforms
|
||||
data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[s + '_masks'] = self.transform[s](
|
||||
image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)
|
||||
|
||||
|
||||
# 2021.1.9 Check whether elements in data[s + '_att'] is all 1
|
||||
# Note that type of data[s + '_att'] is tuple, type of ele is torch.tensor
|
||||
for ele in data[s + '_att']:
|
||||
if (ele == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of original attention mask are all one. Replace it with new data.")
|
||||
return data
|
||||
# 2021.1.10 more strict conditions: require the donwsampled masks not to be all 1
|
||||
for ele in data[s + '_att']:
|
||||
feat_size = self.output_sz[s] // 16 # 16 is the backbone stride
|
||||
# (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
|
||||
mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
|
||||
if (mask_down == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of down-sampled attention mask are all one. "
|
||||
# "Replace it with new data.")
|
||||
return data
|
||||
|
||||
data['valid'] = True
|
||||
# if we use copy-and-paste augmentation
|
||||
if data["template_masks"] is None or data["search_masks"] is None:
|
||||
data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
|
||||
data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))
|
||||
# Prepare output
|
||||
if self.mode == 'sequence':
|
||||
data = data.apply(stack_tensors)
|
||||
else:
|
||||
data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
|
||||
|
||||
return data
|
||||
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
'''modified from the original test implementation
|
||||
Replace cv.BORDER_REPLICATE with cv.BORDER_CONSTANT
|
||||
Add a variable called att_mask for computing attention and positional encoding later'''
|
||||
|
||||
|
||||
def sample_target(im, target_bb, search_area_factor, output_sz=None, mask=None):
|
||||
""" Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area
|
||||
|
||||
args:
|
||||
im - cv image
|
||||
target_bb - target box [x, y, w, h]
|
||||
search_area_factor - Ratio of crop size to target size
|
||||
output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.
|
||||
|
||||
returns:
|
||||
cv image - extracted crop
|
||||
float - the factor by which the crop has been resized to make the crop size equal output_size
|
||||
"""
|
||||
if not isinstance(target_bb, list):
|
||||
x, y, w, h = target_bb.tolist()
|
||||
else:
|
||||
x, y, w, h = target_bb
|
||||
# Crop image
|
||||
crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
|
||||
|
||||
if crop_sz < 1:
|
||||
raise Exception('Too small bounding box.')
|
||||
|
||||
x1 = round(x + 0.5 * w - crop_sz * 0.5)
|
||||
x2 = x1 + crop_sz
|
||||
|
||||
y1 = round(y + 0.5 * h - crop_sz * 0.5)
|
||||
y2 = y1 + crop_sz
|
||||
|
||||
x1_pad = max(0, -x1)
|
||||
x2_pad = max(x2 - im.shape[1] + 1, 0)
|
||||
|
||||
y1_pad = max(0, -y1)
|
||||
y2_pad = max(y2 - im.shape[0] + 1, 0)
|
||||
|
||||
# Crop target
|
||||
im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
|
||||
if mask is not None:
|
||||
mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]
|
||||
|
||||
# Pad
|
||||
im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_CONSTANT)
|
||||
# deal with attention mask
|
||||
H, W, _ = im_crop_padded.shape
|
||||
att_mask = np.ones((H,W))
|
||||
end_x, end_y = -x2_pad, -y2_pad
|
||||
if y2_pad == 0:
|
||||
end_y = None
|
||||
if x2_pad == 0:
|
||||
end_x = None
|
||||
att_mask[y1_pad:end_y, x1_pad:end_x] = 0
|
||||
if mask is not None:
|
||||
mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)
|
||||
|
||||
if output_sz is not None:
|
||||
resize_factor = output_sz / crop_sz
|
||||
im_crop_padded = cv.resize(im_crop_padded, (output_sz, output_sz))
|
||||
att_mask = cv.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
|
||||
if mask is None:
|
||||
return im_crop_padded, resize_factor, att_mask
|
||||
mask_crop_padded = \
|
||||
F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
|
||||
return im_crop_padded, resize_factor, att_mask, mask_crop_padded
|
||||
|
||||
else:
|
||||
if mask is None:
|
||||
return im_crop_padded, att_mask.astype(np.bool_), 1.0
|
||||
return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded
|
||||
|
||||
|
||||
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
|
||||
crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box_in - the box for which the co-ordinates are to be transformed
|
||||
box_extract - the box about which the image crop has been extracted.
|
||||
resize_factor - the ratio between the original image scale and the scale of the image crop
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]
|
||||
|
||||
box_in_center = box_in[0:2] + 0.5 * box_in[2:4]
|
||||
|
||||
box_out_center = (crop_sz - 1) / 2 + (box_in_center - box_extract_center) * resize_factor
|
||||
box_out_wh = box_in[2:4] * resize_factor
|
||||
|
||||
box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
|
||||
def jittered_center_crop(frames, box_extract, box_gt, search_area_factor, output_sz, masks=None):
|
||||
""" For each frame in frames, extracts a square crop centered at box_extract, of area search_area_factor^2
|
||||
times box_extract area. The extracted crops are then resized to output_sz. Further, the co-ordinates of the box
|
||||
box_gt are transformed to the image crop co-ordinates
|
||||
|
||||
args:
|
||||
frames - list of frames
|
||||
box_extract - list of boxes of same length as frames. The crops are extracted using anno_extract
|
||||
box_gt - list of boxes of same length as frames. The co-ordinates of these boxes are transformed from
|
||||
image co-ordinates to the crop co-ordinates
|
||||
search_area_factor - The area of the extracted crop is search_area_factor^2 times box_extract area
|
||||
output_sz - The size to which the extracted crops are resized
|
||||
|
||||
returns:
|
||||
list - list of image crops
|
||||
list - box_gt location in the crop co-ordinates
|
||||
"""
|
||||
|
||||
if masks is None:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz)
|
||||
for f, a in zip(frames, box_extract)]
|
||||
frames_crop, resize_factors, att_mask = zip(*crops_resize_factors)
|
||||
masks_crop = None
|
||||
else:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz, m)
|
||||
for f, a, m in zip(frames, box_extract, masks)]
|
||||
frames_crop, resize_factors, att_mask, masks_crop = zip(*crops_resize_factors)
|
||||
# frames_crop: tuple of ndarray (128,128,3), att_mask: tuple of ndarray (128,128)
|
||||
crop_sz = torch.Tensor([output_sz, output_sz])
|
||||
|
||||
# find the bb location in the crop
|
||||
'''Note that here we use normalized coord'''
|
||||
box_crop = [transform_image_to_crop(a_gt, a_ex, rf, crop_sz, normalize=True)
|
||||
for a_gt, a_ex, rf in zip(box_gt, box_extract, resize_factors)] # (x1,y1,w,h) list of tensors
|
||||
|
||||
return frames_crop, box_crop, att_mask, masks_crop
|
||||
|
||||
|
||||
def transform_box_to_crop(box: torch.Tensor, crop_box: torch.Tensor, crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box - the box for which the co-ordinates are to be transformed
|
||||
crop_box - bounding box defining the crop in the original image
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
|
||||
box_out = box.clone()
|
||||
box_out[:2] -= crop_box[:2]
|
||||
|
||||
scale_factor = crop_sz / crop_box[2:]
|
||||
|
||||
box_out[:2] *= scale_factor
|
||||
box_out[2:] *= scale_factor
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
@@ -0,0 +1,349 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
from lib.utils import TensorDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def no_processing(data):
|
||||
return data
|
||||
|
||||
|
||||
class TrackingSampler(torch.utils.data.Dataset):
|
||||
""" Class responsible for sampling frames from training sequences to form batches.
|
||||
|
||||
The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
|
||||
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
|
||||
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
|
||||
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
|
||||
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.
|
||||
|
||||
The sampled frames are then passed through the input 'processing' function for the necessary processing-
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
|
||||
train_cls=False, pos_prob=0.5):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the test frames.
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the test frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
"""
|
||||
self.datasets = datasets
|
||||
self.train_cls = train_cls # whether we are training classification
|
||||
self.pos_prob = pos_prob # probability of sampling positive class when making classification
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.processing = processing
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
|
||||
allow_invisible=False, force_invisible=False):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
# get valid ids
|
||||
if force_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
|
||||
else:
|
||||
if allow_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id)]
|
||||
else:
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.train_cls:
|
||||
return self.getitem_cls()
|
||||
else:
|
||||
return self.getitem()
|
||||
|
||||
def getitem(self):
|
||||
"""
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
|
||||
if is_video_dataset:
|
||||
template_frame_ids = None
|
||||
search_frame_ids = None
|
||||
gap_increase = 0
|
||||
|
||||
if self.frame_sample_mode == 'causal':
|
||||
# Sample test and train frames in a causal manner, i.e. search_frame_ids > template_frame_ids
|
||||
while search_frame_ids is None:
|
||||
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_template_frames - 1,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
|
||||
min_id=base_frame_id[0] - self.max_gap - gap_increase,
|
||||
max_id=base_frame_id[0])
|
||||
if prev_frame_ids is None:
|
||||
gap_increase += 5
|
||||
continue
|
||||
template_frame_ids = base_frame_id + prev_frame_ids
|
||||
search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
|
||||
max_id=template_frame_ids[0] + self.max_gap + gap_increase,
|
||||
num_ids=self.num_search_frames)
|
||||
# Increase gap until a frame is found
|
||||
gap_increase += 5
|
||||
|
||||
elif self.frame_sample_mode == "trident" or self.frame_sample_mode == "trident_pro":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("Illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def getitem_cls(self):
|
||||
# get data for classification
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
aux (bool): whether the current data is for auxiliary use (e.g. copy-and-paste)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
label = None
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample template and search frame ids
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode in ["trident", "trident_pro"]:
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
# "try" is used to handle trackingnet data failure
|
||||
# get images and bounding boxes (for templates)
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
|
||||
seq_info_dict)
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
|
||||
(H, W))] * self.num_template_frames
|
||||
# get images and bounding boxes (for searches)
|
||||
# positive samples
|
||||
if random.random() < self.pos_prob:
|
||||
label = torch.ones(1,)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
# negative samples
|
||||
else:
|
||||
label = torch.zeros(1,)
|
||||
if is_video_dataset:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
|
||||
if search_frame_ids is None:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
|
||||
seq_info_dict)
|
||||
search_anno["bbox"] = [self.get_center_box(H, W)]
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
H, W, _ = search_frames[0].shape
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
# add classification label
|
||||
data["label"] = label
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def get_center_box(self, H, W, ratio=1/8):
|
||||
cx, cy, w, h = W/2, H/2, W * ratio, H * ratio
|
||||
return torch.tensor([int(cx-w/2), int(cy-h/2), int(w), int(h)])
|
||||
|
||||
def sample_seq_from_dataset(self, dataset, is_video_dataset):
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= 20
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
return seq_id, visible, seq_info_dict
|
||||
|
||||
def get_one_search(self):
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
# sample a sequence
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample a frame
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == "stark":
|
||||
search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
|
||||
else:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)
|
||||
else:
|
||||
search_frame_ids = [1]
|
||||
# get the image, bounding box and other info
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
return search_frames, search_anno, meta_obj_test
|
||||
|
||||
def get_frame_ids_trident(self, visible):
|
||||
# get template and search ids in a 'trident' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
if self.frame_sample_mode == "trident_pro":
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id,
|
||||
allow_invisible=True)
|
||||
else:
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
def get_frame_ids_stark(self, visible, valid):
|
||||
# get template and search ids in a 'stark' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
"""we require the frame to be valid but not necessary visible"""
|
||||
f_id = self._sample_visible_ids(valid, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
@@ -0,0 +1,265 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class SequenceSampler(torch.utils.data.Dataset):
|
||||
"""
|
||||
Sample sequence for sequence-level training
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the search frames.\
|
||||
max_interval - Maximum interval between sampled frames
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the search frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
prob - sequential sampling by prob / interval sampling by 1-prob
|
||||
"""
|
||||
self.datasets = datasets
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
self.prob=prob
|
||||
self.extra=1
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
|
||||
def _sequential_sample(self, visible):
|
||||
# Sample frames in sequential manner
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
if self.max_gap == -1:
|
||||
left = template_frame_ids[0]
|
||||
else:
|
||||
# template frame (1) ->(max_gap) -> search frame (num_search_frames)
|
||||
left_max = min(len(visible) - self.num_search_frames, template_frame_ids[0] + self.max_gap)
|
||||
left = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)[0]
|
||||
|
||||
valid_ids = [i for i in range(left, len(visible)) if visible[i]]
|
||||
search_frame_ids = valid_ids[:self.num_search_frames]
|
||||
|
||||
# if length is not enough
|
||||
last = search_frame_ids[-1]
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
if last >= len(visible) - 1:
|
||||
search_frame_ids.append(last)
|
||||
else:
|
||||
last += 1
|
||||
if visible[last]:
|
||||
search_frame_ids.append(last)
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def _random_interval_sample(self, visible):
|
||||
# Get valid ids
|
||||
valid_ids = [i for i in range(len(visible)) if visible[i]]
|
||||
|
||||
# Sample template frame
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - avg_interval * (self.num_search_frames - 1))
|
||||
if template_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == 0:
|
||||
template_frame_ids = [valid_ids[0]]
|
||||
break
|
||||
|
||||
# Sample first search frame
|
||||
if self.max_gap == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
else:
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
left_max = min(max(len(visible) - avg_interval * (self.num_search_frames - 1), template_frame_ids[0] + 1),
|
||||
template_frame_ids[0] + self.max_gap)
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)
|
||||
|
||||
if search_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
break
|
||||
|
||||
# Sample rest of the search frames with random interval
|
||||
last = search_frame_ids[0]
|
||||
while last <= len(visible) - 1 and len(search_frame_ids) < self.num_search_frames:
|
||||
# sample id with interval
|
||||
max_id = min(last + self.max_interval + 1, len(visible))
|
||||
id = self._sample_visible_ids(visible, num_ids=1, min_id=last,
|
||||
max_id=max_id)
|
||||
|
||||
if id is None:
|
||||
# If not found in current range, find from previous range
|
||||
last = last + self.max_interval
|
||||
else:
|
||||
search_frame_ids.append(id[0])
|
||||
last = search_frame_ids[-1]
|
||||
|
||||
# if length is not enough, randomly sample new ids
|
||||
if len(search_frame_ids) < self.num_search_frames:
|
||||
valid_ids = [x for x in valid_ids if x > search_frame_ids[0] and x not in search_frame_ids]
|
||||
|
||||
if len(valid_ids) > 0:
|
||||
new_ids = random.choices(valid_ids, k=min(len(valid_ids),
|
||||
self.num_search_frames - len(search_frame_ids)))
|
||||
search_frame_ids = search_frame_ids + new_ids
|
||||
search_frame_ids = sorted(search_frame_ids, key=int)
|
||||
|
||||
# if length is still not enough, duplicate last frame
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
search_frame_ids.append(search_frame_ids[-1])
|
||||
|
||||
for i in range(1, self.num_search_frames):
|
||||
if search_frame_ids[i] - search_frame_ids[i - 1] > self.max_interval:
|
||||
print(search_frame_ids[i] - search_frame_ids[i - 1])
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
if dataset.get_name() == 'got10k' :
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
else:
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
self.max_gap = max_gap * self.extra
|
||||
self.max_interval = max_interval * self.extra
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
while True:
|
||||
try:
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= (self.num_search_frames + self.num_template_frames)
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == 'sequential':
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
|
||||
elif self.frame_sample_mode == 'random_interval':
|
||||
if random.random() < self.prob:
|
||||
template_frame_ids, search_frame_ids = self._random_interval_sample(visible)
|
||||
else:
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
else:
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
#print(dataset.get_name(), search_frame_ids, self.max_gap, self.max_interval)
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
#print(self.max_gap, self.max_interval)
|
||||
template_frames, template_anno, meta_obj_template = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
template_bbox = [bbox.numpy() for bbox in template_anno['bbox']] # tensor -> numpy array
|
||||
search_bbox = [bbox.numpy() for bbox in search_anno['bbox']] # tensor -> numpy array
|
||||
# print("====================================================================================")
|
||||
# print("dataset index: {}".format(index))
|
||||
# print("seq_id: {}".format(seq_id))
|
||||
# print('template_frame_ids: {}'.format(template_frame_ids))
|
||||
# print('search_frame_ids: {}'.format(search_frame_ids))
|
||||
return TensorDict({'template_images': np.array(template_frames).squeeze(), # 1 template images
|
||||
'template_annos': np.array(template_bbox).squeeze(),
|
||||
'search_images': np.array(search_frames), # (num_frames) search images
|
||||
'search_annos': np.array(search_bbox),
|
||||
'seq_id': seq_id,
|
||||
'dataset': dataset.get_name(),
|
||||
'search_class': meta_obj_search.get('object_class_name'),
|
||||
'num_frames': len(search_frames)
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,335 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
|
||||
|
||||
class Transform:
|
||||
"""A set of transformations, used for e.g. data augmentation.
|
||||
Args of constructor:
|
||||
transforms: An arbitrary number of transformations, derived from the TransformBase class.
|
||||
They are applied in the order they are given.
|
||||
|
||||
The Transform object can jointly transform images, bounding boxes and segmentation masks.
|
||||
This is done by calling the object with the following key-word arguments (all are optional).
|
||||
|
||||
The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
|
||||
image - Image
|
||||
coords - 2xN dimensional Tensor of 2D image coordinates [y, x]
|
||||
bbox - Bounding box on the form [x, y, w, h]
|
||||
mask - Segmentation mask with discrete classes
|
||||
|
||||
The following parameters can be supplied with calling the transform object:
|
||||
joint [Bool] - If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
|
||||
Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
|
||||
different random rolls. Default: True.
|
||||
new_roll [Bool] - If False, then no new random roll is performed, and the saved result from the previous roll
|
||||
is used instead. Default: True.
|
||||
|
||||
Check the DiMPProcessing class for examples.
|
||||
"""
|
||||
|
||||
def __init__(self, *transforms):
|
||||
if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
|
||||
transforms = transforms[0]
|
||||
self.transforms = transforms
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['joint', 'new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
|
||||
def __call__(self, **inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
for v in inputs.keys():
|
||||
if v not in self._valid_all:
|
||||
raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))
|
||||
|
||||
joint_mode = inputs.get('joint', True)
|
||||
new_roll = inputs.get('new_roll', True)
|
||||
|
||||
if not joint_mode:
|
||||
out = zip(*[self(**inp) for inp in self._split_inputs(inputs)])
|
||||
return tuple(list(o) for o in out)
|
||||
|
||||
out = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
|
||||
for t in self.transforms:
|
||||
out = t(**out, joint=joint_mode, new_roll=new_roll)
|
||||
if len(var_names) == 1:
|
||||
return out[var_names[0]]
|
||||
# Make sure order is correct
|
||||
return tuple(out[v] for v in var_names)
|
||||
|
||||
def _split_inputs(self, inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
split_inputs = [{k: v for k, v in zip(var_names, vals)} for vals in zip(*[inputs[vn] for vn in var_names])]
|
||||
for arg_name, arg_val in filter(lambda it: it[0]!='joint' and it[0] in self._valid_args, inputs.items()):
|
||||
if isinstance(arg_val, list):
|
||||
for inp, av in zip(split_inputs, arg_val):
|
||||
inp[arg_name] = av
|
||||
else:
|
||||
for inp in split_inputs:
|
||||
inp[arg_name] = arg_val
|
||||
return split_inputs
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class TransformBase:
|
||||
"""Base class for transformation objects. See the Transform class for details."""
|
||||
def __init__(self):
|
||||
"""2020.12.24 Add 'att' to valid inputs"""
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
self._rand_params = None
|
||||
|
||||
def __call__(self, **inputs):
|
||||
# Split input
|
||||
input_vars = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
input_args = {k: v for k, v in inputs.items() if k in self._valid_args}
|
||||
|
||||
# Roll random parameters for the transform
|
||||
if input_args.get('new_roll', True):
|
||||
rand_params = self.roll()
|
||||
if rand_params is None:
|
||||
rand_params = ()
|
||||
elif not isinstance(rand_params, tuple):
|
||||
rand_params = (rand_params,)
|
||||
self._rand_params = rand_params
|
||||
|
||||
outputs = dict()
|
||||
for var_name, var in input_vars.items():
|
||||
if var is not None:
|
||||
transform_func = getattr(self, 'transform_' + var_name)
|
||||
if var_name in ['coords', 'bbox']:
|
||||
params = (self._get_image_size(input_vars),) + self._rand_params
|
||||
else:
|
||||
params = self._rand_params
|
||||
if isinstance(var, (list, tuple)):
|
||||
outputs[var_name] = [transform_func(x, *params) for x in var]
|
||||
else:
|
||||
outputs[var_name] = transform_func(var, *params)
|
||||
return outputs
|
||||
|
||||
def _get_image_size(self, inputs):
|
||||
im = None
|
||||
for var_name in ['image', 'mask']:
|
||||
if inputs.get(var_name) is not None:
|
||||
im = inputs[var_name]
|
||||
break
|
||||
if im is None:
|
||||
return None
|
||||
if isinstance(im, (list, tuple)):
|
||||
im = im[0]
|
||||
if isinstance(im, np.ndarray):
|
||||
return im.shape[:2]
|
||||
if torch.is_tensor(im):
|
||||
return (im.shape[-2], im.shape[-1])
|
||||
raise Exception('Unknown image type')
|
||||
|
||||
def roll(self):
|
||||
return None
|
||||
|
||||
def transform_image(self, image, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return coords
|
||||
|
||||
def transform_bbox(self, bbox, image_shape, *rand_params):
|
||||
"""Assumes [x, y, w, h]"""
|
||||
# Check if not overloaded
|
||||
if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
|
||||
return bbox
|
||||
|
||||
coord = bbox.clone().view(-1,2).t().flip(0)
|
||||
|
||||
x1 = coord[1, 0]
|
||||
x2 = coord[1, 0] + coord[1, 1]
|
||||
|
||||
y1 = coord[0, 0]
|
||||
y2 = coord[0, 0] + coord[0, 1]
|
||||
|
||||
coord_all = torch.tensor([[y1, y1, y2, y2], [x1, x2, x2, x1]])
|
||||
|
||||
coord_transf = self.transform_coords(coord_all, image_shape, *rand_params).flip(0)
|
||||
tl = torch.min(coord_transf, dim=1)[0]
|
||||
sz = torch.max(coord_transf, dim=1)[0] - tl
|
||||
bbox_out = torch.cat((tl, sz), dim=-1).reshape(bbox.shape)
|
||||
return bbox_out
|
||||
|
||||
def transform_mask(self, mask, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, *rand_params):
|
||||
"""2020.12.24 Added to deal with attention masks"""
|
||||
return att
|
||||
|
||||
|
||||
class ToTensor(TransformBase):
|
||||
"""Convert to a Tensor"""
|
||||
|
||||
def transform_image(self, image):
|
||||
# handle numpy array
|
||||
if image.ndim == 2:
|
||||
image = image[:, :, None]
|
||||
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(image, torch.ByteTensor):
|
||||
return image.float().div(255)
|
||||
else:
|
||||
return image
|
||||
|
||||
def transfrom_mask(self, mask):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
|
||||
def transform_att(self, att):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class ToTensorAndJitter(TransformBase):
|
||||
"""Convert to a Tensor and jitter brightness"""
|
||||
def __init__(self, brightness_jitter=0.0, normalize=True):
|
||||
super().__init__()
|
||||
self.brightness_jitter = brightness_jitter
|
||||
self.normalize = normalize
|
||||
|
||||
def roll(self):
|
||||
return np.random.uniform(max(0, 1 - self.brightness_jitter), 1 + self.brightness_jitter)
|
||||
|
||||
def transform_image(self, image, brightness_factor):
|
||||
# handle numpy array
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
|
||||
# backward compatibility
|
||||
if self.normalize:
|
||||
return image.float().mul(brightness_factor/255.0).clamp(0.0, 1.0)
|
||||
else:
|
||||
return image.float().mul(brightness_factor).clamp(0.0, 255.0)
|
||||
|
||||
def transform_mask(self, mask, brightness_factor):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
else:
|
||||
return mask
|
||||
def transform_att(self, att, brightness_factor):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class Normalize(TransformBase):
|
||||
"""Normalize image"""
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
super().__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def transform_image(self, image):
|
||||
return tvisf.normalize(image, self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToGrayscale(TransformBase):
|
||||
"""Converts image to grayscale with probability"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_grayscale):
|
||||
if do_grayscale:
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
|
||||
return np.stack([img_gray, img_gray, img_gray], axis=2)
|
||||
# return np.repeat(np.sum(img * self.color_weights, axis=2, keepdims=True).astype(np.uint8), 3, axis=2)
|
||||
return image
|
||||
|
||||
|
||||
class ToBGR(TransformBase):
|
||||
"""Converts image to BGR"""
|
||||
def transform_image(self, image):
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
|
||||
return img_bgr
|
||||
|
||||
|
||||
class RandomHorizontalFlip(TransformBase):
|
||||
"""Horizontally flip image randomly with a probability p."""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(image):
|
||||
return image.flip((2,))
|
||||
return np.fliplr(image).copy()
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = (image_shape[1] - 1) - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
|
||||
def transform_mask(self, mask, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(mask):
|
||||
return mask.flip((-1,))
|
||||
return np.fliplr(mask).copy()
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(att):
|
||||
return att.flip((-1,))
|
||||
return np.fliplr(att).copy()
|
||||
return att
|
||||
|
||||
|
||||
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
|
||||
"""Horizontally flip image randomly with a probability p.
|
||||
The difference is that the coord is normalized to [0,1]"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
"""we should use 1 rather than image_shape"""
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = 1 - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
@@ -0,0 +1,33 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install wandb" to install wandb')
|
||||
|
||||
|
||||
class WandbWriter:
|
||||
def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
|
||||
self.wandb = wandb
|
||||
self.step = cur_step
|
||||
self.interval = step_interval
|
||||
wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)
|
||||
|
||||
def write_log(self, stats: OrderedDict, epoch=-1):
|
||||
self.step += 1
|
||||
for loader_name, loader_stats in stats.items():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
|
||||
log_dict = {}
|
||||
for var_name, val in loader_stats.items():
|
||||
if hasattr(val, 'avg'):
|
||||
log_dict.update({loader_name + '/' + var_name: val.avg})
|
||||
else:
|
||||
log_dict.update({loader_name + '/' + var_name: val.val})
|
||||
|
||||
if epoch >= 0:
|
||||
log_dict.update({loader_name + '/epoch': epoch})
|
||||
|
||||
self.wandb.log(log_dict, step=self.step*self.interval)
|
||||
@@ -0,0 +1,16 @@
|
||||
# README
|
||||
|
||||
## Description for different text files
|
||||
GOT10K
|
||||
- got10k_train_full_split.txt: the complete GOT-10K training set. (9335 videos)
|
||||
- got10k_train_split.txt: part of videos from the GOT-10K training set
|
||||
- got10k_val_split.txt: another part of videos from the GOT-10K training set
|
||||
- got10k_vot_exclude.txt: 1k videos that are forbidden from "using to train models then testing on VOT" (as required by [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html))
|
||||
- got10k_vot_train_split.txt: part of videos from the "VOT-permitted" GOT-10K training set
|
||||
- got10k_vot_val_split.txt: another part of videos from the "VOT-permitted" GOT-10K training set
|
||||
|
||||
LaSOT
|
||||
- lasot_train_split.txt: the complete LaSOT training set
|
||||
|
||||
TrackingNnet
|
||||
- trackingnet_classmap.txt: The map from the sequence name to the target class for the TrackingNet
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,437 @@
|
||||
__author__ = 'tylin'
|
||||
__version__ = '2.0'
|
||||
# Interface for accessing the Microsoft COCO dataset.
|
||||
|
||||
# Microsoft COCO is a large image dataset designed for object detection,
|
||||
# segmentation, and caption generation. pycocotools is a Python API that
|
||||
# assists in loading, parsing and visualizing the annotations in COCO.
|
||||
# Please visit http://mscoco.org/ for more information on COCO, including
|
||||
# for the data, paper, and tutorials. The exact format of the annotations
|
||||
# is also described on the COCO website. For example usage of the pycocotools
|
||||
# please see pycocotools_demo.ipynb. In addition to this API, please download both
|
||||
# the COCO images and annotations in order to run the demo.
|
||||
|
||||
# An alternative to using the API is to load the annotations directly
|
||||
# into Python dictionary
|
||||
# Using the API provides additional utility functions. Note that this API
|
||||
# supports both *instance* and *caption* annotations. In the case of
|
||||
# captions not all functions are defined (e.g. categories are undefined).
|
||||
|
||||
# The following API functions are defined:
|
||||
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
|
||||
# decodeMask - Decode binary mask M encoded via run-length encoding.
|
||||
# encodeMask - Encode binary mask M using run-length encoding.
|
||||
# getAnnIds - Get ann ids that satisfy given filter conditions.
|
||||
# getCatIds - Get cat ids that satisfy given filter conditions.
|
||||
# getImgIds - Get img ids that satisfy given filter conditions.
|
||||
# loadAnns - Load anns with the specified ids.
|
||||
# loadCats - Load cats with the specified ids.
|
||||
# loadImgs - Load imgs with the specified ids.
|
||||
# annToMask - Convert segmentation in an annotation to binary mask.
|
||||
# showAnns - Display the specified annotations.
|
||||
# loadRes - Load algorithm results and create API for accessing them.
|
||||
# download - Download COCO images from mscoco.org server.
|
||||
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
|
||||
# Help on each functions can be accessed by: "help COCO>function".
|
||||
|
||||
# See also COCO>decodeMask,
|
||||
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
|
||||
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
|
||||
# COCO>loadImgs, COCO>annToMask, COCO>showAnns
|
||||
|
||||
# Microsoft COCO Toolbox. version 2.0
|
||||
# Data, paper, and tutorials available at: http://mscoco.org/
|
||||
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
|
||||
# Licensed under the Simplified BSD License [see bsd.txt]
|
||||
|
||||
import json
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
import numpy as np
|
||||
import copy
|
||||
import itertools
|
||||
from pycocotools import mask as maskUtils
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
PYTHON_VERSION = sys.version_info[0]
|
||||
if PYTHON_VERSION == 2:
|
||||
from urllib import urlretrieve
|
||||
elif PYTHON_VERSION == 3:
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _isArrayLike(obj):
|
||||
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
||||
|
||||
|
||||
class COCO:
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
||||
:param annotation_file (str): location of annotation file
|
||||
:param image_folder (str): location to the folder that hosts images.
|
||||
:return:
|
||||
"""
|
||||
# load dataset
|
||||
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
||||
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
||||
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
||||
self.dataset = dataset
|
||||
self.createIndex()
|
||||
|
||||
def createIndex(self):
|
||||
# create index
|
||||
print('creating index...')
|
||||
anns, cats, imgs = {}, {}, {}
|
||||
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
|
||||
if 'annotations' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
anns[ann['id']] = ann
|
||||
|
||||
if 'images' in self.dataset:
|
||||
for img in self.dataset['images']:
|
||||
imgs[img['id']] = img
|
||||
|
||||
if 'categories' in self.dataset:
|
||||
for cat in self.dataset['categories']:
|
||||
cats[cat['id']] = cat
|
||||
|
||||
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
catToImgs[ann['category_id']].append(ann['image_id'])
|
||||
|
||||
print('index created!')
|
||||
|
||||
# create class members
|
||||
self.anns = anns
|
||||
self.imgToAnns = imgToAnns
|
||||
self.catToImgs = catToImgs
|
||||
self.imgs = imgs
|
||||
self.cats = cats
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Print information about the annotation file.
|
||||
:return:
|
||||
"""
|
||||
for key, value in self.dataset['info'].items():
|
||||
print('{}: {}'.format(key, value))
|
||||
|
||||
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
|
||||
"""
|
||||
Get ann ids that satisfy given filter conditions. default skips that filter
|
||||
:param imgIds (int array) : get anns for given imgs
|
||||
catIds (int array) : get anns for given cats
|
||||
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
||||
iscrowd (boolean) : get anns for given crowd label (False or True)
|
||||
:return: ids (int array) : integer array of ann ids
|
||||
"""
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == len(areaRng) == 0:
|
||||
anns = self.dataset['annotations']
|
||||
else:
|
||||
if not len(imgIds) == 0:
|
||||
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
|
||||
anns = list(itertools.chain.from_iterable(lists))
|
||||
else:
|
||||
anns = self.dataset['annotations']
|
||||
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
||||
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
|
||||
if not iscrowd == None:
|
||||
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
||||
else:
|
||||
ids = [ann['id'] for ann in anns]
|
||||
return ids
|
||||
|
||||
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
||||
"""
|
||||
filtering parameters. default skips that filter.
|
||||
:param catNms (str array) : get cats for given cat names
|
||||
:param supNms (str array) : get cats for given supercategory names
|
||||
:param catIds (int array) : get cats for given cat ids
|
||||
:return: ids (int array) : integer array of cat ids
|
||||
"""
|
||||
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
||||
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(catNms) == len(supNms) == len(catIds) == 0:
|
||||
cats = self.dataset['categories']
|
||||
else:
|
||||
cats = self.dataset['categories']
|
||||
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
||||
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
||||
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
||||
ids = [cat['id'] for cat in cats]
|
||||
return ids
|
||||
|
||||
def getImgIds(self, imgIds=[], catIds=[]):
|
||||
'''
|
||||
Get img ids that satisfy given filter conditions.
|
||||
:param imgIds (int array) : get imgs for given ids
|
||||
:param catIds (int array) : get imgs with all given cats
|
||||
:return: ids (int array) : integer array of img ids
|
||||
'''
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == 0:
|
||||
ids = self.imgs.keys()
|
||||
else:
|
||||
ids = set(imgIds)
|
||||
for i, catId in enumerate(catIds):
|
||||
if i == 0 and len(ids) == 0:
|
||||
ids = set(self.catToImgs[catId])
|
||||
else:
|
||||
ids &= set(self.catToImgs[catId])
|
||||
return list(ids)
|
||||
|
||||
def loadAnns(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying anns
|
||||
:return: anns (object array) : loaded ann objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.anns[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.anns[ids]]
|
||||
|
||||
def loadCats(self, ids=[]):
|
||||
"""
|
||||
Load cats with the specified ids.
|
||||
:param ids (int array) : integer ids specifying cats
|
||||
:return: cats (object array) : loaded cat objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.cats[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.cats[ids]]
|
||||
|
||||
def loadImgs(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying img
|
||||
:return: imgs (object array) : loaded img objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.imgs[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.imgs[ids]]
|
||||
|
||||
def showAnns(self, anns, draw_bbox=False):
|
||||
"""
|
||||
Display the specified annotations.
|
||||
:param anns (array of object): annotations to display
|
||||
:return: None
|
||||
"""
|
||||
if len(anns) == 0:
|
||||
return 0
|
||||
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
|
||||
datasetType = 'instances'
|
||||
elif 'caption' in anns[0]:
|
||||
datasetType = 'captions'
|
||||
else:
|
||||
raise Exception('datasetType not supported')
|
||||
if datasetType == 'instances':
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in anns:
|
||||
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
|
||||
if 'segmentation' in ann:
|
||||
if type(ann['segmentation']) == list:
|
||||
# polygon
|
||||
for seg in ann['segmentation']:
|
||||
poly = np.array(seg).reshape((int(len(seg)/2), 2))
|
||||
polygons.append(Polygon(poly))
|
||||
color.append(c)
|
||||
else:
|
||||
# mask
|
||||
t = self.imgs[ann['image_id']]
|
||||
if type(ann['segmentation']['counts']) == list:
|
||||
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
|
||||
else:
|
||||
rle = [ann['segmentation']]
|
||||
m = maskUtils.decode(rle)
|
||||
img = np.ones( (m.shape[0], m.shape[1], 3) )
|
||||
if ann['iscrowd'] == 1:
|
||||
color_mask = np.array([2.0,166.0,101.0])/255
|
||||
if ann['iscrowd'] == 0:
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack( (img, m*0.5) ))
|
||||
if 'keypoints' in ann and type(ann['keypoints']) == list:
|
||||
# turn skeleton into zero-based index
|
||||
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
|
||||
kp = np.array(ann['keypoints'])
|
||||
x = kp[0::3]
|
||||
y = kp[1::3]
|
||||
v = kp[2::3]
|
||||
for sk in sks:
|
||||
if np.all(v[sk]>0):
|
||||
plt.plot(x[sk],y[sk], linewidth=3, color=c)
|
||||
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
|
||||
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
|
||||
|
||||
if draw_bbox:
|
||||
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
||||
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
||||
np_poly = np.array(poly).reshape((4,2))
|
||||
polygons.append(Polygon(np_poly))
|
||||
color.append(c)
|
||||
|
||||
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
|
||||
ax.add_collection(p)
|
||||
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
|
||||
ax.add_collection(p)
|
||||
elif datasetType == 'captions':
|
||||
for ann in anns:
|
||||
print(ann['caption'])
|
||||
|
||||
def loadRes(self, resFile):
|
||||
"""
|
||||
Load result file and return a result api object.
|
||||
:param resFile (str) : file name of result file
|
||||
:return: res (obj) : result api object
|
||||
"""
|
||||
res = COCO()
|
||||
res.dataset['images'] = [img for img in self.dataset['images']]
|
||||
|
||||
print('Loading and preparing results...')
|
||||
tic = time.time()
|
||||
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
||||
with open(resFile) as f:
|
||||
anns = json.load(f)
|
||||
elif type(resFile) == np.ndarray:
|
||||
anns = self.loadNumpyAnnotations(resFile)
|
||||
else:
|
||||
anns = resFile
|
||||
assert type(anns) == list, 'results in not an array of objects'
|
||||
annsImgIds = [ann['image_id'] for ann in anns]
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
|
||||
'Results do not correspond to current coco set'
|
||||
if 'caption' in anns[0]:
|
||||
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
|
||||
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
|
||||
for id, ann in enumerate(anns):
|
||||
ann['id'] = id+1
|
||||
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
bb = ann['bbox']
|
||||
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
|
||||
if not 'segmentation' in ann:
|
||||
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
||||
ann['area'] = bb[2]*bb[3]
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'segmentation' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
# now only support compressed RLE format as segmentation results
|
||||
ann['area'] = maskUtils.area(ann['segmentation'])
|
||||
if not 'bbox' in ann:
|
||||
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'keypoints' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
s = ann['keypoints']
|
||||
x = s[0::3]
|
||||
y = s[1::3]
|
||||
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
||||
ann['area'] = (x1-x0)*(y1-y0)
|
||||
ann['id'] = id + 1
|
||||
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
|
||||
print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
||||
|
||||
res.dataset['annotations'] = anns
|
||||
res.createIndex()
|
||||
return res
|
||||
|
||||
def download(self, tarDir = None, imgIds = [] ):
|
||||
'''
|
||||
Download COCO images from mscoco.org server.
|
||||
:param tarDir (str): COCO results directory name
|
||||
imgIds (list): images to be downloaded
|
||||
:return:
|
||||
'''
|
||||
if tarDir is None:
|
||||
print('Please specify target directory')
|
||||
return -1
|
||||
if len(imgIds) == 0:
|
||||
imgs = self.imgs.values()
|
||||
else:
|
||||
imgs = self.loadImgs(imgIds)
|
||||
N = len(imgs)
|
||||
if not os.path.exists(tarDir):
|
||||
os.makedirs(tarDir)
|
||||
for i, img in enumerate(imgs):
|
||||
tic = time.time()
|
||||
fname = os.path.join(tarDir, img['file_name'])
|
||||
if not os.path.exists(fname):
|
||||
urlretrieve(img['coco_url'], fname)
|
||||
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
|
||||
|
||||
def loadNumpyAnnotations(self, data):
|
||||
"""
|
||||
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
|
||||
:param data (numpy.ndarray)
|
||||
:return: annotations (python nested list)
|
||||
"""
|
||||
print('Converting ndarray to lists...')
|
||||
assert(type(data) == np.ndarray)
|
||||
print(data.shape)
|
||||
assert(data.shape[1] == 7)
|
||||
N = data.shape[0]
|
||||
ann = []
|
||||
for i in range(N):
|
||||
if i % 1000000 == 0:
|
||||
print('{}/{}'.format(i,N))
|
||||
ann += [{
|
||||
'image_id' : int(data[i, 0]),
|
||||
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
|
||||
'score' : data[i, 5],
|
||||
'category_id': int(data[i, 6]),
|
||||
}]
|
||||
return ann
|
||||
|
||||
def annToRLE(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
t = self.imgs[ann['image_id']]
|
||||
h, w = t['height'], t['width']
|
||||
segm = ann['segmentation']
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = ann['segmentation']
|
||||
return rle
|
||||
|
||||
def annToMask(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
rle = self.annToRLE(ann)
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
||||
@@ -0,0 +1,11 @@
|
||||
from .lasot import Lasot
|
||||
from .got10k import Got10k
|
||||
from .tracking_net import TrackingNet
|
||||
from .imagenetvid import ImagenetVID
|
||||
from .coco import MSCOCO
|
||||
from .coco_seq import MSCOCOSeq
|
||||
from .got10k_lmdb import Got10k_lmdb
|
||||
from .lasot_lmdb import Lasot_lmdb
|
||||
from .imagenetvid_lmdb import ImagenetVID_lmdb
|
||||
from .coco_seq_lmdb import MSCOCOSeq_lmdb
|
||||
from .tracking_net_lmdb import TrackingNet_lmdb
|
||||
@@ -0,0 +1,92 @@
|
||||
import torch.utils.data
|
||||
from lib.train.data.image_loader import jpeg4py_loader
|
||||
|
||||
|
||||
class BaseImageDataset(torch.utils.data.Dataset):
|
||||
""" Base class for image datasets """
|
||||
|
||||
def __init__(self, name, root, image_loader=jpeg4py_loader):
|
||||
"""
|
||||
args:
|
||||
root - The root path to the dataset
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
"""
|
||||
self.name = name
|
||||
self.root = root
|
||||
self.image_loader = image_loader
|
||||
|
||||
self.image_list = [] # Contains the list of sequences.
|
||||
self.class_list = []
|
||||
|
||||
def __len__(self):
|
||||
""" Returns size of the dataset
|
||||
returns:
|
||||
int - number of samples in the dataset
|
||||
"""
|
||||
return self.get_num_images()
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Not to be used! Check get_frames() instead.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_name(self):
|
||||
""" Name of the dataset
|
||||
|
||||
returns:
|
||||
string - Name of the dataset
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_images(self):
|
||||
""" Number of sequences in a dataset
|
||||
|
||||
returns:
|
||||
int - number of sequences in the dataset."""
|
||||
return len(self.image_list)
|
||||
|
||||
def has_class_info(self):
|
||||
return False
|
||||
|
||||
def get_class_name(self, image_id):
|
||||
return None
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.class_list
|
||||
|
||||
def get_images_in_class(self, class_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return False
|
||||
|
||||
def get_image_info(self, seq_id):
|
||||
""" Returns information about a particular image,
|
||||
|
||||
args:
|
||||
seq_id - index of the image
|
||||
|
||||
returns:
|
||||
Dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_image(self, image_id, anno=None):
|
||||
""" Get a image
|
||||
|
||||
args:
|
||||
image_id - index of image
|
||||
anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
|
||||
|
||||
returns:
|
||||
image -
|
||||
anno -
|
||||
dict - A dict containing meta information about the sequence, e.g. class of the target object.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
import torch.utils.data
|
||||
# 2021.1.5 use jpeg4py_loader_w_failsafe as default
|
||||
from lib.train.data.image_loader import jpeg4py_loader_w_failsafe
|
||||
|
||||
|
||||
class BaseVideoDataset(torch.utils.data.Dataset):
|
||||
""" Base class for video datasets """
|
||||
|
||||
def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe):
|
||||
"""
|
||||
args:
|
||||
root - The root path to the dataset
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
"""
|
||||
self.name = name
|
||||
self.root = root
|
||||
self.image_loader = image_loader
|
||||
|
||||
self.sequence_list = [] # Contains the list of sequences.
|
||||
self.class_list = []
|
||||
|
||||
def __len__(self):
|
||||
""" Returns size of the dataset
|
||||
returns:
|
||||
int - number of samples in the dataset
|
||||
"""
|
||||
return self.get_num_sequences()
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Not to be used! Check get_frames() instead.
|
||||
"""
|
||||
return None
|
||||
|
||||
def is_video_sequence(self):
|
||||
""" Returns whether the dataset is a video dataset or an image dataset
|
||||
|
||||
returns:
|
||||
bool - True if a video dataset
|
||||
"""
|
||||
return True
|
||||
|
||||
def is_synthetic_video_dataset(self):
|
||||
""" Returns whether the dataset contains real videos or synthetic
|
||||
|
||||
returns:
|
||||
bool - True if a video dataset
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_name(self):
|
||||
""" Name of the dataset
|
||||
|
||||
returns:
|
||||
string - Name of the dataset
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_sequences(self):
|
||||
""" Number of sequences in a dataset
|
||||
|
||||
returns:
|
||||
int - number of sequences in the dataset."""
|
||||
return len(self.sequence_list)
|
||||
|
||||
def has_class_info(self):
|
||||
return False
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.class_list
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return False
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
""" Returns information about a particular sequences,
|
||||
|
||||
args:
|
||||
seq_id - index of the sequence
|
||||
|
||||
returns:
|
||||
Dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
""" Get a set of frames from a particular sequence
|
||||
|
||||
args:
|
||||
seq_id - index of sequence
|
||||
frame_ids - a list of frame numbers
|
||||
anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
|
||||
|
||||
returns:
|
||||
list - List of frames corresponding to frame_ids
|
||||
list - List of dicts for each frame
|
||||
dict - A dict containing meta information about the sequence, e.g. class of the target object.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
from .base_image_dataset import BaseImageDataset
|
||||
import torch
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
|
||||
class MSCOCO(BaseImageDataset):
|
||||
""" The COCO object detection dataset.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None,
|
||||
split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to coco root folder
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
min_area - Objects with area less than min_area are filtered out. Default is 0.0
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO', root, image_loader)
|
||||
|
||||
self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
|
||||
self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))
|
||||
|
||||
self.coco_set = COCO(self.anno_path)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list() # the parent class thing would happen in the sampler
|
||||
|
||||
self.image_list = self._get_image_list(min_area=min_area)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction))
|
||||
self.im_per_class = self._build_im_per_class()
|
||||
|
||||
def _get_image_list(self, min_area=None):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
image_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
if min_area is not None:
|
||||
image_list = [a for a in image_list if self.coco_set.anns[a]['area'] > min_area]
|
||||
|
||||
return image_list
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def _build_im_per_class(self):
|
||||
im_per_class = {}
|
||||
for i, im in enumerate(self.image_list):
|
||||
class_name = self.cats[self.coco_set.anns[im]['category_id']]['name']
|
||||
if class_name not in im_per_class:
|
||||
im_per_class[class_name] = [i]
|
||||
else:
|
||||
im_per_class[class_name].append(i)
|
||||
|
||||
return im_per_class
|
||||
|
||||
def get_images_in_class(self, class_name):
|
||||
return self.im_per_class[class_name]
|
||||
|
||||
def get_image_info(self, im_id):
|
||||
anno = self._get_anno(im_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(4,)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno))
|
||||
|
||||
valid = (bbox[2] > 0) & (bbox[3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, im_id):
|
||||
anno = self.coco_set.anns[self.image_list[im_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_image(self, im_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.image_list[im_id]]['image_id']])[0]['file_name']
|
||||
img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, im_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
def get_class_name(self, im_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_image(self, image_id, anno=None):
|
||||
frame = self._get_image(image_id)
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_image_info(image_id)
|
||||
|
||||
object_meta = self.get_meta_info(image_id)
|
||||
|
||||
return frame, anno, object_meta
|
||||
@@ -0,0 +1,170 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
import random
|
||||
from pycocotools.coco import COCO
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class MSCOCOSeq(BaseVideoDataset):
|
||||
""" The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to the coco dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
|
||||
images will be used
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO', root, image_loader)
|
||||
|
||||
self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
|
||||
self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))
|
||||
|
||||
# Load the COCO set.
|
||||
self.coco_set = COCO(self.anno_path)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list()
|
||||
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
def _get_sequence_list(self):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
return seq_list
|
||||
|
||||
def is_video_sequence(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
anno = self._get_anno(seq_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(1, 4)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)
|
||||
|
||||
'''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
|
||||
valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)
|
||||
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, seq_id):
|
||||
anno = self.coco_set.anns[self.sequence_list[seq_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_frames(self, seq_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
|
||||
img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, seq_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
|
||||
# COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
|
||||
# list containing these replicated images.
|
||||
frame = self._get_frames(seq_id)
|
||||
|
||||
frame_list = [frame.copy() for _ in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[0, ...] for _ in frame_ids]
|
||||
|
||||
object_meta = self.get_meta_info(seq_id)
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,177 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
from lib.train.dataset.COCO_tool import COCO
|
||||
from lib.utils.lmdb_utils import decode_img, decode_json
|
||||
import time
|
||||
|
||||
class MSCOCOSeq_lmdb(BaseVideoDataset):
|
||||
""" The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to the coco dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
|
||||
images will be used
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO_lmdb', root, image_loader)
|
||||
self.root = root
|
||||
self.img_pth = 'images/{}{}/'.format(split, version)
|
||||
self.anno_path = 'annotations/instances_{}{}.json'.format(split, version)
|
||||
|
||||
# Load the COCO set.
|
||||
print('loading annotations into memory...')
|
||||
tic = time.time()
|
||||
coco_json = decode_json(root, self.anno_path)
|
||||
print('Done (t={:0.2f}s)'.format(time.time() - tic))
|
||||
|
||||
self.coco_set = COCO(coco_json)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list()
|
||||
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
def _get_sequence_list(self):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
return seq_list
|
||||
|
||||
def is_video_sequence(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
anno = self._get_anno(seq_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(1, 4)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)
|
||||
|
||||
'''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
|
||||
valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)
|
||||
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, seq_id):
|
||||
anno = self.coco_set.anns[self.sequence_list[seq_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_frames(self, seq_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
|
||||
# img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
img = decode_img(self.root, os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, seq_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
|
||||
# COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
|
||||
# list containing these replicated images.
|
||||
frame = self._get_frames(seq_id)
|
||||
|
||||
frame_list = [frame.copy() for _ in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[0, ...] for _ in frame_ids]
|
||||
|
||||
object_meta = self.get_meta_info(seq_id)
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import torch
|
||||
import csv
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class Got10k(BaseVideoDataset):
|
||||
""" GOT-10k dataset.
|
||||
|
||||
Publication:
|
||||
GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
|
||||
Lianghua Huang, Xin Zhao, and Kaiqi Huang
|
||||
arXiv:1810.11981, 2018
|
||||
https://arxiv.org/pdf/1810.11981.pdf
|
||||
|
||||
Download dataset from http://got-10k.aitestunion.com/downloads
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
|
||||
not NOT the official got-10k validation split. To use the official validation split, provide that as
|
||||
the root folder instead.
|
||||
seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
|
||||
options can be used at the same time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().got10k_dir if root is None else root
|
||||
super().__init__('GOT10k', root, image_loader)
|
||||
|
||||
# all folders inside the root
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
# seq_id is the index of the folder inside the got10k root path
|
||||
if split is not None:
|
||||
if seq_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and seq_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_split.txt')
|
||||
elif split == 'val':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_val_split.txt')
|
||||
elif split == 'train_full':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_full_split.txt')
|
||||
elif split == 'vottrain':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_train_split.txt')
|
||||
elif split == 'votval':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_val_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
# seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
|
||||
seq_ids = pandas.read_csv(file_path, header=None, dtype=np.int64).squeeze("columns").values.tolist()
|
||||
elif seq_ids is None:
|
||||
seq_ids = list(range(0, len(self.sequence_list)))
|
||||
|
||||
self.sequence_list = [self.sequence_list[i] for i in seq_ids]
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.sequence_meta_info = self._load_meta_info()
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def get_name(self):
|
||||
return 'got10k'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def _load_meta_info(self):
|
||||
sequence_meta_info = {s: self._read_meta(os.path.join(self.root, s)) for s in self.sequence_list}
|
||||
return sequence_meta_info
|
||||
|
||||
def _read_meta(self, seq_path):
|
||||
try:
|
||||
with open(os.path.join(seq_path, 'meta_info.ini')) as f:
|
||||
meta_info = f.readlines()
|
||||
object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1][:-1],
|
||||
'motion_class': meta_info[6].split(': ')[-1][:-1],
|
||||
'major_class': meta_info[7].split(': ')[-1][:-1],
|
||||
'root_class': meta_info[8].split(': ')[-1][:-1],
|
||||
'motion_adverb': meta_info[9].split(': ')[-1][:-1]})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
|
||||
for i, s in enumerate(self.sequence_list):
|
||||
object_class = self.sequence_meta_info[s]['object_class_name']
|
||||
if object_class in seq_per_class:
|
||||
seq_per_class[object_class].append(i)
|
||||
else:
|
||||
seq_per_class[object_class] = [i]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _get_sequence_list(self):
|
||||
with open(os.path.join(self.root, 'list.txt')) as f:
|
||||
dir_list = list(csv.reader(f))
|
||||
dir_list = [dir_name[0] for dir_name in dir_list]
|
||||
return dir_list
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "absence.label")
|
||||
cover_file = os.path.join(seq_path, "cover.label")
|
||||
|
||||
with open(occlusion_file, 'r', newline='') as f:
|
||||
occlusion = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])
|
||||
with open(cover_file, 'r', newline='') as f:
|
||||
cover = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])
|
||||
|
||||
target_visible = ~occlusion & (cover>0).byte()
|
||||
|
||||
visible_ratio = cover.float() / 8
|
||||
return target_visible, visible_ratio
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
return os.path.join(self.root, self.sequence_list[seq_id])
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible, visible_ratio = self._read_target_visible(seq_path)
|
||||
visible = visible & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return self.image_loader(self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
return obj_meta['object_class_name']
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
return frame_list, anno_frames, obj_meta
|
||||
@@ -0,0 +1,183 @@
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import torch
|
||||
import csv
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
'''2021.1.16 Gok10k for loading lmdb dataset'''
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
|
||||
class Got10k_lmdb(BaseVideoDataset):
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
|
||||
not NOT the official got-10k validation split. To use the official validation split, provide that as
|
||||
the root folder instead.
|
||||
seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
|
||||
options can be used at the same time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
use_lmdb - whether the dataset is stored in lmdb format
|
||||
"""
|
||||
root = env_settings().got10k_lmdb_dir if root is None else root
|
||||
super().__init__('GOT10k_lmdb', root, image_loader)
|
||||
|
||||
# all folders inside the root
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
# seq_id is the index of the folder inside the got10k root path
|
||||
if split is not None:
|
||||
if seq_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and seq_ids.')
|
||||
train_lib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_split.txt')
|
||||
elif split == 'val':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_val_split.txt')
|
||||
elif split == 'train_full':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_full_split.txt')
|
||||
elif split == 'vottrain':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_train_split.txt')
|
||||
elif split == 'votval':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_val_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
|
||||
elif seq_ids is None:
|
||||
seq_ids = list(range(0, len(self.sequence_list)))
|
||||
|
||||
self.sequence_list = [self.sequence_list[i] for i in seq_ids]
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.sequence_meta_info = self._load_meta_info()
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def get_name(self):
|
||||
return 'got10k_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def _load_meta_info(self):
|
||||
def _read_meta(meta_info):
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1],
|
||||
'motion_class': meta_info[6].split(': ')[-1],
|
||||
'major_class': meta_info[7].split(': ')[-1],
|
||||
'root_class': meta_info[8].split(': ')[-1],
|
||||
'motion_adverb': meta_info[9].split(': ')[-1]})
|
||||
|
||||
return object_meta
|
||||
sequence_meta_info = {}
|
||||
for s in self.sequence_list:
|
||||
try:
|
||||
meta_str = decode_str(self.root, "train/%s/meta_info.ini" %s)
|
||||
sequence_meta_info[s] = _read_meta(meta_str.split('\n'))
|
||||
except:
|
||||
sequence_meta_info[s] = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return sequence_meta_info
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
|
||||
for i, s in enumerate(self.sequence_list):
|
||||
object_class = self.sequence_meta_info[s]['object_class_name']
|
||||
if object_class in seq_per_class:
|
||||
seq_per_class[object_class].append(i)
|
||||
else:
|
||||
seq_per_class[object_class] = [i]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _get_sequence_list(self):
|
||||
dir_str = decode_str(self.root, 'train/list.txt')
|
||||
dir_list = dir_str.split('\n')
|
||||
return dir_list
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line in got10k is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# full occlusion and out_of_view files
|
||||
occlusion_file = os.path.join(seq_path, "absence.label")
|
||||
cover_file = os.path.join(seq_path, "cover.label")
|
||||
# Read these files
|
||||
occ_list = list(map(int, decode_str(self.root, occlusion_file).split('\n')[:-1])) # the last line in got10k is empty
|
||||
occlusion = torch.ByteTensor(occ_list)
|
||||
cover_list = list(map(int, decode_str(self.root, cover_file).split('\n')[:-1])) # the last line in got10k is empty
|
||||
cover = torch.ByteTensor(cover_list)
|
||||
|
||||
target_visible = ~occlusion & (cover>0).byte()
|
||||
|
||||
visible_ratio = cover.float() / 8
|
||||
return target_visible, visible_ratio
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
return os.path.join("train", self.sequence_list[seq_id])
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible, visible_ratio = self._read_target_visible(seq_path)
|
||||
visible = visible & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return decode_img(self.root, self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
return obj_meta['object_class_name']
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
return frame_list, anno_frames, obj_meta
|
||||
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import xml.etree.ElementTree as ET
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
def get_target_to_image_ratio(seq):
|
||||
anno = torch.Tensor(seq['anno'])
|
||||
img_sz = torch.Tensor(seq['image_size'])
|
||||
return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()
|
||||
|
||||
|
||||
class ImagenetVID(BaseVideoDataset):
|
||||
""" Imagenet VID dataset.
|
||||
|
||||
Publication:
|
||||
ImageNet Large Scale Visual Recognition Challenge
|
||||
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
|
||||
Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
|
||||
IJCV, 2015
|
||||
https://arxiv.org/pdf/1409.0575.pdf
|
||||
|
||||
Download the dataset from http://image-net.org/
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
|
||||
"""
|
||||
args:
|
||||
root - path to the imagenet vid dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
min_length - Minimum allowed sequence length.
|
||||
max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
|
||||
which cover complete image.
|
||||
"""
|
||||
root = env_settings().imagenet_dir if root is None else root
|
||||
super().__init__("imagenetvid", root, image_loader)
|
||||
|
||||
cache_file = os.path.join(root, 'cache.json')
|
||||
if os.path.isfile(cache_file):
|
||||
# If available, load the pre-processed cache file containing meta-info for each sequence
|
||||
with open(cache_file, 'r') as f:
|
||||
sequence_list_dict = json.load(f)
|
||||
|
||||
self.sequence_list = sequence_list_dict
|
||||
else:
|
||||
# Else process the imagenet annotations and generate the cache file
|
||||
self.sequence_list = self._process_anno(root)
|
||||
|
||||
with open(cache_file, 'w') as f:
|
||||
json.dump(self.sequence_list, f)
|
||||
|
||||
# Filter the sequences based on min_length and max_target_area in the first frame
|
||||
self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
|
||||
get_target_to_image_ratio(x) < max_target_area]
|
||||
|
||||
def get_name(self):
|
||||
return 'imagenetvid'
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
|
||||
valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
|
||||
visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
|
||||
return {'bbox': bb_anno, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, sequence, frame_id):
|
||||
set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
|
||||
vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
|
||||
frame_number = frame_id + sequence['start_frame']
|
||||
frame_path = os.path.join(self.root, 'Data', 'VID', 'train', set_name, vid_name,
|
||||
'{:06d}.JPEG'.format(frame_number))
|
||||
return self.image_loader(frame_path)
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
sequence = self.sequence_list[seq_id]
|
||||
|
||||
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
# Create anno dict
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
# added the class info to the meta info
|
||||
object_meta = OrderedDict({'object_class': sequence['class_name'],
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
|
||||
def _process_anno(self, root):
|
||||
# Builds individual tracklets
|
||||
base_vid_anno_path = os.path.join(root, 'Annotations', 'VID', 'train')
|
||||
|
||||
all_sequences = []
|
||||
for set in sorted(os.listdir(base_vid_anno_path)):
|
||||
set_id = int(set.split('_')[-1])
|
||||
for vid in sorted(os.listdir(os.path.join(base_vid_anno_path, set))):
|
||||
|
||||
vid_id = int(vid.split('_')[-1])
|
||||
anno_files = sorted(os.listdir(os.path.join(base_vid_anno_path, set, vid)))
|
||||
|
||||
frame1_anno = ET.parse(os.path.join(base_vid_anno_path, set, vid, anno_files[0]))
|
||||
image_size = [int(frame1_anno.find('size/width').text), int(frame1_anno.find('size/height').text)]
|
||||
|
||||
objects = [ET.ElementTree(file=os.path.join(base_vid_anno_path, set, vid, f)).findall('object')
|
||||
for f in anno_files]
|
||||
|
||||
tracklets = {}
|
||||
|
||||
# Find all tracklets along with start frame
|
||||
for f_id, all_targets in enumerate(objects):
|
||||
for target in all_targets:
|
||||
tracklet_id = target.find('trackid').text
|
||||
if tracklet_id not in tracklets:
|
||||
tracklets[tracklet_id] = f_id
|
||||
|
||||
for tracklet_id, tracklet_start in tracklets.items():
|
||||
tracklet_anno = []
|
||||
target_visible = []
|
||||
class_name_id = None
|
||||
|
||||
for f_id in range(tracklet_start, len(objects)):
|
||||
found = False
|
||||
for target in objects[f_id]:
|
||||
if target.find('trackid').text == tracklet_id:
|
||||
if not class_name_id:
|
||||
class_name_id = target.find('name').text
|
||||
x1 = int(target.find('bndbox/xmin').text)
|
||||
y1 = int(target.find('bndbox/ymin').text)
|
||||
x2 = int(target.find('bndbox/xmax').text)
|
||||
y2 = int(target.find('bndbox/ymax').text)
|
||||
|
||||
tracklet_anno.append([x1, y1, x2 - x1, y2 - y1])
|
||||
target_visible.append(target.find('occluded').text == '0')
|
||||
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
break
|
||||
|
||||
new_sequence = {'set_id': set_id, 'vid_id': vid_id, 'class_name': class_name_id,
|
||||
'start_frame': tracklet_start, 'anno': tracklet_anno,
|
||||
'target_visible': target_visible, 'image_size': image_size}
|
||||
all_sequences.append(new_sequence)
|
||||
|
||||
return all_sequences
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
from lib.utils.lmdb_utils import decode_img, decode_json
|
||||
|
||||
|
||||
def get_target_to_image_ratio(seq):
|
||||
anno = torch.Tensor(seq['anno'])
|
||||
img_sz = torch.Tensor(seq['image_size'])
|
||||
return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()
|
||||
|
||||
|
||||
class ImagenetVID_lmdb(BaseVideoDataset):
|
||||
""" Imagenet VID dataset.
|
||||
|
||||
Publication:
|
||||
ImageNet Large Scale Visual Recognition Challenge
|
||||
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
|
||||
Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
|
||||
IJCV, 2015
|
||||
https://arxiv.org/pdf/1409.0575.pdf
|
||||
|
||||
Download the dataset from http://image-net.org/
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
|
||||
"""
|
||||
args:
|
||||
root - path to the imagenet vid dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
min_length - Minimum allowed sequence length.
|
||||
max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
|
||||
which cover complete image.
|
||||
"""
|
||||
root = env_settings().imagenet_dir if root is None else root
|
||||
super().__init__("imagenetvid_lmdb", root, image_loader)
|
||||
|
||||
sequence_list_dict = decode_json(root, "cache.json")
|
||||
self.sequence_list = sequence_list_dict
|
||||
|
||||
# Filter the sequences based on min_length and max_target_area in the first frame
|
||||
self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
|
||||
get_target_to_image_ratio(x) < max_target_area]
|
||||
|
||||
def get_name(self):
|
||||
return 'imagenetvid_lmdb'
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
|
||||
valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
|
||||
visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
|
||||
return {'bbox': bb_anno, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, sequence, frame_id):
|
||||
set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
|
||||
vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
|
||||
frame_number = frame_id + sequence['start_frame']
|
||||
frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name,
|
||||
'{:06d}.JPEG'.format(frame_number))
|
||||
return decode_img(self.root, frame_path)
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
sequence = self.sequence_list[seq_id]
|
||||
|
||||
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
# Create anno dict
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
# added the class info to the meta info
|
||||
object_meta = OrderedDict({'object_class': sequence['class_name'],
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class Lasot(BaseVideoDataset):
|
||||
""" LaSOT dataset.
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_dir if root is None else root
|
||||
super().__init__('LaSOT', root, image_loader)
|
||||
|
||||
# Keep a list of all classes
|
||||
self.class_list = [f for f in os.listdir(self.root)]
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
# sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist()
|
||||
sequence_list = pandas.read_csv(file_path, header=None).squeeze("columns").values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
with open(occlusion_file, 'r', newline='') as f:
|
||||
occlusion = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
with open(out_of_view_file, 'r') as f:
|
||||
out_of_view = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(self.root, class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return self.image_loader(self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
'''2021.1.16 Lasot for loading lmdb dataset'''
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
|
||||
class Lasot_lmdb(BaseVideoDataset):
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_lmdb_dir if root is None else root
|
||||
super().__init__('LaSOT_lmdb', root, image_loader)
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
class_list = [seq_name.split('-')[0] for seq_name in self.sequence_list]
|
||||
self.class_list = []
|
||||
for ele in class_list:
|
||||
if ele not in self.class_list:
|
||||
self.class_list.append(ele)
|
||||
# Keep a list of all classes
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
occ_list = list(map(int, decode_str(self.root, occlusion_file).split(',')))
|
||||
occlusion = torch.ByteTensor(occ_list)
|
||||
out_view_list = list(map(int, decode_str(self.root, out_of_view_file).split(',')))
|
||||
out_of_view = torch.ByteTensor(out_view_list)
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return decode_img(self.root, self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,151 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
def list_sequences(root, set_ids):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
set_ids: Sets (0-11) which are to be used
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
sequence_list = []
|
||||
|
||||
for s in set_ids:
|
||||
anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno")
|
||||
|
||||
sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
|
||||
sequence_list += sequences_cur_set
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_dir if root is None else root
|
||||
super().__init__('TrackingNet', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root, self.set_ids)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
bb_anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False,
|
||||
low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg")
|
||||
return self.image_loader(frame_path)
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
import json
|
||||
from lib.utils.lmdb_utils import decode_img, decode_str
|
||||
|
||||
|
||||
def list_sequences(root):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
fname = os.path.join(root, "seq_list.json")
|
||||
with open(fname, "r") as f:
|
||||
sequence_list = json.loads(f.read())
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet_lmdb(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_lmdb_dir if root is None else root
|
||||
super().__init__('TrackingNet_lmdb', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
gt_str_list = decode_str(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("anno", vid_name + ".txt")).split('\n')[:-1]
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
return decode_img(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("frames", vid_name, str(frame_id) + ".jpg"))
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import importlib
|
||||
import cv2 as cv
|
||||
import torch.backends.cudnn
|
||||
import torch.distributed as dist
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
import _init_paths
|
||||
import lib.train.admin.settings as ws_settings
|
||||
|
||||
|
||||
def init_seeds(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(1)
|
||||
cv.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None,
|
||||
use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False,
|
||||
distill=None, script_teacher=None, config_teacher=None):
|
||||
"""Run the train script.
|
||||
args:
|
||||
script_name: Name of emperiment in the "experiments/" folder.
|
||||
config_name: Name of the yaml file in the "experiments/<script_name>".
|
||||
cudnn_benchmark: Use cudnn benchmark or not (default is True).
|
||||
"""
|
||||
if save_dir is None:
|
||||
print("save_dir dir is not given. Use the default dir instead.")
|
||||
# This is needed to avoid strange crashes related to opencv
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(4)
|
||||
|
||||
torch.backends.cudnn.benchmark = cudnn_benchmark
|
||||
|
||||
print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name))
|
||||
|
||||
'''2021.1.5 set seed for different process'''
|
||||
if base_seed is not None:
|
||||
if local_rank != -1:
|
||||
init_seeds(base_seed + local_rank)
|
||||
else:
|
||||
init_seeds(base_seed)
|
||||
|
||||
settings = ws_settings.Settings()
|
||||
settings.script_name = script_name
|
||||
settings.config_name = config_name
|
||||
settings.project_path = 'train/{}/{}'.format(script_name, config_name)
|
||||
if script_name_prv is not None and config_name_prv is not None:
|
||||
settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv)
|
||||
settings.local_rank = local_rank
|
||||
settings.save_dir = os.path.abspath(save_dir)
|
||||
settings.use_lmdb = use_lmdb
|
||||
prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name))
|
||||
settings.use_wandb = use_wandb
|
||||
if distill:
|
||||
settings.distill = distill
|
||||
settings.script_teacher = script_teacher
|
||||
settings.config_teacher = config_teacher
|
||||
if script_teacher is not None and config_teacher is not None:
|
||||
settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher)
|
||||
settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher))
|
||||
expr_module = importlib.import_module('lib.train.train_script_distill')
|
||||
else:
|
||||
expr_module = importlib.import_module('lib.train.train_script')
|
||||
expr_func = getattr(expr_module, 'run')
|
||||
|
||||
expr_func(settings)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
|
||||
parser.add_argument('--script', type=str, required=True, help='Name of the train script.')
|
||||
parser.add_argument('--config', type=str, required=True, help="Name of the config file.")
|
||||
parser.add_argument('--cudnn_benchmark', type=bool, default=False, help='Set cudnn benchmark on (1) or off (0) (default is on).')
|
||||
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
|
||||
parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs')
|
||||
parser.add_argument('--seed', type=int, default=42, help='seed for random numbers')
|
||||
parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format
|
||||
parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.')
|
||||
parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.")
|
||||
parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb
|
||||
# for knowledge distillation
|
||||
parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation
|
||||
parser.add_argument('--script_teacher', type=str, help='teacher script name')
|
||||
parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.local_rank != -1:
|
||||
dist.init_process_group(backend='nccl')
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
else:
|
||||
torch.cuda.set_device(0)
|
||||
run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark,
|
||||
local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed,
|
||||
use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv,
|
||||
use_wandb=args.use_wandb,
|
||||
distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer, LTRSeqTrainer
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader, sequence_sampler
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
# forward propagation related
|
||||
from lib.train.actors import ARTrackActor, ARTrackSeqActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
from ..utils.focal_loss import FocalLoss
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
def slt_collate(batch):
|
||||
ret = {}
|
||||
for k in batch[0].keys():
|
||||
here_list = []
|
||||
for ex in batch:
|
||||
here_list.append(ex[k])
|
||||
ret[k] = here_list
|
||||
return ret
|
||||
|
||||
class SLTLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
|
||||
if collate_fn is None:
|
||||
collate_fn = slt_collate
|
||||
|
||||
super(SLTLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
bins = cfg.MODEL.BINS
|
||||
search_size = cfg.DATA.SEARCH.SIZE
|
||||
# Create network
|
||||
if settings.script_name == "artrack":
|
||||
net = build_artrack(cfg)
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
net = build_artrack_seq(cfg)
|
||||
dataset_train = sequence_sampler.SequenceSampler(
|
||||
datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_GAP, max_interval=cfg.DATA.MAX_INTERVAL,
|
||||
num_search_frames=cfg.DATA.SEARCH.NUMBER, num_template_frames=1,
|
||||
frame_sample_mode='random_interval',
|
||||
prob=cfg.DATA.INTERVAL_PROB)
|
||||
loader_train = SLTLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER,
|
||||
shuffle=False, drop_last=True)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
if settings.local_rank != -1:
|
||||
# net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "artrack":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackSeqActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# if cfg.TRAIN.DEEP_SUPERVISION:
|
||||
# raise ValueError("Deep supervision is not supported now.")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
if settings.script_name == "artrack":
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
trainer = LTRSeqTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|
||||
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.stark import build_starks, build_starkst
|
||||
from lib.models.stark import build_stark_lightning_x_trt
|
||||
# forward propagation related
|
||||
from lib.train.actors import STARKLightningXtrtdistillActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
|
||||
def build_network(script_name, cfg):
|
||||
# Create network
|
||||
if script_name == "stark_s":
|
||||
net = build_starks(cfg)
|
||||
elif script_name == "stark_st1" or script_name == "stark_st2":
|
||||
net = build_starkst(cfg)
|
||||
elif script_name == "stark_lightning_X_trt":
|
||||
net = build_stark_lightning_x_trt(cfg, phase="train")
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
return net
|
||||
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update the default teacher configs with teacher config file
|
||||
if not os.path.exists(settings.cfg_file_teacher):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file_teacher)
|
||||
config_module_teacher = importlib.import_module("lib.config.%s.config" % settings.script_teacher)
|
||||
cfg_teacher = config_module_teacher.cfg
|
||||
config_module_teacher.update_config_from_file(settings.cfg_file_teacher)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New teacher configuration is shown below.")
|
||||
for key in cfg_teacher.keys():
|
||||
print("%s configuration:" % key, cfg_teacher[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
"""turn on the distillation mode"""
|
||||
cfg.TRAIN.DISTILL = True
|
||||
cfg_teacher.TRAIN.DISTILL = True
|
||||
net = build_network(settings.script_name, cfg)
|
||||
net_teacher = build_network(settings.script_teacher, cfg_teacher)
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
net_teacher.cuda()
|
||||
net_teacher.eval()
|
||||
|
||||
if settings.local_rank != -1:
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
net_teacher = DDP(net_teacher, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
# settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
# settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "L1")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "stark_lightning_X_trt":
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
|
||||
actor = STARKLightningXtrtdistillActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings,
|
||||
net_teacher=net_teacher)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True, distill=True)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base_trainer import BaseTrainer
|
||||
from .ltr_trainer import LTRTrainer
|
||||
from .ltr_seq_trainer import LTRSeqTrainer
|
||||
@@ -0,0 +1,275 @@
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import traceback
|
||||
from lib.train.admin import multigpu
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""Base trainer class. Contains functions for training and saving/loading checkpoints.
|
||||
Trainer classes should inherit from this one and overload the train_epoch function."""
|
||||
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
self.actor = actor
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.loaders = loaders
|
||||
|
||||
self.update_settings(settings)
|
||||
|
||||
self.epoch = 0
|
||||
self.stats = {}
|
||||
|
||||
self.device = getattr(settings, 'device', None)
|
||||
if self.device is None:
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() and settings.use_gpu else "cpu")
|
||||
|
||||
self.actor.to(self.device)
|
||||
self.settings = settings
|
||||
|
||||
def update_settings(self, settings=None):
|
||||
"""Updates the trainer settings. Must be called to update internal settings."""
|
||||
if settings is not None:
|
||||
self.settings = settings
|
||||
|
||||
if self.settings.env.workspace_dir is not None:
|
||||
self.settings.env.workspace_dir = os.path.expanduser(self.settings.env.workspace_dir)
|
||||
'''2021.1.4 New function: specify checkpoint dir'''
|
||||
if self.settings.save_dir is None:
|
||||
self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir, 'checkpoints')
|
||||
else:
|
||||
self._checkpoint_dir = os.path.join(self.settings.save_dir, 'checkpoints')
|
||||
print("checkpoints will be saved to %s" % self._checkpoint_dir)
|
||||
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(self._checkpoint_dir):
|
||||
print("Training with multiple GPUs. checkpoints directory doesn't exist. "
|
||||
"Create checkpoints directory")
|
||||
os.makedirs(self._checkpoint_dir)
|
||||
else:
|
||||
self._checkpoint_dir = None
|
||||
|
||||
def train(self, max_epochs, load_latest=False, fail_safe=True, load_previous_ckpt=False, distill=False):
|
||||
"""Do training for the given number of epochs.
|
||||
args:
|
||||
max_epochs - Max number of training epochs,
|
||||
load_latest - Bool indicating whether to resume from latest epoch.
|
||||
fail_safe - Bool indicating whether the training to automatically restart in case of any crashes.
|
||||
"""
|
||||
|
||||
epoch = -1
|
||||
num_tries = 1
|
||||
for i in range(num_tries):
|
||||
try:
|
||||
if load_latest:
|
||||
self.load_checkpoint()
|
||||
if load_previous_ckpt:
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_prv)
|
||||
self.load_state_dict(directory)
|
||||
if distill:
|
||||
directory_teacher = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_teacher)
|
||||
self.load_state_dict(directory_teacher, distill=True)
|
||||
for epoch in range(self.epoch+1, max_epochs+1):
|
||||
self.epoch = epoch
|
||||
|
||||
self.train_epoch()
|
||||
|
||||
if self.lr_scheduler is not None:
|
||||
if self.settings.scheduler_type != 'cosine':
|
||||
self.lr_scheduler.step()
|
||||
else:
|
||||
self.lr_scheduler.step(epoch - 1)
|
||||
# only save the last 10 checkpoints
|
||||
save_every_epoch = getattr(self.settings, "save_every_epoch", False)
|
||||
save_epochs = []
|
||||
if epoch > (max_epochs - 1) or save_every_epoch or epoch % 5 == 0 or epoch in save_epochs or epoch > (max_epochs - 5):
|
||||
# if epoch > (max_epochs - 10) or save_every_epoch or epoch % 100 == 0:
|
||||
if self._checkpoint_dir:
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
self.save_checkpoint()
|
||||
except:
|
||||
print('Training crashed at epoch {}'.format(epoch))
|
||||
if fail_safe:
|
||||
self.epoch -= 1
|
||||
load_latest = True
|
||||
print('Traceback for the error!')
|
||||
print(traceback.format_exc())
|
||||
print('Restarting training from last epoch ...')
|
||||
else:
|
||||
raise
|
||||
|
||||
print('Finished training!')
|
||||
|
||||
def train_epoch(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self):
|
||||
"""Saves a checkpoint of the network and other variables."""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'actor_type': actor_type,
|
||||
'net_type': net_type,
|
||||
'net': net.state_dict(),
|
||||
'net_info': getattr(net, 'info', None),
|
||||
'constructor': getattr(net, 'constructor', None),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'stats': self.stats,
|
||||
'settings': self.settings
|
||||
}
|
||||
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
|
||||
print(directory)
|
||||
if not os.path.exists(directory):
|
||||
print("directory doesn't exist. creating...")
|
||||
os.makedirs(directory)
|
||||
|
||||
# First save as a tmp file
|
||||
tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
|
||||
torch.save(state, tmp_file_path)
|
||||
|
||||
file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)
|
||||
|
||||
# Now rename to actual checkpoint. os.rename seems to be atomic if files are on same filesystem. Not 100% sure
|
||||
os.rename(tmp_file_path, file_path)
|
||||
|
||||
def load_checkpoint(self, checkpoint = None, fields = None, ignore_fields = None, load_constructor = False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
Can be called in three different ways:
|
||||
load_checkpoint():
|
||||
Loads the latest epoch from the workspace. Use this to continue training.
|
||||
load_checkpoint(epoch_num):
|
||||
Loads the network at the given epoch number (int).
|
||||
load_checkpoint(path_to_checkpoint):
|
||||
Loads the file from the given absolute path (str).
|
||||
"""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
|
||||
if checkpoint is None:
|
||||
# Load most recent checkpoint
|
||||
checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
|
||||
self.settings.project_path, net_type)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
print('No matching checkpoint file found')
|
||||
return
|
||||
elif isinstance(checkpoint, int):
|
||||
# Checkpoint is the epoch number
|
||||
checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
|
||||
net_type, checkpoint)
|
||||
elif isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
if fields is None:
|
||||
fields = checkpoint_dict.keys()
|
||||
if ignore_fields is None:
|
||||
ignore_fields = ['settings']
|
||||
|
||||
# Never load the scheduler. It exists in older checkpoints.
|
||||
ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])
|
||||
|
||||
# Load all fields
|
||||
for key in fields:
|
||||
if key in ignore_fields:
|
||||
continue
|
||||
if key == 'net':
|
||||
net.load_state_dict(checkpoint_dict[key])
|
||||
elif key == 'optimizer':
|
||||
self.optimizer.load_state_dict(checkpoint_dict[key])
|
||||
else:
|
||||
setattr(self, key, checkpoint_dict[key])
|
||||
|
||||
# Set the net info
|
||||
if load_constructor and 'constructor' in checkpoint_dict and checkpoint_dict['constructor'] is not None:
|
||||
net.constructor = checkpoint_dict['constructor']
|
||||
if 'net_info' in checkpoint_dict and checkpoint_dict['net_info'] is not None:
|
||||
net.info = checkpoint_dict['net_info']
|
||||
|
||||
# Update the epoch in lr scheduler
|
||||
if 'epoch' in fields:
|
||||
self.lr_scheduler.last_epoch = self.epoch
|
||||
# 2021.1.10 Update the epoch in data_samplers
|
||||
for loader in self.loaders:
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
return True
|
||||
|
||||
def load_state_dict(self, checkpoint=None, distill=False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
Can be called in three different ways:
|
||||
load_checkpoint():
|
||||
Loads the latest epoch from the workspace. Use this to continue training.
|
||||
load_checkpoint(epoch_num):
|
||||
Loads the network at the given epoch number (int).
|
||||
load_checkpoint(path_to_checkpoint):
|
||||
Loads the file from the given absolute path (str).
|
||||
"""
|
||||
if distill:
|
||||
net = self.actor.net_teacher.module if multigpu.is_multi_gpu(self.actor.net_teacher) \
|
||||
else self.actor.net_teacher
|
||||
else:
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
net_type = type(net).__name__
|
||||
|
||||
if isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
print("Loading pretrained model from ", checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
|
||||
print("previous checkpoint is loaded.")
|
||||
print("missing keys: ", missing_k)
|
||||
print("unexpected keys:", unexpected_k)
|
||||
|
||||
return True
|
||||
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
# from lib.train.data.wandb_logger import WandbWriter
|
||||
from lib.train.trainers import BaseTrainer
|
||||
from lib.train.admin import AverageMeter, StatValue
|
||||
from memory_profiler import profile
|
||||
# from lib.train.admin import TensorboardWriter
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lib.utils.misc import get_world_size
|
||||
|
||||
|
||||
class LTRSeqTrainer(BaseTrainer):
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
super().__init__(actor, loaders, optimizer, settings, lr_scheduler)
|
||||
|
||||
self._set_default_settings()
|
||||
|
||||
# Initialize statistics variables
|
||||
self.stats = OrderedDict({loader.name: None for loader in self.loaders})
|
||||
|
||||
# Initialize tensorboard and wandb
|
||||
# self.wandb_writer = None
|
||||
# if settings.local_rank in [-1, 0]:
|
||||
# tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
|
||||
# if not os.path.exists(tensorboard_writer_dir):
|
||||
# os.makedirs(tensorboard_writer_dir)
|
||||
# self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])
|
||||
|
||||
# if settings.use_wandb:
|
||||
# world_size = get_world_size()
|
||||
# cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
|
||||
# interval = (world_size * settings.batchsize) # * interval
|
||||
# self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)
|
||||
|
||||
self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
|
||||
print("move_data", self.move_data_to_gpu)
|
||||
self.settings = settings
|
||||
self.use_amp = use_amp
|
||||
if use_amp:
|
||||
self.scaler = GradScaler()
|
||||
|
||||
def _set_default_settings(self):
|
||||
# Dict of all default values
|
||||
default = {'print_interval': 10,
|
||||
'print_stats': None,
|
||||
'description': ''}
|
||||
|
||||
for param, default_value in default.items():
|
||||
if getattr(self.settings, param, None) is None:
|
||||
setattr(self.settings, param, default_value)
|
||||
|
||||
self.miou_list = []
|
||||
|
||||
def cycle_dataset(self, loader):
|
||||
"""Do a cycle of training or validation."""
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
self.actor.train(loader.training)
|
||||
torch.set_grad_enabled(loader.training)
|
||||
|
||||
self._init_timing()
|
||||
|
||||
for i, data in enumerate(loader, 1):
|
||||
self.actor.eval()
|
||||
self.data_read_done_time = time.time()
|
||||
with torch.no_grad():
|
||||
explore_result = self.actor.explore(data)
|
||||
if explore_result == None:
|
||||
print("this time i skip")
|
||||
# self._update_stats(stats, batch_size, loader)
|
||||
continue
|
||||
# get inputs
|
||||
# print(data)
|
||||
|
||||
self.data_to_gpu_time = time.time()
|
||||
|
||||
data['epoch'] = self.epoch
|
||||
data['settings'] = self.settings
|
||||
|
||||
stats = {}
|
||||
reward_record = []
|
||||
miou_record = []
|
||||
e_miou_record = []
|
||||
num_seq = len(data['num_frames'])
|
||||
|
||||
# Calculate reward tensor
|
||||
# reward_tensor = torch.zeros(explore_result['baseline_iou'].size())
|
||||
baseline_iou = explore_result['baseline_iou']
|
||||
# explore_iou = explore_result['explore_iou']
|
||||
for seq_idx in range(num_seq):
|
||||
num_frames = data['num_frames'][seq_idx] - 1
|
||||
b_miou = torch.mean(baseline_iou[:num_frames, seq_idx])
|
||||
# e_miou = torch.mean(explore_iou[:num_frames, seq_idx])
|
||||
miou_record.append(b_miou.item())
|
||||
# e_miou_record.append(e_miou.item())
|
||||
|
||||
b_reward = b_miou.item()
|
||||
# e_reward = e_miou.item()
|
||||
# iou_gap = e_reward - b_reward
|
||||
# reward_record.append(iou_gap)
|
||||
# reward_tensor[:num_frames, seq_idx] = iou_gap
|
||||
|
||||
# Training mode
|
||||
cursor = 0
|
||||
bs_backward = 1
|
||||
|
||||
# print(self.actor.net.module.box_head.decoder.layers[2].mlpx.fc1.weight)
|
||||
self.optimizer.zero_grad()
|
||||
while cursor < num_seq:
|
||||
# print("now is ", cursor , "and all is ", num_seq)
|
||||
model_inputs = {}
|
||||
model_inputs['slt_loss_weight'] = 15
|
||||
if cursor < num_seq:
|
||||
model_inputs['template_images'] = explore_result['template_images'][
|
||||
cursor:cursor + bs_backward].cuda()
|
||||
else:
|
||||
model_inputs['template_images'] = explore_result['template_images_reverse'][
|
||||
cursor - num_seq:cursor - num_seq + bs_backward].cuda()
|
||||
model_inputs['search_images'] = explore_result['search_images'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['search_anno'] = explore_result['search_anno'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['pre_seq'] = explore_result['pre_seq'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['x_feat'] = explore_result['x_feat'].squeeze(1)[:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['epoch'] = data['epoch']
|
||||
# model_inputs['template_update'] = explore_result['template_update'].squeeze(1)[:,
|
||||
# cursor:cursor + bs_backward].cuda()
|
||||
# print("this is cursor")
|
||||
# print(explore_result['pre_seq'].shape)
|
||||
# print(explore_result['x_feat'].squeeze(1).shape)
|
||||
# model_inputs['action_tensor'] = explore_result['action_tensor'][:, cursor:cursor + bs_backward].cuda()
|
||||
# model_inputs['reward_tensor'] = reward_tensor[:, cursor:cursor + bs_backward].cuda()
|
||||
|
||||
loss, stats_cur = self.actor.compute_sequence_losses(model_inputs)
|
||||
# for name, param in self.actor.net.named_parameters():
|
||||
# shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
|
||||
# print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')
|
||||
# print("i make this!")
|
||||
loss.backward()
|
||||
# print("i made that?")
|
||||
|
||||
for key, val in stats_cur.items():
|
||||
if key in stats:
|
||||
stats[key] += val * (bs_backward / num_seq)
|
||||
else:
|
||||
stats[key] = val * (bs_backward / num_seq)
|
||||
cursor += bs_backward
|
||||
grad_norm = clip_grad_norm_(self.actor.net.parameters(), 100)
|
||||
stats['grad_norm'] = grad_norm
|
||||
# print(self.actor.net.module.backbone.blocks[8].mlp.fc1.weight)
|
||||
self.optimizer.step()
|
||||
# print(self.optimizer)
|
||||
|
||||
miou = np.mean(miou_record)
|
||||
self.miou_list.append(miou)
|
||||
# stats['reward'] = np.mean(reward_record)
|
||||
# stats['e_mIoU'] = np.mean(e_miou_record)
|
||||
stats['mIoU'] = miou
|
||||
stats['mIoU10'] = np.mean(self.miou_list[-10:])
|
||||
stats['mIoU100'] = np.mean(self.miou_list[-100:])
|
||||
|
||||
batch_size = num_seq * np.max(data['num_frames'])
|
||||
self._update_stats(stats, batch_size, loader)
|
||||
self._print_stats(i, loader, batch_size)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# # forward pass
|
||||
# if not self.use_amp:
|
||||
# loss, stats = self.actor(data)
|
||||
# else:
|
||||
# with autocast():
|
||||
# loss, stats = self.actor(data)
|
||||
#
|
||||
# # backward pass and update weights
|
||||
# if loader.training:
|
||||
# self.optimizer.zero_grad()
|
||||
# if not self.use_amp:
|
||||
# loss.backward()
|
||||
# if self.settings.grad_clip_norm > 0:
|
||||
# torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
|
||||
# self.optimizer.step()
|
||||
# else:
|
||||
# self.scaler.scale(loss).backward()
|
||||
# self.scaler.step(self.optimizer)
|
||||
# self.scaler.update()
|
||||
|
||||
# update statistics
|
||||
# batch_size = data['template_images'].shape[loader.stack_dim]
|
||||
# self._update_stats(stats, batch_size, loader)
|
||||
|
||||
# print statistics
|
||||
# self._print_stats(i, loader, batch_size)
|
||||
|
||||
# update wandb status
|
||||
# if self.wandb_writer is not None and i % self.settings.print_interval == 0:
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self.wandb_writer.write_log(self.stats, self.epoch)
|
||||
|
||||
# calculate ETA after every epoch
|
||||
# epoch_time = self.prev_time - self.start_time
|
||||
# print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
|
||||
# print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
|
||||
# print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
|
||||
# print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))
|
||||
|
||||
def train_epoch(self):
|
||||
"""Do one epoch for each loader."""
|
||||
for loader in self.loaders:
|
||||
if self.epoch % loader.epoch_interval == 0:
|
||||
# 2021.1.10 Set epoch
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
self.cycle_dataset(loader)
|
||||
|
||||
self._stats_new_epoch()
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self._write_tensorboard()
|
||||
|
||||
def _init_timing(self):
|
||||
self.num_frames = 0
|
||||
self.start_time = time.time()
|
||||
self.prev_time = self.start_time
|
||||
self.avg_date_time = 0
|
||||
self.avg_gpu_trans_time = 0
|
||||
self.avg_forward_time = 0
|
||||
|
||||
def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
|
||||
# Initialize stats if not initialized yet
|
||||
if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
|
||||
self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})
|
||||
|
||||
# add lr state
|
||||
if loader.training:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for name, val in new_stats.items():
|
||||
if name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][name] = AverageMeter()
|
||||
self.stats[loader.name][name].update(val, batch_size)
|
||||
|
||||
def _print_stats(self, i, loader, batch_size):
|
||||
self.num_frames += batch_size
|
||||
current_time = time.time()
|
||||
batch_fps = batch_size / (current_time - self.prev_time)
|
||||
average_fps = self.num_frames / (current_time - self.start_time)
|
||||
prev_frame_time_backup = self.prev_time
|
||||
self.prev_time = current_time
|
||||
|
||||
self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
|
||||
self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
|
||||
self.avg_forward_time += current_time - self.data_to_gpu_time
|
||||
|
||||
if i % self.settings.print_interval == 0 or i == loader.__len__():
|
||||
print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
|
||||
print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)
|
||||
|
||||
# 2021.12.14 add data time print
|
||||
print_str += 'DataTime: %.3f (%.3f) , ' % (
|
||||
self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
|
||||
print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
|
||||
print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
|
||||
# print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
|
||||
# print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
|
||||
# print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)
|
||||
|
||||
for name, val in self.stats[loader.name].items():
|
||||
if (self.settings.print_stats is None or name in self.settings.print_stats):
|
||||
if hasattr(val, 'avg'):
|
||||
print_str += '%s: %.5f , ' % (name, val.avg)
|
||||
# else:
|
||||
# print_str += '%s: %r , ' % (name, val)
|
||||
|
||||
print(print_str[:-5])
|
||||
log_str = print_str[:-5] + '\n'
|
||||
with open(self.settings.log_file, 'a') as f:
|
||||
f.write(log_str)
|
||||
|
||||
def _stats_new_epoch(self):
|
||||
# Record learning rate
|
||||
for loader in self.loaders:
|
||||
if loader.training:
|
||||
try:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
except:
|
||||
lr_list = self.lr_scheduler._get_lr(self.epoch)
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for loader_stats in self.stats.values():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
for stat_value in loader_stats.values():
|
||||
if hasattr(stat_value, 'new_epoch'):
|
||||
stat_value.new_epoch()
|
||||
|
||||
# def _write_tensorboard(self):
|
||||
# if self.epoch == 1:
|
||||
# self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
|
||||
|
||||
# self.tensorboard_writer.write_epoch(self.stats, self.epoch)
|
||||
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
|
||||
#from lib.train.data.wandb_logger import WandbWriter
|
||||
from lib.train.trainers import BaseTrainer
|
||||
from lib.train.admin import AverageMeter, StatValue
|
||||
#from lib.train.admin import TensorboardWriter
|
||||
import torch
|
||||
import time
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lib.utils.misc import get_world_size
|
||||
|
||||
|
||||
class LTRTrainer(BaseTrainer):
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
super().__init__(actor, loaders, optimizer, settings, lr_scheduler)
|
||||
|
||||
self._set_default_settings()
|
||||
|
||||
# Initialize statistics variables
|
||||
self.stats = OrderedDict({loader.name: None for loader in self.loaders})
|
||||
|
||||
# Initialize tensorboard and wandb
|
||||
#self.wandb_writer = None
|
||||
#if settings.local_rank in [-1, 0]:
|
||||
# tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
|
||||
# if not os.path.exists(tensorboard_writer_dir):
|
||||
# os.makedirs(tensorboard_writer_dir)
|
||||
# self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])
|
||||
|
||||
# if settings.use_wandb:
|
||||
# world_size = get_world_size()
|
||||
# cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
|
||||
# interval = (world_size * settings.batchsize) # * interval
|
||||
# self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)
|
||||
|
||||
self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
|
||||
print("move_data", self.move_data_to_gpu)
|
||||
self.settings = settings
|
||||
self.use_amp = use_amp
|
||||
if use_amp:
|
||||
self.scaler = GradScaler()
|
||||
|
||||
def _set_default_settings(self):
|
||||
# Dict of all default values
|
||||
default = {'print_interval': 10,
|
||||
'print_stats': None,
|
||||
'description': ''}
|
||||
|
||||
for param, default_value in default.items():
|
||||
if getattr(self.settings, param, None) is None:
|
||||
setattr(self.settings, param, default_value)
|
||||
|
||||
def cycle_dataset(self, loader):
|
||||
"""Do a cycle of training or validation."""
|
||||
|
||||
self.actor.train(loader.training)
|
||||
torch.set_grad_enabled(loader.training)
|
||||
|
||||
self._init_timing()
|
||||
|
||||
for i, data in enumerate(loader, 1):
|
||||
self.data_read_done_time = time.time()
|
||||
# get inputs
|
||||
if self.move_data_to_gpu:
|
||||
data = data.to(self.device)
|
||||
|
||||
self.data_to_gpu_time = time.time()
|
||||
|
||||
data['epoch'] = self.epoch
|
||||
data['settings'] = self.settings
|
||||
# forward pass
|
||||
if not self.use_amp:
|
||||
loss, stats = self.actor(data)
|
||||
else:
|
||||
with autocast():
|
||||
loss, stats = self.actor(data)
|
||||
|
||||
# backward pass and update weights
|
||||
if loader.training:
|
||||
self.optimizer.zero_grad()
|
||||
if not self.use_amp:
|
||||
loss.backward()
|
||||
if self.settings.grad_clip_norm > 0:
|
||||
torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
|
||||
self.optimizer.step()
|
||||
else:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
# update statistics
|
||||
batch_size = data['template_images'].shape[loader.stack_dim]
|
||||
self._update_stats(stats, batch_size, loader)
|
||||
|
||||
# print statistics
|
||||
self._print_stats(i, loader, batch_size)
|
||||
|
||||
# update wandb status
|
||||
#if self.wandb_writer is not None and i % self.settings.print_interval == 0:
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self.wandb_writer.write_log(self.stats, self.epoch)
|
||||
|
||||
# calculate ETA after every epoch
|
||||
epoch_time = self.prev_time - self.start_time
|
||||
print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
|
||||
print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
|
||||
print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
|
||||
print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))
|
||||
|
||||
def train_epoch(self):
|
||||
"""Do one epoch for each loader."""
|
||||
for loader in self.loaders:
|
||||
if self.epoch % loader.epoch_interval == 0:
|
||||
# 2021.1.10 Set epoch
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
self.cycle_dataset(loader)
|
||||
|
||||
self._stats_new_epoch()
|
||||
#if self.settings.local_rank in [-1, 0]:
|
||||
# self._write_tensorboard()
|
||||
|
||||
def _init_timing(self):
|
||||
self.num_frames = 0
|
||||
self.start_time = time.time()
|
||||
self.prev_time = self.start_time
|
||||
self.avg_date_time = 0
|
||||
self.avg_gpu_trans_time = 0
|
||||
self.avg_forward_time = 0
|
||||
|
||||
def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
|
||||
# Initialize stats if not initialized yet
|
||||
if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
|
||||
self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})
|
||||
|
||||
# add lr state
|
||||
if loader.training:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for name, val in new_stats.items():
|
||||
if name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][name] = AverageMeter()
|
||||
self.stats[loader.name][name].update(val, batch_size)
|
||||
|
||||
def _print_stats(self, i, loader, batch_size):
|
||||
self.num_frames += batch_size
|
||||
current_time = time.time()
|
||||
batch_fps = batch_size / (current_time - self.prev_time)
|
||||
average_fps = self.num_frames / (current_time - self.start_time)
|
||||
prev_frame_time_backup = self.prev_time
|
||||
self.prev_time = current_time
|
||||
|
||||
self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
|
||||
self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
|
||||
self.avg_forward_time += current_time - self.data_to_gpu_time
|
||||
|
||||
if i % self.settings.print_interval == 0 or i == loader.__len__():
|
||||
print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
|
||||
print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)
|
||||
|
||||
# 2021.12.14 add data time print
|
||||
print_str += 'DataTime: %.3f (%.3f) , ' % (self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
|
||||
print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
|
||||
print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
|
||||
# print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
|
||||
# print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
|
||||
# print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)
|
||||
|
||||
for name, val in self.stats[loader.name].items():
|
||||
if (self.settings.print_stats is None or name in self.settings.print_stats):
|
||||
if hasattr(val, 'avg'):
|
||||
print_str += '%s: %.5f , ' % (name, val.avg)
|
||||
# else:
|
||||
# print_str += '%s: %r , ' % (name, val)
|
||||
|
||||
print(print_str[:-5])
|
||||
log_str = print_str[:-5] + '\n'
|
||||
with open(self.settings.log_file, 'a') as f:
|
||||
f.write(log_str)
|
||||
|
||||
def _stats_new_epoch(self):
|
||||
# Record learning rate
|
||||
for loader in self.loaders:
|
||||
if loader.training:
|
||||
try:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
except:
|
||||
lr_list = self.lr_scheduler._get_lr(self.epoch)
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for loader_stats in self.stats.values():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
for stat_value in loader_stats.values():
|
||||
if hasattr(stat_value, 'new_epoch'):
|
||||
stat_value.new_epoch()
|
||||
|
||||
#def _write_tensorboard(self):
|
||||
# if self.epoch == 1:
|
||||
# self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
|
||||
|
||||
# self.tensorboard_writer.write_epoch(self.stats, self.epoch)
|
||||
@@ -0,0 +1 @@
|
||||
from .tensor import TensorDict, TensorList
|
||||
@@ -0,0 +1,106 @@
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
import numpy as np
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xywh_to_xyxy(x):
|
||||
x1, y1, w, h = x.unbind(-1)
|
||||
b = [x1, y1, x1 + w, y1 + h]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(x):
|
||||
x1, y1, x2, y2 = x.unbind(-1)
|
||||
b = [x1, y1, x2 - x1, y2 - y1]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
'''Note that this function only supports shape (N,4)'''
|
||||
|
||||
|
||||
def box_iou(boxes1, boxes2):
|
||||
"""
|
||||
|
||||
:param boxes1: (N, 4) (x1,y1,x2,y2)
|
||||
:param boxes2: (N, 4) (x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
area1 = box_area(boxes1) # (N,)
|
||||
area2 = box_area(boxes2) # (N,)
|
||||
|
||||
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2)
|
||||
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2)
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # (N,2)
|
||||
inter = wh[:, 0] * wh[:, 1] # (N,)
|
||||
|
||||
union = area1 + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
'''Note that this implementation is different from DETR's'''
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
boxes1: (N, 4)
|
||||
boxes2: (N, 4)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
# try:
|
||||
#assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
# assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2) # (N,)
|
||||
|
||||
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # (N,2)
|
||||
area = wh[:, 0] * wh[:, 1] # (N,)
|
||||
|
||||
return iou - (area - union) / area, iou
|
||||
|
||||
|
||||
def giou_loss(boxes1, boxes2):
|
||||
"""
|
||||
|
||||
:param boxes1: (N, 4) (x1,y1,x2,y2)
|
||||
:param boxes2: (N, 4) (x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
giou, iou = generalized_box_iou(boxes1, boxes2)
|
||||
return (1 - giou).mean(), iou
|
||||
|
||||
|
||||
def clip_box(box: list, H, W, margin=0):
|
||||
x1, y1, w, h = box
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
x1 = min(max(0, x1), W-margin)
|
||||
x2 = min(max(margin, x2), W)
|
||||
y1 = min(max(0, y1), H-margin)
|
||||
y2 = min(max(margin, y2), H)
|
||||
w = max(margin, x2-x1)
|
||||
h = max(margin, y2-y1)
|
||||
return [x1, y1, w, h]
|
||||
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def generate_bbox_mask(bbox_mask, bbox):
|
||||
b, h, w = bbox_mask.shape
|
||||
for i in range(b):
|
||||
bbox_i = bbox[i].cpu().tolist()
|
||||
bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1
|
||||
return bbox_mask
|
||||
|
||||
|
||||
def generate_mask_cond(cfg, bs, device, gt_bbox):
|
||||
template_size = cfg.DATA.TEMPLATE.SIZE
|
||||
stride = cfg.MODEL.BACKBONE.STRIDE
|
||||
template_feat_size = template_size // stride
|
||||
|
||||
if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL':
|
||||
box_mask_z = None
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT':
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 6)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 14:
|
||||
index = slice(6, 7)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC':
|
||||
# use fixed 4x4 region, 3:5 for 8x8
|
||||
# use fixed 4x4 region 5:6 for 12x12
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 5)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 7)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX':
|
||||
box_mask_z = torch.zeros([bs, template_size, template_size], device=device)
|
||||
# box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128)
|
||||
box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to(
|
||||
torch.float) # (batch, 1, 128, 128)
|
||||
# box_mask_z_vis = box_mask_z.cpu().numpy()
|
||||
box_mask_z = F.interpolate(box_mask_z, scale_factor=1. / cfg.MODEL.BACKBONE.STRIDE, mode='bilinear',
|
||||
align_corners=False)
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
# box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy()
|
||||
# gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return box_mask_z
|
||||
|
||||
|
||||
def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1):
|
||||
if epoch < warmup_epochs:
|
||||
return 1
|
||||
if epoch >= total_epochs:
|
||||
return base_keep_rate
|
||||
if iters == -1:
|
||||
iters = epoch * ITERS_PER_EPOCH
|
||||
total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs)
|
||||
iters = iters - ITERS_PER_EPOCH * warmup_epochs
|
||||
keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \
|
||||
* (math.cos(iters / total_iters * math.pi) + 1) * 0.5
|
||||
|
||||
return keep_rate
|
||||
@@ -0,0 +1,63 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FocalLoss(nn.Module, ABC):
|
||||
def __init__(self, alpha=2, beta=4):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, prediction, target):
|
||||
positive_index = target.eq(1).float()
|
||||
negative_index = target.lt(1).float()
|
||||
|
||||
negative_weights = torch.pow(1 - target, self.beta)
|
||||
# clamp min value is set to 1e-12 to maintain the numerical stability
|
||||
prediction = torch.clamp(prediction, 1e-12)
|
||||
|
||||
positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index
|
||||
negative_loss = torch.log(1 - prediction) * torch.pow(prediction,
|
||||
self.alpha) * negative_weights * negative_index
|
||||
|
||||
num_positive = positive_index.float().sum()
|
||||
positive_loss = positive_loss.sum()
|
||||
negative_loss = negative_loss.sum()
|
||||
|
||||
if num_positive == 0:
|
||||
loss = -negative_loss
|
||||
else:
|
||||
loss = -(positive_loss + negative_loss) / num_positive
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class LBHinge(nn.Module):
|
||||
"""Loss that uses a 'hinge' on the lower bound.
|
||||
This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is
|
||||
also smaller than that threshold.
|
||||
args:
|
||||
error_matric: What base loss to use (MSE by default).
|
||||
threshold: Threshold to use for the hinge.
|
||||
clip: Clip the loss if it is above this value.
|
||||
"""
|
||||
def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None):
|
||||
super().__init__()
|
||||
self.error_metric = error_metric
|
||||
self.threshold = threshold if threshold is not None else -100
|
||||
self.clip = clip
|
||||
|
||||
def forward(self, prediction, label, target_bb=None):
|
||||
negative_mask = (label < self.threshold).float()
|
||||
positive_mask = (1.0 - negative_mask)
|
||||
|
||||
prediction = negative_mask * F.relu(prediction) + positive_mask * prediction
|
||||
|
||||
loss = self.error_metric(prediction, positive_mask * label)
|
||||
|
||||
if self.clip is not None:
|
||||
loss = torch.min(loss, torch.tensor([self.clip], device=loss.device))
|
||||
return loss
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user