init commit of samurai
.gitignore  (vendored, new file, 159 lines)
@@ -0,0 +1,159 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# evaluation results
evaluation_results/*
raw_results/*
debug/*
data/.gitignore  (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
lib/test/__init__.py  (new, empty file)
lib/test/analysis/__init__.py  (new, empty file)
lib/test/analysis/extract_results.py  (new file, 226 lines)
@@ -0,0 +1,226 @@
import os
import os.path as osp
import sys
import numpy as np
from lib.test.utils.load_text import load_text
import torch
import pickle
from tqdm import tqdm

env_path = os.path.join(os.path.dirname(__file__), '../../..')
if env_path not in sys.path:
    sys.path.append(env_path)

from lib.test.evaluation.environment import env_settings


def calc_err_center(pred_bb, anno_bb, normalized=False):
    pred_center = pred_bb[:, :2] + 0.5 * (pred_bb[:, 2:] - 1.0)
    anno_center = anno_bb[:, :2] + 0.5 * (anno_bb[:, 2:] - 1.0)

    if normalized:
        pred_center = pred_center / anno_bb[:, 2:]
        anno_center = anno_center / anno_bb[:, 2:]

    err_center = ((pred_center - anno_center)**2).sum(1).sqrt()
    return err_center


def calc_iou_overlap(pred_bb, anno_bb):
    tl = torch.max(pred_bb[:, :2], anno_bb[:, :2])
    br = torch.min(pred_bb[:, :2] + pred_bb[:, 2:] - 1.0, anno_bb[:, :2] + anno_bb[:, 2:] - 1.0)
    sz = (br - tl + 1.0).clamp(0)

    # Area
    intersection = sz.prod(dim=1)
    union = pred_bb[:, 2:].prod(dim=1) + anno_bb[:, 2:].prod(dim=1) - intersection

    return intersection / union
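
# Worked example (annotation added for clarity, not part of the original file),
# assuming [x, y, w, h] boxes: pred = [0, 0, 10, 10], anno = [5, 5, 10, 10]
#   tl = [5, 5], br = min([9, 9], [14, 14]) = [9, 9], sz = [5, 5]
#   intersection = 25, union = 100 + 100 - 25 = 175, IoU = 25 / 175 ≈ 0.14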


def calc_seq_err_robust(pred_bb, anno_bb, dataset, target_visible=None):
    pred_bb = pred_bb.clone()

    # Check if invalid values are present
    if torch.isnan(pred_bb).any() or (pred_bb[:, 2:] < 0.0).any():
        raise Exception('Error: Invalid results')

    if torch.isnan(anno_bb).any():
        if dataset == 'uav':
            pass
        else:
            raise Exception('Warning: NaNs in annotation')

    if (pred_bb[:, 2:] == 0.0).any():
        for i in range(1, pred_bb.shape[0]):
            if i >= anno_bb.shape[0]:
                continue
            if (pred_bb[i, 2:] == 0.0).any() and not torch.isnan(anno_bb[i, :]).any():
                pred_bb[i, :] = pred_bb[i-1, :]

    if pred_bb.shape[0] != anno_bb.shape[0]:
        if dataset == 'lasot':
            if pred_bb.shape[0] > anno_bb.shape[0]:
                # For monkey-17, there is a mismatch for some trackers.
                pred_bb = pred_bb[:anno_bb.shape[0], :]
            else:
                raise Exception('Mis-match in tracker prediction and GT lengths')
        else:
            # print('Warning: Mis-match in tracker prediction and GT lengths')
            if pred_bb.shape[0] > anno_bb.shape[0]:
                pred_bb = pred_bb[:anno_bb.shape[0], :]
            else:
                pad = torch.zeros((anno_bb.shape[0] - pred_bb.shape[0], 4)).type_as(pred_bb)
                pred_bb = torch.cat((pred_bb, pad), dim=0)

    pred_bb[0, :] = anno_bb[0, :]

    if target_visible is not None:
        target_visible = target_visible.bool()
        valid = ((anno_bb[:, 2:] > 0.0).sum(1) == 2) & target_visible
    else:
        valid = ((anno_bb[:, 2:] > 0.0).sum(1) == 2)

    err_center = calc_err_center(pred_bb, anno_bb)
    err_center_normalized = calc_err_center(pred_bb, anno_bb, normalized=True)
    err_overlap = calc_iou_overlap(pred_bb, anno_bb)

    # handle invalid anno cases
    if dataset in ['uav']:
        err_center[~valid] = -1.0
    else:
        err_center[~valid] = float("Inf")
    err_center_normalized[~valid] = -1.0
    err_overlap[~valid] = -1.0

    if dataset == 'lasot':
        err_center_normalized[~target_visible] = float("Inf")
        err_center[~target_visible] = float("Inf")

    if torch.isnan(err_overlap).any():
        raise Exception('Nans in calculated overlap')
    return err_overlap, err_center, err_center_normalized, valid


def extract_results(trackers, dataset, report_name, skip_missing_seq=False, plot_bin_gap=0.05,
                    exclude_invalid_frames=False):
    settings = env_settings()
    eps = 1e-16

    result_plot_path = os.path.join(settings.result_plot_path, report_name)

    if not os.path.exists(result_plot_path):
        os.makedirs(result_plot_path)

    threshold_set_overlap = torch.arange(0.0, 1.0 + plot_bin_gap, plot_bin_gap, dtype=torch.float64)
    threshold_set_center = torch.arange(0, 51, dtype=torch.float64)
    threshold_set_center_norm = torch.arange(0, 51, dtype=torch.float64) / 100.0

    avg_overlap_all = torch.zeros((len(dataset), len(trackers)), dtype=torch.float64)
    ave_success_rate_plot_overlap = torch.zeros((len(dataset), len(trackers), threshold_set_overlap.numel()),
                                                dtype=torch.float32)
    ave_success_rate_plot_center = torch.zeros((len(dataset), len(trackers), threshold_set_center.numel()),
                                               dtype=torch.float32)
    ave_success_rate_plot_center_norm = torch.zeros((len(dataset), len(trackers), threshold_set_center.numel()),
                                                    dtype=torch.float32)

    from collections import defaultdict
    # default dict of default dict of list

    valid_sequence = torch.ones(len(dataset), dtype=torch.uint8)

    for seq_id, seq in enumerate(tqdm(dataset)):
        frame_success_rate_plot_overlap = defaultdict(lambda: defaultdict(list))
        frame_success_rate_plot_center = defaultdict(lambda: defaultdict(list))
        frame_success_rate_plot_center_norm = defaultdict(lambda: defaultdict(list))
        # Load anno
        anno_bb = torch.tensor(seq.ground_truth_rect)
        target_visible = torch.tensor(seq.target_visible, dtype=torch.uint8) if seq.target_visible is not None else None
        for trk_id, trk in enumerate(trackers):
            # Load results
            base_results_path = '{}/{}'.format(trk.results_dir, seq.name)
            results_path = '{}.txt'.format(base_results_path)

            if os.path.isfile(results_path):
                pred_bb = torch.tensor(load_text(str(results_path), delimiter=('\t', ','), dtype=np.float64))
            else:
                if skip_missing_seq:
                    valid_sequence[seq_id] = 0
                    break
                else:
                    raise Exception('Result not found. {}'.format(results_path))

            # Calculate measures
            err_overlap, err_center, err_center_normalized, valid_frame = calc_seq_err_robust(
                pred_bb, anno_bb, seq.dataset, target_visible)

            avg_overlap_all[seq_id, trk_id] = err_overlap[valid_frame].mean()

            if exclude_invalid_frames:
                seq_length = valid_frame.long().sum()
            else:
                seq_length = anno_bb.shape[0]

            if seq_length <= 0:
                raise Exception('Seq length zero')

            ave_success_rate_plot_overlap[seq_id, trk_id, :] = (err_overlap.view(-1, 1) > threshold_set_overlap.view(1, -1)).sum(0).float() / seq_length
            ave_success_rate_plot_center[seq_id, trk_id, :] = (err_center.view(-1, 1) <= threshold_set_center.view(1, -1)).sum(0).float() / seq_length
            ave_success_rate_plot_center_norm[seq_id, trk_id, :] = (err_center_normalized.view(-1, 1) <= threshold_set_center_norm.view(1, -1)).sum(0).float() / seq_length

            # for frame_id in range(seq_length):
            #     frame_success_rate_plot_overlap[trk_id][frame_id].append((err_overlap[frame_id]).item())
            #     frame_success_rate_plot_center[trk_id][frame_id].append((err_center[frame_id]).item())
            #     frame_success_rate_plot_center_norm[trk_id][frame_id].append((err_center_normalized[frame_id] < 0.2).item())

        # output_folder = "../cvpr2025/per_frame_success_rate"
        # os.makedirs(output_folder, exist_ok=True)
        # with open(osp.join(output_folder, f"{seq.name}.txt"), 'w') as f:
        #     for frame_id in range(seq_length):
        #         suc_score = frame_success_rate_plot_overlap[trk_id][frame_id][0]
        #         f.write(f"{suc_score}\n")

        # # plot the average success rate, center normalized for each tracker
        # # y axis: success rate
        # # x axis: frame number
        # # different color for each tracker
        # # save the plot as a figure
        # import matplotlib.pyplot as plt
        # plt.figure(figsize=(10, 6))
        # for trk_id, trk in enumerate(trackers):
        #     list_to_plot = [np.mean(frame_success_rate_plot_overlap[trk_id][frame_id]) for frame_id in range(2000)]
        #     # smooth the curve; window size = 10
        #     smooth_list_to_plot = np.convolve(list_to_plot, np.ones((10,))/10, mode='valid')
        #     # the smooth curve and non smooth curve have the same label
        #     plt.plot(smooth_list_to_plot, label=trk.display_name, alpha=1)
        # plt.xlabel('Frame Number')
        # plt.ylabel('Success Rate')
        # plt.title('Average Success Rate Over Frames')
        # plt.legend()
        # plt.grid(True)
        # plt.savefig('average_success_rate_plot_overlap.png')
        # plt.close()

    print('\n\nComputed results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))

    # Prepare dictionary for saving data
    seq_names = [s.name for s in dataset]
    tracker_names = [{'name': t.name, 'param': t.parameter_name, 'run_id': t.run_id, 'disp_name': t.display_name}
                     for t in trackers]

    eval_data = {'sequences': seq_names, 'trackers': tracker_names,
                 'valid_sequence': valid_sequence.tolist(),
                 'ave_success_rate_plot_overlap': ave_success_rate_plot_overlap.tolist(),
                 'ave_success_rate_plot_center': ave_success_rate_plot_center.tolist(),
                 'ave_success_rate_plot_center_norm': ave_success_rate_plot_center_norm.tolist(),
                 'avg_overlap_all': avg_overlap_all.tolist(),
                 'threshold_set_overlap': threshold_set_overlap.tolist(),
                 'threshold_set_center': threshold_set_center.tolist(),
                 'threshold_set_center_norm': threshold_set_center_norm.tolist()}

    with open(result_plot_path + '/eval_data.pkl', 'wb') as fh:
        pickle.dump(eval_data, fh)

    return eval_data
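
# Annotation (added, not part of the original file): eval_data stores, per sequence
# and per tracker, the fraction of frames whose IoU exceeds each of the 21 overlap
# thresholds (0.00-1.00 in steps of plot_bin_gap) and whose center error stays below
# each of the 51 pixel / normalized-center thresholds. plot_results.py recovers AUC,
# precision and normalized precision by averaging these curves over the valid
# sequences, so the pickled dict is the only thing it needs to re-read.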
lib/test/analysis/plot_results.py  (new file, 796 lines)
@@ -0,0 +1,796 @@
import tikzplotlib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import os.path as osp
import torch
import pickle
import json
from lib.test.evaluation.environment import env_settings
from lib.test.analysis.extract_results import extract_results


def get_plot_draw_styles():
    plot_draw_style = [
        # {'color': (1.0, 0.0, 0.0), 'line_style': '-'},
        # {'color': (0.0, 1.0, 0.0), 'line_style': '-'},
        {'color': (0.0, 1.0, 0.0), 'line_style': '-'},
        {'color': (0.0, 0.0, 0.0), 'line_style': '-'},
        {'color': (1.0, 0.0, 1.0), 'line_style': '-'},
        {'color': (0.0, 1.0, 1.0), 'line_style': '-'},
        {'color': (0.5, 0.5, 0.5), 'line_style': '-'},
        {'color': (136.0 / 255.0, 0.0, 21.0 / 255.0), 'line_style': '-'},
        {'color': (1.0, 127.0 / 255.0, 39.0 / 255.0), 'line_style': '-'},
        {'color': (0.0, 162.0 / 255.0, 232.0 / 255.0), 'line_style': '-'},
        {'color': (0.0, 0.5, 0.0), 'line_style': '-'},
        {'color': (1.0, 0.5, 0.2), 'line_style': '-'},
        {'color': (0.1, 0.4, 0.0), 'line_style': '-'},
        {'color': (0.6, 0.3, 0.9), 'line_style': '-'},
        {'color': (0.4, 0.7, 0.1), 'line_style': '-'},
        {'color': (0.2, 0.1, 0.7), 'line_style': '-'},
        {'color': (0.7, 0.6, 0.2), 'line_style': '-'}]

    return plot_draw_style


def check_eval_data_is_valid(eval_data, trackers, dataset):
    """Checks if the pre-computed results are valid."""
    seq_names = [s.name for s in dataset]
    seq_names_saved = eval_data['sequences']

    tracker_names_f = [(t.name, t.parameter_name, t.run_id) for t in trackers]
    tracker_names_f_saved = [(t['name'], t['param'], t['run_id']) for t in eval_data['trackers']]

    return seq_names == seq_names_saved and tracker_names_f == tracker_names_f_saved


def merge_multiple_runs(eval_data):
    new_tracker_names = []
    ave_success_rate_plot_overlap_merged = []
    ave_success_rate_plot_center_merged = []
    ave_success_rate_plot_center_norm_merged = []
    avg_overlap_all_merged = []

    ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
    ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
    ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
    avg_overlap_all = torch.tensor(eval_data['avg_overlap_all'])

    trackers = eval_data['trackers']
    merged = torch.zeros(len(trackers), dtype=torch.uint8)
    for i in range(len(trackers)):
        if merged[i]:
            continue
        base_tracker = trackers[i]
        new_tracker_names.append(base_tracker)

        match = [t['name'] == base_tracker['name'] and t['param'] == base_tracker['param'] for t in trackers]
        match = torch.tensor(match)

        ave_success_rate_plot_overlap_merged.append(ave_success_rate_plot_overlap[:, match, :].mean(1))
        ave_success_rate_plot_center_merged.append(ave_success_rate_plot_center[:, match, :].mean(1))
        ave_success_rate_plot_center_norm_merged.append(ave_success_rate_plot_center_norm[:, match, :].mean(1))
        avg_overlap_all_merged.append(avg_overlap_all[:, match].mean(1))

        merged[match] = 1

    ave_success_rate_plot_overlap_merged = torch.stack(ave_success_rate_plot_overlap_merged, dim=1)
    ave_success_rate_plot_center_merged = torch.stack(ave_success_rate_plot_center_merged, dim=1)
    ave_success_rate_plot_center_norm_merged = torch.stack(ave_success_rate_plot_center_norm_merged, dim=1)
    avg_overlap_all_merged = torch.stack(avg_overlap_all_merged, dim=1)

    eval_data['trackers'] = new_tracker_names
    eval_data['ave_success_rate_plot_overlap'] = ave_success_rate_plot_overlap_merged.tolist()
    eval_data['ave_success_rate_plot_center'] = ave_success_rate_plot_center_merged.tolist()
    eval_data['ave_success_rate_plot_center_norm'] = ave_success_rate_plot_center_norm_merged.tolist()
    eval_data['avg_overlap_all'] = avg_overlap_all_merged.tolist()

    return eval_data


def get_tracker_display_name(tracker):
    if tracker['disp_name'] is None:
        if tracker['run_id'] is None:
            disp_name = '{}_{}'.format(tracker['name'], tracker['param'])
        else:
            disp_name = '{}_{}_{:03d}'.format(tracker['name'], tracker['param'],
                                              tracker['run_id'])
    else:
        disp_name = tracker['disp_name']

    return disp_name


def plot_draw_save(y, x, scores, trackers, plot_draw_styles, result_plot_path, plot_opts):
    plt.rcParams['text.usetex'] = True
    plt.rcParams["font.family"] = "Times New Roman"
    # Plot settings
    font_size = plot_opts.get('font_size', 25)
    font_size_axis = plot_opts.get('font_size_axis', 20)
    line_width = plot_opts.get('line_width', 2)
    font_size_legend = plot_opts.get('font_size_legend', 15)

    plot_type = plot_opts['plot_type']
    legend_loc = plot_opts['legend_loc']
    if 'attr' in plot_opts:
        attr = plot_opts['attr']
    else:
        attr = None

    xlabel = plot_opts['xlabel']
    ylabel = plot_opts['ylabel']
    ylabel = "%s" % (ylabel.replace('%', '\\%'))
    xlim = plot_opts['xlim']
    ylim = plot_opts['ylim']

    title = r"\textbf{%s}" % (plot_opts['title'])

    matplotlib.rcParams.update({'font.size': font_size})
    matplotlib.rcParams.update({'axes.titlesize': font_size_axis})
    matplotlib.rcParams.update({'axes.titleweight': 'black'})
    matplotlib.rcParams.update({'axes.labelsize': font_size_axis})

    fig, ax = plt.subplots()

    index_sort = scores.argsort(descending=False)

    plotted_lines = []
    legend_text = []

    for id, id_sort in enumerate(index_sort):
        if trackers[id_sort]['disp_name'].startswith('SAMURAI'):
            alpha = 1.0
            line_style = '-'
            if trackers[id_sort]['disp_name'] == 'SAMURAI-L':
                color = (1.0, 0.0, 0.0)
            elif trackers[id_sort]['disp_name'] == 'SAMURAI-B':
                color = (0.0, 0.0, 1.0)
        elif trackers[id_sort]['disp_name'].startswith('SAM2.1'):
            alpha = 0.8
            line_style = '--'
            if trackers[id_sort]['disp_name'] == 'SAM2.1-L':
                color = (1.0, 0.0, 0.0)
            elif trackers[id_sort]['disp_name'] == 'SAM2.1-B':
                color = (0.0, 0.0, 1.0)
        else:
            alpha = 0.5
            color = plot_draw_styles[index_sort.numel() - id - 1]['color']
            line_style = ":"
        line = ax.plot(x.tolist(), y[id_sort, :].tolist(),
                       linewidth=line_width,
                       color=color,
                       linestyle=line_style,
                       alpha=alpha)

        plotted_lines.append(line[0])

        tracker = trackers[id_sort]
        disp_name = get_tracker_display_name(tracker)

        legend_text.append('{} [{:.1f}]'.format(disp_name, scores[id_sort]))

    try:
        # add bold to top method
        # for i in range(1,2):
        #     legend_text[-i] = r'\textbf{%s}'%(legend_text[-i])

        for id, id_sort in enumerate(index_sort):
            if trackers[id_sort]['disp_name'].startswith('SAMTrack'):
                legend_text[id] = r'\textbf{%s}' % (legend_text[id])

        ax.legend(plotted_lines[::-1], legend_text[::-1], loc=legend_loc, fancybox=False, edgecolor='black',
                  fontsize=font_size_legend, framealpha=1.0)
    except:
        pass

    ax.set(xlabel=xlabel,
           ylabel=ylabel,
           xlim=xlim, ylim=ylim,
           title=title)

    ax.grid(True, linestyle='-.')
    fig.tight_layout()

    def tikzplotlib_fix_ncols(obj):
        """
        Workaround: matplotlib 3.6 renamed the legend's _ncol to _ncols, which breaks tikzplotlib.
        """
        if hasattr(obj, "_ncols"):
            obj._ncol = obj._ncols
        for child in obj.get_children():
            tikzplotlib_fix_ncols(child)

    tikzplotlib_fix_ncols(fig)

    # tikzplotlib.save('{}/{}_plot.tex'.format(result_plot_path, plot_type))
    if attr is not None:
        fig.savefig('{}/{}_{}_plot.pdf'.format(result_plot_path, plot_type, attr), dpi=300, format='pdf', transparent=True)
    else:
        fig.savefig('{}/{}_plot.pdf'.format(result_plot_path, plot_type), dpi=300, format='pdf', transparent=True)
    plt.draw()
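
# Annotation (added, not part of the original file): plot_draw_save enables
# text.usetex, so rendering these plots requires a working LaTeX installation on
# PATH. Colors and line styles for trackers whose display names start with
# 'SAMURAI' or 'SAM2.1' are hard-coded in the loop above; all other trackers fall
# back to the palette returned by get_plot_draw_styles().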


def check_and_load_precomputed_results(trackers, dataset, report_name, force_evaluation=False, **kwargs):
    # Load data
    settings = env_settings()

    # Load pre-computed results
    result_plot_path = os.path.join(settings.result_plot_path, report_name)
    eval_data_path = os.path.join(result_plot_path, 'eval_data.pkl')

    if os.path.isfile(eval_data_path) and not force_evaluation:
        with open(eval_data_path, 'rb') as fh:
            eval_data = pickle.load(fh)
    else:
        # print('Pre-computed evaluation data not found. Computing results!')
        eval_data = extract_results(trackers, dataset, report_name, **kwargs)

    if not check_eval_data_is_valid(eval_data, trackers, dataset):
        # print('Pre-computed evaluation data invalid. Re-computing results!')
        eval_data = extract_results(trackers, dataset, report_name, **kwargs)
        # pass
    else:
        # Update display names
        tracker_names = [{'name': t.name, 'param': t.parameter_name, 'run_id': t.run_id, 'disp_name': t.display_name}
                         for t in trackers]
        eval_data['trackers'] = tracker_names
        with open(eval_data_path, 'wb') as fh:
            pickle.dump(eval_data, fh)
    return eval_data
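
# Annotation (added, not part of the original file): the cached eval_data.pkl is
# reused only when it lists exactly the same sequences and (name, param, run_id)
# tracker tuples as the current request; otherwise, or when force_evaluation=True,
# the raw per-sequence result files are re-processed by extract_results().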


def get_auc_curve(ave_success_rate_plot_overlap, valid_sequence):
    ave_success_rate_plot_overlap = ave_success_rate_plot_overlap[valid_sequence, :, :]
    auc_curve = ave_success_rate_plot_overlap.mean(0) * 100.0
    auc = auc_curve.mean(-1)

    return auc_curve, auc


def get_prec_curve(ave_success_rate_plot_center, valid_sequence):
    ave_success_rate_plot_center = ave_success_rate_plot_center[valid_sequence, :, :]
    prec_curve = ave_success_rate_plot_center.mean(0) * 100.0
    prec_score = prec_curve[:, 20]

    return prec_curve, prec_score
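
# Annotation (added, not part of the original file): auc averages the success curve
# over its 21 overlap thresholds (the usual AUC score, in percent), while prec_score
# reads column 20 of the precision curve, i.e. the 20-pixel threshold for the plain
# center error and the 0.20 threshold for the normalized center error.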


def plot_per_attribute_results(trackers, dataset, report_name, merge_results=False,
                               plot_types=('success'), **kwargs):
    # Load data
    settings = env_settings()

    plot_draw_styles = get_plot_draw_styles()

    # Load pre-computed results
    result_plot_path = os.path.join(settings.result_plot_path, report_name)
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)

    tracker_names = eval_data['trackers']
    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

    attr_folder = 'data/LaSOT/att'

    attr_list = ['Illumination Variation', 'Partial Occlusion', 'Deformation', 'Motion Blur', 'Camera Motion', 'Rotation', 'Background Clutter', 'Viewpoint Change', 'Scale Variation', 'Full Occlusion', 'Fast Motion', 'Out-of-View', 'Low Resolution', 'Aspect Ratio Change']
    attr_list = ['IV', 'POC', 'DEF', 'MB', 'CM', 'ROT', 'BC', 'VC', 'SV', 'FOC', 'FM', 'OV', 'LR', 'ARC']

    # Iterate over the sequences and construct a valid_sequence mask for each attribute
    valid_sequence_attr = {}
    for attr in attr_list:
        valid_sequence_attr[attr] = torch.zeros(valid_sequence.shape[0], dtype=torch.bool)
    for seq_id, seq_obj in enumerate(dataset):
        seq_name = seq_obj.name
        attr_txt = osp.join(attr_folder, f'{seq_name}.txt')
        if osp.exists(attr_txt):
            # read the attribute file into a list of True and False
            # the attribute file looks like this: 0,0,0,0,0,1,0,1,1,0,0,0,0,0
            attr_anno = np.loadtxt(attr_txt, dtype=int, delimiter=',')
            # broadcast the valid_sequence to the attribute list
            for attr_id, attr in enumerate(attr_list):
                valid_sequence_attr[attr][seq_id] = attr_anno[attr_id]
        else:
            raise Exception(f'Attribute file not found for sequence {seq_name}')

    tracker_names = eval_data['trackers']

    if report_name == 'LaSOT-ext':
        ylim_success = (0, 95)
        ylim_precision = (0, 85)
        ylim_norm_precision = (0, 85)
        report_name = "LaSOT_{ext}"
    else:
        ylim_success = (0, 95)
        ylim_precision = (0, 100)
        ylim_norm_precision = (0, 88)

    threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
    ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
    ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
    ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
    for attr in attr_list:
        scores = {}

        print(f'{attr}: {valid_sequence_attr[attr].sum().item()}')
        valid_sequence_attr[attr] = valid_sequence_attr[attr] & valid_sequence

        auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence_attr[attr])
        scores['AUC'] = auc

        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence_attr[attr])
        scores['Precision'] = prec_score

        norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence_attr[attr])
        scores['Norm Precision'] = norm_prec_score

        tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
        report_text = generate_formatted_report(tracker_disp_names, scores, table_name=attr)
        print(report_text)

        success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold', 'attr': attr,
                             'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': ylim_success, 'title': f'Success\ of\ {attr}\ ({report_name})'}
        plot_draw_save(auc_curve, threshold_set_overlap, auc, tracker_names, plot_draw_styles, result_plot_path, success_plot_opts)


def plot_results(trackers, dataset, report_name, merge_results=False,
                 plot_types=('success'), force_evaluation=False, **kwargs):
    """
    Plot results for the given trackers

    args:
        trackers - List of trackers to evaluate
        dataset - List of sequences to evaluate
        report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
        merge_results - If True, multiple random runs of a non-deterministic tracker are averaged
        plot_types - List of scores to display. Can contain 'success',
                    'prec' (precision), and 'norm_prec' (normalized precision)
    """
    # Load data
    settings = env_settings()

    plot_draw_styles = get_plot_draw_styles()

    # Load pre-computed results
    result_plot_path = os.path.join(settings.result_plot_path, report_name)
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, force_evaluation, **kwargs)

    # Merge results from multiple runs
    if merge_results:
        eval_data = merge_multiple_runs(eval_data)

    tracker_names = eval_data['trackers']

    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

    print('\nPlotting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))

    print('\nGenerating plots for: {}'.format(report_name))

    print(report_name)
    if report_name == 'LaSOT':
        ylim_success = (0, 95)
        ylim_precision = (0, 95)
        ylim_norm_precision = (0, 95)
    elif report_name == 'LaSOT-ext':
        ylim_success = (0, 85)
        ylim_precision = (0, 85)
        ylim_norm_precision = (0, 85)
    else:
        ylim_success = (0, 85)
        ylim_precision = (0, 85)
        ylim_norm_precision = (0, 85)

    # ******************************** Success Plot **************************************
    if 'success' in plot_types:
        ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])

        # Index out valid sequences
        auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
        threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])

        success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold',
                             'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': ylim_success, 'title': f'Success\ ({report_name})'}
        plot_draw_save(auc_curve, threshold_set_overlap, auc, tracker_names, plot_draw_styles, result_plot_path, success_plot_opts)

    # ******************************** Precision Plot **************************************
    if 'prec' in plot_types:
        ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])

        # Index out valid sequences
        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
        threshold_set_center = torch.tensor(eval_data['threshold_set_center'])

        precision_plot_opts = {'plot_type': 'precision', 'legend_loc': 'lower right',
                               'xlabel': 'Location error threshold [pixels]', 'ylabel': 'Distance Precision [%]',
                               'xlim': (0, 50), 'ylim': ylim_precision, 'title': f'Precision\ ({report_name})'}
        plot_draw_save(prec_curve, threshold_set_center, prec_score, tracker_names, plot_draw_styles, result_plot_path,
                       precision_plot_opts)

    # ******************************** Norm Precision Plot **************************************
    if 'norm_prec' in plot_types:
        ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])

        # Index out valid sequences
        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
        threshold_set_center_norm = torch.tensor(eval_data['threshold_set_center_norm'])

        norm_precision_plot_opts = {'plot_type': 'norm_precision', 'legend_loc': 'lower right',
                                    'xlabel': 'Location error threshold', 'ylabel': 'Distance Precision [%]',
                                    'xlim': (0, 0.5), 'ylim': ylim_norm_precision, 'title': f'Normalized\ Precision\ ({report_name})'}
        plot_draw_save(prec_curve, threshold_set_center_norm, prec_score, tracker_names, plot_draw_styles, result_plot_path,
                       norm_precision_plot_opts)

    plt.show()


def generate_formatted_report(row_labels, scores, table_name=''):
    name_width = max([len(d) for d in row_labels] + [len(table_name)]) + 5
    min_score_width = 10

    report_text = '\n{label: <{width}} |'.format(label=table_name, width=name_width)

    score_widths = [max(min_score_width, len(k) + 3) for k in scores.keys()]

    for s, s_w in zip(scores.keys(), score_widths):
        report_text = '{prev} {s: <{width}} |'.format(prev=report_text, s=s, width=s_w)

    report_text = '{prev}\n'.format(prev=report_text)

    for trk_id, d_name in enumerate(row_labels):
        # display name
        report_text = '{prev}{tracker: <{width}} |'.format(prev=report_text, tracker=d_name,
                                                           width=name_width)
        for (score_type, score_value), s_w in zip(scores.items(), score_widths):
            report_text = '{prev} {score: <{width}} |'.format(prev=report_text,
                                                              score='{:0.2f}'.format(score_value[trk_id].item()),
                                                              width=s_w)
        report_text = '{prev}\n'.format(prev=report_text)

    return report_text
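
# Illustrative output of generate_formatted_report (annotation added for clarity;
# the tracker names, scores, and exact column widths below are hypothetical):
#
#   LaSOT        | AUC        | Precision  | Norm Precision   |
#   SAMURAI-B    | 70.12      | 76.34      | 79.56            |
#   SAM2.1-B     | 66.01      | 71.22      | 74.87            |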


def print_per_attribute_results(trackers, dataset, report_name, merge_results=False,
                                plot_types=('success'), **kwargs):
    # Load pre-computed results
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)

    tracker_names = eval_data['trackers']
    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

    attr_folder = 'data/LaSOT/att'

    attr_list = ['Illumination Variation', 'Partial Occlusion', 'Deformation', 'Motion Blur', 'Camera Motion', 'Rotation', 'Background Clutter', 'Viewpoint Change', 'Scale Variation', 'Full Occlusion', 'Fast Motion', 'Out-of-View', 'Low Resolution', 'Aspect Ratio Change']
    attr_list = ['IV', 'POC', 'DEF', 'MB', 'CM', 'ROT', 'BC', 'VC', 'SV', 'FOC', 'FM', 'OV', 'LR', 'ARC']

    # Iterate over the sequences and construct a valid_sequence mask for each attribute
    valid_sequence_attr = {}
    for attr in attr_list:
        valid_sequence_attr[attr] = torch.zeros(valid_sequence.shape[0], dtype=torch.bool)
    for seq_id, seq_obj in enumerate(dataset):
        seq_name = seq_obj.name
        attr_txt = osp.join(attr_folder, f'{seq_name}.txt')
        if osp.exists(attr_txt):
            # read the attribute file into a list of True and False
            # the attribute file looks like this: 0,0,0,0,0,1,0,1,1,0,0,0,0,0
            attr_anno = np.loadtxt(attr_txt, dtype=int, delimiter=',')
            # broadcast the valid_sequence to the attribute list
            for attr_id, attr in enumerate(attr_list):
                valid_sequence_attr[attr][seq_id] = attr_anno[attr_id]
        else:
            raise Exception(f'Attribute file not found for sequence {seq_name}')

    tracker_names = eval_data['trackers']

    threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
    ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])
    ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])
    ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])
    for attr in attr_list:
        scores = {}

        print(f'{attr}: {valid_sequence_attr[attr].sum().item()}')
        valid_sequence_attr[attr] = valid_sequence_attr[attr] & valid_sequence

        auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence_attr[attr])
        scores['AUC'] = auc

        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence_attr[attr])
        scores['Precision'] = prec_score

        norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence_attr[attr])
        scores['Norm Precision'] = norm_prec_score

        tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
        report_text = generate_formatted_report(tracker_disp_names, scores, table_name=attr)
        print(report_text)


def print_results(trackers, dataset, report_name, merge_results=False,
                  plot_types=('success'), **kwargs):
    """ Print the results for the given trackers in a formatted table
    args:
        trackers - List of trackers to evaluate
        dataset - List of sequences to evaluate
        report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
        merge_results - If True, multiple random runs of a non-deterministic tracker are averaged
        plot_types - List of scores to display. Can contain 'success' (prints AUC, OP50, and OP75 scores),
                    'prec' (prints precision score), and 'norm_prec' (prints normalized precision score)
    """
    # Load pre-computed results
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)

    # Merge results from multiple runs
    if merge_results:
        eval_data = merge_multiple_runs(eval_data)

    tracker_names = eval_data['trackers']
    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

    print('\nReporting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))

    scores = {}

    # ******************************** Success Plot **************************************
    if 'success' in plot_types:
        threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
        ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])

        # Index out valid sequences
        auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
        scores['AUC'] = auc
        scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
        scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]

    # ******************************** Precision Plot **************************************
    if 'prec' in plot_types:
        ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])

        # Index out valid sequences
        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
        scores['Precision'] = prec_score

    # ******************************** Norm Precision Plot *********************************
    if 'norm_prec' in plot_types:
        ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])

        # Index out valid sequences
        norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
        scores['Norm Precision'] = norm_prec_score

    # Print
    tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
    report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
    print(report_text)


def plot_got_success(trackers, report_name):
    """ Plot success plot for GOT-10k dataset using the json reports.
    Save the json reports from http://got-10k.aitestunion.com/leaderboard in the directory set to
    env_settings.got_reports_path

    The tracker name in the experiment file should be set to the name of the report file for that tracker,
    e.g. DiMP50_report_2019_09_02_15_44_25 if the report is named DiMP50_report_2019_09_02_15_44_25.json

    args:
        trackers - List of trackers to evaluate
        report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
    """
    # Load data
    settings = env_settings()
    plot_draw_styles = get_plot_draw_styles()

    result_plot_path = os.path.join(settings.result_plot_path, report_name)

    auc_curve = torch.zeros((len(trackers), 101))
    scores = torch.zeros(len(trackers))

    # Load results
    tracker_names = []
    for trk_id, trk in enumerate(trackers):
        json_path = '{}/{}.json'.format(settings.got_reports_path, trk.name)

        if os.path.isfile(json_path):
            with open(json_path, 'r') as f:
                eval_data = json.load(f)
        else:
            raise Exception('Report not found {}'.format(json_path))

        if len(eval_data.keys()) > 1:
            raise Exception

        # First field is the tracker name. Index it out
        eval_data = eval_data[list(eval_data.keys())[0]]
        if 'succ_curve' in eval_data.keys():
            curve = eval_data['succ_curve']
            ao = eval_data['ao']
        elif 'overall' in eval_data.keys() and 'succ_curve' in eval_data['overall'].keys():
            curve = eval_data['overall']['succ_curve']
            ao = eval_data['overall']['ao']
        else:
            raise Exception('Invalid JSON file {}'.format(json_path))

        auc_curve[trk_id, :] = torch.tensor(curve) * 100.0
        scores[trk_id] = ao * 100.0

        tracker_names.append({'name': trk.name, 'param': trk.parameter_name, 'run_id': trk.run_id,
                              'disp_name': trk.display_name})

    threshold_set_overlap = torch.arange(0.0, 1.01, 0.01, dtype=torch.float64)

    success_plot_opts = {'plot_type': 'success', 'legend_loc': 'lower left', 'xlabel': 'Overlap threshold',
                         'ylabel': 'Overlap Precision [%]', 'xlim': (0, 1.0), 'ylim': (0, 100), 'title': 'Success plot'}
    plot_draw_save(auc_curve, threshold_set_overlap, scores, tracker_names, plot_draw_styles, result_plot_path,
                   success_plot_opts)
    plt.show()


def print_per_sequence_results(trackers, dataset, report_name, merge_results=False,
                               filter_criteria=None, **kwargs):
    """ Print per-sequence results for the given trackers. Additionally, the sequences to list can be filtered using
    the filter criteria.

    args:
        trackers - List of trackers to evaluate
        dataset - List of sequences to evaluate
        report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
        merge_results - If True, multiple random runs of a non-deterministic tracker are averaged
        filter_criteria - Filter sequence results which are reported. The following modes are supported:
                        None: No filtering. Display results for all sequences in dataset
                        'ao_min': Only display sequences for which the minimum average overlap (AO) score over the
                                  trackers is less than a threshold filter_criteria['threshold']. This mode can
                                  be used to select sequences where at least one tracker performs poorly.
                        'ao_max': Only display sequences for which the maximum average overlap (AO) score over the
                                  trackers is less than a threshold filter_criteria['threshold']. This mode can
                                  be used to select sequences where all trackers perform poorly.
                        'delta_ao': Only display sequences for which the performance of different trackers varies by at
                                    least filter_criteria['threshold'] in average overlap (AO) score. This mode can
                                    be used to select sequences where the behaviour of the trackers differs greatly
                                    between each other.
    """
    # Load pre-computed results
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)

    # Merge results from multiple runs
    if merge_results:
        eval_data = merge_multiple_runs(eval_data)

    tracker_names = eval_data['trackers']
    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)
    sequence_names = eval_data['sequences']
    avg_overlap_all = torch.tensor(eval_data['avg_overlap_all']) * 100.0

    # Filter sequences
    if filter_criteria is not None:
        if filter_criteria['mode'] == 'ao_min':
            min_ao = avg_overlap_all.min(dim=1)[0]
            valid_sequence = valid_sequence & (min_ao < filter_criteria['threshold'])
        elif filter_criteria['mode'] == 'ao_max':
            max_ao = avg_overlap_all.max(dim=1)[0]
            valid_sequence = valid_sequence & (max_ao < filter_criteria['threshold'])
        elif filter_criteria['mode'] == 'delta_ao':
            min_ao = avg_overlap_all.min(dim=1)[0]
            max_ao = avg_overlap_all.max(dim=1)[0]
            valid_sequence = valid_sequence & ((max_ao - min_ao) > filter_criteria['threshold'])
        else:
            raise Exception

    avg_overlap_all = avg_overlap_all[valid_sequence, :]
    sequence_names = [s + ' (ID={})'.format(i) for i, (s, v) in enumerate(zip(sequence_names, valid_sequence.tolist())) if v]

    tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]

    scores_per_tracker = {k: avg_overlap_all[:, i] for i, k in enumerate(tracker_disp_names)}
    report_text = generate_formatted_report(sequence_names, scores_per_tracker)

    print(report_text)


def print_results_per_video(trackers, dataset, report_name, merge_results=False,
                            plot_types=('success'), per_video=False, **kwargs):
    """ Print the results for the given trackers in a formatted table
    args:
        trackers - List of trackers to evaluate
        dataset - List of sequences to evaluate
        report_name - Name of the folder in env_settings.perm_mat_path where the computed results and plots are saved
        merge_results - If True, multiple random runs of a non-deterministic tracker are averaged
        plot_types - List of scores to display. Can contain 'success' (prints AUC, OP50, and OP75 scores),
                    'prec' (prints precision score), and 'norm_prec' (prints normalized precision score)
    """
    # Load pre-computed results
    eval_data = check_and_load_precomputed_results(trackers, dataset, report_name, **kwargs)

    # Merge results from multiple runs
    if merge_results:
        eval_data = merge_multiple_runs(eval_data)

    seq_lens = len(eval_data['sequences'])
    eval_datas = [{} for _ in range(seq_lens)]
    if per_video:
        for key, value in eval_data.items():
            if len(value) == seq_lens:
                for i in range(seq_lens):
                    eval_datas[i][key] = [value[i]]
            else:
                for i in range(seq_lens):
                    eval_datas[i][key] = value

    tracker_names = eval_data['trackers']
    valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

    print('\nReporting results over {} / {} sequences'.format(valid_sequence.long().sum().item(), valid_sequence.shape[0]))

    scores = {}

    # ******************************** Success Plot **************************************
    if 'success' in plot_types:
        threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
        ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])

        # Index out valid sequences
        auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
        scores['AUC'] = auc
        scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
        scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]

    # ******************************** Precision Plot **************************************
    if 'prec' in plot_types:
        ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])

        # Index out valid sequences
        prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
        scores['Precision'] = prec_score

    # ******************************** Norm Precision Plot *********************************
    if 'norm_prec' in plot_types:
        ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])

        # Index out valid sequences
        norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
        scores['Norm Precision'] = norm_prec_score

    # Print
    tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
    report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
    print(report_text)

    if per_video:
        for i in range(seq_lens):
            eval_data = eval_datas[i]

            print('\n{} sequences'.format(eval_data['sequences'][0]))

            scores = {}
            valid_sequence = torch.tensor(eval_data['valid_sequence'], dtype=torch.bool)

            # ******************************** Success Plot **************************************
            if 'success' in plot_types:
                threshold_set_overlap = torch.tensor(eval_data['threshold_set_overlap'])
                ave_success_rate_plot_overlap = torch.tensor(eval_data['ave_success_rate_plot_overlap'])

                # Index out valid sequences
                auc_curve, auc = get_auc_curve(ave_success_rate_plot_overlap, valid_sequence)
                scores['AUC'] = auc
                scores['OP50'] = auc_curve[:, threshold_set_overlap == 0.50]
                scores['OP75'] = auc_curve[:, threshold_set_overlap == 0.75]

            # ******************************** Precision Plot **************************************
            if 'prec' in plot_types:
                ave_success_rate_plot_center = torch.tensor(eval_data['ave_success_rate_plot_center'])

                # Index out valid sequences
                prec_curve, prec_score = get_prec_curve(ave_success_rate_plot_center, valid_sequence)
                scores['Precision'] = prec_score

            # ******************************** Norm Precision Plot *********************************
            if 'norm_prec' in plot_types:
                ave_success_rate_plot_center_norm = torch.tensor(eval_data['ave_success_rate_plot_center_norm'])

                # Index out valid sequences
                norm_prec_curve, norm_prec_score = get_prec_curve(ave_success_rate_plot_center_norm, valid_sequence)
                scores['Norm Precision'] = norm_prec_score

            # Print
            tracker_disp_names = [get_tracker_display_name(trk) for trk in tracker_names]
            report_text = generate_formatted_report(tracker_disp_names, scores, table_name=report_name)
            print(report_text)
lib/test/evaluation/__init__.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from .data import Sequence
from .tracker import Tracker, trackerlist
from .datasets import get_dataset
from .environment import create_default_local_file_ITP_test
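
# Annotation (added, not part of the original file): this module simply re-exports
# the evaluation primitives, so downstream analysis scripts can write, for example,
#   from lib.test.evaluation import get_dataset, trackerlist
# instead of importing each submodule individually (illustrative usage only).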
lib/test/evaluation/data.py  (new file, 169 lines)
@@ -0,0 +1,169 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.train.data.image_loader import imread_indexed
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class BaseDataset:
|
||||
"""Base class for all datasets."""
|
||||
def __init__(self):
|
||||
self.env_settings = env_settings()
|
||||
|
||||
def __len__(self):
|
||||
"""Overload this function in your dataset. This should return number of sequences in the dataset."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_sequence_list(self):
|
||||
"""Overload this in your dataset. Should return the list of sequences in the dataset."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Sequence:
|
||||
"""Class for the sequence in an evaluation."""
|
||||
def __init__(self, name, frames, dataset, ground_truth_rect, ground_truth_seg=None, init_data=None,
|
||||
object_class=None, target_visible=None, object_ids=None, multiobj_mode=False):
|
||||
self.name = name
|
||||
self.frames = frames
|
||||
self.dataset = dataset
|
||||
self.ground_truth_rect = ground_truth_rect
|
||||
self.ground_truth_seg = ground_truth_seg
|
||||
self.object_class = object_class
|
||||
self.target_visible = target_visible
|
||||
self.object_ids = object_ids
|
||||
self.multiobj_mode = multiobj_mode
|
||||
self.init_data = self._construct_init_data(init_data)
|
||||
self._ensure_start_frame()
|
||||
|
||||
def _ensure_start_frame(self):
|
||||
# Ensure start frame is 0
|
||||
start_frame = min(list(self.init_data.keys()))
|
||||
if start_frame > 0:
|
||||
self.frames = self.frames[start_frame:]
|
||||
if self.ground_truth_rect is not None:
|
||||
if isinstance(self.ground_truth_rect, (dict, OrderedDict)):
|
||||
for obj_id, gt in self.ground_truth_rect.items():
|
||||
self.ground_truth_rect[obj_id] = gt[start_frame:,:]
|
||||
else:
|
||||
self.ground_truth_rect = self.ground_truth_rect[start_frame:,:]
|
||||
if self.ground_truth_seg is not None:
|
||||
self.ground_truth_seg = self.ground_truth_seg[start_frame:]
|
||||
assert len(self.frames) == len(self.ground_truth_seg)
|
||||
|
||||
if self.target_visible is not None:
|
||||
self.target_visible = self.target_visible[start_frame:]
|
||||
self.init_data = {frame-start_frame: val for frame, val in self.init_data.items()}
|
||||
|
||||
def _construct_init_data(self, init_data):
|
||||
if init_data is not None:
|
||||
if not self.multiobj_mode:
|
||||
assert self.object_ids is None or len(self.object_ids) == 1
|
||||
for frame, init_val in init_data.items():
|
||||
if 'bbox' in init_val and isinstance(init_val['bbox'], (dict, OrderedDict)):
|
||||
init_val['bbox'] = init_val['bbox'][self.object_ids[0]]
|
||||
# convert to list
|
||||
for frame, init_val in init_data.items():
|
||||
if 'bbox' in init_val:
|
||||
if isinstance(init_val['bbox'], (dict, OrderedDict)):
|
||||
init_val['bbox'] = OrderedDict({obj_id: list(init) for obj_id, init in init_val['bbox'].items()})
|
||||
else:
|
||||
init_val['bbox'] = list(init_val['bbox'])
|
||||
else:
|
||||
init_data = {0: dict()} # Assume start from frame 0
|
||||
|
||||
if self.object_ids is not None:
|
||||
init_data[0]['object_ids'] = self.object_ids
|
||||
|
||||
if self.ground_truth_rect is not None:
|
||||
if self.multiobj_mode:
|
||||
assert isinstance(self.ground_truth_rect, (dict, OrderedDict))
|
||||
init_data[0]['bbox'] = OrderedDict({obj_id: list(gt[0,:]) for obj_id, gt in self.ground_truth_rect.items()})
|
||||
else:
|
||||
assert self.object_ids is None or len(self.object_ids) == 1
|
||||
if isinstance(self.ground_truth_rect, (dict, OrderedDict)):
|
||||
init_data[0]['bbox'] = list(self.ground_truth_rect[self.object_ids[0]][0, :])
|
||||
else:
|
||||
init_data[0]['bbox'] = list(self.ground_truth_rect[0,:])
|
||||
|
||||
if self.ground_truth_seg is not None:
|
||||
init_data[0]['mask'] = self.ground_truth_seg[0]
|
||||
|
||||
return init_data
|
||||
|
||||
def init_info(self):
|
||||
info = self.frame_info(frame_num=0)
|
||||
return info
|
||||
|
||||
def frame_info(self, frame_num):
|
||||
info = self.object_init_data(frame_num=frame_num)
|
||||
return info
|
||||
|
||||
def init_bbox(self, frame_num=0):
|
||||
return self.object_init_data(frame_num=frame_num).get('init_bbox')
|
||||
|
||||
def init_mask(self, frame_num=0):
|
||||
return self.object_init_data(frame_num=frame_num).get('init_mask')
|
||||
|
||||
def get_info(self, keys, frame_num=None):
|
||||
info = dict()
|
||||
for k in keys:
|
||||
val = self.get(k, frame_num=frame_num)
|
||||
if val is not None:
|
||||
info[k] = val
|
||||
return info
|
||||
|
||||
def object_init_data(self, frame_num=None) -> dict:
|
||||
if frame_num is None:
|
||||
frame_num = 0
|
||||
if frame_num not in self.init_data:
|
||||
return dict()
|
||||
|
||||
init_data = dict()
|
||||
for key, val in self.init_data[frame_num].items():
|
||||
if val is None:
|
||||
continue
|
||||
init_data['init_'+key] = val
|
||||
|
||||
if 'init_mask' in init_data and init_data['init_mask'] is not None:
|
||||
anno = imread_indexed(init_data['init_mask'])
|
||||
if not self.multiobj_mode and self.object_ids is not None:
|
||||
assert len(self.object_ids) == 1
|
||||
anno = (anno == int(self.object_ids[0])).astype(np.uint8)
|
||||
init_data['init_mask'] = anno
|
||||
|
||||
if self.object_ids is not None:
|
||||
init_data['object_ids'] = self.object_ids
|
||||
init_data['sequence_object_ids'] = self.object_ids
|
||||
|
||||
return init_data
|
||||
|
||||
def target_class(self, frame_num=None):
|
||||
return self.object_class
|
||||
|
||||
def get(self, name, frame_num=None):
|
||||
return getattr(self, name)(frame_num)
|
||||
|
||||
def __repr__(self):
|
||||
return "{self.__class__.__name__} {self.name}, length={len} frames".format(self=self, len=len(self.frames))
|
||||
|
||||
|
||||
|
||||
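For reference, a minimal usage sketch of the Sequence init-data accessors defined above; the sequence name, frame paths and box values are invented for illustration:

import numpy as np
from lib.test.evaluation.data import Sequence

# Hypothetical example: a two-frame sequence with one ground-truth box per frame.
gt = np.array([[10.0, 20.0, 50.0, 80.0],
               [12.0, 22.0, 50.0, 80.0]])
seq = Sequence('demo-1', ['frames/00000001.jpg', 'frames/00000002.jpg'], 'demo', gt)

info = seq.init_info()        # same as seq.frame_info(frame_num=0)
print(info['init_bbox'])      # -> [10.0, 20.0, 50.0, 80.0], taken from gt[0, :]
print(seq.init_bbox())        # convenience accessor returning the same list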
class SequenceList(list):
|
||||
"""List of sequences. Supports the addition operator to concatenate sequence lists."""
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, str):
|
||||
for seq in self:
|
||||
if seq.name == item:
|
||||
return seq
|
||||
raise IndexError('Sequence name not in the dataset.')
|
||||
elif isinstance(item, int):
|
||||
return super(SequenceList, self).__getitem__(item)
|
||||
elif isinstance(item, (tuple, list)):
|
||||
return SequenceList([super(SequenceList, self).__getitem__(i) for i in item])
|
||||
else:
|
||||
return SequenceList(super(SequenceList, self).__getitem__(item))
|
||||
|
||||
def __add__(self, other):
|
||||
return SequenceList(super(SequenceList, self).__add__(other))
|
||||
|
||||
def copy(self):
|
||||
return SequenceList(super(SequenceList, self).copy())
|
||||
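A hedged sketch of how the SequenceList container above is typically indexed; the variable and sequence names are hypothetical:

from lib.test.evaluation.data import SequenceList

seqs = SequenceList([seq_a, seq_b, seq_c])   # seq_a, seq_b, seq_c are Sequence objects
by_name = seqs['airplane-1']                 # lookup by sequence name (IndexError if absent)
first = seqs[0]                              # plain integer indexing
subset = seqs[(0, 2)]                        # tuple/list of indices -> new SequenceList
combined = seqs + SequenceList([seq_d])      # __add__ keeps the SequenceList type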
48
lib/test/evaluation/datasets.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from collections import namedtuple
|
||||
import importlib
|
||||
from lib.test.evaluation.data import SequenceList
|
||||
|
||||
DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs'])
|
||||
|
||||
pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter
|
||||
|
||||
dataset_dict = dict(
|
||||
otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()),
|
||||
nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()),
|
||||
uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()),
|
||||
tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()),
|
||||
tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()),
|
||||
trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()),
|
||||
got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')),
|
||||
got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')),
|
||||
got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')),
|
||||
lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()),
|
||||
lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()),
|
||||
|
||||
vot18=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict()),
|
||||
vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)),
|
||||
itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()),
|
||||
tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()),
|
||||
lasot_extension_subset=DatasetInfo(module=pt % "lasotextensionsubset", class_name="LaSOTExtensionSubsetDataset",
|
||||
kwargs=dict()),
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(name: str):
|
||||
""" Import and load a single dataset."""
|
||||
name = name.lower()
|
||||
dset_info = dataset_dict.get(name)
|
||||
if dset_info is None:
|
||||
raise ValueError('Unknown dataset \'%s\'' % name)
|
||||
|
||||
m = importlib.import_module(dset_info.module)
|
||||
dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor
|
||||
return dataset.get_sequence_list()
|
||||
|
||||
|
||||
def get_dataset(*args):
|
||||
""" Get a single or set of datasets."""
|
||||
dset = SequenceList()
|
||||
for name in args:
|
||||
dset.extend(load_dataset(name))
|
||||
return dset
|
||||
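For reference, a minimal sketch of how the helpers above are used; it assumes the corresponding dataset paths are already set in lib/test/evaluation/local.py:

from lib.test.evaluation.datasets import get_dataset, load_dataset

lasot_seqs = load_dataset('lasot')              # SequenceList from LaSOTDataset.get_sequence_list()
all_seqs = get_dataset('lasot', 'got10k_test')  # several datasets concatenated into one SequenceList
print(len(all_seqs), all_seqs[0].name)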
124
lib/test/evaluation/environment.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
class EnvSettings:
|
||||
def __init__(self):
|
||||
test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
self.results_path = '{}/results/'.format(test_path)
|
||||
self.segmentation_path = '{}/segmentation_results/'.format(test_path)
|
||||
self.network_path = '{}/networks/'.format(test_path)
|
||||
self.result_plot_path = '{}/result_plots/'.format(test_path)
|
||||
self.otb_path = ''
|
||||
self.nfs_path = ''
|
||||
self.uav_path = ''
|
||||
self.tpl_path = ''
|
||||
self.vot_path = ''
|
||||
self.got10k_path = ''
|
||||
self.lasot_path = ''
|
||||
self.trackingnet_path = ''
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
||||
|
||||
self.got_packed_results_path = ''
|
||||
self.got_reports_path = ''
|
||||
self.tn_packed_results_path = ''
|
||||
|
||||
|
||||
def create_default_local_file():
|
||||
comment = {'results_path': 'Where to store tracking results',
|
||||
'network_path': 'Where tracking networks are stored.'}
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
with open(path, 'w') as f:
|
||||
settings = EnvSettings()
|
||||
|
||||
f.write('from lib.test.evaluation.environment import EnvSettings\n\n')
|
||||
f.write('def local_env_settings():\n')
|
||||
f.write(' settings = EnvSettings()\n\n')
|
||||
f.write(' # Set your local paths here.\n\n')
|
||||
|
||||
for attr in dir(settings):
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
attr_val = getattr(settings, attr)
|
||||
if not attr.startswith('__') and not callable(attr_val):
|
||||
if comment_str is None:
|
||||
f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
f.write('\n return settings\n\n')
|
||||
|
||||
|
||||
class EnvSettings_ITP:
|
||||
def __init__(self, workspace_dir, data_dir, save_dir):
|
||||
self.prj_dir = workspace_dir
|
||||
self.save_dir = save_dir
|
||||
self.results_path = os.path.join(save_dir, 'test/tracking_results')
|
||||
self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results')
|
||||
self.network_path = os.path.join(save_dir, 'test/networks')
|
||||
self.result_plot_path = os.path.join(save_dir, 'test/result_plots')
|
||||
self.otb_path = os.path.join(data_dir, 'otb')
|
||||
self.nfs_path = os.path.join(data_dir, 'nfs')
|
||||
self.uav_path = os.path.join(data_dir, 'uav')
|
||||
self.tc128_path = os.path.join(data_dir, 'TC128')
|
||||
self.tpl_path = ''
|
||||
self.vot_path = os.path.join(data_dir, 'VOT2019')
|
||||
self.got10k_path = os.path.join(data_dir, 'got10k')
|
||||
self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb')
|
||||
self.lasot_path = os.path.join(data_dir, 'lasot')
|
||||
self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb')
|
||||
self.trackingnet_path = os.path.join(data_dir, 'trackingnet')
|
||||
self.vot18_path = os.path.join(data_dir, 'vot2018')
|
||||
self.vot22_path = os.path.join(data_dir, 'vot2022')
|
||||
self.itb_path = os.path.join(data_dir, 'itb')
|
||||
self.tnl2k_path = os.path.join(data_dir, 'tnl2k')
|
||||
self.lasot_extension_subset_path = os.path.join(data_dir, 'lasot_extension_subset')
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
||||
|
||||
self.got_packed_results_path = ''
|
||||
self.got_reports_path = ''
|
||||
self.tn_packed_results_path = ''
|
||||
|
||||
|
||||
def create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir):
|
||||
comment = {'results_path': 'Where to store tracking results',
|
||||
'network_path': 'Where tracking networks are stored.'}
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
with open(path, 'w') as f:
|
||||
settings = EnvSettings_ITP(workspace_dir, data_dir, save_dir)
|
||||
|
||||
f.write('from lib.test.evaluation.environment import EnvSettings\n\n')
|
||||
f.write('def local_env_settings():\n')
|
||||
f.write(' settings = EnvSettings()\n\n')
|
||||
f.write(' # Set your local paths here.\n\n')
|
||||
|
||||
for attr in dir(settings):
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
attr_val = getattr(settings, attr)
|
||||
if not attr.startswith('__') and not callable(attr_val):
|
||||
if comment_str is None:
|
||||
f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
f.write('\n return settings\n\n')
|
||||
|
||||
|
||||
def env_settings():
|
||||
env_module_name = 'lib.test.evaluation.local'
|
||||
try:
|
||||
env_module = importlib.import_module(env_module_name)
|
||||
return env_module.local_env_settings()
|
||||
except ImportError:
|
||||
env_file = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
# Create a default file
|
||||
create_default_local_file()
|
||||
raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. '
|
||||
'Then try to run again.'.format(env_file))
|
||||
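A short sketch of the intended flow for env_settings(): the first call fails when lib/test/evaluation/local.py does not exist, writes a default file next to environment.py, and raises; once the paths in that file are filled in, the call succeeds:

from lib.test.evaluation.environment import env_settings

try:
    settings = env_settings()          # imports lib.test.evaluation.local
    print(settings.results_path)
except RuntimeError:
    # A default local.py was just generated; edit its paths and run again.
    raise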
56
lib/test/evaluation/got10kdataset.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
import os
|
||||
|
||||
|
||||
class GOT10KDataset(BaseDataset):
|
||||
""" GOT-10k dataset.
|
||||
|
||||
Publication:
|
||||
GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
|
||||
Lianghua Huang, Xin Zhao, and Kaiqi Huang
|
||||
arXiv:1810.11981, 2018
|
||||
https://arxiv.org/pdf/1810.11981.pdf
|
||||
|
||||
Download dataset from http://got-10k.aitestunion.com/downloads
|
||||
"""
|
||||
def __init__(self, split):
|
||||
super().__init__()
|
||||
# Split can be test, val, or ltrval (a validation split consisting of videos from the official train set)
|
||||
if split == 'test' or split == 'val':
|
||||
self.base_path = os.path.join(self.env_settings.got10k_path, split)
|
||||
else:
|
||||
self.base_path = os.path.join(self.env_settings.got10k_path, 'train')
|
||||
|
||||
self.sequence_list = self._get_sequence_list(split)
|
||||
self.split = split
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
frames_path = '{}/{}'.format(self.base_path, sequence_name)
|
||||
frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
|
||||
frame_list.sort(key=lambda f: int(f[:-4]))
|
||||
frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
|
||||
|
||||
return Sequence(sequence_name, frames_list, 'got10k', ground_truth_rect.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self, split):
|
||||
with open('{}/list.txt'.format(self.base_path)) as f:
|
||||
sequence_list = f.read().splitlines()
|
||||
|
||||
if split == 'ltrval':
|
||||
with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f:
|
||||
seq_ids = f.read().splitlines()
|
||||
|
||||
sequence_list = [sequence_list[int(x)] for x in seq_ids]
|
||||
return sequence_list
|
||||
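Based on the path handling above, a hedged sketch of the directory layout GOT10KDataset expects under got10k_path; sequence and frame names are illustrative:

<got10k_path>/
    test/
        list.txt                    # one sequence name per line
        GOT-10k_Test_000001/
            groundtruth.txt         # comma-separated x,y,w,h per frame
            00000001.jpg
            ...
    val/
        ...                         # same structure as test/
    train/
        ...                         # used by the 'ltrval' split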
75
lib/test/evaluation/itbdataset.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
import os
|
||||
|
||||
|
||||
class ITBDataset(BaseDataset):
|
||||
""" NUS-PRO dataset
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.itb_path
|
||||
self.sequence_info_list = self._get_sequence_info_list(self.base_path)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num,
|
||||
nz=nz, ext=ext) for frame_num in
|
||||
range(start_frame + init_omit, end_frame + 1)]
|
||||
|
||||
anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
|
||||
# NOTE: some annotations are irregular and cannot be parsed by the pandas backend
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy')
|
||||
return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:, :],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def get_fileNames(self, rootdir):
|
||||
fs = []
|
||||
fs_all = []
|
||||
for root, dirs, files in os.walk(rootdir, topdown=True):
|
||||
files.sort()
|
||||
files.sort(key=len)
|
||||
if files is not None:
|
||||
for name in files:
|
||||
_, ending = os.path.splitext(name)
|
||||
if ending == ".jpg":
|
||||
_, root_ = os.path.split(root)
|
||||
fs.append(os.path.join(root_, name))
|
||||
fs_all.append(os.path.join(root, name))
|
||||
|
||||
return fs_all, fs
|
||||
|
||||
def _get_sequence_info_list(self, base_path):
|
||||
sequence_info_list = []
|
||||
for scene in os.listdir(base_path):
|
||||
if '.' in scene:
|
||||
continue
|
||||
videos = os.listdir(os.path.join(base_path, scene))
|
||||
for video in videos:
|
||||
_, fs = self.get_fileNames(os.path.join(base_path, scene, video))
|
||||
video_tmp = {"name": video, "path": scene + '/' + video, "startFrame": 1, "endFrame": len(fs),
|
||||
"nz": len(fs[0].split('/')[-1].split('.')[0]), "ext": "jpg",
|
||||
"anno_path": scene + '/' + video + "/groundtruth.txt",
|
||||
"object_class": "unknown"}
|
||||
sequence_info_list.append(video_tmp)
|
||||
|
||||
return sequence_info_list # sequence_info_list_50 #
|
||||
345
lib/test/evaluation/lasot_lmdbdataset.py
Normal file
@@ -0,0 +1,345 @@
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
'''2021.1.27 LaSOT dataset using lmdb data'''
|
||||
|
||||
|
||||
class LaSOTlmdbDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT test set consisting of 280 videos (see Protocol-II in the LaSOT paper)
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_lmdb_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = str('{}/{}/groundtruth.txt'.format(class_name, sequence_name))
|
||||
# decode the groundtruth
|
||||
gt_str_list = decode_str(self.base_path, anno_path).split('\n')[:-1] # the last line is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
ground_truth_rect = np.array(gt_list).astype(np.float64)
|
||||
# decode occlusion file
|
||||
occlusion_label_path = str('{}/{}/full_occlusion.txt'.format(class_name, sequence_name))
|
||||
occ_list = list(map(int, decode_str(self.base_path, occlusion_label_path).split(',')))
|
||||
full_occlusion = np.array(occ_list).astype(np.float64)
|
||||
# decode out of view file
|
||||
out_of_view_label_path = str('{}/{}/out_of_view.txt'.format(class_name, sequence_name))
|
||||
out_of_view_list = list(map(int, decode_str(self.base_path, out_of_view_label_path).split(',')))
|
||||
out_of_view = np.array(out_of_view_list).astype(np.float64)
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/img'.format(class_name, sequence_name)
|
||||
|
||||
frames_list = [[self.base_path, '{}/{:08d}.jpg'.format(frames_path, frame_number)] for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['airplane-1',
|
||||
'airplane-9',
|
||||
'airplane-13',
|
||||
'airplane-15',
|
||||
'basketball-1',
|
||||
'basketball-6',
|
||||
'basketball-7',
|
||||
'basketball-11',
|
||||
'bear-2',
|
||||
'bear-4',
|
||||
'bear-6',
|
||||
'bear-17',
|
||||
'bicycle-2',
|
||||
'bicycle-7',
|
||||
'bicycle-9',
|
||||
'bicycle-18',
|
||||
'bird-2',
|
||||
'bird-3',
|
||||
'bird-15',
|
||||
'bird-17',
|
||||
'boat-3',
|
||||
'boat-4',
|
||||
'boat-12',
|
||||
'boat-17',
|
||||
'book-3',
|
||||
'book-10',
|
||||
'book-11',
|
||||
'book-19',
|
||||
'bottle-1',
|
||||
'bottle-12',
|
||||
'bottle-14',
|
||||
'bottle-18',
|
||||
'bus-2',
|
||||
'bus-5',
|
||||
'bus-17',
|
||||
'bus-19',
|
||||
'car-2',
|
||||
'car-6',
|
||||
'car-9',
|
||||
'car-17',
|
||||
'cat-1',
|
||||
'cat-3',
|
||||
'cat-18',
|
||||
'cat-20',
|
||||
'cattle-2',
|
||||
'cattle-7',
|
||||
'cattle-12',
|
||||
'cattle-13',
|
||||
'spider-14',
|
||||
'spider-16',
|
||||
'spider-18',
|
||||
'spider-20',
|
||||
'coin-3',
|
||||
'coin-6',
|
||||
'coin-7',
|
||||
'coin-18',
|
||||
'crab-3',
|
||||
'crab-6',
|
||||
'crab-12',
|
||||
'crab-18',
|
||||
'surfboard-12',
|
||||
'surfboard-4',
|
||||
'surfboard-5',
|
||||
'surfboard-8',
|
||||
'cup-1',
|
||||
'cup-4',
|
||||
'cup-7',
|
||||
'cup-17',
|
||||
'deer-4',
|
||||
'deer-8',
|
||||
'deer-10',
|
||||
'deer-14',
|
||||
'dog-1',
|
||||
'dog-7',
|
||||
'dog-15',
|
||||
'dog-19',
|
||||
'guitar-3',
|
||||
'guitar-8',
|
||||
'guitar-10',
|
||||
'guitar-16',
|
||||
'person-1',
|
||||
'person-5',
|
||||
'person-10',
|
||||
'person-12',
|
||||
'pig-2',
|
||||
'pig-10',
|
||||
'pig-13',
|
||||
'pig-18',
|
||||
'rubicCube-1',
|
||||
'rubicCube-6',
|
||||
'rubicCube-14',
|
||||
'rubicCube-19',
|
||||
'swing-10',
|
||||
'swing-14',
|
||||
'swing-17',
|
||||
'swing-20',
|
||||
'drone-13',
|
||||
'drone-15',
|
||||
'drone-2',
|
||||
'drone-7',
|
||||
'pool-12',
|
||||
'pool-15',
|
||||
'pool-3',
|
||||
'pool-7',
|
||||
'rabbit-10',
|
||||
'rabbit-13',
|
||||
'rabbit-17',
|
||||
'rabbit-19',
|
||||
'racing-10',
|
||||
'racing-15',
|
||||
'racing-16',
|
||||
'racing-20',
|
||||
'robot-1',
|
||||
'robot-19',
|
||||
'robot-5',
|
||||
'robot-8',
|
||||
'sepia-13',
|
||||
'sepia-16',
|
||||
'sepia-6',
|
||||
'sepia-8',
|
||||
'sheep-3',
|
||||
'sheep-5',
|
||||
'sheep-7',
|
||||
'sheep-9',
|
||||
'skateboard-16',
|
||||
'skateboard-19',
|
||||
'skateboard-3',
|
||||
'skateboard-8',
|
||||
'tank-14',
|
||||
'tank-16',
|
||||
'tank-6',
|
||||
'tank-9',
|
||||
'tiger-12',
|
||||
'tiger-18',
|
||||
'tiger-4',
|
||||
'tiger-6',
|
||||
'train-1',
|
||||
'train-11',
|
||||
'train-20',
|
||||
'train-7',
|
||||
'truck-16',
|
||||
'truck-3',
|
||||
'truck-6',
|
||||
'truck-7',
|
||||
'turtle-16',
|
||||
'turtle-5',
|
||||
'turtle-8',
|
||||
'turtle-9',
|
||||
'umbrella-17',
|
||||
'umbrella-19',
|
||||
'umbrella-2',
|
||||
'umbrella-9',
|
||||
'yoyo-15',
|
||||
'yoyo-17',
|
||||
'yoyo-19',
|
||||
'yoyo-7',
|
||||
'zebra-10',
|
||||
'zebra-14',
|
||||
'zebra-16',
|
||||
'zebra-17',
|
||||
'elephant-1',
|
||||
'elephant-12',
|
||||
'elephant-16',
|
||||
'elephant-18',
|
||||
'goldfish-3',
|
||||
'goldfish-7',
|
||||
'goldfish-8',
|
||||
'goldfish-10',
|
||||
'hat-1',
|
||||
'hat-2',
|
||||
'hat-5',
|
||||
'hat-18',
|
||||
'kite-4',
|
||||
'kite-6',
|
||||
'kite-10',
|
||||
'kite-15',
|
||||
'motorcycle-1',
|
||||
'motorcycle-3',
|
||||
'motorcycle-9',
|
||||
'motorcycle-18',
|
||||
'mouse-1',
|
||||
'mouse-8',
|
||||
'mouse-9',
|
||||
'mouse-17',
|
||||
'flag-3',
|
||||
'flag-9',
|
||||
'flag-5',
|
||||
'flag-2',
|
||||
'frog-3',
|
||||
'frog-4',
|
||||
'frog-20',
|
||||
'frog-9',
|
||||
'gametarget-1',
|
||||
'gametarget-2',
|
||||
'gametarget-7',
|
||||
'gametarget-13',
|
||||
'hand-2',
|
||||
'hand-3',
|
||||
'hand-9',
|
||||
'hand-16',
|
||||
'helmet-5',
|
||||
'helmet-11',
|
||||
'helmet-19',
|
||||
'helmet-13',
|
||||
'licenseplate-6',
|
||||
'licenseplate-12',
|
||||
'licenseplate-13',
|
||||
'licenseplate-15',
|
||||
'electricfan-1',
|
||||
'electricfan-10',
|
||||
'electricfan-18',
|
||||
'electricfan-20',
|
||||
'chameleon-3',
|
||||
'chameleon-6',
|
||||
'chameleon-11',
|
||||
'chameleon-20',
|
||||
'crocodile-3',
|
||||
'crocodile-4',
|
||||
'crocodile-10',
|
||||
'crocodile-14',
|
||||
'gecko-1',
|
||||
'gecko-5',
|
||||
'gecko-16',
|
||||
'gecko-19',
|
||||
'fox-2',
|
||||
'fox-3',
|
||||
'fox-5',
|
||||
'fox-20',
|
||||
'giraffe-2',
|
||||
'giraffe-10',
|
||||
'giraffe-13',
|
||||
'giraffe-15',
|
||||
'gorilla-4',
|
||||
'gorilla-6',
|
||||
'gorilla-9',
|
||||
'gorilla-13',
|
||||
'hippo-1',
|
||||
'hippo-7',
|
||||
'hippo-9',
|
||||
'hippo-20',
|
||||
'horse-1',
|
||||
'horse-4',
|
||||
'horse-12',
|
||||
'horse-15',
|
||||
'kangaroo-2',
|
||||
'kangaroo-5',
|
||||
'kangaroo-11',
|
||||
'kangaroo-14',
|
||||
'leopard-1',
|
||||
'leopard-7',
|
||||
'leopard-16',
|
||||
'leopard-20',
|
||||
'lion-1',
|
||||
'lion-5',
|
||||
'lion-12',
|
||||
'lion-20',
|
||||
'lizard-1',
|
||||
'lizard-3',
|
||||
'lizard-6',
|
||||
'lizard-13',
|
||||
'microphone-2',
|
||||
'microphone-6',
|
||||
'microphone-14',
|
||||
'microphone-16',
|
||||
'monkey-3',
|
||||
'monkey-4',
|
||||
'monkey-9',
|
||||
'monkey-17',
|
||||
'shark-2',
|
||||
'shark-3',
|
||||
'shark-5',
|
||||
'shark-6',
|
||||
'squirrel-8',
|
||||
'squirrel-11',
|
||||
'squirrel-13',
|
||||
'squirrel-19',
|
||||
'volleyball-1',
|
||||
'volleyball-13',
|
||||
'volleyball-18',
|
||||
'volleyball-19']
|
||||
return sequence_list
|
||||
342
lib/test/evaluation/lasotdataset.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class LaSOTDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT test set consisting of 280 videos (see Protocol-II in the LaSOT paper)
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/{}/groundtruth.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
occlusion_label_path = '{}/{}/{}/full_occlusion.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
# NOTE: the pandas backend is very slow for loading occlusion/out-of-view masks
|
||||
full_occlusion = load_text(str(occlusion_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
out_of_view_label_path = '{}/{}/{}/out_of_view.txt'.format(self.base_path, class_name, sequence_name)
|
||||
out_of_view = load_text(str(out_of_view_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/{}/img'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
frames_list = ['{}/{:08d}.jpg'.format(frames_path, frame_number) for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['airplane-1',
|
||||
'airplane-9',
|
||||
'airplane-13',
|
||||
'airplane-15',
|
||||
'basketball-1',
|
||||
'basketball-6',
|
||||
'basketball-7',
|
||||
'basketball-11',
|
||||
'bear-2',
|
||||
'bear-4',
|
||||
'bear-6',
|
||||
'bear-17',
|
||||
'bicycle-2',
|
||||
'bicycle-7',
|
||||
'bicycle-9',
|
||||
'bicycle-18',
|
||||
'bird-2',
|
||||
'bird-3',
|
||||
'bird-15',
|
||||
'bird-17',
|
||||
'boat-3',
|
||||
'boat-4',
|
||||
'boat-12',
|
||||
'boat-17',
|
||||
'book-3',
|
||||
'book-10',
|
||||
'book-11',
|
||||
'book-19',
|
||||
'bottle-1',
|
||||
'bottle-12',
|
||||
'bottle-14',
|
||||
'bottle-18',
|
||||
'bus-2',
|
||||
'bus-5',
|
||||
'bus-17',
|
||||
'bus-19',
|
||||
'car-2',
|
||||
'car-6',
|
||||
'car-9',
|
||||
'car-17',
|
||||
'cat-1',
|
||||
'cat-3',
|
||||
'cat-18',
|
||||
'cat-20',
|
||||
'cattle-2',
|
||||
'cattle-7',
|
||||
'cattle-12',
|
||||
'cattle-13',
|
||||
'spider-14',
|
||||
'spider-16',
|
||||
'spider-18',
|
||||
'spider-20',
|
||||
'coin-3',
|
||||
'coin-6',
|
||||
'coin-7',
|
||||
'coin-18',
|
||||
'crab-3',
|
||||
'crab-6',
|
||||
'crab-12',
|
||||
'crab-18',
|
||||
'surfboard-12',
|
||||
'surfboard-4',
|
||||
'surfboard-5',
|
||||
'surfboard-8',
|
||||
'cup-1',
|
||||
'cup-4',
|
||||
'cup-7',
|
||||
'cup-17',
|
||||
'deer-4',
|
||||
'deer-8',
|
||||
'deer-10',
|
||||
'deer-14',
|
||||
'dog-1',
|
||||
'dog-7',
|
||||
'dog-15',
|
||||
'dog-19',
|
||||
'guitar-3',
|
||||
'guitar-8',
|
||||
'guitar-10',
|
||||
'guitar-16',
|
||||
'person-1',
|
||||
'person-5',
|
||||
'person-10',
|
||||
'person-12',
|
||||
'pig-2',
|
||||
'pig-10',
|
||||
'pig-13',
|
||||
'pig-18',
|
||||
'rubicCube-1',
|
||||
'rubicCube-6',
|
||||
'rubicCube-14',
|
||||
'rubicCube-19',
|
||||
'swing-10',
|
||||
'swing-14',
|
||||
'swing-17',
|
||||
'swing-20',
|
||||
'drone-13',
|
||||
'drone-15',
|
||||
'drone-2',
|
||||
'drone-7',
|
||||
'pool-12',
|
||||
'pool-15',
|
||||
'pool-3',
|
||||
'pool-7',
|
||||
'rabbit-10',
|
||||
'rabbit-13',
|
||||
'rabbit-17',
|
||||
'rabbit-19',
|
||||
'racing-10',
|
||||
'racing-15',
|
||||
'racing-16',
|
||||
'racing-20',
|
||||
'robot-1',
|
||||
'robot-19',
|
||||
'robot-5',
|
||||
'robot-8',
|
||||
'sepia-13',
|
||||
'sepia-16',
|
||||
'sepia-6',
|
||||
'sepia-8',
|
||||
'sheep-3',
|
||||
'sheep-5',
|
||||
'sheep-7',
|
||||
'sheep-9',
|
||||
'skateboard-16',
|
||||
'skateboard-19',
|
||||
'skateboard-3',
|
||||
'skateboard-8',
|
||||
'tank-14',
|
||||
'tank-16',
|
||||
'tank-6',
|
||||
'tank-9',
|
||||
'tiger-12',
|
||||
'tiger-18',
|
||||
'tiger-4',
|
||||
'tiger-6',
|
||||
'train-1',
|
||||
'train-11',
|
||||
'train-20',
|
||||
'train-7',
|
||||
'truck-16',
|
||||
'truck-3',
|
||||
'truck-6',
|
||||
'truck-7',
|
||||
'turtle-16',
|
||||
'turtle-5',
|
||||
'turtle-8',
|
||||
'turtle-9',
|
||||
'umbrella-17',
|
||||
'umbrella-19',
|
||||
'umbrella-2',
|
||||
'umbrella-9',
|
||||
'yoyo-15',
|
||||
'yoyo-17',
|
||||
'yoyo-19',
|
||||
'yoyo-7',
|
||||
'zebra-10',
|
||||
'zebra-14',
|
||||
'zebra-16',
|
||||
'zebra-17',
|
||||
'elephant-1',
|
||||
'elephant-12',
|
||||
'elephant-16',
|
||||
'elephant-18',
|
||||
'goldfish-3',
|
||||
'goldfish-7',
|
||||
'goldfish-8',
|
||||
'goldfish-10',
|
||||
'hat-1',
|
||||
'hat-2',
|
||||
'hat-5',
|
||||
'hat-18',
|
||||
'kite-4',
|
||||
'kite-6',
|
||||
'kite-10',
|
||||
'kite-15',
|
||||
'motorcycle-1',
|
||||
'motorcycle-3',
|
||||
'motorcycle-9',
|
||||
'motorcycle-18',
|
||||
'mouse-1',
|
||||
'mouse-8',
|
||||
'mouse-9',
|
||||
'mouse-17',
|
||||
'flag-3',
|
||||
'flag-9',
|
||||
'flag-5',
|
||||
'flag-2',
|
||||
'frog-3',
|
||||
'frog-4',
|
||||
'frog-20',
|
||||
'frog-9',
|
||||
'gametarget-1',
|
||||
'gametarget-2',
|
||||
'gametarget-7',
|
||||
'gametarget-13',
|
||||
'hand-2',
|
||||
'hand-3',
|
||||
'hand-9',
|
||||
'hand-16',
|
||||
'helmet-5',
|
||||
'helmet-11',
|
||||
'helmet-19',
|
||||
'helmet-13',
|
||||
'licenseplate-6',
|
||||
'licenseplate-12',
|
||||
'licenseplate-13',
|
||||
'licenseplate-15',
|
||||
'electricfan-1',
|
||||
'electricfan-10',
|
||||
'electricfan-18',
|
||||
'electricfan-20',
|
||||
'chameleon-3',
|
||||
'chameleon-6',
|
||||
'chameleon-11',
|
||||
'chameleon-20',
|
||||
'crocodile-3',
|
||||
'crocodile-4',
|
||||
'crocodile-10',
|
||||
'crocodile-14',
|
||||
'gecko-1',
|
||||
'gecko-5',
|
||||
'gecko-16',
|
||||
'gecko-19',
|
||||
'fox-2',
|
||||
'fox-3',
|
||||
'fox-5',
|
||||
'fox-20',
|
||||
'giraffe-2',
|
||||
'giraffe-10',
|
||||
'giraffe-13',
|
||||
'giraffe-15',
|
||||
'gorilla-4',
|
||||
'gorilla-6',
|
||||
'gorilla-9',
|
||||
'gorilla-13',
|
||||
'hippo-1',
|
||||
'hippo-7',
|
||||
'hippo-9',
|
||||
'hippo-20',
|
||||
'horse-1',
|
||||
'horse-4',
|
||||
'horse-12',
|
||||
'horse-15',
|
||||
'kangaroo-2',
|
||||
'kangaroo-5',
|
||||
'kangaroo-11',
|
||||
'kangaroo-14',
|
||||
'leopard-1',
|
||||
'leopard-7',
|
||||
'leopard-16',
|
||||
'leopard-20',
|
||||
'lion-1',
|
||||
'lion-5',
|
||||
'lion-12',
|
||||
'lion-20',
|
||||
'lizard-1',
|
||||
'lizard-3',
|
||||
'lizard-6',
|
||||
'lizard-13',
|
||||
'microphone-2',
|
||||
'microphone-6',
|
||||
'microphone-14',
|
||||
'microphone-16',
|
||||
'monkey-3',
|
||||
'monkey-4',
|
||||
'monkey-9',
|
||||
'monkey-17',
|
||||
'shark-2',
|
||||
'shark-3',
|
||||
'shark-5',
|
||||
'shark-6',
|
||||
'squirrel-8',
|
||||
'squirrel-11',
|
||||
'squirrel-13',
|
||||
'squirrel-19',
|
||||
'volleyball-1',
|
||||
'volleyball-13',
|
||||
'volleyball-18',
|
||||
'volleyball-19']
|
||||
return sequence_list
|
||||
211
lib/test/evaluation/lasotextensionsubsetdataset.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class LaSOTExtensionSubsetDataset(BaseDataset):
|
||||
"""
|
||||
LaSOT extension subset (LaSOT-ext), consisting of 150 videos from 15 object classes not present in the original LaSOT
|
||||
Publication:
|
||||
LaSOT: A High-quality Large-scale Single Object Tracking Benchmark
|
||||
Heng Fan, Hexin Bai, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Harshit, Mingzhen Huang, Juehuan Liu,
|
||||
Yong Xu, Chunyuan Liao, Lin Yuan, Haibin Ling
|
||||
IJCV, 2020
|
||||
https://arxiv.org/pdf/2009.03465.pdf
|
||||
Download the dataset from http://vision.cs.stonybrook.edu/~lasot/download.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.lasot_extension_subset_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
self.clean_list = self.clean_seq_list()
|
||||
|
||||
def clean_seq_list(self):
|
||||
clean_lst = []
|
||||
for i in range(len(self.sequence_list)):
|
||||
cls, _ = self.sequence_list[i].split('-')
|
||||
clean_lst.append(cls)
|
||||
return clean_lst
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/{}/groundtruth.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
occlusion_label_path = '{}/{}/{}/full_occlusion.txt'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
# NOTE: the pandas backend is very slow for loading occlusion/out-of-view masks
|
||||
full_occlusion = load_text(str(occlusion_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
out_of_view_label_path = '{}/{}/{}/out_of_view.txt'.format(self.base_path, class_name, sequence_name)
|
||||
out_of_view = load_text(str(out_of_view_label_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
target_visible = np.logical_and(full_occlusion == 0, out_of_view == 0)
|
||||
|
||||
frames_path = '{}/{}/{}/img'.format(self.base_path, class_name, sequence_name)
|
||||
|
||||
frames_list = ['{}/{:08d}.jpg'.format(frames_path, frame_number) for frame_number in range(1, ground_truth_rect.shape[0] + 1)]
|
||||
|
||||
target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'lasot_extension_subset', ground_truth_rect.reshape(-1, 4),
|
||||
object_class=target_class, target_visible=target_visible)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = ['atv-1',
|
||||
'atv-2',
|
||||
'atv-3',
|
||||
'atv-4',
|
||||
'atv-5',
|
||||
'atv-6',
|
||||
'atv-7',
|
||||
'atv-8',
|
||||
'atv-9',
|
||||
'atv-10',
|
||||
'badminton-1',
|
||||
'badminton-2',
|
||||
'badminton-3',
|
||||
'badminton-4',
|
||||
'badminton-5',
|
||||
'badminton-6',
|
||||
'badminton-7',
|
||||
'badminton-8',
|
||||
'badminton-9',
|
||||
'badminton-10',
|
||||
'cosplay-1',
|
||||
'cosplay-10',
|
||||
'cosplay-2',
|
||||
'cosplay-3',
|
||||
'cosplay-4',
|
||||
'cosplay-5',
|
||||
'cosplay-6',
|
||||
'cosplay-7',
|
||||
'cosplay-8',
|
||||
'cosplay-9',
|
||||
'dancingshoe-1',
|
||||
'dancingshoe-2',
|
||||
'dancingshoe-3',
|
||||
'dancingshoe-4',
|
||||
'dancingshoe-5',
|
||||
'dancingshoe-6',
|
||||
'dancingshoe-7',
|
||||
'dancingshoe-8',
|
||||
'dancingshoe-9',
|
||||
'dancingshoe-10',
|
||||
'footbag-1',
|
||||
'footbag-2',
|
||||
'footbag-3',
|
||||
'footbag-4',
|
||||
'footbag-5',
|
||||
'footbag-6',
|
||||
'footbag-7',
|
||||
'footbag-8',
|
||||
'footbag-9',
|
||||
'footbag-10',
|
||||
'frisbee-1',
|
||||
'frisbee-2',
|
||||
'frisbee-3',
|
||||
'frisbee-4',
|
||||
'frisbee-5',
|
||||
'frisbee-6',
|
||||
'frisbee-7',
|
||||
'frisbee-8',
|
||||
'frisbee-9',
|
||||
'frisbee-10',
|
||||
'jianzi-1',
|
||||
'jianzi-2',
|
||||
'jianzi-3',
|
||||
'jianzi-4',
|
||||
'jianzi-5',
|
||||
'jianzi-6',
|
||||
'jianzi-7',
|
||||
'jianzi-8',
|
||||
'jianzi-9',
|
||||
'jianzi-10',
|
||||
'lantern-1',
|
||||
'lantern-2',
|
||||
'lantern-3',
|
||||
'lantern-4',
|
||||
'lantern-5',
|
||||
'lantern-6',
|
||||
'lantern-7',
|
||||
'lantern-8',
|
||||
'lantern-9',
|
||||
'lantern-10',
|
||||
'misc-1',
|
||||
'misc-2',
|
||||
'misc-3',
|
||||
'misc-4',
|
||||
'misc-5',
|
||||
'misc-6',
|
||||
'misc-7',
|
||||
'misc-8',
|
||||
'misc-9',
|
||||
'misc-10',
|
||||
'opossum-1',
|
||||
'opossum-2',
|
||||
'opossum-3',
|
||||
'opossum-4',
|
||||
'opossum-5',
|
||||
'opossum-6',
|
||||
'opossum-7',
|
||||
'opossum-8',
|
||||
'opossum-9',
|
||||
'opossum-10',
|
||||
'paddle-1',
|
||||
'paddle-2',
|
||||
'paddle-3',
|
||||
'paddle-4',
|
||||
'paddle-5',
|
||||
'paddle-6',
|
||||
'paddle-7',
|
||||
'paddle-8',
|
||||
'paddle-9',
|
||||
'paddle-10',
|
||||
'raccoon-1',
|
||||
'raccoon-2',
|
||||
'raccoon-3',
|
||||
'raccoon-4',
|
||||
'raccoon-5',
|
||||
'raccoon-6',
|
||||
'raccoon-7',
|
||||
'raccoon-8',
|
||||
'raccoon-9',
|
||||
'raccoon-10',
|
||||
'rhino-1',
|
||||
'rhino-2',
|
||||
'rhino-3',
|
||||
'rhino-4',
|
||||
'rhino-5',
|
||||
'rhino-6',
|
||||
'rhino-7',
|
||||
'rhino-8',
|
||||
'rhino-9',
|
||||
'rhino-10',
|
||||
'skatingshoe-1',
|
||||
'skatingshoe-2',
|
||||
'skatingshoe-3',
|
||||
'skatingshoe-4',
|
||||
'skatingshoe-5',
|
||||
'skatingshoe-6',
|
||||
'skatingshoe-7',
|
||||
'skatingshoe-8',
|
||||
'skatingshoe-9',
|
||||
'skatingshoe-10',
|
||||
'wingsuit-1',
|
||||
'wingsuit-2',
|
||||
'wingsuit-3',
|
||||
'wingsuit-4',
|
||||
'wingsuit-5',
|
||||
'wingsuit-6',
|
||||
'wingsuit-7',
|
||||
'wingsuit-8',
|
||||
'wingsuit-9',
|
||||
'wingsuit-10']
|
||||
return sequence_list
|
||||
38
lib/test/evaluation/local.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from lib.test.evaluation.environment import EnvSettings
|
||||
|
||||
def local_env_settings():
|
||||
settings = EnvSettings()
|
||||
|
||||
# Set your local paths here.
|
||||
|
||||
settings.lasot_path = '/home/cycyang/code/vot-sam/data/LaSOT'
|
||||
settings.lasot_extension_subset_path = '/home/cycyang/code/vot-sam/data/LaSOT-ext'
|
||||
settings.nfs_path = '/home/cycyang/code/vot-sam/data/NFS'
|
||||
settings.otb_path = '/home/cycyang/code/vot-sam/data/otb'
|
||||
settings.uav_path = '/home/cycyang/code/vot-sam/data/uav'
|
||||
settings.results_path = '/home/cycyang/code/vot-sam/raw_results'
|
||||
settings.result_plot_path = '/home/cycyang/code/vot-sam/evaluation_results'
|
||||
settings.save_dir = '/home/cycyang/code/vot-sam/evaluation_results'
|
||||
|
||||
settings.davis_dir = ''
|
||||
settings.got10k_lmdb_path = '/home/baiyifan/code/OSTrack/data/got10k_lmdb'
|
||||
settings.got10k_path = '/home/baiyifan/GOT-10k'
|
||||
settings.got_packed_results_path = ''
|
||||
settings.got_reports_path = ''
|
||||
settings.itb_path = '/home/baiyifan/code/OSTrack/data/itb'
|
||||
settings.lasot_lmdb_path = '/home/baiyifan/code/OSTrack/data/lasot_lmdb'
|
||||
settings.network_path = '/ssddata/baiyifan/artrack_256_full_re/' # Where tracking networks are stored.
|
||||
settings.prj_dir = '/home/baiyifan/code/2d_autoregressive/bins_mask'
|
||||
settings.segmentation_path = '/data1/os/test/segmentation_results'
|
||||
settings.tc128_path = '/home/baiyifan/code/OSTrack/data/TC128'
|
||||
settings.tn_packed_results_path = ''
|
||||
settings.tnl2k_path = '/home/baiyifan/code/OSTrack/data/tnl2k'
|
||||
settings.tpl_path = ''
|
||||
settings.trackingnet_path = '/ssddata/TrackingNet/all_zip'
|
||||
settings.vot18_path = '/home/baiyifan/code/OSTrack/data/vot2018'
|
||||
settings.vot22_path = '/home/baiyifan/code/OSTrack/data/vot2022'
|
||||
settings.vot_path = '/home/baiyifan/code/OSTrack/data/VOT2019'
|
||||
settings.youtubevos_dir = ''
|
||||
|
||||
return settings
|
||||
|
||||
153
lib/test/evaluation/nfsdataset.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class NFSDataset(BaseDataset):
|
||||
""" NFS dataset.
|
||||
Publication:
|
||||
Need for Speed: A Benchmark for Higher Frame Rate Object Tracking
|
||||
H. Kiani Galoogahi, A. Fagg, C. Huang, D. Ramanan, and S.Lucey
|
||||
ICCV, 2017
|
||||
http://openaccess.thecvf.com/content_ICCV_2017/papers/Galoogahi_Need_for_Speed_ICCV_2017_paper.pdf
|
||||
Download the dataset from http://ci2cv.net/nfs/index.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.nfs_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
# anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
anno_path = f"{self.base_path}/{sequence_info['name'][4:]}/30/groundtruth.txt"
|
||||
|
||||
# ground_truth_rect = load_text(str(anno_path), delimiter='\t', dtype=np.float64)
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
return Sequence(sequence_info['name'][4:], frames, 'nfs', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "nfs_Gymnastics", "path": "sequences/Gymnastics", "startFrame": 1, "endFrame": 368, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Gymnastics.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_MachLoop_jet", "path": "sequences/MachLoop_jet", "startFrame": 1, "endFrame": 99, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_MachLoop_jet.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_Skiing_red", "path": "sequences/Skiing_red", "startFrame": 1, "endFrame": 69, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Skiing_red.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_Skydiving", "path": "sequences/Skydiving", "startFrame": 1, "endFrame": 196, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_Skydiving.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_airboard_1", "path": "sequences/airboard_1", "startFrame": 1, "endFrame": 425, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airboard_1.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_airplane_landing", "path": "sequences/airplane_landing", "startFrame": 1, "endFrame": 81, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airplane_landing.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_airtable_3", "path": "sequences/airtable_3", "startFrame": 1, "endFrame": 482, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_airtable_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_1", "path": "sequences/basketball_1", "startFrame": 1, "endFrame": 282, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_1.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_2", "path": "sequences/basketball_2", "startFrame": 1, "endFrame": 102, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_3", "path": "sequences/basketball_3", "startFrame": 1, "endFrame": 421, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_6", "path": "sequences/basketball_6", "startFrame": 1, "endFrame": 224, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_basketball_7", "path": "sequences/basketball_7", "startFrame": 1, "endFrame": 240, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_7.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_basketball_player", "path": "sequences/basketball_player", "startFrame": 1, "endFrame": 369, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_player.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_basketball_player_2", "path": "sequences/basketball_player_2", "startFrame": 1, "endFrame": 437, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_basketball_player_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_beach_flipback_person", "path": "sequences/beach_flipback_person", "startFrame": 1, "endFrame": 61, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_beach_flipback_person.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_bee", "path": "sequences/bee", "startFrame": 1, "endFrame": 45, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bee.txt", "object_class": "insect", 'occlusion': False},
|
||||
{"name": "nfs_biker_acrobat", "path": "sequences/biker_acrobat", "startFrame": 1, "endFrame": 128, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_acrobat.txt", "object_class": "bicycle", 'occlusion': False},
|
||||
{"name": "nfs_biker_all_1", "path": "sequences/biker_all_1", "startFrame": 1, "endFrame": 113, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_all_1.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_biker_head_2", "path": "sequences/biker_head_2", "startFrame": 1, "endFrame": 132, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_head_2.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_biker_head_3", "path": "sequences/biker_head_3", "startFrame": 1, "endFrame": 254, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_head_3.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_biker_upper_body", "path": "sequences/biker_upper_body", "startFrame": 1, "endFrame": 194, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_upper_body.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_biker_whole_body", "path": "sequences/biker_whole_body", "startFrame": 1, "endFrame": 572, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_biker_whole_body.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_billiard_2", "path": "sequences/billiard_2", "startFrame": 1, "endFrame": 604, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_3", "path": "sequences/billiard_3", "startFrame": 1, "endFrame": 698, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_6", "path": "sequences/billiard_6", "startFrame": 1, "endFrame": 771, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_7", "path": "sequences/billiard_7", "startFrame": 1, "endFrame": 724, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_7.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_billiard_8", "path": "sequences/billiard_8", "startFrame": 1, "endFrame": 778, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_billiard_8.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_bird_2", "path": "sequences/bird_2", "startFrame": 1, "endFrame": 476, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bird_2.txt", "object_class": "bird", 'occlusion': False},
|
||||
{"name": "nfs_book", "path": "sequences/book", "startFrame": 1, "endFrame": 288, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_book.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_bottle", "path": "sequences/bottle", "startFrame": 1, "endFrame": 2103, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bottle.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_bowling_1", "path": "sequences/bowling_1", "startFrame": 1, "endFrame": 303, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_1.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_2", "path": "sequences/bowling_2", "startFrame": 1, "endFrame": 710, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_2.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_3", "path": "sequences/bowling_3", "startFrame": 1, "endFrame": 271, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_3.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bowling_6", "path": "sequences/bowling_6", "startFrame": 1, "endFrame": 260, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_6.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_bowling_ball", "path": "sequences/bowling_ball", "startFrame": 1, "endFrame": 275, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bowling_ball.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_bunny", "path": "sequences/bunny", "startFrame": 1, "endFrame": 705, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_bunny.txt", "object_class": "mammal", 'occlusion': False},
|
||||
{"name": "nfs_car", "path": "sequences/car", "startFrame": 1, "endFrame": 2020, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car.txt", "object_class": "car", 'occlusion': True},
|
||||
{"name": "nfs_car_camaro", "path": "sequences/car_camaro", "startFrame": 1, "endFrame": 36, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_camaro.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_drifting", "path": "sequences/car_drifting", "startFrame": 1, "endFrame": 173, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_drifting.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_jumping", "path": "sequences/car_jumping", "startFrame": 1, "endFrame": 22, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_jumping.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_rc_rolling", "path": "sequences/car_rc_rolling", "startFrame": 1, "endFrame": 62, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_rc_rolling.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_rc_rotating", "path": "sequences/car_rc_rotating", "startFrame": 1, "endFrame": 80, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_rc_rotating.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_side", "path": "sequences/car_side", "startFrame": 1, "endFrame": 108, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_side.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_car_white", "path": "sequences/car_white", "startFrame": 1, "endFrame": 2063, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_car_white.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_cheetah", "path": "sequences/cheetah", "startFrame": 1, "endFrame": 167, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cheetah.txt", "object_class": "mammal", 'occlusion': True},
|
||||
{"name": "nfs_cup", "path": "sequences/cup", "startFrame": 1, "endFrame": 1281, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cup.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_cup_2", "path": "sequences/cup_2", "startFrame": 1, "endFrame": 182, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_cup_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_dog", "path": "sequences/dog", "startFrame": 1, "endFrame": 1030, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dog_1", "path": "sequences/dog_1", "startFrame": 1, "endFrame": 168, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_1.txt", "object_class": "dog", 'occlusion': False},
|
||||
# {"name": "nfs_dog_2", "path": "sequences/dog_2", "startFrame": 1, "endFrame": 594, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_2.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dog_3", "path": "sequences/dog_3", "startFrame": 1, "endFrame": 200, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dog_3.txt", "object_class": "dog", 'occlusion': False},
|
||||
{"name": "nfs_dogs", "path": "sequences/dogs", "startFrame": 1, "endFrame": 198, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dogs.txt", "object_class": "dog", 'occlusion': True},
|
||||
{"name": "nfs_dollar", "path": "sequences/dollar", "startFrame": 1, "endFrame": 1426, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_dollar.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_drone", "path": "sequences/drone", "startFrame": 1, "endFrame": 70, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_drone.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_ducks_lake", "path": "sequences/ducks_lake", "startFrame": 1, "endFrame": 107, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_ducks_lake.txt", "object_class": "bird", 'occlusion': False},
|
||||
{"name": "nfs_exit", "path": "sequences/exit", "startFrame": 1, "endFrame": 359, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_exit.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_first", "path": "sequences/first", "startFrame": 1, "endFrame": 435, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_first.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_flower", "path": "sequences/flower", "startFrame": 1, "endFrame": 448, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_flower.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_footbal_skill", "path": "sequences/footbal_skill", "startFrame": 1, "endFrame": 131, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_footbal_skill.txt", "object_class": "ball", 'occlusion': True},
|
||||
{"name": "nfs_helicopter", "path": "sequences/helicopter", "startFrame": 1, "endFrame": 310, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_helicopter.txt", "object_class": "aircraft", 'occlusion': False},
|
||||
{"name": "nfs_horse_jumping", "path": "sequences/horse_jumping", "startFrame": 1, "endFrame": 117, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_horse_jumping.txt", "object_class": "horse", 'occlusion': True},
|
||||
{"name": "nfs_horse_running", "path": "sequences/horse_running", "startFrame": 1, "endFrame": 139, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_horse_running.txt", "object_class": "horse", 'occlusion': False},
|
||||
{"name": "nfs_iceskating_6", "path": "sequences/iceskating_6", "startFrame": 1, "endFrame": 603, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_iceskating_6.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_jellyfish_5", "path": "sequences/jellyfish_5", "startFrame": 1, "endFrame": 746, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_jellyfish_5.txt", "object_class": "invertebrate", 'occlusion': False},
|
||||
{"name": "nfs_kid_swing", "path": "sequences/kid_swing", "startFrame": 1, "endFrame": 169, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_kid_swing.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_motorcross", "path": "sequences/motorcross", "startFrame": 1, "endFrame": 39, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_motorcross.txt", "object_class": "vehicle", 'occlusion': True},
|
||||
{"name": "nfs_motorcross_kawasaki", "path": "sequences/motorcross_kawasaki", "startFrame": 1, "endFrame": 65, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_motorcross_kawasaki.txt", "object_class": "vehicle", 'occlusion': False},
|
||||
{"name": "nfs_parkour", "path": "sequences/parkour", "startFrame": 1, "endFrame": 58, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_parkour.txt", "object_class": "person head", 'occlusion': False},
|
||||
{"name": "nfs_person_scooter", "path": "sequences/person_scooter", "startFrame": 1, "endFrame": 413, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_person_scooter.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_pingpong_2", "path": "sequences/pingpong_2", "startFrame": 1, "endFrame": 1277, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_pingpong_7", "path": "sequences/pingpong_7", "startFrame": 1, "endFrame": 1290, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_7.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_pingpong_8", "path": "sequences/pingpong_8", "startFrame": 1, "endFrame": 296, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_pingpong_8.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_purse", "path": "sequences/purse", "startFrame": 1, "endFrame": 968, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_purse.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_rubber", "path": "sequences/rubber", "startFrame": 1, "endFrame": 1328, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_rubber.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_running", "path": "sequences/running", "startFrame": 1, "endFrame": 677, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_running_100_m", "path": "sequences/running_100_m", "startFrame": 1, "endFrame": 313, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_100_m.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_running_100_m_2", "path": "sequences/running_100_m_2", "startFrame": 1, "endFrame": 337, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_100_m_2.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_running_2", "path": "sequences/running_2", "startFrame": 1, "endFrame": 363, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_running_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_1", "path": "sequences/shuffleboard_1", "startFrame": 1, "endFrame": 42, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_1.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_2", "path": "sequences/shuffleboard_2", "startFrame": 1, "endFrame": 41, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_4", "path": "sequences/shuffleboard_4", "startFrame": 1, "endFrame": 62, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_4.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_5", "path": "sequences/shuffleboard_5", "startFrame": 1, "endFrame": 32, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_5.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffleboard_6", "path": "sequences/shuffleboard_6", "startFrame": 1, "endFrame": 52, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffleboard_6.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_2", "path": "sequences/shuffletable_2", "startFrame": 1, "endFrame": 372, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_2.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_3", "path": "sequences/shuffletable_3", "startFrame": 1, "endFrame": 368, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_3.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_shuffletable_4", "path": "sequences/shuffletable_4", "startFrame": 1, "endFrame": 101, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_shuffletable_4.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_ski_long", "path": "sequences/ski_long", "startFrame": 1, "endFrame": 274, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_ski_long.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball", "path": "sequences/soccer_ball", "startFrame": 1, "endFrame": 163, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball_2", "path": "sequences/soccer_ball_2", "startFrame": 1, "endFrame": 1934, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball_2.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_ball_3", "path": "sequences/soccer_ball_3", "startFrame": 1, "endFrame": 1381, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_ball_3.txt", "object_class": "ball", 'occlusion': False},
|
||||
{"name": "nfs_soccer_player_2", "path": "sequences/soccer_player_2", "startFrame": 1, "endFrame": 475, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_player_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_soccer_player_3", "path": "sequences/soccer_player_3", "startFrame": 1, "endFrame": 319, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_soccer_player_3.txt", "object_class": "person", 'occlusion': True},
|
||||
{"name": "nfs_stop_sign", "path": "sequences/stop_sign", "startFrame": 1, "endFrame": 302, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_stop_sign.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_suv", "path": "sequences/suv", "startFrame": 1, "endFrame": 2584, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_suv.txt", "object_class": "car", 'occlusion': False},
|
||||
{"name": "nfs_tiger", "path": "sequences/tiger", "startFrame": 1, "endFrame": 1556, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_tiger.txt", "object_class": "mammal", 'occlusion': False},
|
||||
{"name": "nfs_walking", "path": "sequences/walking", "startFrame": 1, "endFrame": 555, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_walking.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_walking_3", "path": "sequences/walking_3", "startFrame": 1, "endFrame": 1427, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_walking_3.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_water_ski_2", "path": "sequences/water_ski_2", "startFrame": 1, "endFrame": 47, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_water_ski_2.txt", "object_class": "person", 'occlusion': False},
|
||||
{"name": "nfs_yoyo", "path": "sequences/yoyo", "startFrame": 1, "endFrame": 67, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_yoyo.txt", "object_class": "other", 'occlusion': False},
|
||||
{"name": "nfs_zebra_fish", "path": "sequences/zebra_fish", "startFrame": 1, "endFrame": 671, "nz": 5, "ext": "jpg", "anno_path": "anno/nfs_zebra_fish.txt", "object_class": "fish", 'occlusion': False},
|
||||
]
|
||||
|
||||
return sequence_info_list
|
||||
259
lib/test/evaluation/otbdataset.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class OTBDataset(BaseDataset):
|
||||
""" OTB-2015 dataset
|
||||
Publication:
|
||||
Object Tracking Benchmark
|
||||
Wu, Yi, Jongwoo Lim, and Ming-Hsuan Yang
|
||||
TPAMI, 2015
|
||||
http://faculty.ucmerced.edu/mhyang/papers/pami15_tracking_benchmark.pdf
|
||||
Download the dataset from http://cvlab.hanyang.ac.kr/tracker_benchmark/index.html
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.otb_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
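# nz is the zero-padding width of the frame filenames, e.g. nz=4 gives 0001.jpg, 0002.jpg, ...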
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
# anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_info['name'])
|
||||
|
||||
# NOTE: OTB has some malformed annotations which pandas cannot handle
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy')
|
||||
|
||||
return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "Basketball", "path": "Basketball/img", "startFrame": 1, "endFrame": 725, "nz": 4, "ext": "jpg", "anno_path": "Basketball/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Biker", "path": "Biker/img", "startFrame": 1, "endFrame": 142, "nz": 4, "ext": "jpg", "anno_path": "Biker/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Bird1", "path": "Bird1/img", "startFrame": 1, "endFrame": 408, "nz": 4, "ext": "jpg", "anno_path": "Bird1/groundtruth_rect.txt",
|
||||
"object_class": "bird"},
|
||||
{"name": "Bird2", "path": "Bird2/img", "startFrame": 1, "endFrame": 99, "nz": 4, "ext": "jpg", "anno_path": "Bird2/groundtruth_rect.txt",
|
||||
"object_class": "bird"},
|
||||
{"name": "BlurBody", "path": "BlurBody/img", "startFrame": 1, "endFrame": 334, "nz": 4, "ext": "jpg", "anno_path": "BlurBody/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "BlurCar1", "path": "BlurCar1/img", "startFrame": 247, "endFrame": 988, "nz": 4, "ext": "jpg", "anno_path": "BlurCar1/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar2", "path": "BlurCar2/img", "startFrame": 1, "endFrame": 585, "nz": 4, "ext": "jpg", "anno_path": "BlurCar2/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar3", "path": "BlurCar3/img", "startFrame": 3, "endFrame": 359, "nz": 4, "ext": "jpg", "anno_path": "BlurCar3/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurCar4", "path": "BlurCar4/img", "startFrame": 18, "endFrame": 397, "nz": 4, "ext": "jpg", "anno_path": "BlurCar4/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "BlurFace", "path": "BlurFace/img", "startFrame": 1, "endFrame": 493, "nz": 4, "ext": "jpg", "anno_path": "BlurFace/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "BlurOwl", "path": "BlurOwl/img", "startFrame": 1, "endFrame": 631, "nz": 4, "ext": "jpg", "anno_path": "BlurOwl/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Board", "path": "Board/img", "startFrame": 1, "endFrame": 698, "nz": 5, "ext": "jpg", "anno_path": "Board/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Bolt", "path": "Bolt/img", "startFrame": 1, "endFrame": 350, "nz": 4, "ext": "jpg", "anno_path": "Bolt/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Bolt2", "path": "Bolt2/img", "startFrame": 1, "endFrame": 293, "nz": 4, "ext": "jpg", "anno_path": "Bolt2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Box", "path": "Box/img", "startFrame": 1, "endFrame": 1161, "nz": 4, "ext": "jpg", "anno_path": "Box/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Boy", "path": "Boy/img", "startFrame": 1, "endFrame": 602, "nz": 4, "ext": "jpg", "anno_path": "Boy/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Car1", "path": "Car1/img", "startFrame": 1, "endFrame": 1020, "nz": 4, "ext": "jpg", "anno_path": "Car1/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car2", "path": "Car2/img", "startFrame": 1, "endFrame": 913, "nz": 4, "ext": "jpg", "anno_path": "Car2/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car24", "path": "Car24/img", "startFrame": 1, "endFrame": 3059, "nz": 4, "ext": "jpg", "anno_path": "Car24/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Car4", "path": "Car4/img", "startFrame": 1, "endFrame": 659, "nz": 4, "ext": "jpg", "anno_path": "Car4/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "CarDark", "path": "CarDark/img", "startFrame": 1, "endFrame": 393, "nz": 4, "ext": "jpg", "anno_path": "CarDark/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "CarScale", "path": "CarScale/img", "startFrame": 1, "endFrame": 252, "nz": 4, "ext": "jpg", "anno_path": "CarScale/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "ClifBar", "path": "ClifBar/img", "startFrame": 1, "endFrame": 472, "nz": 4, "ext": "jpg", "anno_path": "ClifBar/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Coke", "path": "Coke/img", "startFrame": 1, "endFrame": 291, "nz": 4, "ext": "jpg", "anno_path": "Coke/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Couple", "path": "Couple/img", "startFrame": 1, "endFrame": 140, "nz": 4, "ext": "jpg", "anno_path": "Couple/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Coupon", "path": "Coupon/img", "startFrame": 1, "endFrame": 327, "nz": 4, "ext": "jpg", "anno_path": "Coupon/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Crossing", "path": "Crossing/img", "startFrame": 1, "endFrame": 120, "nz": 4, "ext": "jpg", "anno_path": "Crossing/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Crowds", "path": "Crowds/img", "startFrame": 1, "endFrame": 347, "nz": 4, "ext": "jpg", "anno_path": "Crowds/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dancer", "path": "Dancer/img", "startFrame": 1, "endFrame": 225, "nz": 4, "ext": "jpg", "anno_path": "Dancer/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dancer2", "path": "Dancer2/img", "startFrame": 1, "endFrame": 150, "nz": 4, "ext": "jpg", "anno_path": "Dancer2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "David", "path": "David/img", "startFrame": 300, "endFrame": 770, "nz": 4, "ext": "jpg", "anno_path": "David/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "David2", "path": "David2/img", "startFrame": 1, "endFrame": 537, "nz": 4, "ext": "jpg", "anno_path": "David2/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "David3", "path": "David3/img", "startFrame": 1, "endFrame": 252, "nz": 4, "ext": "jpg", "anno_path": "David3/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Deer", "path": "Deer/img", "startFrame": 1, "endFrame": 71, "nz": 4, "ext": "jpg", "anno_path": "Deer/groundtruth_rect.txt",
|
||||
"object_class": "mammal"},
|
||||
{"name": "Diving", "path": "Diving/img", "startFrame": 1, "endFrame": 215, "nz": 4, "ext": "jpg", "anno_path": "Diving/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Dog", "path": "Dog/img", "startFrame": 1, "endFrame": 127, "nz": 4, "ext": "jpg", "anno_path": "Dog/groundtruth_rect.txt",
|
||||
"object_class": "dog"},
|
||||
{"name": "Dog1", "path": "Dog1/img", "startFrame": 1, "endFrame": 1350, "nz": 4, "ext": "jpg", "anno_path": "Dog1/groundtruth_rect.txt",
|
||||
"object_class": "dog"},
|
||||
{"name": "Doll", "path": "Doll/img", "startFrame": 1, "endFrame": 3872, "nz": 4, "ext": "jpg", "anno_path": "Doll/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "DragonBaby", "path": "DragonBaby/img", "startFrame": 1, "endFrame": 113, "nz": 4, "ext": "jpg", "anno_path": "DragonBaby/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Dudek", "path": "Dudek/img", "startFrame": 1, "endFrame": 1145, "nz": 4, "ext": "jpg", "anno_path": "Dudek/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "FaceOcc1", "path": "FaceOcc1/img", "startFrame": 1, "endFrame": 892, "nz": 4, "ext": "jpg", "anno_path": "FaceOcc1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "FaceOcc2", "path": "FaceOcc2/img", "startFrame": 1, "endFrame": 812, "nz": 4, "ext": "jpg", "anno_path": "FaceOcc2/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Fish", "path": "Fish/img", "startFrame": 1, "endFrame": 476, "nz": 4, "ext": "jpg", "anno_path": "Fish/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "FleetFace", "path": "FleetFace/img", "startFrame": 1, "endFrame": 707, "nz": 4, "ext": "jpg", "anno_path": "FleetFace/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Football", "path": "Football/img", "startFrame": 1, "endFrame": 362, "nz": 4, "ext": "jpg", "anno_path": "Football/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Football1", "path": "Football1/img", "startFrame": 1, "endFrame": 74, "nz": 4, "ext": "jpg", "anno_path": "Football1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman1", "path": "Freeman1/img", "startFrame": 1, "endFrame": 326, "nz": 4, "ext": "jpg", "anno_path": "Freeman1/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman3", "path": "Freeman3/img", "startFrame": 1, "endFrame": 460, "nz": 4, "ext": "jpg", "anno_path": "Freeman3/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Freeman4", "path": "Freeman4/img", "startFrame": 1, "endFrame": 283, "nz": 4, "ext": "jpg", "anno_path": "Freeman4/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Girl", "path": "Girl/img", "startFrame": 1, "endFrame": 500, "nz": 4, "ext": "jpg", "anno_path": "Girl/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Girl2", "path": "Girl2/img", "startFrame": 1, "endFrame": 1500, "nz": 4, "ext": "jpg", "anno_path": "Girl2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Gym", "path": "Gym/img", "startFrame": 1, "endFrame": 767, "nz": 4, "ext": "jpg", "anno_path": "Gym/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human2", "path": "Human2/img", "startFrame": 1, "endFrame": 1128, "nz": 4, "ext": "jpg", "anno_path": "Human2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human3", "path": "Human3/img", "startFrame": 1, "endFrame": 1698, "nz": 4, "ext": "jpg", "anno_path": "Human3/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human4_2", "path": "Human4/img", "startFrame": 1, "endFrame": 667, "nz": 4, "ext": "jpg", "anno_path": "Human4/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human4", "path": "Human4/img", "startFrame": 1, "endFrame": 667, "nz": 4, "ext": "jpg", "anno_path": "Human4/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human5", "path": "Human5/img", "startFrame": 1, "endFrame": 713, "nz": 4, "ext": "jpg", "anno_path": "Human5/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human6", "path": "Human6/img", "startFrame": 1, "endFrame": 792, "nz": 4, "ext": "jpg", "anno_path": "Human6/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human7", "path": "Human7/img", "startFrame": 1, "endFrame": 250, "nz": 4, "ext": "jpg", "anno_path": "Human7/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human8", "path": "Human8/img", "startFrame": 1, "endFrame": 128, "nz": 4, "ext": "jpg", "anno_path": "Human8/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Human9", "path": "Human9/img", "startFrame": 1, "endFrame": 305, "nz": 4, "ext": "jpg", "anno_path": "Human9/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Ironman", "path": "Ironman/img", "startFrame": 1, "endFrame": 166, "nz": 4, "ext": "jpg", "anno_path": "Ironman/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Jogging", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
# {"name": "Jogging_1", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.1.txt",
|
||||
# "object_class": "person"},
|
||||
# {"name": "Jogging_2", "path": "Jogging/img", "startFrame": 1, "endFrame": 307, "nz": 4, "ext": "jpg", "anno_path": "Jogging/groundtruth_rect.2.txt",
|
||||
# "object_class": "person"},
|
||||
{"name": "Jump", "path": "Jump/img", "startFrame": 1, "endFrame": 122, "nz": 4, "ext": "jpg", "anno_path": "Jump/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Jumping", "path": "Jumping/img", "startFrame": 1, "endFrame": 313, "nz": 4, "ext": "jpg", "anno_path": "Jumping/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "KiteSurf", "path": "KiteSurf/img", "startFrame": 1, "endFrame": 84, "nz": 4, "ext": "jpg", "anno_path": "KiteSurf/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Lemming", "path": "Lemming/img", "startFrame": 1, "endFrame": 1336, "nz": 4, "ext": "jpg", "anno_path": "Lemming/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Liquor", "path": "Liquor/img", "startFrame": 1, "endFrame": 1741, "nz": 4, "ext": "jpg", "anno_path": "Liquor/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Man", "path": "Man/img", "startFrame": 1, "endFrame": 134, "nz": 4, "ext": "jpg", "anno_path": "Man/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Matrix", "path": "Matrix/img", "startFrame": 1, "endFrame": 100, "nz": 4, "ext": "jpg", "anno_path": "Matrix/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Mhyang", "path": "Mhyang/img", "startFrame": 1, "endFrame": 1490, "nz": 4, "ext": "jpg", "anno_path": "Mhyang/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "MotorRolling", "path": "MotorRolling/img", "startFrame": 1, "endFrame": 164, "nz": 4, "ext": "jpg", "anno_path": "MotorRolling/groundtruth_rect.txt",
|
||||
"object_class": "vehicle"},
|
||||
{"name": "MountainBike", "path": "MountainBike/img", "startFrame": 1, "endFrame": 228, "nz": 4, "ext": "jpg", "anno_path": "MountainBike/groundtruth_rect.txt",
|
||||
"object_class": "bicycle"},
|
||||
{"name": "Panda", "path": "Panda/img", "startFrame": 1, "endFrame": 1000, "nz": 4, "ext": "jpg", "anno_path": "Panda/groundtruth_rect.txt",
|
||||
"object_class": "mammal"},
|
||||
{"name": "RedTeam", "path": "RedTeam/img", "startFrame": 1, "endFrame": 1918, "nz": 4, "ext": "jpg", "anno_path": "RedTeam/groundtruth_rect.txt",
|
||||
"object_class": "vehicle"},
|
||||
{"name": "Rubik", "path": "Rubik/img", "startFrame": 1, "endFrame": 1997, "nz": 4, "ext": "jpg", "anno_path": "Rubik/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Shaking", "path": "Shaking/img", "startFrame": 1, "endFrame": 365, "nz": 4, "ext": "jpg", "anno_path": "Shaking/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Singer1", "path": "Singer1/img", "startFrame": 1, "endFrame": 351, "nz": 4, "ext": "jpg", "anno_path": "Singer1/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Singer2", "path": "Singer2/img", "startFrame": 1, "endFrame": 366, "nz": 4, "ext": "jpg", "anno_path": "Singer2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skater", "path": "Skater/img", "startFrame": 1, "endFrame": 160, "nz": 4, "ext": "jpg", "anno_path": "Skater/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skater2", "path": "Skater2/img", "startFrame": 1, "endFrame": 435, "nz": 4, "ext": "jpg", "anno_path": "Skater2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating1", "path": "Skating1/img", "startFrame": 1, "endFrame": 400, "nz": 4, "ext": "jpg", "anno_path": "Skating1/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2_1", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.1.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skating2_2", "path": "Skating2/img", "startFrame": 1, "endFrame": 473, "nz": 4, "ext": "jpg", "anno_path": "Skating2/groundtruth_rect.2.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Skiing", "path": "Skiing/img", "startFrame": 1, "endFrame": 81, "nz": 4, "ext": "jpg", "anno_path": "Skiing/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Soccer", "path": "Soccer/img", "startFrame": 1, "endFrame": 392, "nz": 4, "ext": "jpg", "anno_path": "Soccer/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Subway", "path": "Subway/img", "startFrame": 1, "endFrame": 175, "nz": 4, "ext": "jpg", "anno_path": "Subway/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Surfer", "path": "Surfer/img", "startFrame": 1, "endFrame": 376, "nz": 4, "ext": "jpg", "anno_path": "Surfer/groundtruth_rect.txt",
|
||||
"object_class": "person head"},
|
||||
{"name": "Suv", "path": "Suv/img", "startFrame": 1, "endFrame": 945, "nz": 4, "ext": "jpg", "anno_path": "Suv/groundtruth_rect.txt",
|
||||
"object_class": "car"},
|
||||
{"name": "Sylvester", "path": "Sylvester/img", "startFrame": 1, "endFrame": 1345, "nz": 4, "ext": "jpg", "anno_path": "Sylvester/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Tiger1", "path": "Tiger1/img", "startFrame": 1, "endFrame": 354, "nz": 4, "ext": "jpg", "anno_path": "Tiger1/groundtruth_rect.txt", "initOmit": 5,
|
||||
"object_class": "other"},
|
||||
{"name": "Tiger2", "path": "Tiger2/img", "startFrame": 1, "endFrame": 365, "nz": 4, "ext": "jpg", "anno_path": "Tiger2/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Toy", "path": "Toy/img", "startFrame": 1, "endFrame": 271, "nz": 4, "ext": "jpg", "anno_path": "Toy/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Trans", "path": "Trans/img", "startFrame": 1, "endFrame": 124, "nz": 4, "ext": "jpg", "anno_path": "Trans/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Trellis", "path": "Trellis/img", "startFrame": 1, "endFrame": 569, "nz": 4, "ext": "jpg", "anno_path": "Trellis/groundtruth_rect.txt",
|
||||
"object_class": "face"},
|
||||
{"name": "Twinnings", "path": "Twinnings/img", "startFrame": 1, "endFrame": 472, "nz": 4, "ext": "jpg", "anno_path": "Twinnings/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Vase", "path": "Vase/img", "startFrame": 1, "endFrame": 271, "nz": 4, "ext": "jpg", "anno_path": "Vase/groundtruth_rect.txt",
|
||||
"object_class": "other"},
|
||||
{"name": "Walking", "path": "Walking/img", "startFrame": 1, "endFrame": 412, "nz": 4, "ext": "jpg", "anno_path": "Walking/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Walking2", "path": "Walking2/img", "startFrame": 1, "endFrame": 500, "nz": 4, "ext": "jpg", "anno_path": "Walking2/groundtruth_rect.txt",
|
||||
"object_class": "person"},
|
||||
{"name": "Woman", "path": "Woman/img", "startFrame": 1, "endFrame": 597, "nz": 4, "ext": "jpg", "anno_path": "Woman/groundtruth_rect.txt",
|
||||
"object_class": "person"}
|
||||
]
|
||||
|
||||
return sequence_info_list
|
||||
183
lib/test/evaluation/running.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from itertools import product
|
||||
from collections import OrderedDict
|
||||
from lib.test.evaluation import Sequence, Tracker
|
||||
import torch
|
||||
|
||||
|
||||
def _save_tracker_output(seq: Sequence, tracker: Tracker, output: dict):
|
||||
"""Saves the output of the tracker."""
|
||||
|
||||
if not os.path.exists(tracker.results_dir):
|
||||
print("create tracking result dir:", tracker.results_dir)
|
||||
os.makedirs(tracker.results_dir)
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
if not os.path.exists(os.path.join(tracker.results_dir, seq.dataset)):
|
||||
os.makedirs(os.path.join(tracker.results_dir, seq.dataset))
|
||||
# 2021.1.5: create a separate results sub-folder for these two datasets
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.dataset, seq.name)
|
||||
else:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.name)
|
||||
|
||||
def save_bb(file, data):
|
||||
tracked_bb = np.array(data).astype(int)
|
||||
np.savetxt(file, tracked_bb, delimiter='\t', fmt='%d')
|
||||
|
||||
def save_time(file, data):
|
||||
exec_times = np.array(data).astype(float)
|
||||
np.savetxt(file, exec_times, delimiter='\t', fmt='%f')
|
||||
|
||||
def save_score(file, data):
|
||||
scores = np.array(data).astype(float)
|
||||
np.savetxt(file, scores, delimiter='\t', fmt='%.2f')
|
||||
|
||||
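# Convert a list of per-frame {obj_id: value} dicts into a single {obj_id: [values per frame]} dict.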
def _convert_dict(input_dict):
|
||||
data_dict = {}
|
||||
for elem in input_dict:
|
||||
for k, v in elem.items():
|
||||
if k in data_dict.keys():
|
||||
data_dict[k].append(v)
|
||||
else:
|
||||
data_dict[k] = [v, ]
|
||||
return data_dict
|
||||
|
||||
for key, data in output.items():
|
||||
# If data is empty
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if key == 'target_bbox':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}.txt'.format(base_results_path, obj_id)
|
||||
save_bb(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
save_bb(bbox_file, data)
|
||||
|
||||
if key == 'all_boxes':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}_all_boxes.txt'.format(base_results_path, obj_id)
|
||||
save_bb(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
bbox_file = '{}_all_boxes.txt'.format(base_results_path)
|
||||
save_bb(bbox_file, data)
|
||||
|
||||
if key == 'all_scores':
|
||||
if isinstance(data[0], (dict, OrderedDict)):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
bbox_file = '{}_{}_all_scores.txt'.format(base_results_path, obj_id)
|
||||
save_score(bbox_file, d)
|
||||
else:
|
||||
# Single-object mode
|
||||
print("saving scores...")
|
||||
bbox_file = '{}_all_scores.txt'.format(base_results_path)
|
||||
save_score(bbox_file, data)
|
||||
|
||||
elif key == 'time':
|
||||
if isinstance(data[0], dict):
|
||||
data_dict = _convert_dict(data)
|
||||
|
||||
for obj_id, d in data_dict.items():
|
||||
timings_file = '{}_{}_time.txt'.format(base_results_path, obj_id)
|
||||
save_time(timings_file, d)
|
||||
else:
|
||||
timings_file = '{}_time.txt'.format(base_results_path)
|
||||
save_time(timings_file, data)
|
||||
|
||||
|
||||
def run_sequence(seq: Sequence, tracker: Tracker, debug=False, num_gpu=8):
|
||||
"""Runs a tracker on a sequence."""
|
||||
# 2021.1.2: add multi-GPU support
|
||||
try:
|
||||
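# Pool workers are named like 'SpawnPoolWorker-3'; the numeric suffix is used to assign GPUs round-robin.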
worker_name = multiprocessing.current_process().name
|
||||
worker_id = int(worker_name[worker_name.find('-') + 1:]) - 1
|
||||
gpu_id = worker_id % num_gpu
|
||||
torch.cuda.set_device(gpu_id)
|
||||
except Exception:
|
||||
pass  # fall back to the default CUDA device if the worker name cannot be parsed
|
||||
|
||||
def _results_exist():
|
||||
if seq.object_ids is None:
|
||||
if seq.dataset in ['trackingnet', 'got10k']:
|
||||
base_results_path = os.path.join(tracker.results_dir, seq.dataset, seq.name)
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
else:
|
||||
bbox_file = '{}/{}.txt'.format(tracker.results_dir, seq.name)
|
||||
return os.path.isfile(bbox_file)
|
||||
else:
|
||||
bbox_files = ['{}/{}_{}.txt'.format(tracker.results_dir, seq.name, obj_id) for obj_id in seq.object_ids]
|
||||
missing = [not os.path.isfile(f) for f in bbox_files]
|
||||
return sum(missing) == 0
|
||||
|
||||
if _results_exist() and not debug:
|
||||
print('FPS: {}'.format(-1))
|
||||
return
|
||||
|
||||
print('Tracker: {} {} {} , Sequence: {}'.format(tracker.name, tracker.parameter_name, tracker.run_id, seq.name))
|
||||
|
||||
if debug:
|
||||
output = tracker.run_sequence(seq, debug=debug)
|
||||
else:
|
||||
try:
|
||||
output = tracker.run_sequence(seq, debug=debug)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return
|
||||
|
||||
sys.stdout.flush()
|
||||
|
||||
if isinstance(output['time'][0], (dict, OrderedDict)):
|
||||
exec_time = sum([sum(times.values()) for times in output['time']])
|
||||
num_frames = len(output['time'])
|
||||
else:
|
||||
exec_time = sum(output['time'])
|
||||
num_frames = len(output['time'])
|
||||
|
||||
print('FPS: {}'.format(num_frames / exec_time))
|
||||
|
||||
if not debug:
|
||||
_save_tracker_output(seq, tracker, output)
|
||||
|
||||
|
||||
def run_dataset(dataset, trackers, debug=False, threads=0, num_gpus=8):
|
||||
"""Runs a list of trackers on a dataset.
|
||||
args:
|
||||
dataset: List of Sequence instances, forming a dataset.
|
||||
trackers: List of Tracker instances.
|
||||
debug: Debug level.
|
||||
threads: Number of threads to use (default 0 means run sequentially).
num_gpus: Number of GPUs to distribute the parallel workers over (default 8).
|
||||
"""
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
print('Evaluating {:4d} trackers on {:5d} sequences'.format(len(trackers), len(dataset)))
|
||||
|
||||
|
||||
|
||||
if threads == 0:
|
||||
mode = 'sequential'
|
||||
else:
|
||||
mode = 'parallel'
|
||||
|
||||
if mode == 'sequential':
|
||||
for seq in dataset:
|
||||
for tracker_info in trackers:
|
||||
run_sequence(seq, tracker_info, debug=debug)
|
||||
elif mode == 'parallel':
|
||||
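# Build one (sequence, tracker) job per pair and fan them out over the worker pool.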
param_list = [(seq, tracker_info, debug, num_gpus) for seq, tracker_info in product(dataset, trackers)]
|
||||
with multiprocessing.Pool(processes=threads) as pool:
|
||||
pool.starmap(run_sequence, param_list)
|
||||
print('Done')
|
||||
46
lib/test/evaluation/tc128cedataset.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
import glob
|
||||
import six
|
||||
|
||||
|
||||
class TC128CEDataset(BaseDataset):
|
||||
"""
|
||||
TC-128 Dataset (78 newly added sequences)
|
||||
modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tc128_path
|
||||
self.anno_files = sorted(glob.glob(
|
||||
os.path.join(self.base_path, '*/*_gt.txt')))
|
||||
"""filter the newly added sequences (_ce)"""
|
||||
self.anno_files = [s for s in self.anno_files if "_ce" in s]
|
||||
self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
|
||||
self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
|
||||
# valid frame range for each sequence
|
||||
self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.seq_names])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
if isinstance(sequence_name, six.string_types):
|
||||
if sequence_name not in self.seq_names:
|
||||
raise Exception('Sequence {} not found.'.format(sequence_name))
|
||||
index = self.seq_names.index(sequence_name)
|
||||
# load valid frame range
|
||||
frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
|
||||
img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
|
||||
|
||||
# load annotations
|
||||
anno = np.loadtxt(self.anno_files[index], delimiter=',')
|
||||
assert len(img_files) == len(anno)
|
||||
assert anno.shape[1] == 4
|
||||
|
||||
# return img_files, anno
|
||||
return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
44
lib/test/evaluation/tc128dataset.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
import glob
|
||||
import six
|
||||
|
||||
|
||||
class TC128Dataset(BaseDataset):
|
||||
"""
|
||||
TC-128 Dataset
|
||||
modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tc128_path
|
||||
self.anno_files = sorted(glob.glob(
|
||||
os.path.join(self.base_path, '*/*_gt.txt')))
|
||||
self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
|
||||
self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
|
||||
# valid frame range for each sequence
|
||||
self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.seq_names])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
if isinstance(sequence_name, six.string_types):
|
||||
if sequence_name not in self.seq_names:
|
||||
raise Exception('Sequence {} not found.'.format(sequence_name))
|
||||
index = self.seq_names.index(sequence_name)
|
||||
# load valid frame range
|
||||
frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
|
||||
img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
|
||||
|
||||
# load annotations
|
||||
anno = np.loadtxt(self.anno_files[index], delimiter=',')
|
||||
assert len(img_files) == len(anno)
|
||||
assert anno.shape[1] == 4
|
||||
|
||||
# return img_files, anno
|
||||
return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
50
lib/test/evaluation/tnl2kdataset.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text, load_str
|
||||
|
||||
############
|
||||
# NOTE: frame 00000492.png of test_015_Sord_video_Q01_done is damaged and has been replaced by a copy of 00000491.png
|
||||
############
|
||||
|
||||
|
||||
class TNL2kDataset(BaseDataset):
|
||||
"""
|
||||
TNL2k test set
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.tnl2k_path
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
# class_name = sequence_name.split('-')[0]
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
text_dsp_path = '{}/{}/language.txt'.format(self.base_path, sequence_name)
|
||||
text_dsp = load_str(text_dsp_path)
|
||||
|
||||
frames_path = '{}/{}/imgs'.format(self.base_path, sequence_name)
|
||||
frames_list = [f for f in os.listdir(frames_path)]
|
||||
frames_list = sorted(frames_list)
|
||||
frames_list = ['{}/{}'.format(frames_path, frame_i) for frame_i in frames_list]
|
||||
|
||||
# target_class = class_name
|
||||
return Sequence(sequence_name, frames_list, 'tnl2k', ground_truth_rect.reshape(-1, 4), text_dsp=text_dsp)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self):
|
||||
sequence_list = []
|
||||
for seq in os.listdir(self.base_path):
|
||||
if os.path.isdir(os.path.join(self.base_path, seq)):
|
||||
sequence_list.append(seq)
|
||||
|
||||
return sequence_list
|
||||
291
lib/test/evaluation/tracker.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
import time
|
||||
import cv2 as cv
|
||||
|
||||
from lib.utils.lmdb_utils import decode_img
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
|
||||
def trackerlist(name: str, parameter_name: str, dataset_name: str, run_ids = None, display_name: str = None,
|
||||
result_only=False):
|
||||
"""Generate list of trackers.
|
||||
args:
|
||||
name: Name of tracking method.
|
||||
parameter_name: Name of parameter file.
|
||||
run_ids: A single or list of run_ids.
|
||||
display_name: Name to be displayed in the result plots.
|
||||
"""
|
||||
if run_ids is None or isinstance(run_ids, int):
|
||||
run_ids = [run_ids]
|
||||
return [Tracker(name, parameter_name, dataset_name, run_id, display_name, result_only) for run_id in run_ids]
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Wraps the tracker for evaluation and running purposes.
|
||||
args:
|
||||
name: Name of tracking method.
|
||||
parameter_name: Name of parameter file.
|
||||
run_id: The run id.
|
||||
display_name: Name to be displayed in the result plots.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, parameter_name: str, dataset_name: str, run_id: int = None, display_name: str = None,
|
||||
result_only=False):
|
||||
assert run_id is None or isinstance(run_id, int)
|
||||
|
||||
self.name = name
|
||||
self.parameter_name = parameter_name
|
||||
self.dataset_name = dataset_name
|
||||
self.run_id = run_id
|
||||
self.display_name = display_name
|
||||
|
||||
env = env_settings()
|
||||
if self.run_id is None:
|
||||
self.results_dir = '{}/{}/{}'.format(env.results_path, self.name, self.parameter_name)
|
||||
else:
|
||||
self.results_dir = '{}/{}/{}_{:03d}'.format(env.results_path, self.name, self.parameter_name, self.run_id)
|
||||
if result_only:
|
||||
self.results_dir = '{}/{}'.format(env.results_path, self.name)
|
||||
|
||||
tracker_module_abspath = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'..', 'tracker', '%s.py' % self.name))
|
||||
if os.path.isfile(tracker_module_abspath):
|
||||
tracker_module = importlib.import_module('lib.test.tracker.{}'.format(self.name))
|
||||
self.tracker_class = tracker_module.get_tracker_class()
|
||||
else:
|
||||
self.tracker_class = None
|
||||
|
||||
def create_tracker(self, params):
|
||||
tracker = self.tracker_class(params, self.dataset_name)
|
||||
return tracker
|
||||
|
||||
def run_sequence(self, seq, debug=None):
|
||||
"""Run tracker on sequence.
|
||||
args:
|
||||
seq: Sequence to run the tracker on.
|
||||
debug: Set debug level (None means default value specified in the parameters).
|
||||
"""
|
||||
params = self.get_parameters()
|
||||
|
||||
debug_ = debug
|
||||
if debug is None:
|
||||
debug_ = getattr(params, 'debug', 0)
|
||||
|
||||
params.debug = debug_
|
||||
|
||||
# Get init information
|
||||
init_info = seq.init_info()
|
||||
|
||||
tracker = self.create_tracker(params)
|
||||
|
||||
output = self._track_sequence(tracker, seq, init_info)
|
||||
return output
|
||||
|
||||
def _track_sequence(self, tracker, seq, init_info):
|
||||
# Define outputs
|
||||
# Each field in output is a list containing tracker prediction for each frame.
|
||||
|
||||
# In case of single object tracking mode:
|
||||
# target_bbox[i] is the predicted bounding box for frame i
|
||||
# time[i] is the processing time for frame i
|
||||
|
||||
# In case of multi object tracking mode:
|
||||
# target_bbox[i] is an OrderedDict, where target_bbox[i][obj_id] is the predicted box for target obj_id in
|
||||
# frame i
|
||||
# time[i] is either the processing time for frame i, or an OrderedDict containing processing times for each
|
||||
# object in frame i
|
||||
|
||||
output = {'target_bbox': [],
|
||||
'time': []}
|
||||
if tracker.params.save_all_boxes:
|
||||
output['all_boxes'] = []
|
||||
output['all_scores'] = []
|
||||
|
||||
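# Append this frame's tracker outputs; keys missing from tracker_out fall back to the provided defaults.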
def _store_outputs(tracker_out: dict, defaults=None):
|
||||
defaults = {} if defaults is None else defaults
|
||||
for key in output.keys():
|
||||
val = tracker_out.get(key, defaults.get(key, None))
|
||||
if key in tracker_out or val is not None:
|
||||
output[key].append(val)
|
||||
|
||||
# Initialize
|
||||
image = self._read_image(seq.frames[0])
|
||||
|
||||
start_time = time.time()
|
||||
out = tracker.initialize(image, init_info)
|
||||
if out is None:
|
||||
out = {}
|
||||
|
||||
prev_output = OrderedDict(out)
|
||||
init_default = {'target_bbox': init_info.get('init_bbox'),
|
||||
'time': time.time() - start_time}
|
||||
if tracker.params.save_all_boxes:
|
||||
init_default['all_boxes'] = out['all_boxes']
|
||||
init_default['all_scores'] = out['all_scores']
|
||||
|
||||
_store_outputs(out, init_default)
|
||||
|
||||
for frame_num, frame_path in enumerate(seq.frames[1:], start=1):
|
||||
image = self._read_image(frame_path)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
info = seq.frame_info(frame_num)
|
||||
info['previous_output'] = prev_output
|
||||
|
||||
if len(seq.ground_truth_rect) > 1:
|
||||
info['gt_bbox'] = seq.ground_truth_rect[frame_num]
|
||||
out = tracker.track(image, info)
|
||||
prev_output = OrderedDict(out)
|
||||
_store_outputs(out, {'time': time.time() - start_time})
|
||||
|
||||
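# Drop fields that only contain the initialization entry, i.e. nothing was recorded for later frames.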
for key in ['target_bbox', 'all_boxes', 'all_scores']:
|
||||
if key in output and len(output[key]) <= 1:
|
||||
output.pop(key)
|
||||
|
||||
return output
|
||||
|
||||
def run_video(self, videofilepath, optional_box=None, debug=None, visdom_info=None, save_results=False):
|
||||
"""Run the tracker with the vieofile.
|
||||
args:
|
||||
debug: Debug level.
|
||||
"""
|
||||
|
||||
params = self.get_parameters()
|
||||
|
||||
debug_ = debug
|
||||
if debug is None:
|
||||
debug_ = getattr(params, 'debug', 0)
|
||||
params.debug = debug_
|
||||
|
||||
params.tracker_name = self.name
|
||||
params.param_name = self.parameter_name
|
||||
# self._init_visdom(visdom_info, debug_)
|
||||
|
||||
multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))
|
||||
|
||||
if multiobj_mode == 'default':
|
||||
tracker = self.create_tracker(params)
|
||||
|
||||
elif multiobj_mode == 'parallel':
|
||||
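# NOTE: parallel mode assumes MultiObjectWrapper is importable and self.visdom has been initialized; neither is defined in this file.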
tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
|
||||
else:
|
||||
raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))
|
||||
|
||||
assert os.path.isfile(videofilepath), "Invalid param {}, videofilepath must be a valid videofile".format(videofilepath)
|
||||
|
||||
|
||||
output_boxes = []
|
||||
|
||||
cap = cv.VideoCapture(videofilepath)
|
||||
display_name = 'Display: ' + tracker.params.tracker_name
|
||||
cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
|
||||
cv.resizeWindow(display_name, 960, 720)
|
||||
success, frame = cap.read()
|
||||
cv.imshow(display_name, frame)
|
||||
|
||||
def _build_init_info(box):
|
||||
return {'init_bbox': box}
|
||||
|
||||
if success is not True:
|
||||
print("Read frame from {} failed.".format(videofilepath))
|
||||
exit(-1)
|
||||
if optional_box is not None:
|
||||
assert isinstance(optional_box, (list, tuple))
|
||||
assert len(optional_box) == 4, "valid box's format is [x,y,w,h]"
|
||||
tracker.initialize(frame, _build_init_info(optional_box))
|
||||
output_boxes.append(optional_box)
|
||||
else:
|
||||
while True:
|
||||
# cv.waitKey()
|
||||
frame_disp = frame.copy()
|
||||
|
||||
cv.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL,
|
||||
1.5, (0, 0, 0), 1)
|
||||
|
||||
x, y, w, h = cv.selectROI(display_name, frame_disp, fromCenter=False)
|
||||
init_state = [x, y, w, h]
|
||||
tracker.initialize(frame, _build_init_info(init_state))
|
||||
output_boxes.append(init_state)
|
||||
break
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
frame_disp = frame.copy()
|
||||
|
||||
# Draw box
|
||||
out = tracker.track(frame)
|
||||
state = [int(s) for s in out['target_bbox']]
|
||||
output_boxes.append(state)
|
||||
|
||||
cv.rectangle(frame_disp, (state[0], state[1]), (state[2] + state[0], state[3] + state[1]),
|
||||
(0, 255, 0), 5)
|
||||
|
||||
font_color = (0, 0, 0)
|
||||
cv.putText(frame_disp, 'Tracking!', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
cv.putText(frame_disp, 'Press r to reset', (20, 55), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
cv.putText(frame_disp, 'Press q to quit', (20, 80), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
|
||||
font_color, 1)
|
||||
|
||||
# Display the resulting frame
|
||||
cv.imshow(display_name, frame_disp)
|
||||
key = cv.waitKey(1)
|
||||
if key == ord('q'):
|
||||
break
|
||||
elif key == ord('r'):
|
||||
ret, frame = cap.read()
|
||||
frame_disp = frame.copy()
|
||||
|
||||
cv.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.5,
|
||||
(0, 0, 0), 1)
|
||||
|
||||
cv.imshow(display_name, frame_disp)
|
||||
x, y, w, h = cv.selectROI(display_name, frame_disp, fromCenter=False)
|
||||
init_state = [x, y, w, h]
|
||||
tracker.initialize(frame, _build_init_info(init_state))
|
||||
output_boxes.append(init_state)
|
||||
|
||||
# When everything done, release the capture
|
||||
cap.release()
|
||||
cv.destroyAllWindows()
|
||||
|
||||
if save_results:
|
||||
if not os.path.exists(self.results_dir):
|
||||
os.makedirs(self.results_dir)
|
||||
video_name = Path(videofilepath).stem
|
||||
base_results_path = os.path.join(self.results_dir, 'video_{}'.format(video_name))
|
||||
|
||||
tracked_bb = np.array(output_boxes).astype(int)
|
||||
bbox_file = '{}.txt'.format(base_results_path)
|
||||
np.savetxt(bbox_file, tracked_bb, delimiter='\t', fmt='%d')
|
||||
|
||||
|
||||
def get_parameters(self):
|
||||
"""Get parameters."""
|
||||
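# Dynamically import lib/test/parameter/<name>.py and build the parameter set for this tracker.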
param_module = importlib.import_module('lib.test.parameter.{}'.format(self.name))
|
||||
params = param_module.parameters(self.parameter_name)
|
||||
return params
|
||||
|
||||
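# image_file is either an image path on disk or a two-element list handled by decode_img (presumably an LMDB path/key pair).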
def _read_image(self, image_file: str):
|
||||
if isinstance(image_file, str):
|
||||
im = cv.imread(image_file)
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
elif isinstance(image_file, list) and len(image_file) == 2:
|
||||
return decode_img(image_file[0], image_file[1])
|
||||
else:
|
||||
raise ValueError("type of image_file should be str or list")
|
||||
|
||||
|
||||
|
||||
58
lib/test/evaluation/trackingnetdataset.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
import os
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class TrackingNetDataset(BaseDataset):
|
||||
""" TrackingNet test set.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller, Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.trackingnet_path
|
||||
|
||||
sets = 'TEST'
|
||||
if not isinstance(sets, (list, tuple)):
|
||||
if sets == 'TEST':
|
||||
sets = ['TEST']
|
||||
elif sets == 'TRAIN':
|
||||
sets = ['TRAIN_{}'.format(i) for i in range(5)]
|
||||
|
||||
self.sequence_list = self._list_sequences(self.base_path, sets)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(set, seq_name) for set, seq_name in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, set, sequence_name):
|
||||
anno_path = '{}/{}/anno/{}.txt'.format(self.base_path, set, sequence_name)
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
frames_path = '{}/{}/frames/{}'.format(self.base_path, set, sequence_name)
|
||||
frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
|
||||
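# Frame files are named '<index>.jpg'; sort numerically by stripping the 4-character extension.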
frame_list.sort(key=lambda f: int(f[:-4]))
|
||||
frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
|
||||
|
||||
return Sequence(sequence_name, frames_list, 'trackingnet', ground_truth_rect.reshape(-1, 4))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _list_sequences(self, root, set_ids):
|
||||
sequence_list = []
|
||||
|
||||
for s in set_ids:
|
||||
anno_dir = os.path.join(root, s, "anno")
|
||||
sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
|
||||
|
||||
sequence_list += sequences_cur_set
|
||||
|
||||
return sequence_list
|
||||
298
lib/test/evaluation/uavdataset.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import numpy as np
|
||||
from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
|
||||
from lib.test.utils.load_text import load_text
|
||||
|
||||
|
||||
class UAVDataset(BaseDataset):
|
||||
""" UAV123 dataset.
|
||||
Publication:
|
||||
A Benchmark and Simulator for UAV Tracking.
|
||||
Matthias Mueller, Neil Smith and Bernard Ghanem
|
||||
ECCV, 2016
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2016/A%20Benchmark%20and%20Simulator%20for%20UAV%20Tracking.pdf
|
||||
Download the dataset from https://ivul.kaust.edu.sa/Pages/pub-benchmark-simulator-uav.aspx
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_path = self.env_settings.uav_path
|
||||
self.sequence_info_list = self._get_sequence_info_list()
|
||||
|
||||
def get_sequence_list(self):
|
||||
# return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
|
||||
|
||||
def _construct_sequence(self, sequence_info):
|
||||
sequence_path = sequence_info['path']
|
||||
nz = sequence_info['nz']
|
||||
ext = sequence_info['ext']
|
||||
start_frame = sequence_info['startFrame']
|
||||
end_frame = sequence_info['endFrame']
|
||||
|
||||
init_omit = 0
|
||||
if 'initOmit' in sequence_info:
|
||||
init_omit = sequence_info['initOmit']
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext) for frame_num in range(start_frame+init_omit, end_frame+1)]
|
||||
|
||||
anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
|
||||
|
||||
ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy')
|
||||
|
||||
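# Strip the 'uav_' prefix from the sequence name before constructing the Sequence.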
return Sequence(sequence_info['name'][4:], frames, 'uav', ground_truth_rect[init_omit:,:],
|
||||
object_class=sequence_info['object_class'])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_info_list)
|
||||
|
||||
def _get_sequence_info_list(self):
|
||||
sequence_info_list = [
|
||||
{"name": "uav_bike1", "path": "data_seq/UAV123/bike1", "startFrame": 1, "endFrame": 3085, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike1.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bike2", "path": "data_seq/UAV123/bike2", "startFrame": 1, "endFrame": 553, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike2.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bike3", "path": "data_seq/UAV123/bike3", "startFrame": 1, "endFrame": 433, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bike3.txt", "object_class": "vehicle"},
|
||||
{"name": "uav_bird1_1", "path": "data_seq/UAV123/bird1", "startFrame": 1, "endFrame": 253, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_1.txt", "object_class": "bird"},
|
||||
{"name": "uav_bird1_2", "path": "data_seq/UAV123/bird1", "startFrame": 775, "endFrame": 1477, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_2.txt", "object_class": "bird"},
|
||||
{"name": "uav_bird1_3", "path": "data_seq/UAV123/bird1", "startFrame": 1573, "endFrame": 2437, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/bird1_3.txt", "object_class": "bird"},
|
||||
{"name": "uav_boat1", "path": "data_seq/UAV123/boat1", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat1.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat2", "path": "data_seq/UAV123/boat2", "startFrame": 1, "endFrame": 799, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat2.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat3", "path": "data_seq/UAV123/boat3", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat3.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat4", "path": "data_seq/UAV123/boat4", "startFrame": 1, "endFrame": 553, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat4.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat5", "path": "data_seq/UAV123/boat5", "startFrame": 1, "endFrame": 505, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat5.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat6", "path": "data_seq/UAV123/boat6", "startFrame": 1, "endFrame": 805, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat6.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat7", "path": "data_seq/UAV123/boat7", "startFrame": 1, "endFrame": 535, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat7.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat8", "path": "data_seq/UAV123/boat8", "startFrame": 1, "endFrame": 685, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat8.txt", "object_class": "vessel"},
|
||||
{"name": "uav_boat9", "path": "data_seq/UAV123/boat9", "startFrame": 1, "endFrame": 1399, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/boat9.txt", "object_class": "vessel"},
|
||||
{"name": "uav_building1", "path": "data_seq/UAV123/building1", "startFrame": 1, "endFrame": 469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building1.txt", "object_class": "other"},
|
||||
{"name": "uav_building2", "path": "data_seq/UAV123/building2", "startFrame": 1, "endFrame": 577, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building2.txt", "object_class": "other"},
|
||||
{"name": "uav_building3", "path": "data_seq/UAV123/building3", "startFrame": 1, "endFrame": 829, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building3.txt", "object_class": "other"},
|
||||
{"name": "uav_building4", "path": "data_seq/UAV123/building4", "startFrame": 1, "endFrame": 787, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building4.txt", "object_class": "other"},
|
||||
{"name": "uav_building5", "path": "data_seq/UAV123/building5", "startFrame": 1, "endFrame": 481, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/building5.txt", "object_class": "other"},
|
||||
{"name": "uav_car1_1", "path": "data_seq/UAV123/car1", "startFrame": 1, "endFrame": 751, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_2", "path": "data_seq/UAV123/car1", "startFrame": 751, "endFrame": 1627, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_3", "path": "data_seq/UAV123/car1", "startFrame": 1627, "endFrame": 2629, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_3.txt", "object_class": "car"},
|
||||
{"name": "uav_car10", "path": "data_seq/UAV123/car10", "startFrame": 1, "endFrame": 1405, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car10.txt", "object_class": "car"},
|
||||
{"name": "uav_car11", "path": "data_seq/UAV123/car11", "startFrame": 1, "endFrame": 337, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car11.txt", "object_class": "car"},
|
||||
{"name": "uav_car12", "path": "data_seq/UAV123/car12", "startFrame": 1, "endFrame": 499, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car12.txt", "object_class": "car"},
|
||||
{"name": "uav_car13", "path": "data_seq/UAV123/car13", "startFrame": 1, "endFrame": 415, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car13.txt", "object_class": "car"},
|
||||
{"name": "uav_car14", "path": "data_seq/UAV123/car14", "startFrame": 1, "endFrame": 1327, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car14.txt", "object_class": "car"},
|
||||
{"name": "uav_car15", "path": "data_seq/UAV123/car15", "startFrame": 1, "endFrame": 469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car15.txt", "object_class": "car"},
|
||||
{"name": "uav_car16_1", "path": "data_seq/UAV123/car16", "startFrame": 1, "endFrame": 415, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car16_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car16_2", "path": "data_seq/UAV123/car16", "startFrame": 415, "endFrame": 1993, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car16_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car17", "path": "data_seq/UAV123/car17", "startFrame": 1, "endFrame": 1057, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car17.txt", "object_class": "car"},
|
||||
{"name": "uav_car18", "path": "data_seq/UAV123/car18", "startFrame": 1, "endFrame": 1207, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car18.txt", "object_class": "car"},
|
||||
{"name": "uav_car1_s", "path": "data_seq/UAV123/car1_s", "startFrame": 1, "endFrame": 1475, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car1_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car2", "path": "data_seq/UAV123/car2", "startFrame": 1, "endFrame": 1321, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car2.txt", "object_class": "car"},
|
||||
{"name": "uav_car2_s", "path": "data_seq/UAV123/car2_s", "startFrame": 1, "endFrame": 320, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car2_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car3", "path": "data_seq/UAV123/car3", "startFrame": 1, "endFrame": 1717, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car3.txt", "object_class": "car"},
|
||||
{"name": "uav_car3_s", "path": "data_seq/UAV123/car3_s", "startFrame": 1, "endFrame": 1300, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car3_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car4", "path": "data_seq/UAV123/car4", "startFrame": 1, "endFrame": 1345, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car4.txt", "object_class": "car"},
|
||||
{"name": "uav_car4_s", "path": "data_seq/UAV123/car4_s", "startFrame": 1, "endFrame": 830, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car4_s.txt", "object_class": "car"},
|
||||
{"name": "uav_car5", "path": "data_seq/UAV123/car5", "startFrame": 1, "endFrame": 745, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car5.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_1", "path": "data_seq/UAV123/car6", "startFrame": 1, "endFrame": 487, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_2", "path": "data_seq/UAV123/car6", "startFrame": 487, "endFrame": 1807, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_3", "path": "data_seq/UAV123/car6", "startFrame": 1807, "endFrame": 2953, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_3.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_4", "path": "data_seq/UAV123/car6", "startFrame": 2953, "endFrame": 3925, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_4.txt", "object_class": "car"},
|
||||
{"name": "uav_car6_5", "path": "data_seq/UAV123/car6", "startFrame": 3925, "endFrame": 4861, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car6_5.txt", "object_class": "car"},
|
||||
{"name": "uav_car7", "path": "data_seq/UAV123/car7", "startFrame": 1, "endFrame": 1033, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car7.txt", "object_class": "car"},
|
||||
{"name": "uav_car8_1", "path": "data_seq/UAV123/car8", "startFrame": 1, "endFrame": 1357, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car8_1.txt", "object_class": "car"},
|
||||
{"name": "uav_car8_2", "path": "data_seq/UAV123/car8", "startFrame": 1357, "endFrame": 2575, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car8_2.txt", "object_class": "car"},
|
||||
{"name": "uav_car9", "path": "data_seq/UAV123/car9", "startFrame": 1, "endFrame": 1879, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/car9.txt", "object_class": "car"},
|
||||
{"name": "uav_group1_1", "path": "data_seq/UAV123/group1", "startFrame": 1, "endFrame": 1333, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_2", "path": "data_seq/UAV123/group1", "startFrame": 1333, "endFrame": 2515, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_3", "path": "data_seq/UAV123/group1", "startFrame": 2515, "endFrame": 3925, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group1_4", "path": "data_seq/UAV123/group1", "startFrame": 3925, "endFrame": 4873, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group1_4.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_1", "path": "data_seq/UAV123/group2", "startFrame": 1, "endFrame": 907, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_2", "path": "data_seq/UAV123/group2", "startFrame": 907, "endFrame": 1771, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group2_3", "path": "data_seq/UAV123/group2", "startFrame": 1771, "endFrame": 2683, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group2_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_1", "path": "data_seq/UAV123/group3", "startFrame": 1, "endFrame": 1567, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_1.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_2", "path": "data_seq/UAV123/group3", "startFrame": 1567, "endFrame": 2827, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_2.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_3", "path": "data_seq/UAV123/group3", "startFrame": 2827, "endFrame": 4369, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_3.txt", "object_class": "person"},
|
||||
{"name": "uav_group3_4", "path": "data_seq/UAV123/group3", "startFrame": 4369, "endFrame": 5527, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/group3_4.txt", "object_class": "person"},
|
||||
{"name": "uav_person1", "path": "data_seq/UAV123/person1", "startFrame": 1, "endFrame": 799, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person1.txt", "object_class": "person"},
|
||||
{"name": "uav_person10", "path": "data_seq/UAV123/person10", "startFrame": 1, "endFrame": 1021, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person10.txt", "object_class": "person"},
|
||||
{"name": "uav_person11", "path": "data_seq/UAV123/person11", "startFrame": 1, "endFrame": 721, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person11.txt", "object_class": "person"},
|
||||
{"name": "uav_person12_1", "path": "data_seq/UAV123/person12", "startFrame": 1, "endFrame": 601, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person12_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person12_2", "path": "data_seq/UAV123/person12", "startFrame": 601, "endFrame": 1621, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person12_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person13", "path": "data_seq/UAV123/person13", "startFrame": 1, "endFrame": 883, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person13.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_1", "path": "data_seq/UAV123/person14", "startFrame": 1, "endFrame": 847, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person14_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_2", "path": "data_seq/UAV123/person14", "startFrame": 847, "endFrame": 1813, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person14_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person14_3", "path": "data_seq/UAV123/person14", "startFrame": 1813, "endFrame": 2923,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person14_3.txt", "object_class": "person"},
|
||||
{"name": "uav_person15", "path": "data_seq/UAV123/person15", "startFrame": 1, "endFrame": 1339, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person15.txt", "object_class": "person"},
|
||||
{"name": "uav_person16", "path": "data_seq/UAV123/person16", "startFrame": 1, "endFrame": 1147, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person16.txt", "object_class": "person"},
|
||||
{"name": "uav_person17_1", "path": "data_seq/UAV123/person17", "startFrame": 1, "endFrame": 1501, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person17_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person17_2", "path": "data_seq/UAV123/person17", "startFrame": 1501, "endFrame": 2347,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person17_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person18", "path": "data_seq/UAV123/person18", "startFrame": 1, "endFrame": 1393, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person18.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_1", "path": "data_seq/UAV123/person19", "startFrame": 1, "endFrame": 1243, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person19_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_2", "path": "data_seq/UAV123/person19", "startFrame": 1243, "endFrame": 2791,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person19_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person19_3", "path": "data_seq/UAV123/person19", "startFrame": 2791, "endFrame": 4357,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/person19_3.txt", "object_class": "person"},
|
||||
{"name": "uav_person1_s", "path": "data_seq/UAV123/person1_s", "startFrame": 1, "endFrame": 1600, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person1_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_1", "path": "data_seq/UAV123/person2", "startFrame": 1, "endFrame": 1189, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_2", "path": "data_seq/UAV123/person2", "startFrame": 1189, "endFrame": 2623, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person20", "path": "data_seq/UAV123/person20", "startFrame": 1, "endFrame": 1783, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person20.txt", "object_class": "person"},
|
||||
{"name": "uav_person21", "path": "data_seq/UAV123/person21", "startFrame": 1, "endFrame": 487, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person21.txt", "object_class": "person"},
|
||||
{"name": "uav_person22", "path": "data_seq/UAV123/person22", "startFrame": 1, "endFrame": 199, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person22.txt", "object_class": "person"},
|
||||
{"name": "uav_person23", "path": "data_seq/UAV123/person23", "startFrame": 1, "endFrame": 397, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person23.txt", "object_class": "person"},
|
||||
{"name": "uav_person2_s", "path": "data_seq/UAV123/person2_s", "startFrame": 1, "endFrame": 250, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person2_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person3", "path": "data_seq/UAV123/person3", "startFrame": 1, "endFrame": 643, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person3.txt", "object_class": "person"},
|
||||
{"name": "uav_person3_s", "path": "data_seq/UAV123/person3_s", "startFrame": 1, "endFrame": 505, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person3_s.txt", "object_class": "person"},
|
||||
{"name": "uav_person4_1", "path": "data_seq/UAV123/person4", "startFrame": 1, "endFrame": 1501, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person4_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person4_2", "path": "data_seq/UAV123/person4", "startFrame": 1501, "endFrame": 2743, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person4_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person5_1", "path": "data_seq/UAV123/person5", "startFrame": 1, "endFrame": 877, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person5_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person5_2", "path": "data_seq/UAV123/person5", "startFrame": 877, "endFrame": 2101, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person5_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person6", "path": "data_seq/UAV123/person6", "startFrame": 1, "endFrame": 901, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person6.txt", "object_class": "person"},
|
||||
{"name": "uav_person7_1", "path": "data_seq/UAV123/person7", "startFrame": 1, "endFrame": 1249, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person7_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person7_2", "path": "data_seq/UAV123/person7", "startFrame": 1249, "endFrame": 2065, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person7_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person8_1", "path": "data_seq/UAV123/person8", "startFrame": 1, "endFrame": 1075, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person8_1.txt", "object_class": "person"},
|
||||
{"name": "uav_person8_2", "path": "data_seq/UAV123/person8", "startFrame": 1075, "endFrame": 1525, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person8_2.txt", "object_class": "person"},
|
||||
{"name": "uav_person9", "path": "data_seq/UAV123/person9", "startFrame": 1, "endFrame": 661, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/person9.txt", "object_class": "person"},
|
||||
{"name": "uav_truck1", "path": "data_seq/UAV123/truck1", "startFrame": 1, "endFrame": 463, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck1.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck2", "path": "data_seq/UAV123/truck2", "startFrame": 1, "endFrame": 385, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck2.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck3", "path": "data_seq/UAV123/truck3", "startFrame": 1, "endFrame": 535, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck3.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck4_1", "path": "data_seq/UAV123/truck4", "startFrame": 1, "endFrame": 577, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck4_1.txt", "object_class": "truck"},
|
||||
{"name": "uav_truck4_2", "path": "data_seq/UAV123/truck4", "startFrame": 577, "endFrame": 1261, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/truck4_2.txt", "object_class": "truck"},
|
||||
{"name": "uav_uav1_1", "path": "data_seq/UAV123/uav1", "startFrame": 1, "endFrame": 1555, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_1.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav1_2", "path": "data_seq/UAV123/uav1", "startFrame": 1555, "endFrame": 2377, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_2.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav1_3", "path": "data_seq/UAV123/uav1", "startFrame": 2473, "endFrame": 3469, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav1_3.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav2", "path": "data_seq/UAV123/uav2", "startFrame": 1, "endFrame": 133, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav2.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav3", "path": "data_seq/UAV123/uav3", "startFrame": 1, "endFrame": 265, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav3.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav4", "path": "data_seq/UAV123/uav4", "startFrame": 1, "endFrame": 157, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav4.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav5", "path": "data_seq/UAV123/uav5", "startFrame": 1, "endFrame": 139, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav5.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav6", "path": "data_seq/UAV123/uav6", "startFrame": 1, "endFrame": 109, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav6.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav7", "path": "data_seq/UAV123/uav7", "startFrame": 1, "endFrame": 373, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav7.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_uav8", "path": "data_seq/UAV123/uav8", "startFrame": 1, "endFrame": 301, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/uav8.txt", "object_class": "aircraft"},
|
||||
{"name": "uav_wakeboard1", "path": "data_seq/UAV123/wakeboard1", "startFrame": 1, "endFrame": 421, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard1.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard10", "path": "data_seq/UAV123/wakeboard10", "startFrame": 1, "endFrame": 469,
|
||||
"nz": 6, "ext": "jpg", "anno_path": "anno/UAV123/wakeboard10.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard2", "path": "data_seq/UAV123/wakeboard2", "startFrame": 1, "endFrame": 733, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard2.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard3", "path": "data_seq/UAV123/wakeboard3", "startFrame": 1, "endFrame": 823, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard3.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard4", "path": "data_seq/UAV123/wakeboard4", "startFrame": 1, "endFrame": 697, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard4.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard5", "path": "data_seq/UAV123/wakeboard5", "startFrame": 1, "endFrame": 1675, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard5.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard6", "path": "data_seq/UAV123/wakeboard6", "startFrame": 1, "endFrame": 1165, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard6.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard7", "path": "data_seq/UAV123/wakeboard7", "startFrame": 1, "endFrame": 199, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard7.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard8", "path": "data_seq/UAV123/wakeboard8", "startFrame": 1, "endFrame": 1543, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard8.txt", "object_class": "person"},
|
||||
{"name": "uav_wakeboard9", "path": "data_seq/UAV123/wakeboard9", "startFrame": 1, "endFrame": 355, "nz": 6,
|
||||
"ext": "jpg", "anno_path": "anno/UAV123/wakeboard9.txt", "object_class": "person"}
|
||||
]
|
||||
|
||||
return sequence_info_list
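# Added illustration (not in the original file): the '{frame:0{nz}}' pattern used in
# _construct_sequence zero-pads frame indices to nz digits. The values below are
# hypothetical and only demonstrate the formatting; they do not refer to a real setup.
if __name__ == '__main__':
    base_path, sequence_path, nz, ext = '/data/UAV123', 'data_seq/UAV123/bike1', 6, 'jpg'
    example = '{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(
        base_path=base_path, sequence_path=sequence_path, frame=1, nz=nz, ext=ext)
    print(example)  # /data/UAV123/data_seq/UAV123/bike1/000001.jpg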
|
||||
349
lib/test/evaluation/votdataset.py
Normal file
@@ -0,0 +1,349 @@
|
||||
from typing import Union, TextIO
|
||||
|
||||
import numpy as np
|
||||
from numba import jit
|
||||
|
||||
from lib.test.evaluation.data import SequenceList, BaseDataset, Sequence
|
||||
|
||||
|
||||
class VOTDataset(BaseDataset):
|
||||
"""
|
||||
VOT2018 dataset
|
||||
|
||||
Publication:
|
||||
The sixth Visual Object Tracking VOT2018 challenge results.
|
||||
Matej Kristan, Ales Leonardis, Jiri Matas, Michael Felsberg, Roman Pflugfelder, Luka Cehovin Zajc, Tomas Vojir,
|
||||
Goutam Bhat, Alan Lukezic et al.
|
||||
ECCV, 2018
|
||||
https://prints.vicos.si/publications/365
|
||||
|
||||
Download the dataset from http://www.votchallenge.net/vot2018/dataset.html
|
||||
"""
|
||||
def __init__(self, year=18):
|
||||
super().__init__()
|
||||
self.year = year
|
||||
if year == 18:
|
||||
self.base_path = self.env_settings.vot18_path
|
||||
elif year == 20:
|
||||
self.base_path = self.env_settings.vot20_path
|
||||
elif year == 22:
|
||||
self.base_path = self.env_settings.vot22_path
|
||||
self.sequence_list = self._get_sequence_list(year)
|
||||
|
||||
def get_sequence_list(self):
|
||||
return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
|
||||
|
||||
def _construct_sequence(self, sequence_name):
|
||||
sequence_path = sequence_name
|
||||
nz = 8
|
||||
ext = 'jpg'
|
||||
start_frame = 1
|
||||
|
||||
anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
|
||||
|
||||
if self.year == 18 or self.year == 22:
|
||||
try:
|
||||
ground_truth_rect = np.loadtxt(str(anno_path), dtype=np.float64)
|
||||
except:
|
||||
ground_truth_rect = np.loadtxt(str(anno_path), delimiter=',', dtype=np.float64)
|
||||
|
||||
end_frame = ground_truth_rect.shape[0]
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/color/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path, frame=frame_num, nz=nz, ext=ext)
|
||||
for frame_num in range(start_frame, end_frame+1)]
|
||||
|
||||
# Convert gt
|
||||
if ground_truth_rect.shape[1] > 4:
|
||||
gt_x_all = ground_truth_rect[:, [0, 2, 4, 6]]
|
||||
gt_y_all = ground_truth_rect[:, [1, 3, 5, 7]]
|
||||
|
||||
x1 = np.amin(gt_x_all, 1).reshape(-1,1)
|
||||
y1 = np.amin(gt_y_all, 1).reshape(-1,1)
|
||||
x2 = np.amax(gt_x_all, 1).reshape(-1,1)
|
||||
y2 = np.amax(gt_y_all, 1).reshape(-1,1)
|
||||
|
||||
ground_truth_rect = np.concatenate((x1, y1, x2-x1, y2-y1), 1)
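# Added worked example: a polygon annotation (10, 20, 50, 22, 48, 60, 12, 58) has
# x-min 10, y-min 20, x-max 50, y-max 60, so it is converted to the axis-aligned
# box (10, 20, 40, 40) in (x, y, w, h) form.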
|
||||
|
||||
elif self.year == 20:
|
||||
ground_truth_rect = read_file(str(anno_path))
|
||||
ground_truth_rect = np.array(ground_truth_rect, dtype=np.float64)
|
||||
end_frame = ground_truth_rect.shape[0]
|
||||
|
||||
frames = ['{base_path}/{sequence_path}/color/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
|
||||
sequence_path=sequence_path,
|
||||
frame=frame_num, nz=nz, ext=ext)
|
||||
for frame_num in range(start_frame, end_frame + 1)]
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return Sequence(sequence_name, frames, 'vot', ground_truth_rect)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _get_sequence_list(self, year):
|
||||
if year == 18:
|
||||
sequence_list = ['ants1', 'ants3', 'bag', 'ball1', 'ball2', 'basketball', 'birds1', 'blanket',
                 'bmx', 'bolt1', 'bolt2', 'book', 'butterfly', 'car1', 'conduction1', 'crabs1',
                 'crossing', 'dinosaur', 'drone_across', 'drone_flip', 'drone1', 'fernando',
                 'fish1', 'fish2', 'fish3', 'flamingo1', 'frisbee', 'girl', 'glove', 'godfather',
                 'graduate', 'gymnastics1', 'gymnastics2', 'gymnastics3', 'hand', 'handball1',
                 'handball2', 'helicopter', 'iceskater1', 'iceskater2', 'leaves', 'matrix',
                 'motocross1', 'motocross2', 'nature', 'pedestrian1', 'rabbit', 'racing', 'road',
                 'shaking', 'sheep', 'singer2', 'singer3', 'soccer1', 'soccer2', 'soldier',
                 'tiger', 'traffic', 'wiper', 'zebrafish1']
|
||||
elif year == 20:
|
||||
|
||||
sequence_list = ['agility', 'ants1', 'ball2', 'ball3', 'basketball', 'birds1', 'bolt1', 'book',
                 'butterfly', 'car1', 'conduction1', 'crabs1', 'dinosaur', 'dribble', 'drone1',
                 'drone_across', 'drone_flip', 'fernando', 'fish1', 'fish2', 'flamingo1',
                 'frisbee', 'girl', 'glove', 'godfather', 'graduate', 'gymnastics1',
                 'gymnastics2', 'gymnastics3', 'hand', 'hand02', 'hand2', 'handball1',
                 'handball2', 'helicopter', 'iceskater1', 'iceskater2', 'lamb', 'leaves',
                 'marathon', 'matrix', 'monkey', 'motocross1', 'nature', 'polo', 'rabbit',
                 'rabbit2', 'road', 'rowing', 'shaking', 'singer2', 'singer3', 'soccer1',
                 'soccer2', 'soldier', 'surfing', 'tiger', 'wheel', 'wiper', 'zebrafish1']
|
||||
elif year == 22:
|
||||
sequence_list = ['agility', 'animal', 'ants1', 'bag', 'ball2', 'ball3', 'basketball', 'birds1',
                 'birds2', 'bolt1', 'book', 'bubble', 'butterfly', 'car1', 'conduction1',
                 'crabs1', 'dinosaur', 'diver', 'drone1', 'drone_across', 'fernando', 'fish1',
                 'fish2', 'flamingo1', 'frisbee', 'girl', 'graduate', 'gymnastics1',
                 'gymnastics2', 'gymnastics3', 'hand', 'hand2', 'handball1', 'handball2',
                 'helicopter', 'iceskater1', 'iceskater2', 'kangaroo', 'lamb', 'leaves',
                 'marathon', 'matrix', 'monkey', 'motocross1', 'nature', 'polo', 'rabbit',
                 'rabbit2', 'rowing', 'shaking', 'singer2', 'singer3', 'snake', 'soccer1',
                 'soccer2', 'soldier', 'surfing', 'tennis', 'tiger', 'wheel', 'wiper',
                 'zebrafish1']
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
def parse(string):
|
||||
"""
|
||||
parse string to the appropriate region format and return region object
|
||||
"""
|
||||
from vot.region.shapes import Rectangle, Polygon, Mask
|
||||
|
||||
|
||||
if string[0] == 'm':
|
||||
# input is a mask - decode it
|
||||
m_, offset_, region = create_mask_from_string(string[1:].split(','))
|
||||
# return Mask(m_, offset=offset_)
|
||||
return region
|
||||
else:
# input is not a mask - special, rectangle and polygon formats are not handled here
print('Unknown region format.')
raise NotImplementedError
|
||||
|
||||
|
||||
def read_file(fp: Union[str, TextIO]):
|
||||
if isinstance(fp, str):
|
||||
with open(fp) as file:
|
||||
lines = file.readlines()
|
||||
else:
|
||||
lines = fp.readlines()
|
||||
|
||||
regions = []
|
||||
# iterate over all lines in the file
|
||||
for i, line in enumerate(lines):
|
||||
regions.append(parse(line.strip()))
|
||||
return regions
|
||||
|
||||
|
||||
def create_mask_from_string(mask_encoding):
|
||||
"""
|
||||
mask_encoding: a string in the following format: x0, y0, w, h, RLE
|
||||
output: mask, offset
|
||||
mask: 2-D binary mask, size defined in the mask encoding
|
||||
offset: (x, y) offset of the mask in the image coordinates
|
||||
"""
|
||||
elements = [int(el) for el in mask_encoding]
|
||||
tl_x, tl_y, region_w, region_h = elements[:4]
|
||||
rle = np.array([el for el in elements[4:]], dtype=np.int32)
|
||||
|
||||
# create mask from RLE within target region
|
||||
mask = rle_to_mask(rle, region_w, region_h)
|
||||
region = [tl_x, tl_y, region_w, region_h]
|
||||
|
||||
return mask, (tl_x, tl_y), region
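# Added worked example: the encoding "2,3,4,2,5,3" yields offset (2, 3), a 4x2 region,
# and RLE [5, 3] (five 0s then three 1s), i.e. the mask rows [0,0,0,0] and [0,1,1,1].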
|
||||
|
||||
@jit(nopython=True)
|
||||
def rle_to_mask(rle, width, height):
|
||||
"""
|
||||
rle: input rle mask encoding
|
||||
each even-indexed element represents the number of consecutive 0s
each odd-indexed element represents the number of consecutive 1s
|
||||
width and height are dimensions of the mask
|
||||
output: 2-D binary mask
|
||||
"""
|
||||
# allocate list of zeros
|
||||
v = [0] * (width * height)
|
||||
|
||||
# set id of the last different element to the beginning of the vector
|
||||
idx_ = 0
|
||||
for i in range(len(rle)):
|
||||
if i % 2 != 0:
|
||||
# write as many 1s as RLE says (zeros are already in the vector)
|
||||
for j in range(rle[i]):
|
||||
v[idx_+j] = 1
|
||||
idx_ += rle[i]
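# The remainder of rle_to_mask appears to be cut off in this view; based on the
# docstring above, a minimal completion would reshape the vector into the 2-D mask:
#     return np.array(v, dtype=np.uint8).reshape((height, width))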
|
||||
0
lib/test/parameter/__init__.py
Normal file
30
lib/test/parameter/artrack.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from lib.test.utils import TrackerParams
|
||||
import os
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.config.artrack.config import cfg, update_config_from_file
|
||||
|
||||
|
||||
def parameters(yaml_name: str):
|
||||
params = TrackerParams()
|
||||
prj_dir = env_settings().prj_dir
|
||||
save_dir = env_settings().save_dir
|
||||
# update default config from yaml file
|
||||
yaml_file = os.path.join(prj_dir, 'experiments/artrack/%s.yaml' % yaml_name)
|
||||
update_config_from_file(yaml_file)
|
||||
params.cfg = cfg
|
||||
print("test config: ", cfg)
|
||||
|
||||
# template and search region
|
||||
params.template_factor = cfg.TEST.TEMPLATE_FACTOR
|
||||
params.template_size = cfg.TEST.TEMPLATE_SIZE
|
||||
params.search_factor = cfg.TEST.SEARCH_FACTOR
|
||||
params.search_size = cfg.TEST.SEARCH_SIZE
|
||||
|
||||
# Network checkpoint path
|
||||
params.checkpoint = os.path.join(save_dir, "checkpoints/train/artrack/%s/ARTrack_ep%04d.pth.tar" %
|
||||
(yaml_name, cfg.TEST.EPOCH))
|
||||
|
||||
# whether to save boxes from all queries
|
||||
params.save_all_boxes = False
|
||||
|
||||
return params
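# Added usage sketch (the yaml name below is hypothetical, not from the original file):
#     params = parameters('artrack_256_full')
#     print(params.checkpoint)  # checkpoints/train/artrack/artrack_256_full/ARTrack_ep<EPOCH>.pth.tar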
|
||||
30
lib/test/parameter/artrack_seq.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from lib.test.utils import TrackerParams
|
||||
import os
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
from lib.config.artrack_seq.config import cfg, update_config_from_file
|
||||
|
||||
|
||||
def parameters(yaml_name: str):
|
||||
params = TrackerParams()
|
||||
prj_dir = env_settings().prj_dir
|
||||
save_dir = env_settings().save_dir
|
||||
# update default config from yaml file
|
||||
yaml_file = os.path.join(prj_dir, 'experiments/artrack_seq/%s.yaml' % yaml_name)
|
||||
update_config_from_file(yaml_file)
|
||||
params.cfg = cfg
|
||||
print("test config: ", cfg)
|
||||
|
||||
# template and search region
|
||||
params.template_factor = cfg.TEST.TEMPLATE_FACTOR
|
||||
params.template_size = cfg.TEST.TEMPLATE_SIZE
|
||||
params.search_factor = cfg.TEST.SEARCH_FACTOR
|
||||
params.search_size = cfg.TEST.SEARCH_SIZE
|
||||
|
||||
# Network checkpoint path
|
||||
params.checkpoint = os.path.join(save_dir, "checkpoints/train/artrack_seq/%s/ARTrackSeq_ep%04d.pth.tar" %
|
||||
(yaml_name, cfg.TEST.EPOCH))
|
||||
|
||||
# whether to save boxes from all queries
|
||||
params.save_all_boxes = False
|
||||
|
||||
return params
|
||||
0
lib/test/tracker/__init__.py
Normal file
225
lib/test/tracker/artrack.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import math
|
||||
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.test.tracker.basetracker import BaseTracker
|
||||
import torch
|
||||
|
||||
from lib.test.tracker.vis_utils import gen_visualization
|
||||
from lib.test.utils.hann import hann2d
|
||||
from lib.train.data.processing_utils import sample_target
|
||||
# for debug
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from lib.test.tracker.data_utils import Preprocessor
|
||||
from lib.utils.box_ops import clip_box
|
||||
from lib.utils.ce_utils import generate_mask_cond
|
||||
import random
|
||||
|
||||
class RandomErasing(object):
|
||||
def __init__(self, EPSILON=0.5, sl=0.02, sh=0.33, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
|
||||
self.EPSILON = EPSILON
|
||||
self.mean = mean
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.r1 = r1
|
||||
|
||||
def __call__(self, img):
|
||||
|
||||
if random.uniform(0, 1) > self.EPSILON:
|
||||
return img
|
||||
|
||||
for attempt in range(100):
|
||||
# print(img.size())  # leftover debug print; avoid printing the tensor size on every attempt
|
||||
area = img.size()[1] * img.size()[2]
|
||||
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w < img.size()[2] and h < img.size()[1]:
|
||||
x1 = random.randint(0, img.size()[1] - h)
|
||||
y1 = random.randint(0, img.size()[2] - w)
|
||||
if img.size()[0] == 3:
|
||||
# img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
# img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
# img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
|
||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
|
||||
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
|
||||
# img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w))
|
||||
else:
|
||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||
# img[0, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, h, w))
|
||||
return img
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ARTrack(BaseTracker):
|
||||
def __init__(self, params, dataset_name):
|
||||
super(ARTrack, self).__init__(params)
|
||||
network = build_artrack(params.cfg, training=False)
|
||||
print(self.params.checkpoint)
|
||||
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
|
||||
self.cfg = params.cfg
|
||||
self.bins = self.cfg.MODEL.BINS
|
||||
self.network = network.cuda()
|
||||
self.network.eval()
|
||||
self.preprocessor = Preprocessor()
|
||||
self.state = None
|
||||
self.range = self.cfg.MODEL.RANGE
|
||||
|
||||
self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
|
||||
# motion constraint
|
||||
self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()
|
||||
|
||||
# for debug
|
||||
self.debug = params.debug
|
||||
self.use_visdom = params.debug
|
||||
self.frame_id = 0
|
||||
self.erase = RandomErasing()
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
self.save_dir = "debug"
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
else:
|
||||
# self.add_hook()
|
||||
self._init_visdom(None, 1)
|
||||
# for save boxes from all queries
|
||||
self.save_all_boxes = params.save_all_boxes
|
||||
self.z_dict1 = {}
|
||||
|
||||
def initialize(self, image, info: dict):
|
||||
# forward the template once
|
||||
|
||||
z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
|
||||
output_sz=self.params.template_size)#output_sz=self.params.template_size
|
||||
self.z_patch_arr = z_patch_arr
|
||||
template = self.preprocessor.process(z_patch_arr, z_amask_arr)
|
||||
with torch.no_grad():
|
||||
self.z_dict1 = template
|
||||
|
||||
self.box_mask_z = None
|
||||
#if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
# template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor,
|
||||
# template.tensors.device).squeeze(1)
|
||||
# self.box_mask_z = generate_mask_cond(self.cfg, 1, template.tensors.device, template_bbox)
|
||||
|
||||
# save states
|
||||
self.state = info['init_bbox']
|
||||
self.frame_id = 0
|
||||
if self.save_all_boxes:
|
||||
'''save all predicted boxes'''
|
||||
all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
|
||||
return {"all_boxes": all_boxes_save}
|
||||
|
||||
def track(self, image, info: dict = None):
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
H, W, _ = image.shape
|
||||
self.frame_id += 1
|
||||
x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
|
||||
output_sz=self.params.search_size) # (x1, y1, w, h)
|
||||
search = self.preprocessor.process(x_patch_arr, x_amask_arr)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
x_dict = search
|
||||
# merge the template and the search
|
||||
# run the transformer
|
||||
out_dict = self.network.forward(
|
||||
template=self.z_dict1.tensors, search=x_dict.tensors)
|
||||
|
||||
# add hann windows
|
||||
# pred_score_map = out_dict['score_map']
|
||||
# response = self.output_window * pred_score_map
|
||||
# pred_boxes = self.network.box_head.cal_bbox(response, out_dict['size_map'], out_dict['offset_map'])
|
||||
# pred_boxes = pred_boxes.view(-1, 4)
|
||||
|
||||
pred_boxes = out_dict['seqs'][:, 0:4] / (self.bins - 1) - magic_num
|
||||
pred_boxes = pred_boxes.view(-1, 4).mean(dim=0)
|
||||
pred_new = pred_boxes
|
||||
pred_new[2] = pred_boxes[2] - pred_boxes[0]
|
||||
pred_new[3] = pred_boxes[3] - pred_boxes[1]
|
||||
pred_new[0] = pred_boxes[0] + pred_boxes[2]/2
|
||||
pred_new[1] = pred_boxes[1] + pred_boxes[3]/2
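# Added note: pred_new aliases pred_boxes, so indices 2 and 3 already hold the width and
# height when the two lines above compute the centre; the result is in (cx, cy, w, h) form.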
|
||||
|
||||
pred_boxes = (pred_new * self.params.search_size / resize_factor).tolist()
|
||||
|
||||
# Baseline: Take the mean of all pred boxes as the final result
|
||||
#pred_box = (pred_boxes.mean(
|
||||
# dim=0) * self.params.search_size / resize_factor).tolist() # (cx, cy, w, h) [0,1]
|
||||
# get the final box result
|
||||
self.state = clip_box(self.map_box_back(pred_boxes, resize_factor), H, W, margin=10)
|
||||
|
||||
# for debug
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
x1, y1, w, h = self.state
|
||||
image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(image_BGR, (int(x1),int(y1)), (int(x1+w),int(y1+h)), color=(0,0,255), thickness=2)
|
||||
save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
|
||||
cv2.imwrite(save_path, image_BGR)
|
||||
else:
|
||||
self.visdom.register((image, info['gt_bbox'].tolist(), self.state), 'Tracking', 1, 'Tracking')
|
||||
|
||||
self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
|
||||
self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
|
||||
self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
|
||||
self.visdom.register((pred_score_map * self.output_window).view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map_hann')
|
||||
|
||||
if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
|
||||
removed_indexes_s = out_dict['removed_indexes_s']
|
||||
removed_indexes_s = [removed_indexes_s_i.cpu().numpy() for removed_indexes_s_i in removed_indexes_s]
|
||||
masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
|
||||
self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')
|
||||
|
||||
while self.pause_mode:
|
||||
if self.step:
|
||||
self.step = False
|
||||
break
|
||||
|
||||
if self.save_all_boxes:
|
||||
'''save all predictions'''
|
||||
all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
|
||||
all_boxes_save = all_boxes.view(-1).tolist() # (4N, )
|
||||
return {"target_bbox": self.state,
|
||||
"all_boxes": all_boxes_save}
|
||||
else:
|
||||
return {"target_bbox": self.state}
|
||||
|
||||
def map_box_back(self, pred_box: list, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
#cx_real = cx + cx_prev
|
||||
#cy_real = cy + cy_prev
|
||||
return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
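# Added worked example: with search_size 256 and resize_factor 2.0, half_side is 64, so a
# predicted centre (70, 60) in crop pixels maps back to (cx_prev + 6, cy_prev - 4) in the
# full image before being converted to the (x1, y1, w, h) output above.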
|
||||
|
||||
def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,)
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
|
||||
|
||||
def add_hook(self):
|
||||
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
|
||||
|
||||
for i in range(12):
|
||||
self.network.backbone.blocks[i].attn.register_forward_hook(
|
||||
# lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
)
|
||||
|
||||
self.enc_attn_weights = enc_attn_weights
|
||||
|
||||
|
||||
def get_tracker_class():
|
||||
return ARTrack
|
||||
209
lib/test/tracker/artrack_seq.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import math
|
||||
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
from lib.test.tracker.basetracker import BaseTracker
|
||||
import torch
|
||||
|
||||
from lib.test.tracker.vis_utils import gen_visualization
|
||||
from lib.test.utils.hann import hann2d
|
||||
from lib.train.data.processing_utils import sample_target, transform_image_to_crop
|
||||
# for debug
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from lib.test.tracker.data_utils import Preprocessor
|
||||
from lib.utils.box_ops import clip_box
|
||||
from lib.utils.ce_utils import generate_mask_cond
|
||||
|
||||
|
||||
class ARTrackSeq(BaseTracker):
|
||||
def __init__(self, params, dataset_name):
|
||||
super(ARTrackSeq, self).__init__(params)
|
||||
network = build_artrack_seq(params.cfg, training=False)
|
||||
print(self.params.checkpoint)
|
||||
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
|
||||
self.cfg = params.cfg
|
||||
self.bins = self.cfg.MODEL.BINS
|
||||
self.network = network.cuda()
|
||||
self.network.eval()
|
||||
self.preprocessor = Preprocessor()
|
||||
self.state = None
|
||||
|
||||
self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
|
||||
# motion constraint
|
||||
self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()
|
||||
|
||||
# for debug
|
||||
self.debug = params.debug
|
||||
self.use_visdom = params.debug
|
||||
self.frame_id = 0
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
self.save_dir = "debug"
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
else:
|
||||
# self.add_hook()
|
||||
self._init_visdom(None, 1)
|
||||
# for save boxes from all queries
|
||||
self.save_all_boxes = params.save_all_boxes
|
||||
self.z_dict1 = {}
|
||||
self.store_result = None
|
||||
self.save_all = 7
|
||||
self.x_feat = None
|
||||
self.update = None
|
||||
self.update_threshold = 5.0
|
||||
self.update_intervals = 1
|
||||
|
||||
def initialize(self, image, info: dict):
|
||||
# forward the template once
|
||||
self.x_feat = None
|
||||
|
||||
z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
|
||||
output_sz=self.params.template_size) # output_sz=self.params.template_size
|
||||
self.z_patch_arr = z_patch_arr
|
||||
template = self.preprocessor.process(z_patch_arr, z_amask_arr)
|
||||
with torch.no_grad():
|
||||
self.z_dict1 = template
|
||||
|
||||
self.box_mask_z = None
|
||||
# if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
# template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor,
|
||||
# template.tensors.device).squeeze(1)
|
||||
# self.box_mask_z = generate_mask_cond(self.cfg, 1, template.tensors.device, template_bbox)
|
||||
|
||||
# save states
|
||||
self.state = info['init_bbox']
|
||||
self.store_result = [info['init_bbox'].copy()]
|
||||
for i in range(self.save_all - 1):
|
||||
self.store_result.append(info['init_bbox'].copy())
|
||||
self.frame_id = 0
|
||||
self.update = None
|
||||
if self.save_all_boxes:
|
||||
'''save all predicted boxes'''
|
||||
all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
|
||||
return {"all_boxes": all_boxes_save}
|
||||
|
||||
def track(self, image, info: dict = None):
|
||||
H, W, _ = image.shape
|
||||
self.frame_id += 1
|
||||
x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
|
||||
output_sz=self.params.search_size) # (x1, y1, w, h)
|
||||
for i in range(len(self.store_result)):
|
||||
box_temp = self.store_result[i].copy()
|
||||
box_out_i = transform_image_to_crop(torch.Tensor(self.store_result[i]), torch.Tensor(self.state),
|
||||
resize_factor,
|
||||
torch.Tensor([self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]),
|
||||
normalize=True)
|
||||
box_out_i[2] = box_out_i[2] + box_out_i[0]
|
||||
box_out_i[3] = box_out_i[3] + box_out_i[1]
|
||||
box_out_i = box_out_i.clamp(min=-0.5, max=1.5)
|
||||
box_out_i = (box_out_i + 0.5) * (self.bins - 1)
|
||||
if i == 0:
|
||||
seqs_out = box_out_i
|
||||
else:
|
||||
seqs_out = torch.cat((seqs_out, box_out_i), dim=-1)
|
||||
seqs_out = seqs_out.unsqueeze(0)
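# Added note: each stored box is mapped into the current search crop as (x1, y1, x2, y2),
# clamped to [-0.5, 1.5] and discretised via (value + 0.5) * (bins - 1), so seqs_out is the
# coordinate-token history that conditions the sequence decoder.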
|
||||
search = self.preprocessor.process(x_patch_arr, x_amask_arr)
|
||||
with torch.no_grad():
|
||||
x_dict = search
|
||||
# merge the template and the search
|
||||
# run the transformer
|
||||
out_dict = self.network.forward(
|
||||
template=self.z_dict1.tensors, search=x_dict.tensors,
|
||||
seq_input=seqs_out, stage="sequence", search_feature=self.x_feat, update=None)
|
||||
|
||||
self.x_feat = out_dict['x_feat']
|
||||
|
||||
pred_boxes = out_dict['seqs'][:, 0:4] / (self.bins - 1) - 0.5
|
||||
pred_boxes = pred_boxes.view(-1, 4).mean(dim=0)
|
||||
pred_new = pred_boxes
|
||||
pred_new[2] = pred_boxes[2] - pred_boxes[0]
|
||||
pred_new[3] = pred_boxes[3] - pred_boxes[1]
|
||||
pred_new[0] = pred_boxes[0] + pred_new[2] / 2
|
||||
pred_new[1] = pred_boxes[1] + pred_new[3] / 2
|
||||
pred_boxes = (pred_new * self.params.search_size / resize_factor).tolist()
|
||||
|
||||
# Baseline: Take the mean of all pred boxes as the final result
|
||||
# pred_box = (pred_boxes.mean(
|
||||
# dim=0) * self.params.search_size / resize_factor).tolist() # (cx, cy, w, h) [0,1]
|
||||
# get the final box result
|
||||
self.state = clip_box(self.map_box_back(pred_boxes, resize_factor), H, W, margin=10)
|
||||
if len(self.store_result) < self.save_all:
|
||||
self.store_result.append(self.state.copy())
|
||||
else:
|
||||
for i in range(self.save_all):
|
||||
if i != self.save_all - 1:
|
||||
self.store_result[i] = self.store_result[i + 1]
|
||||
else:
|
||||
self.store_result[i] = self.state.copy()
|
||||
|
||||
# for debug
|
||||
if self.debug:
|
||||
if not self.use_visdom:
|
||||
x1, y1, w, h = self.state
|
||||
image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)
|
||||
save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
|
||||
cv2.imwrite(save_path, image_BGR)
|
||||
else:
|
||||
self.visdom.register((image, info['gt_bbox'].tolist(), self.state), 'Tracking', 1, 'Tracking')
|
||||
|
||||
self.visdom.register(torch.from_numpy(x_patch_arr).permute(2, 0, 1), 'image', 1, 'search_region')
|
||||
self.visdom.register(torch.from_numpy(self.z_patch_arr).permute(2, 0, 1), 'image', 1, 'template')
|
||||
self.visdom.register(pred_score_map.view(self.feat_sz, self.feat_sz), 'heatmap', 1, 'score_map')
|
||||
self.visdom.register((pred_score_map * self.output_window).view(self.feat_sz, self.feat_sz), 'heatmap',
|
||||
1, 'score_map_hann')
|
||||
|
||||
if 'removed_indexes_s' in out_dict and out_dict['removed_indexes_s']:
|
||||
removed_indexes_s = out_dict['removed_indexes_s']
|
||||
removed_indexes_s = [removed_indexes_s_i.cpu().numpy() for removed_indexes_s_i in removed_indexes_s]
|
||||
masked_search = gen_visualization(x_patch_arr, removed_indexes_s)
|
||||
self.visdom.register(torch.from_numpy(masked_search).permute(2, 0, 1), 'image', 1, 'masked_search')
|
||||
|
||||
while self.pause_mode:
|
||||
if self.step:
|
||||
self.step = False
|
||||
break
|
||||
|
||||
if self.save_all_boxes:
|
||||
'''save all predictions'''
|
||||
all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
|
||||
all_boxes_save = all_boxes.view(-1).tolist() # (4N, )
|
||||
return {"target_bbox": self.state,
|
||||
"all_boxes": all_boxes_save}
|
||||
else:
|
||||
return {"target_bbox": self.state}
|
||||
|
||||
def map_box_back(self, pred_box: list, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
# cx_real = cx + cx_prev
|
||||
# cy_real = cy + cy_prev
|
||||
return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
|
||||
|
||||
def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
|
||||
cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
|
||||
cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,)
|
||||
half_side = 0.5 * self.params.search_size / resize_factor
|
||||
cx_real = cx + (cx_prev - half_side)
|
||||
cy_real = cy + (cy_prev - half_side)
|
||||
return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
|
||||
|
||||
def add_hook(self):
|
||||
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
|
||||
|
||||
for i in range(12):
|
||||
self.network.backbone.blocks[i].attn.register_forward_hook(
|
||||
# lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
lambda self, input, output: enc_attn_weights.append(output[1])
|
||||
)
|
||||
|
||||
self.enc_attn_weights = enc_attn_weights
|
||||
|
||||
|
||||
def get_tracker_class():
|
||||
return ARTrackSeq
|
||||
89
lib/test/tracker/basetracker.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from _collections import OrderedDict
|
||||
|
||||
from lib.train.data.processing_utils import transform_image_to_crop
|
||||
from lib.vis.visdom_cus import Visdom
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
"""Base class for all trackers."""
|
||||
|
||||
def __init__(self, params):
|
||||
self.params = params
|
||||
self.visdom = None
|
||||
|
||||
def predicts_segmentation_mask(self):
|
||||
return False
|
||||
|
||||
def initialize(self, image, info: dict) -> dict:
|
||||
"""Overload this function in your tracker. This should initialize the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def track(self, image, info: dict = None) -> dict:
|
||||
"""Overload this function in your tracker. This should track in the frame and update the model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def visdom_draw_tracking(self, image, box, segmentation=None):
|
||||
if isinstance(box, OrderedDict):
|
||||
box = [v for k, v in box.items()]
|
||||
else:
|
||||
box = (box,)
|
||||
if segmentation is None:
|
||||
self.visdom.register((image, *box), 'Tracking', 1, 'Tracking')
|
||||
else:
|
||||
self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking')
|
||||
|
||||
def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'):
|
||||
# box_in: list [x1, y1, w, h], not normalized
|
||||
# box_extract: same as box_in
|
||||
# out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized
|
||||
if crop_type == 'template':
|
||||
crop_sz = torch.Tensor([self.params.template_size, self.params.template_size])
|
||||
elif crop_type == 'search':
|
||||
crop_sz = torch.Tensor([self.params.search_size, self.params.search_size])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
box_in = torch.tensor(box_in)
|
||||
if box_extract is None:
|
||||
box_extract = box_in
|
||||
else:
|
||||
box_extract = torch.tensor(box_extract)
|
||||
template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True)
|
||||
template_bbox = template_bbox.view(1, 1, 4).to(device)
|
||||
|
||||
return template_bbox
|
||||
|
||||
def _init_visdom(self, visdom_info, debug):
|
||||
visdom_info = {} if visdom_info is None else visdom_info
|
||||
self.pause_mode = False
|
||||
self.step = False
|
||||
self.next_seq = False
|
||||
if debug > 0 and visdom_info.get('use_visdom', True):
|
||||
try:
|
||||
self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
|
||||
visdom_info=visdom_info)
|
||||
|
||||
# # Show help
|
||||
# help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
|
||||
# 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
|
||||
# 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
|
||||
# 'block list.'
|
||||
# self.visdom.register(help_text, 'text', 1, 'Help')
|
||||
except:
|
||||
time.sleep(0.5)
|
||||
print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
|
||||
'!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')
|
||||
|
||||
def _visdom_ui_handler(self, data):
|
||||
if data['event_type'] == 'KeyPress':
|
||||
if data['key'] == ' ':
|
||||
self.pause_mode = not self.pause_mode
|
||||
|
||||
elif data['key'] == 'ArrowRight' and self.pause_mode:
|
||||
self.step = True
|
||||
|
||||
elif data['key'] == 'n':
|
||||
self.next_seq = True
|
||||
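A minimal subclass sketch showing the contract BaseTracker expects from concrete trackers; DummyTracker, its trivial behaviour and the 'init_bbox'/'target_bbox' keys are hypothetical:

class DummyTracker(BaseTracker):
    def initialize(self, image, info: dict) -> dict:
        # Remember the first-frame box (x, y, w, h); no model setup in this sketch.
        self.state = info['init_bbox']
        return {}

    def track(self, image, info: dict = None) -> dict:
        # Trivially repeat the last box for every subsequent frame.
        return {'target_bbox': self.state}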
46
lib/test/tracker/data_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from lib.utils.misc import NestedTensor
|
||||
|
||||
|
||||
class Preprocessor(object):
|
||||
def __init__(self):
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
# Deal with the image patch
|
||||
img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
|
||||
img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
|
||||
# Deal with the attention mask
|
||||
amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
|
||||
return NestedTensor(img_tensor_norm, amask_tensor)
|
||||
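A minimal usage sketch of the class above; the crop array, its mask and the 256x256 size are placeholder values, and a CUDA device is assumed:

import numpy as np

crop = np.random.randint(0, 255, (256, 256, 3)).astype(np.uint8)   # fake BGR patch
mask = np.ones((256, 256), dtype=bool)                              # no padded pixels
preprocessor = Preprocessor()
nested = preprocessor.process(crop, mask)
# nested.tensors: (1, 3, 256, 256) normalized image, nested.mask: (1, 256, 256) bool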
|
||||
|
||||
class PreprocessorX(object):
|
||||
def __init__(self):
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
# Deal with the image patch
|
||||
img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
|
||||
img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
|
||||
# Deal with the attention mask
|
||||
amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
|
||||
return img_tensor_norm, amask_tensor
|
||||
|
||||
|
||||
class PreprocessorX_onnx(object):
|
||||
def __init__(self):
|
||||
self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
|
||||
self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
|
||||
|
||||
def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
|
||||
"""img_arr: (H,W,3), amask_arr: (H,W)"""
|
||||
# Deal with the image patch
|
||||
img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2)
|
||||
img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W)
|
||||
# Deal with the attention mask
|
||||
amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W)
|
||||
return img_arr_4d.astype(np.float32), amask_arr_3d.astype(bool)  # np.bool was removed in NumPy 1.24; plain bool is equivalent
|
||||
59
lib/test/tracker/vis_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
############## used to visualize eliminated tokens #################
|
||||
def get_keep_indices(decisions):
|
||||
keep_indices = []
|
||||
for i in range(3):
|
||||
if i == 0:
|
||||
keep_indices.append(decisions[i])
|
||||
else:
|
||||
keep_indices.append(keep_indices[-1][decisions[i]])
|
||||
return keep_indices
|
||||
|
||||
|
||||
def gen_masked_tokens(tokens, indices, alpha=0.2):
|
||||
# indices = [i for i in range(196) if i not in indices]
|
||||
indices = indices[0].astype(int)
|
||||
tokens = tokens.copy()
|
||||
tokens[indices] = alpha * tokens[indices] + (1 - alpha) * 255
|
||||
return tokens
|
||||
|
||||
|
||||
def recover_image(tokens, H, W, Hp, Wp, patch_size):
|
||||
# image: (C, 196, 16, 16)
|
||||
image = tokens.reshape(Hp, Wp, patch_size, patch_size, 3).swapaxes(1, 2).reshape(H, W, 3)
|
||||
return image
|
||||
|
||||
|
||||
def pad_img(img):
|
||||
height, width, channels = img.shape
|
||||
im_bg = np.ones((height, width + 8, channels)) * 255
|
||||
im_bg[0:height, 0:width, :] = img
|
||||
return im_bg
|
||||
|
||||
|
||||
def gen_visualization(image, mask_indices, patch_size=16):
|
||||
# image [224, 224, 3]
|
||||
# mask_indices, list of masked token indices
|
||||
|
||||
# mask mask_indices need to cat
|
||||
# mask_indices = mask_indices[::-1]
|
||||
num_stages = len(mask_indices)
|
||||
for i in range(1, num_stages):
|
||||
mask_indices[i] = np.concatenate([mask_indices[i-1], mask_indices[i]], axis=1)
|
||||
|
||||
# keep_indices = get_keep_indices(decisions)
|
||||
image = np.asarray(image)
|
||||
H, W, C = image.shape
|
||||
Hp, Wp = H // patch_size, W // patch_size
|
||||
image_tokens = image.reshape(Hp, patch_size, Wp, patch_size, 3).swapaxes(1, 2).reshape(Hp * Wp, patch_size, patch_size, 3)
|
||||
|
||||
stages = [
|
||||
recover_image(gen_masked_tokens(image_tokens, mask_indices[i]), H, W, Hp, Wp, patch_size)
|
||||
for i in range(num_stages)
|
||||
]
|
||||
imgs = [image] + stages
|
||||
imgs = [pad_img(img) for img in imgs]
|
||||
viz = np.concatenate(imgs, axis=1)
|
||||
return viz
|
||||
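An illustrative call to the visualization helper above; the image and the two stages of pruned-token indices are random placeholders:

import numpy as np

image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
mask_indices = [np.random.choice(196, (1, 30), replace=False),
                np.random.choice(196, (1, 30), replace=False)]
viz = gen_visualization(image, mask_indices, patch_size=16)   # (224, 696, 3) strip: original + 2 stages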
1
lib/test/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .params import TrackerParams, FeatureParams, Choice
|
||||
17
lib/test/utils/_init_paths.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
prj_path = osp.join(this_dir, '..', '..', '..')
|
||||
add_path(prj_path)
|
||||
93
lib/test/utils/hann.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def hann1d(sz: int, centered = True) -> torch.Tensor:
|
||||
"""1D cosine window."""
|
||||
if centered:
|
||||
return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float()))
|
||||
w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float()))
|
||||
return torch.cat([w, w[1:sz-sz//2].flip((0,))])
|
||||
|
||||
|
||||
def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""2D cosine window."""
|
||||
return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1)
|
||||
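A short sketch of the usual way such a window is applied, multiplying a response map by a centered Hann penalty; the 16x16 map size is arbitrary:

import torch

window = hann2d(torch.tensor([16, 16]), centered=True)   # (1, 1, 16, 16)
score_map = torch.rand(1, 1, 16, 16)
penalized = score_map * window                            # suppresses peaks far from the centre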
|
||||
|
||||
def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""2D cosine window."""
|
||||
distance = torch.stack([ctr_point, sz-ctr_point], dim=0)
|
||||
max_distance, _ = distance.max(dim=0)
|
||||
|
||||
hann1d_x = hann1d(max_distance[0].item() * 2, centered)
|
||||
hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]]
|
||||
hann1d_y = hann1d(max_distance[1].item() * 2, centered)
|
||||
hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]]
|
||||
|
||||
return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1)
|
||||
|
||||
|
||||
|
||||
def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor:
|
||||
"""1D clipped cosine window."""
|
||||
|
||||
# Ensure that the difference is even
|
||||
effective_sz += (effective_sz - sz) % 2
|
||||
effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1)
|
||||
|
||||
pad = (sz - effective_sz) // 2
|
||||
|
||||
window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate')
|
||||
|
||||
if centered:
|
||||
return window
|
||||
else:
|
||||
mid = (sz / 2).int()
|
||||
window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3)
|
||||
return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2)
|
||||
|
||||
|
||||
def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor:
|
||||
if half:
|
||||
k = torch.arange(0, int(sz/2+1))
|
||||
else:
|
||||
k = torch.arange(-int((sz-1)/2), int(sz/2+1))
|
||||
return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2)
|
||||
|
||||
|
||||
def gauss_spatial(sz, sigma, center=0, end_pad=0):
|
||||
k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad)
|
||||
return torch.exp(-1.0/(2*sigma**2) * (k - center)**2)
|
||||
|
||||
|
||||
def label_function(sz: torch.Tensor, sigma: torch.Tensor):
|
||||
return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1)
|
||||
|
||||
def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)):
|
||||
"""The origin is in the middle of the image."""
|
||||
return gauss_spatial(sz[0].item(), sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \
|
||||
gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1)
|
||||
|
||||
|
||||
def cubic_spline_fourier(f, a):
|
||||
"""The continuous Fourier transform of a cubic spline kernel."""
|
||||
|
||||
bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f))
|
||||
- (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \
|
||||
/ (4 * math.pi**4 * f**4)
|
||||
|
||||
bf[f == 0] = 1
|
||||
|
||||
return bf
|
||||
|
||||
def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor):
|
||||
"""Computes maximum and argmax in the last two dimensions."""
|
||||
|
||||
max_val_row, argmax_row = torch.max(a, dim=-2)
|
||||
max_val, argmax_col = torch.max(max_val_row, dim=-1)
|
||||
argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)]
|
||||
argmax_row = argmax_row.reshape(argmax_col.shape)
|
||||
argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1)
|
||||
return max_val, argmax
|
||||
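A quick check of max2d on a toy batch of response maps; the peak positions are planted by hand:

import torch

scores = torch.zeros(2, 16, 16)
scores[0, 3, 7] = 1.0
scores[1, 9, 2] = 1.0
max_val, argmax = max2d(scores)   # max_val: tensor([1., 1.]), argmax: tensor([[3, 7], [9, 2]])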
47
lib/test/utils/load_text.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def load_text_numpy(path, delimiter, dtype):
|
||||
if isinstance(delimiter, (tuple, list)):
|
||||
for d in delimiter:
|
||||
try:
|
||||
ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype)
|
||||
return ground_truth_rect
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise Exception('Could not read file {}'.format(path))
|
||||
else:
|
||||
ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype)
|
||||
return ground_truth_rect
|
||||
|
||||
|
||||
def load_text_pandas(path, delimiter, dtype):
|
||||
if isinstance(delimiter, (tuple, list)):
|
||||
for d in delimiter:
|
||||
try:
|
||||
ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False,
|
||||
low_memory=False).values
|
||||
return ground_truth_rect
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
raise Exception('Could not read file {}'.format(path))
|
||||
else:
|
||||
ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False,
|
||||
low_memory=False).values
|
||||
return ground_truth_rect
|
||||
|
||||
|
||||
def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'):
|
||||
if backend == 'numpy':
|
||||
return load_text_numpy(path, delimiter, dtype)
|
||||
elif backend == 'pandas':
|
||||
return load_text_pandas(path, delimiter, dtype)
|
||||
|
||||
|
||||
def load_str(path):
|
||||
with open(path, "r") as f:
|
||||
text_str = f.readline().strip().lower()
|
||||
return text_str
|
||||
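An illustrative call to load_text; the file path is a placeholder, and passing both delimiters lets the numpy backend fall back from comma to tab:

import numpy as np

gt = load_text('/path/to/groundtruth.txt', delimiter=(',', '\t'),
               dtype=np.float64, backend='numpy')
# gt: (num_frames, 4) array of (x, y, w, h) annotations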
43
lib/test/utils/params.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from lib.utils import TensorList
|
||||
import random
|
||||
|
||||
|
||||
class TrackerParams:
|
||||
"""Class for tracker parameters."""
|
||||
def set_default_values(self, default_vals: dict):
|
||||
for name, val in default_vals.items():
|
||||
if not hasattr(self, name):
|
||||
setattr(self, name, val)
|
||||
|
||||
def get(self, name: str, *default):
|
||||
"""Get a parameter value with the given name. If it does not exists, it return the default value given as a
|
||||
second argument or returns an error if no default value is given."""
|
||||
if len(default) > 1:
|
||||
raise ValueError('Can only give one default value.')
|
||||
|
||||
if not default:
|
||||
return getattr(self, name)
|
||||
|
||||
return getattr(self, name, default[0])
|
||||
|
||||
def has(self, name: str):
|
||||
"""Check if there exist a parameter with the given name."""
|
||||
return hasattr(self, name)
|
||||
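A short sketch of how the parameter container above behaves; the attribute names and values are arbitrary:

params = TrackerParams()
params.search_size = 256
params.set_default_values({'template_size': 128, 'search_size': 384})

params.get('search_size')      # 256 -- an explicitly set value is not overwritten by the default
params.get('missing', 0.5)     # 0.5 -- falls back to the supplied default
params.has('template_size')    # True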
|
||||
|
||||
class FeatureParams:
|
||||
"""Class for feature specific parameters"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
raise ValueError
|
||||
|
||||
for name, val in kwargs.items():
|
||||
if isinstance(val, list):
|
||||
setattr(self, name, TensorList(val))
|
||||
else:
|
||||
setattr(self, name, val)
|
||||
|
||||
|
||||
def Choice(*args):
|
||||
"""Can be used to sample random parameter values."""
|
||||
return random.choice(args)
|
||||
52
lib/test/utils/transform_got10k.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import _init_paths
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
|
||||
|
||||
def transform_got10k(tracker_name, cfg_name):
|
||||
env = env_settings()
|
||||
result_dir = env.results_path
|
||||
src_dir = os.path.join(result_dir, "%s/%s/got10k/" % (tracker_name, cfg_name))
|
||||
dest_dir = os.path.join(result_dir, "%s/%s/got10k_submit/" % (tracker_name, cfg_name))
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
items = os.listdir(src_dir)
|
||||
for item in items:
|
||||
if "all" in item:
|
||||
continue
|
||||
src_path = os.path.join(src_dir, item)
|
||||
if "time" not in item:
|
||||
seq_name = item.replace(".txt", '')
|
||||
seq_dir = os.path.join(dest_dir, seq_name)
|
||||
if not os.path.exists(seq_dir):
|
||||
os.makedirs(seq_dir)
|
||||
new_item = item.replace(".txt", '_001.txt')
|
||||
dest_path = os.path.join(seq_dir, new_item)
|
||||
bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t')  # np.int was removed in NumPy 1.24
|
||||
np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
|
||||
else:
|
||||
seq_name = item.replace("_time.txt", '')
|
||||
seq_dir = os.path.join(dest_dir, seq_name)
|
||||
if not os.path.exists(seq_dir):
|
||||
os.makedirs(seq_dir)
|
||||
dest_path = os.path.join(seq_dir, item)
|
||||
os.system("cp %s %s" % (src_path, dest_path))
|
||||
# make zip archive
|
||||
shutil.make_archive(src_dir, "zip", src_dir)
|
||||
shutil.make_archive(dest_dir, "zip", dest_dir)
|
||||
# Remove the original files
|
||||
shutil.rmtree(src_dir)
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='transform got10k results.')
|
||||
parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
|
||||
parser.add_argument('--cfg_name', type=str, help='Name of config file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
transform_got10k(args.tracker_name, args.cfg_name)
|
||||
|
||||
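Equivalent to the CLI entry point above, the conversion can also be called directly; the tracker and config names here are hypothetical and must match folders under the results path:

transform_got10k(tracker_name='artrack_seq', cfg_name='baseline')
# reads <results_path>/artrack_seq/baseline/got10k/*.txt, rewrites each sequence into the
# GOT-10k submission layout, zips both folders and removes the unzipped copies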
39
lib/test/utils/transform_trackingnet.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import _init_paths
|
||||
from lib.test.evaluation.environment import env_settings
|
||||
|
||||
|
||||
def transform_trackingnet(tracker_name, cfg_name):
|
||||
env = env_settings()
|
||||
result_dir = env.results_path
|
||||
src_dir = os.path.join(result_dir, "%s/%s/trackingnet/" % (tracker_name, cfg_name))
|
||||
dest_dir = os.path.join(result_dir, "%s/%s/trackingnet_submit/" % (tracker_name, cfg_name))
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
items = os.listdir(src_dir)
|
||||
for item in items:
|
||||
if "all" in item:
|
||||
continue
|
||||
if "time" not in item:
|
||||
src_path = os.path.join(src_dir, item)
|
||||
dest_path = os.path.join(dest_dir, item)
|
||||
bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t')  # np.int was removed in NumPy 1.24
|
||||
np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
|
||||
# make zip archive
|
||||
shutil.make_archive(src_dir, "zip", src_dir)
|
||||
shutil.make_archive(dest_dir, "zip", dest_dir)
|
||||
# Remove the original files
|
||||
shutil.rmtree(src_dir)
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='transform trackingnet results.')
|
||||
parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
|
||||
parser.add_argument('--cfg_name', type=str, help='Name of config file.')
|
||||
|
||||
args = parser.parse_args()
|
||||
transform_trackingnet(args.tracker_name, args.cfg_name)
|
||||
BIN
lib/train/.DS_Store
vendored
Normal file
Binary file not shown.
1
lib/train/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .admin.multigpu import MultiGPU
|
||||
17
lib/train/_init_paths.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
prj_path = osp.join(this_dir, '../..')
|
||||
add_path(prj_path)
|
||||
3
lib/train/actors/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_actor import BaseActor
|
||||
from .artrack import ARTrackActor
|
||||
from .artrack_seq import ARTrackSeqActor
|
||||
281
lib/train/actors/artrack.py
Normal file
@@ -0,0 +1,281 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
from lib.utils.merge import merge_template_search
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt)**2 + (cy_pred - cy_gt)**2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2*torch.sin(torch.arcsin(x)-torch.pi/4)**2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw+eps))**2
|
||||
py = ((cy_gt - cy_pred) / (ch+eps))**2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
#shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
#IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
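A toy check of the SIoU term above on one predicted/ground-truth pair in (x1, y1, x2, y2) format; the coordinates are illustrative:

import torch

pred = torch.tensor([[0.20, 0.20, 0.60, 0.60]])
gt   = torch.tensor([[0.25, 0.25, 0.65, 0.70]])
siou, iou = SIoU_loss(pred, gt, theta=4)   # iou is the plain IoU, siou = 1 - iou + angle/distance/shape penalties
loss = siou.mean()                          # scalar; identical boxes give the smallest value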
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw**2 + ch**2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi**2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v**2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
class ARTrackActor(BaseActor):
|
||||
""" Actor for training ARTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.range = self.cfg.MODEL.RANGE
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.loss_weight['KL'] = 100
|
||||
self.loss_weight['focal'] = 2
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins * self.range
|
||||
end = self.bins * self.range + 1
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=(-1*magic_num), max=(1+magic_num))
|
||||
data['real_bbox'] = gt_bbox
|
||||
|
||||
seq_ori = (gt_bbox + magic_num) * (self.bins - 1)
|
||||
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
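A worked example of the coordinate tokenization performed above, assuming bins=400 and RANGE=2 (so magic_num=0.5); the box values are made up:

import torch

bins, rng = 400, 2
magic_num = (rng - 1) * 0.5
begin, end = bins * rng, bins * rng + 1                 # start token 800, end token 801
box = torch.tensor([[-0.1, 0.05, 0.8, 0.9]])            # clamped (x1, y1, x2, y2)
seq = ((box + magic_num) * (bins - 1)).int()            # tensor([[159, 219, 518, 558]])
start = torch.full((1, 1), begin, dtype=seq.dtype)
stop = torch.full((1, 1), end, dtype=seq.dtype)
seq_input = torch.cat([start, seq], dim=1)              # [[800, 159, 219, 518, 558]]
seq_output = torch.cat([seq, stop], dim=1)              # [[159, 219, 518, 558, 801]]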
|
||||
def compute_losses(self, pred_dict, gt_dict, return_status=True):
|
||||
bins = self.bins
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
seq_output = gt_dict['seq_output']
|
||||
pred_feat = pred_dict["feat"]
|
||||
if self.focal is None:
|
||||
weight = torch.ones(bins*self.range+2) * 1
|
||||
weight[bins*self.range+1] = 0.1
|
||||
weight[bins*self.range] = 0.1
|
||||
weight = weight.to(pred_feat)
|
||||
self.klloss = torch.nn.KLDivLoss(reduction='none').to(pred_feat)
|
||||
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, reduction='mean').to(pred_feat)  # size_average is deprecated; reduction='mean' is equivalent
|
||||
# compute varifocal loss
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, bins*2+2)
|
||||
target = seq_output.reshape(-1).to(torch.int64)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
# compute giou and L1 loss
|
||||
beta = 1
|
||||
pred = pred_feat[0:4, :, 0:bins*self.range] * beta
|
||||
target = seq_output[:, 0:4].to(pred_feat)
|
||||
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range((-1*magic_num+1/(self.bins*self.range)), (1+magic_num-1/(self.bins*self.range)), 2/(self.bins*self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
target = target / (bins - 1) - magic_num
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
sious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
sious = sious.mean()
|
||||
siou_loss = sious
|
||||
l1_loss = self.objective['l1'](extra_seq, target)
|
||||
|
||||
loss = self.loss_weight['giou'] * siou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight['focal'] * varifocal_loss
|
||||
|
||||
if return_status:
|
||||
# status for log
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": loss.item(),
|
||||
"Loss/giou": siou_loss.item(),
|
||||
"Loss/l1": l1_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
return loss, status
|
||||
else:
|
||||
return loss
|
||||
629
lib/train/actors/artrack_seq.py
Normal file
@@ -0,0 +1,629 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
import lib.train.data.bounding_box_utils as bbutils
|
||||
from lib.utils.merge import merge_template_search
|
||||
from torch.distributions.categorical import Categorical
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
|
||||
def IoU(rect1, rect2):
|
||||
""" caculate interection over union
|
||||
Args:
|
||||
rect1: (x1, y1, x2, y2)
|
||||
rect2: (x1, y1, x2, y2)
|
||||
Returns:
|
||||
iou
|
||||
"""
|
||||
# overlap
|
||||
x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
|
||||
tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
|
||||
|
||||
xx1 = np.maximum(tx1, x1)
|
||||
yy1 = np.maximum(ty1, y1)
|
||||
xx2 = np.minimum(tx2, x2)
|
||||
yy2 = np.minimum(ty2, y2)
|
||||
|
||||
ww = np.maximum(0, xx2 - xx1)
|
||||
hh = np.maximum(0, yy2 - yy1)
|
||||
|
||||
area = (x2 - x1) * (y2 - y1)
|
||||
target_a = (tx2 - tx1) * (ty2 - ty1)
|
||||
inter = ww * hh
|
||||
iou = inter / (area + target_a - inter)
|
||||
return iou
|
||||
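A quick numeric check of the IoU helper above on two overlapping boxes in (x1, y1, x2, y2) format:

import numpy as np

a = np.array([0.0, 0.0, 2.0, 2.0])
b = np.array([1.0, 1.0, 3.0, 3.0])
print(IoU(a, b))   # 1 / 7 ≈ 0.143: intersection 1, union 4 + 4 - 1 = 7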
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt) ** 2 + (cy_pred - cy_gt) ** 2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2 * torch.sin(torch.arcsin(x) - torch.pi / 4) ** 2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw + eps)) ** 2
|
||||
py = ((cy_gt - cy_pred) / (ch + eps)) ** 2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
# shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
# IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
|
||||
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw ** 2 + ch ** 2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi ** 2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v ** 2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
|
||||
class ARTrackSeqActor(BaseActor):
|
||||
""" Actor for training OSTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.range = cfg.MODEL.RANGE
|
||||
self.pre_num = cfg.MODEL.PRENUM
|
||||
self.loss_weight['KL'] = 0
|
||||
self.loss_weight['focal'] = 0
|
||||
self.pre_bbox = None
|
||||
self.x_feat_rem = None
|
||||
self.update_rem = None
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def _bbox_clip(self, cx, cy, width, height, boundary):
|
||||
cx = max(0, min(cx, boundary[1]))
|
||||
cy = max(0, min(cy, boundary[0]))
|
||||
width = max(10, min(width, boundary[1]))
|
||||
height = max(10, min(height, boundary[0]))
|
||||
return cx, cy, width, height
|
||||
|
||||
def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
|
||||
"""
|
||||
args:
|
||||
im: bgr based image
|
||||
pos: center position
|
||||
model_sz: exemplar size
|
||||
s_z: original size
|
||||
avg_chans: channel average
|
||||
"""
|
||||
if isinstance(pos, float):
|
||||
pos = [pos, pos]
|
||||
sz = original_sz
|
||||
im_sz = im.shape
|
||||
c = (original_sz + 1) / 2
|
||||
# context_xmin = round(pos[0] - c) # py2 and py3 round
|
||||
context_xmin = np.floor(pos[0] - c + 0.5)
|
||||
context_xmax = context_xmin + sz - 1
|
||||
# context_ymin = round(pos[1] - c)
|
||||
context_ymin = np.floor(pos[1] - c + 0.5)
|
||||
context_ymax = context_ymin + sz - 1
|
||||
left_pad = int(max(0., -context_xmin))
|
||||
top_pad = int(max(0., -context_ymin))
|
||||
right_pad = int(max(0., context_xmax - im_sz[1] + 1))
|
||||
bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
|
||||
|
||||
context_xmin = context_xmin + left_pad
|
||||
context_xmax = context_xmax + left_pad
|
||||
context_ymin = context_ymin + top_pad
|
||||
context_ymax = context_ymax + top_pad
|
||||
|
||||
r, c, k = im.shape
|
||||
if any([top_pad, bottom_pad, left_pad, right_pad]):
|
||||
size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k)
|
||||
te_im = np.zeros(size, np.uint8)
|
||||
te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
|
||||
if top_pad:
|
||||
te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
|
||||
if bottom_pad:
|
||||
te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
|
||||
if left_pad:
|
||||
te_im[:, 0:left_pad, :] = avg_chans
|
||||
if right_pad:
|
||||
te_im[:, c + left_pad:, :] = avg_chans
|
||||
im_patch = te_im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
else:
|
||||
im_patch = im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
|
||||
if not np.array_equal(model_sz, original_sz):
|
||||
try:
|
||||
im_patch = cv2.resize(im_patch, (model_sz, model_sz))
|
||||
except:
|
||||
return None
|
||||
im_patch = im_patch.transpose(2, 0, 1)
|
||||
im_patch = im_patch[np.newaxis, :, :, :]
|
||||
im_patch = im_patch.astype(np.float32)
|
||||
im_patch = torch.from_numpy(im_patch)
|
||||
im_patch = im_patch.cuda()
|
||||
return im_patch
|
||||
|
||||
def batch_init(self, images, template_bbox, initial_bbox) -> dict:
|
||||
self.frame_num = 1
|
||||
self.device = 'cuda'
|
||||
# Convert bbox (x1, y1, w, h) -> (cx, cy, w, h)
|
||||
|
||||
template_bbox = bbutils.batch_xywh2center2(template_bbox) # ndarray:(2*num_seq,4)
|
||||
initial_bbox = bbutils.batch_xywh2center2(initial_bbox) # ndarray:(2*num_seq,4)
|
||||
self.center_pos = initial_bbox[:, :2] # ndarray:(2*num_seq,2)
|
||||
self.size = initial_bbox[:, 2:] # ndarray:(2*num_seq,2)
|
||||
self.pre_bbox = initial_bbox
|
||||
for i in range(self.pre_num - 1):
|
||||
self.pre_bbox = numpy.concatenate((self.pre_bbox, initial_bbox), axis=1)
|
||||
# print(self.pre_bbox.shape)
|
||||
|
||||
template_factor = self.cfg.DATA.TEMPLATE.FACTOR
|
||||
w_z = template_bbox[:, 2] * template_factor # ndarray:(2*num_seq)
|
||||
h_z = template_bbox[:, 3] * template_factor # ndarray:(2*num_seq)
|
||||
s_z = np.ceil(np.sqrt(w_z * h_z)) # ndarray:(2*num_seq)
|
||||
|
||||
self.channel_average = []
|
||||
for img in images:
|
||||
self.channel_average.append(np.mean(img, axis=(0, 1)))
|
||||
self.channel_average = np.array(self.channel_average) # ndarray:(2*num_seq,3)
|
||||
|
||||
# get crop
|
||||
z_crop_list = []
|
||||
for i in range(len(images)):
|
||||
here_crop = self.get_subwindow(images[i], template_bbox[i, :2],
|
||||
self.cfg.DATA.TEMPLATE.SIZE, s_z[i], self.channel_average[i])
|
||||
z_crop = here_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
self.mean = [0.485, 0.456, 0.406]
|
||||
self.std = [0.229, 0.224, 0.225]
|
||||
self.inplace = False
|
||||
z_crop[0] = tvisf.normalize(z_crop[0], self.mean, self.std, self.inplace)
|
||||
z_crop_list.append(z_crop.clone())
|
||||
z_crop = torch.cat(z_crop_list, dim=0) # Tensor(2*num_seq,3,128,128)
|
||||
|
||||
self.update_rem = None
|
||||
|
||||
out = {'template_images': z_crop}
|
||||
return out
|
||||
|
||||
def batch_track(self, img, gt_boxes, template, action_mode='max') -> dict:
|
||||
search_factor = self.cfg.DATA.SEARCH.FACTOR
|
||||
w_x = self.size[:, 0] * search_factor
|
||||
h_x = self.size[:, 1] * search_factor
|
||||
s_x = np.ceil(np.sqrt(w_x * h_x))
|
||||
|
||||
gt_boxes_corner = bbutils.batch_xywh2corner(gt_boxes) # ndarray:(2*num_seq,4)
|
||||
|
||||
x_crop_list = []
|
||||
gt_in_crop_list = []
|
||||
pre_seq_list = []
|
||||
pre_seq_in_list = []
|
||||
x_feat_list = []
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
for i in range(len(img)):
|
||||
channel_avg = np.mean(img[i], axis=(0, 1))
|
||||
x_crop = self.get_subwindow(img[i], self.center_pos[i], self.cfg.DATA.SEARCH.SIZE,
|
||||
round(s_x[i]), channel_avg)
|
||||
if x_crop is None:
|
||||
return None
|
||||
for q in range(self.pre_num):
|
||||
pre_seq_temp = bbutils.batch_center2corner(self.pre_bbox[:, 0 + 4 * q:4 + 4 * q])
|
||||
if q == 0:
|
||||
pre_seq = pre_seq_temp
|
||||
else:
|
||||
pre_seq = numpy.concatenate((pre_seq, pre_seq_temp), axis=1)
|
||||
|
||||
if gt_boxes_corner is not None and np.sum(np.abs(gt_boxes_corner[i] - np.zeros(4))) > 10:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
for w in range(self.pre_num):
|
||||
|
||||
pre_in[0 + w * 4:2 + w * 4] = pre_seq[i, 0 + w * 4:2 + w * 4] - self.center_pos[i]
|
||||
pre_in[2 + w * 4:4 + w * 4] = pre_seq[i, 2 + w * 4:4 + w * 4] - self.center_pos[i]
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] * (
|
||||
self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] / self.cfg.DATA.SEARCH.SIZE
|
||||
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop = np.zeros(4)
|
||||
gt_in_crop[:2] = gt_boxes_corner[i, :2] - self.center_pos[i]
|
||||
gt_in_crop[2:] = gt_boxes_corner[i, 2:] - self.center_pos[i]
|
||||
gt_in_crop = gt_in_crop * (self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
gt_in_crop[2:] = gt_in_crop[2:] - gt_in_crop[:2] # (x1,y1,x2,y2) to (x1,y1,w,h)
|
||||
gt_in_crop_list.append(gt_in_crop)
|
||||
else:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop_list.append(np.zeros(4))
|
||||
pre_seq_input = torch.from_numpy(pre_in).clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq_input = (pre_seq_input + 0.5) * (self.bins - 1)
|
||||
pre_seq_in_list.append(pre_seq_input.clone())
|
||||
x_crop = x_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
x_crop[0] = tvisf.normalize(x_crop[0], self.mean, self.std, self.inplace)
|
||||
x_crop_list.append(x_crop.clone())
|
||||
|
||||
x_crop = torch.cat(x_crop_list, dim=0)
|
||||
pre_seq_output = torch.cat(pre_seq_in_list, dim=0).reshape(-1, 4 * self.pre_num)
|
||||
|
||||
outputs = self.net(template, x_crop, seq_input=pre_seq_output, head_type=None, stage="batch_track",
|
||||
search_feature=self.x_feat_rem, update=None)
|
||||
selected_indices = outputs['seqs'].detach()
|
||||
x_feat = outputs['x_feat'].detach().cpu()
|
||||
self.x_feat_rem = x_feat.clone()
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
pred_bbox = selected_indices[:, 0:4].data.cpu().numpy()
|
||||
bbox = (pred_bbox / (self.bins - 1) - magic_num) * s_x.reshape(-1, 1)
|
||||
cx = bbox[:, 0] + self.center_pos[:, 0] - s_x / 2
|
||||
cy = bbox[:, 1] + self.center_pos[:, 1] - s_x / 2
|
||||
width = bbox[:, 2] - bbox[:, 0]
|
||||
height = bbox[:, 3] - bbox[:, 1]
|
||||
cx = cx + width / 2
|
||||
cy = cy + height / 2
|
||||
|
||||
for i in range(len(img)):
|
||||
cx[i], cy[i], width[i], height[i] = self._bbox_clip(cx[i], cy[i], width[i],
|
||||
height[i], img[i].shape[:2])
|
||||
self.center_pos = np.stack([cx, cy], 1)
|
||||
self.size = np.stack([width, height], 1)
|
||||
for e in range(self.pre_num):
|
||||
if e != self.pre_num - 1:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = self.pre_bbox[:, 4 + e * 4:8 + e * 4]
|
||||
else:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = numpy.stack([cx, cy, width, height], 1)
|
||||
|
||||
bbox = np.stack([cx - width / 2, cy - height / 2, width, height], 1)
|
||||
|
||||
out = {
|
||||
'search_images': x_crop,
|
||||
'pred_bboxes': bbox,
|
||||
'selected_indices': selected_indices.cpu(),
|
||||
'gt_in_crop': torch.tensor(np.stack(gt_in_crop_list, axis=0), dtype=torch.float),
|
||||
'pre_seq': torch.tensor(np.stack(pre_seq_list, axis=0), dtype=torch.float),
|
||||
'x_feat': torch.tensor([item.cpu().detach().numpy() for item in x_feat_list], dtype=torch.float),
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
def explore(self, data):
|
||||
results = {}
|
||||
search_images_list = []
|
||||
search_anno_list = []
|
||||
iou_list = []
|
||||
pre_seq_list = []
|
||||
x_feat_list = []
|
||||
|
||||
num_frames = data['num_frames']
|
||||
images = data['search_images']
|
||||
gt_bbox = data['search_annos']
|
||||
template = data['template_images']
|
||||
template_bbox = data['template_annos']
|
||||
|
||||
template = template
|
||||
template_bbox = template_bbox
|
||||
template_bbox = np.array(template_bbox)
|
||||
num_seq = len(num_frames)
|
||||
|
||||
for idx in range(np.max(num_frames)):
|
||||
here_images = [img[idx] for img in images] # S, N
|
||||
here_gt_bbox = np.array([gt[idx] for gt in gt_bbox])
|
||||
|
||||
here_images = here_images
|
||||
here_gt_bbox = np.concatenate([here_gt_bbox], 0)
|
||||
|
||||
if idx == 0:
|
||||
outputs_template = self.batch_init(template, template_bbox, here_gt_bbox)
|
||||
results['template_images'] = outputs_template['template_images']
|
||||
|
||||
else:
|
||||
outputs = self.batch_track(here_images, here_gt_bbox, outputs_template['template_images'],
|
||||
action_mode='half')
|
||||
if outputs is None:
|
||||
return None
|
||||
|
||||
x_feat = outputs['x_feat']
|
||||
pred_bbox = outputs['pred_bboxes']
|
||||
search_images_list.append(outputs['search_images'])
|
||||
search_anno_list.append(outputs['gt_in_crop'])
|
||||
if len(outputs['pre_seq']) != 8:
|
||||
print(outputs['pre_seq'])
|
||||
print(len(outputs['pre_seq']))
|
||||
print(idx)
|
||||
print(data['num_frames'])
|
||||
print(data['search_annos'])
|
||||
return None
|
||||
pre_seq_list.append(outputs['pre_seq'])
|
||||
pred_bbox_corner = bbutils.batch_xywh2corner(pred_bbox)
|
||||
gt_bbox_corner = bbutils.batch_xywh2corner(here_gt_bbox)
|
||||
here_iou = []
|
||||
for i in range(num_seq):
|
||||
bbox_iou = IoU(pred_bbox_corner[i], gt_bbox_corner[i])
|
||||
here_iou.append(bbox_iou)
|
||||
iou_list.append(here_iou)
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
results['x_feat'] = torch.cat([torch.stack(x_feat_list)], dim=2)
|
||||
|
||||
results['search_images'] = torch.cat([torch.stack(search_images_list)],
|
||||
dim=1)
|
||||
results['search_anno'] = torch.cat([torch.stack(search_anno_list)],
|
||||
dim=1)
|
||||
results['pre_seq'] = torch.cat([torch.stack(pre_seq_list)], dim=1)
|
||||
|
||||
iou_tensor = torch.tensor(iou_list, dtype=torch.float)
|
||||
results['baseline_iou'] = torch.cat([iou_tensor[:, :num_seq]], dim=1)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
box_mask_z = None
|
||||
ce_keep_rate = None
|
||||
if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
box_mask_z = generate_mask_cond(self.cfg, template_list[0].shape[0], template_list[0].device,
|
||||
data['template_anno'][0])
|
||||
|
||||
ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH
|
||||
ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH
|
||||
ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch,
|
||||
total_epochs=ce_start_epoch + ce_warm_epoch,
|
||||
ITERS_PER_EPOCH=1,
|
||||
base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0])
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins
|
||||
end = self.bins + 1
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=0.5, max=1.5)
|
||||
data['real_bbox'] = gt_bbox
|
||||
seq_ori = gt_bbox * (self.bins - 1)
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
ce_template_mask=box_mask_z,
|
||||
ce_keep_rate=ce_keep_rate,
|
||||
return_last_attn=False,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
|
||||
def compute_sequence_losses(self, data):
|
||||
num_frames = data['search_images'].shape[0]
|
||||
template_images = data['template_images'].repeat(num_frames, 1, 1, 1, 1)
|
||||
template_images = template_images.view(-1, *template_images.size()[2:])
|
||||
search_images = data['search_images'].reshape(-1, *data['search_images'].size()[2:])
|
||||
search_anno = data['search_anno'].reshape(-1, *data['search_anno'].size()[2:])
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
self.loss_weight['focal'] = 0
|
||||
pre_seq = data['pre_seq'].reshape(-1, 4 * self.pre_num)
|
||||
x_feat = data['x_feat'].reshape(-1, *data['x_feat'].size()[2:])
|
||||
pre_seq = pre_seq.clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq = (pre_seq + magic_num) * (self.bins - 1)
|
||||
|
||||
outputs = self.net(template_images, search_images, seq_input=pre_seq, stage="forward_pass",
|
||||
search_feature=x_feat, update=None)
|
||||
|
||||
pred_feat = outputs["feat"]
|
||||
# generate labels
|
||||
if self.focal is None:
|
||||
weight = torch.ones(self.bins * self.range + 2) * 1
|
||||
weight[self.bins * self.range + 1] = 0.1
|
||||
weight[self.bins * self.range] = 0.1
|
||||
weight = weight.to(pred_feat)
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, reduction='mean').to(pred_feat)  # size_average is deprecated; reduction='mean' is equivalent
|
||||
|
||||
search_anno[:, 2] = search_anno[:, 2] + search_anno[:, 0]
|
||||
search_anno[:, 3] = search_anno[:, 3] + search_anno[:, 1]
|
||||
target = (search_anno / self.cfg.DATA.SEARCH.SIZE + 0.5) * (self.bins - 1)
|
||||
|
||||
target = target.clamp(min=0.0, max=(self.bins * self.range - 0.0001))
|
||||
target_iou = target
|
||||
target = torch.cat([target], dim=1)
|
||||
target = target.reshape(-1).to(torch.int64)
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, self.bins * self.range + 2)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
pred = pred_feat[0:4, :, 0:self.bins * self.range]
|
||||
target = target_iou[:, 0:4].to(pred_feat) / (self.bins - 1) - magic_num
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range(-1 * magic_num + 1 / (self.bins * self.range), 1 + magic_num - 1 / (self.bins * self.range), 2 / (self.bins * self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
|
||||
cious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
cious = cious.mean()
|
||||
|
||||
giou_loss = cious
|
||||
loss_bb = self.loss_weight['giou'] * giou_loss + self.loss_weight[
|
||||
'focal'] * varifocal_loss
|
||||
|
||||
total_losses = loss_bb
|
||||
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": total_losses.item(),
|
||||
"Loss/giou": giou_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
|
||||
return total_losses, status
|
||||
|
||||
44
lib/train/actors/base_actor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class BaseActor:
|
||||
""" Base class for actor. The actor class handles the passing of the data through the network
|
||||
and calculation the loss"""
|
||||
def __init__(self, net, objective):
|
||||
"""
|
||||
args:
|
||||
net - The network to train
|
||||
objective - The loss function
|
||||
"""
|
||||
self.net = net
|
||||
self.objective = objective
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
""" Called in each training iteration. Should pass in input data through the network, calculate the loss, and
|
||||
return the training stats for the input data
|
||||
args:
|
||||
data - A TensorDict containing all the necessary data blocks.
|
||||
|
||||
returns:
|
||||
loss - loss for the input data
|
||||
stats - a dict containing detailed losses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device):
|
||||
""" Move the network to device
|
||||
args:
|
||||
device - device to use. 'cpu' or 'cuda'
|
||||
"""
|
||||
self.net.to(device)
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set whether the network is in train mode.
|
||||
args:
|
||||
mode (True) - Bool specifying whether in training mode.
|
||||
"""
|
||||
self.net.train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set network to eval mode"""
|
||||
self.train(False)
|
||||
3
lib/train/admin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .environment import env_settings, create_default_local_file_ITP_train
|
||||
from .stats import AverageMeter, StatValue
|
||||
#from .tensorboard import TensorboardWriter
|
||||
102
lib/train/admin/environment.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_default_local_file():
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': empty_str,
|
||||
'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
|
||||
'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
|
||||
'lasot_dir': empty_str,
|
||||
'got10k_dir': empty_str,
|
||||
'trackingnet_dir': empty_str,
|
||||
'coco_dir': empty_str,
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': empty_str,
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def create_default_local_file_ITP_train(workspace_dir, data_dir):
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': workspace_dir,
|
||||
'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files.
|
||||
'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'),
|
||||
'lasot_dir': os.path.join(data_dir, 'lasot'),
|
||||
'got10k_dir': os.path.join(data_dir, 'got10k/train'),
|
||||
'got10k_val_dir': os.path.join(data_dir, 'got10k/val'),
|
||||
'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'),
|
||||
'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'),
|
||||
'trackingnet_dir': os.path.join(data_dir, 'trackingnet'),
|
||||
'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'),
|
||||
'coco_dir': os.path.join(data_dir, 'coco'),
|
||||
'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'),
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': os.path.join(data_dir, 'vid'),
|
||||
'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'),
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
if attr_val == empty_str:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def env_settings():
|
||||
env_module_name = 'lib.train.admin.local'
|
||||
try:
|
||||
env_module = importlib.import_module(env_module_name)
|
||||
return env_module.EnvironmentSettings()
|
||||
except:
|
||||
env_file = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
create_default_local_file()
|
||||
raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file))
|
||||
24
lib/train/admin/local.py
Normal file
@@ -0,0 +1,24 @@
class EnvironmentSettings:
    def __init__(self):
        self.workspace_dir = '/home/baiyifan/code/2stage_update_intrain'    # Base directory for saving network checkpoints.
        self.tensorboard_dir = '/home/baiyifan/code/2stage/tensorboard'    # Directory for tensorboard files.
        self.pretrained_networks = '/home/baiyifan/code/2stage/pretrained_networks'
        self.lasot_dir = '/home/baiyifan/LaSOT/LaSOTBenchmark'
        self.got10k_dir = '/home/baiyifan/GOT-10k/train'
        self.got10k_val_dir = '/home/baiyifan/GOT-10k/val'
        self.lasot_lmdb_dir = '/home/baiyifan/code/2stage/data/lasot_lmdb'
        self.got10k_lmdb_dir = '/home/baiyifan/code/2stage/data/got10k_lmdb'
        self.trackingnet_dir = '/ssddata/TrackingNet/all_zip'
        self.trackingnet_lmdb_dir = '/home/baiyifan/code/2stage/data/trackingnet_lmdb'
        self.coco_dir = '/home/baiyifan/coco'
        self.coco_lmdb_dir = '/home/baiyifan/code/2stage/data/coco_lmdb'
        self.lvis_dir = ''
        self.sbd_dir = ''
        self.imagenet_dir = '/home/baiyifan/code/2stage/data/vid'
        self.imagenet_lmdb_dir = '/home/baiyifan/code/2stage/data/vid_lmdb'
        self.imagenetdet_dir = ''
        self.ecssd_dir = ''
        self.hkuis_dir = ''
        self.msra10k_dir = ''
        self.davis_dir = ''
        self.youtubevos_dir = ''
15
lib/train/admin/multigpu.py
Normal file
@@ -0,0 +1,15 @@
import torch.nn as nn
# Here we use DistributedDataParallel (DDP) rather than DataParallel (DP) for multi-GPU training


def is_multi_gpu(net):
    return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))


class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError:
            pass
        return getattr(self.module, item)
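MultiGPU falls back to the wrapped module for any attribute that DistributedDataParallel itself does not define, so code written against the bare network keeps working after wrapping. A short hedged sketch (actually wrapping requires an initialised process group, which is omitted here):

import torch.nn as nn
from lib.train.admin.multigpu import is_multi_gpu

net = nn.Linear(4, 2)
print(is_multi_gpu(net))    # False: a plain module is not wrapped
# Inside an initialised DDP process group one would wrap it as MultiGPU(net, device_ids=[rank]);
# attributes such as .weight would then still resolve through __getattr__ forwarding to the underlying module.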
13
lib/train/admin/settings.py
Normal file
@@ -0,0 +1,13 @@
from lib.train.admin.environment import env_settings


class Settings:
    """ Training settings, e.g. the paths to datasets and networks."""
    def __init__(self):
        self.set_default()

    def set_default(self):
        self.env = env_settings()
        self.use_gpu = True
71
lib/train/admin/stats.py
Normal file
@@ -0,0 +1,71 @@
class StatValue:
    def __init__(self):
        self.clear()

    def reset(self):
        self.val = 0

    def clear(self):
        self.reset()
        self.history = []

    def update(self, val):
        self.val = val
        self.history.append(self.val)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.clear()
        self.has_new_data = False

    def reset(self):
        self.avg = 0
        self.val = 0
        self.sum = 0
        self.count = 0

    def clear(self):
        self.reset()
        self.history = []

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def new_epoch(self):
        if self.count > 0:
            self.history.append(self.avg)
            self.reset()
            self.has_new_data = True
        else:
            self.has_new_data = False


def topk_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    single_input = not isinstance(topk, (tuple, list))
    if single_input:
        topk = (topk,)

    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0]
        res.append(correct_k * 100.0 / batch_size)

    if single_input:
        return res[0]

    return res
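AverageMeter keeps a weighted running average and, via new_epoch, snapshots the per-epoch mean into history. A small usage sketch with made-up loss values:

from lib.train.admin.stats import AverageMeter

meter = AverageMeter()
for loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(loss, n=batch_size)        # weight each value by its batch size
print(round(meter.avg, 4))                  # 0.74: (0.9*32 + 0.7*32 + 0.5*16) / 80
meter.new_epoch()                           # push the epoch average into meter.history and reset the counters
print(len(meter.history), meter.count)      # 1 0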
27
lib/train/admin/tensorboard.py
Normal file
@@ -0,0 +1,27 @@
#import os
#from collections import OrderedDict
#try:
#    from torch.utils.tensorboard import SummaryWriter
#except:
#    print('WARNING: You are using tensorboardX instead since you have a too old pytorch version.')
#    from tensorboardX import SummaryWriter


#class TensorboardWriter:
#    def __init__(self, directory, loader_names):
#        self.directory = directory
#        self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})

#    def write_info(self, script_name, description):
#        tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
#        tb_info_writer.add_text('Script_name', script_name)
#        tb_info_writer.add_text('Description', description)
#        tb_info_writer.close()

#    def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
#        for loader_name, loader_stats in stats.items():
#            if loader_stats is None:
#                continue
#            for var_name, val in loader_stats.items():
#                if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
#                    self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
193
lib/train/base_functions.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
# datasets related
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader
|
||||
import lib.train.data.transforms as tfm
|
||||
from lib.utils.misc import is_main_process
|
||||
|
||||
|
||||
def update_settings(settings, cfg):
|
||||
settings.print_interval = cfg.TRAIN.PRINT_INTERVAL
|
||||
settings.search_area_factor = {'template': cfg.DATA.TEMPLATE.FACTOR,
|
||||
'search': cfg.DATA.SEARCH.FACTOR}
|
||||
settings.output_sz = {'template': cfg.DATA.TEMPLATE.SIZE,
|
||||
'search': cfg.DATA.SEARCH.SIZE}
|
||||
settings.center_jitter_factor = {'template': cfg.DATA.TEMPLATE.CENTER_JITTER,
|
||||
'search': cfg.DATA.SEARCH.CENTER_JITTER}
|
||||
settings.scale_jitter_factor = {'template': cfg.DATA.TEMPLATE.SCALE_JITTER,
|
||||
'search': cfg.DATA.SEARCH.SCALE_JITTER}
|
||||
settings.grad_clip_norm = cfg.TRAIN.GRAD_CLIP_NORM
|
||||
settings.print_stats = None
|
||||
settings.batchsize = cfg.TRAIN.BATCH_SIZE
|
||||
settings.scheduler_type = cfg.TRAIN.SCHEDULER.TYPE
|
||||
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
|
||||
def build_dataloaders(cfg, settings):
|
||||
# Data transform
|
||||
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05),
|
||||
tfm.RandomHorizontalFlip(probability=0.5))
|
||||
|
||||
transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
|
||||
tfm.RandomHorizontalFlip_Norm(probability=0.5),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
transform_val = tfm.Transform(tfm.ToTensor(),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
# The tracking pairs processing module
|
||||
output_sz = settings.output_sz
|
||||
search_area_factor = settings.search_area_factor
|
||||
|
||||
data_processing_train = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_train,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
data_processing_val = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_val,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
# Train sampler and loader
|
||||
settings.num_template = getattr(cfg.DATA.TEMPLATE, "NUMBER", 1)
|
||||
settings.num_search = getattr(cfg.DATA.SEARCH, "NUMBER", 1)
|
||||
sampler_mode = getattr(cfg.DATA, "SAMPLER_MODE", "causal")
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
print("sampler_mode", sampler_mode)
|
||||
dataset_train = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_train,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
|
||||
train_sampler = DistributedSampler(dataset_train) if settings.local_rank != -1 else None
|
||||
shuffle = False if settings.local_rank != -1 else True
|
||||
|
||||
loader_train = LTRLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=shuffle,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=train_sampler)
|
||||
|
||||
# Validation samplers and loaders
|
||||
dataset_val = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.VAL.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.VAL.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.VAL.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_val,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
val_sampler = DistributedSampler(dataset_val) if settings.local_rank != -1 else None
|
||||
loader_val = LTRLoader('val', dataset_val, training=False, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=val_sampler,
|
||||
epoch_interval=cfg.TRAIN.VAL_EPOCH_INTERVAL)
|
||||
|
||||
return loader_train, loader_val
|
||||
|
||||
|
||||
def get_optimizer_scheduler(net, cfg):
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
if train_cls:
|
||||
print("Only training classification head. Learnable parameters are shown below.")
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "cls" in n and p.requires_grad]}
|
||||
]
|
||||
|
||||
for n, p in net.named_parameters():
|
||||
if "cls" not in n:
|
||||
p.requires_grad = False
|
||||
else:
|
||||
print(n)
|
||||
else:
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in net.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.TRAIN.LR * cfg.TRAIN.BACKBONE_MULTIPLIER,
|
||||
},
|
||||
]
|
||||
if is_main_process():
|
||||
print("Learnable parameters are shown below.")
|
||||
for n, p in net.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
if cfg.TRAIN.OPTIMIZER == "ADAMW":
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR,
|
||||
weight_decay=cfg.TRAIN.WEIGHT_DECAY)
|
||||
else:
|
||||
raise ValueError("Unsupported Optimizer")
|
||||
if cfg.TRAIN.SCHEDULER.TYPE == 'step':
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP_EPOCH)
|
||||
elif cfg.TRAIN.SCHEDULER.TYPE == "Mstep":
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=cfg.TRAIN.SCHEDULER.MILESTONES,
|
||||
gamma=cfg.TRAIN.SCHEDULER.GAMMA)
|
||||
else:
|
||||
raise ValueError("Unsupported scheduler")
|
||||
return optimizer, lr_scheduler
|
||||
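Taken together, these helpers form the setup path of a training run: the settings object is filled from the config, dataloaders are built, and the optimizer and scheduler are created from the same config. A hedged sketch of that flow; the config object, the local_rank handling and the use_lmdb flag wiring are assumptions about how the surrounding training script ties things together, not code from this commit.

from lib.train.admin.settings import Settings
from lib.train.base_functions import update_settings, build_dataloaders, get_optimizer_scheduler


def setup_training(cfg, net, local_rank=-1):
    # cfg is assumed to be the YACS-style config used throughout this commit.
    settings = Settings()
    settings.local_rank = local_rank
    settings.use_lmdb = False                 # assumed flag; names2datasets reads it
    update_settings(settings, cfg)

    loader_train, loader_val = build_dataloaders(cfg, settings)
    optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
    return settings, loader_train, loader_val, optimizer, lr_scheduler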
2
lib/train/data/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .loader import LTRLoader
from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader
150
lib/train/data/bounding_box_utils.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def batch_center2corner(boxes):
|
||||
xmin = boxes[:, 0] - boxes[:, 2] * 0.5
|
||||
ymin = boxes[:, 1] - boxes[:, 3] * 0.5
|
||||
xmax = boxes[:, 0] + boxes[:, 2] * 0.5
|
||||
ymax = boxes[:, 1] + boxes[:, 3] * 0.5
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def batch_corner2center(boxes):
|
||||
cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
|
||||
cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
|
||||
w = (boxes[:, 2] - boxes[:, 0])
|
||||
h = (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center(boxes):
|
||||
cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2
|
||||
cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center2(boxes):
|
||||
cx = boxes[:, 0] + boxes[:, 2] / 2
|
||||
cy = boxes[:, 1] + boxes[:, 3] / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
|
||||
def batch_xywh2corner(boxes):
|
||||
xmin = boxes[:, 0]
|
||||
ymin = boxes[:, 1]
|
||||
xmax = boxes[:, 0] + boxes[:, 2]
|
||||
ymax = boxes[:, 1] + boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def rect_to_rel(bb, sz_norm=None):
|
||||
"""Convert standard rectangular parametrization of the bounding box [x, y, w, h]
|
||||
to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
|
||||
args:
|
||||
bb - N x 4 tensor of boxes.
|
||||
sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
|
||||
"""
|
||||
|
||||
c = bb[...,:2] + 0.5 * bb[...,2:]
|
||||
if sz_norm is None:
|
||||
c_rel = c / bb[...,2:]
|
||||
else:
|
||||
c_rel = c / sz_norm
|
||||
sz_rel = torch.log(bb[...,2:])
|
||||
return torch.cat((c_rel, sz_rel), dim=-1)
|
||||
|
||||
|
||||
def rel_to_rect(bb, sz_norm=None):
|
||||
"""Inverts the effect of rect_to_rel. See above."""
|
||||
|
||||
sz = torch.exp(bb[...,2:])
|
||||
if sz_norm is None:
|
||||
c = bb[...,:2] * sz
|
||||
else:
|
||||
c = bb[...,:2] * sz_norm
|
||||
tl = c - 0.5 * sz
|
||||
return torch.cat((tl, sz), dim=-1)
|
||||
|
||||
|
||||
def masks_to_bboxes(mask, fmt='c'):
|
||||
|
||||
""" Convert a mask tensor to one or more bounding boxes.
|
||||
Note: This function is a bit new, make sure it does what it says. /Andreas
|
||||
:param mask: Tensor of masks, shape = (..., H, W)
|
||||
:param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
|
||||
't' => "top left + size" or (x_left, y_top, width, height)
|
||||
'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
|
||||
:return: tensor containing a batch of bounding boxes, shape = (..., 4)
|
||||
"""
|
||||
batch_shape = mask.shape[:-2]
|
||||
mask = mask.reshape((-1, *mask.shape[-2:]))
|
||||
bboxes = []
|
||||
|
||||
for m in mask:
|
||||
mx = m.sum(dim=-2).nonzero()
|
||||
my = m.sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
bboxes.append(bb)
|
||||
|
||||
bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
|
||||
bboxes = bboxes.reshape(batch_shape + (4,))
|
||||
|
||||
if fmt == 'v':
|
||||
return bboxes
|
||||
|
||||
x1 = bboxes[..., :2]
|
||||
s = bboxes[..., 2:] - x1 + 1
|
||||
|
||||
if fmt == 'c':
|
||||
return torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
return torch.cat((x1, s), dim=-1)
|
||||
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
|
||||
|
||||
def masks_to_bboxes_multi(mask, ids, fmt='c'):
|
||||
assert mask.dim() == 2
|
||||
bboxes = []
|
||||
|
||||
for id in ids:
|
||||
mx = (mask == id).sum(dim=-2).nonzero()
|
||||
my = (mask == id).float().sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
|
||||
bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
|
||||
|
||||
x1 = bb[:2]
|
||||
s = bb[2:] - x1 + 1
|
||||
|
||||
if fmt == 'v':
|
||||
pass
|
||||
elif fmt == 'c':
|
||||
bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
bb = torch.cat((x1, s), dim=-1)
|
||||
else:
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
bboxes.append(bb)
|
||||
|
||||
return bboxes
|
||||
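rect_to_rel and rel_to_rect are inverses of each other, which is easy to confirm numerically. A short round-trip check with an illustrative box:

import torch
from lib.train.data.bounding_box_utils import rect_to_rel, rel_to_rect

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])    # [x, y, w, h]
rel = rect_to_rel(boxes)                             # [cx/w, cy/h, log w, log h]
back = rel_to_rect(rel)
print(torch.allclose(back, boxes))                   # True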
103
lib/train/data/image_loader.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import jpeg4py
|
||||
import cv2 as cv
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
|
||||
davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
|
||||
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
|
||||
[64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
|
||||
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
|
||||
[0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
|
||||
[0, 64, 128], [128, 64, 128]]
|
||||
|
||||
|
||||
def default_image_loader(path):
|
||||
"""The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
|
||||
but reverts to the opencv_loader if the former is not available."""
|
||||
if default_image_loader.use_jpeg4py is None:
|
||||
# Try using jpeg4py
|
||||
im = jpeg4py_loader(path)
|
||||
if im is None:
|
||||
default_image_loader.use_jpeg4py = False
|
||||
print('Using opencv_loader instead.')
|
||||
else:
|
||||
default_image_loader.use_jpeg4py = True
|
||||
return im
|
||||
if default_image_loader.use_jpeg4py:
|
||||
return jpeg4py_loader(path)
|
||||
return opencv_loader(path)
|
||||
|
||||
default_image_loader.use_jpeg4py = None
|
||||
|
||||
|
||||
def jpeg4py_loader(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_loader(path):
|
||||
""" Read image using opencv's imread function and returns it in rgb format"""
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def jpeg4py_loader_w_failsafe(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except:
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_seg_loader(path):
|
||||
""" Read segmentation annotation using opencv's imread function"""
|
||||
try:
|
||||
return cv.imread(path)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def imread_indexed(filename):
|
||||
""" Load indexed image with given filename. Used to read segmentation annotations."""
|
||||
|
||||
im = Image.open(filename)
|
||||
|
||||
annotation = np.atleast_3d(im)[...,0]
|
||||
return annotation
|
||||
|
||||
|
||||
def imwrite_indexed(filename, array, color_palette=None):
|
||||
""" Save indexed image as png. Used to save segmentation annotation."""
|
||||
|
||||
if color_palette is None:
|
||||
color_palette = davis_palette
|
||||
|
||||
if np.atleast_3d(array).shape[2] != 1:
|
||||
raise Exception("Saving indexed PNGs requires 2D array.")
|
||||
|
||||
im = Image.fromarray(array)
|
||||
im.putpalette(color_palette.ravel())
|
||||
im.save(filename, format='PNG')
|
||||
199
lib/train/data/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.utils.data.dataloader
|
||||
import importlib
|
||||
import collections.abc
|
||||
# from torch._six import string_classes
|
||||
from lib.utils import TensorDict, TensorList
|
||||
|
||||
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
|
||||
int_classes = int
|
||||
else:
|
||||
from torch._six import int_classes
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
string_classes = str
|
||||
|
||||
def _check_use_shared_memory():
|
||||
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
|
||||
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
|
||||
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
|
||||
if hasattr(collate_lib, '_use_shared_memory'):
|
||||
return getattr(collate_lib, '_use_shared_memory')
|
||||
return torch.utils.data.get_worker_info() is not None
|
||||
|
||||
|
||||
def ltr_collate(batch):
|
||||
"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 0, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.abc.Mapping):
|
||||
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.abc.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
def ltr_collate_stack1(batch):
|
||||
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 1, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 1)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.abc.Mapping):
|
||||
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.abc.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate_stack1(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
class LTRLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
|
||||
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
|
||||
select along which dimension the data should be stacked to form a batch.
|
||||
|
||||
Arguments:
|
||||
dataset (Dataset): dataset from which to load the data.
|
||||
batch_size (int, optional): how many samples per batch to load
|
||||
(default: 1).
|
||||
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
||||
at every epoch (default: False).
|
||||
sampler (Sampler, optional): defines the strategy to draw samples from
|
||||
the dataset. If specified, ``shuffle`` must be False.
|
||||
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
||||
indices at a time. Mutually exclusive with batch_size, shuffle,
|
||||
sampler, and drop_last.
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means that the data will be loaded in the main process.
|
||||
(default: 0)
|
||||
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
||||
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
|
||||
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
||||
into CUDA pinned memory before returning them.
|
||||
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
||||
if the dataset size is not divisible by the batch size. If ``False`` and
|
||||
the size of dataset is not divisible by the batch size, then the last batch
|
||||
will be smaller. (default: False)
|
||||
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
||||
from workers. Should always be non-negative. (default: 0)
|
||||
worker_init_fn (callable, optional): If not None, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: None)
|
||||
|
||||
.. note:: By default, each worker will have its PyTorch seed set to
|
||||
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
||||
by main process using its RNG. However, seeds for other libraries
|
||||
may be duplicated upon initializing workers (e.g., NumPy), causing
|
||||
each worker to return identical random numbers. (See
|
||||
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
|
||||
use ``torch.initial_seed()`` to access the PyTorch seed for each
|
||||
worker in :attr:`worker_init_fn`, and use it to set other seeds
|
||||
before data loading.
|
||||
|
||||
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
|
||||
unpicklable object, e.g., a lambda function.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
print("pin_memory is", pin_memory)
|
||||
if collate_fn is None:
|
||||
if stack_dim == 0:
|
||||
collate_fn = ltr_collate
|
||||
elif stack_dim == 1:
|
||||
collate_fn = ltr_collate_stack1
|
||||
else:
|
||||
raise ValueError('Stack dim not supported. Must be 0 or 1.')
|
||||
|
||||
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
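The only behavioural difference from a stock DataLoader is stack_dim: with stack_dim=1 (as used in build_dataloaders above) each tensor field is collated with the batch placed on dimension 1 instead of 0. A hedged sketch of the effect on shapes, using a toy dataset rather than one of the tracking datasets:

import torch
from lib.train.data.loader import LTRLoader


class ToyDataset(torch.utils.data.Dataset):
    # Hypothetical dataset returning a fixed-size image tensor per sample.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(3, 16, 16)


loader = LTRLoader('train', ToyDataset(), training=True, batch_size=4, stack_dim=1)
batch = next(iter(loader))
print(batch.shape)        # torch.Size([3, 4, 16, 16]) -- batch stacked on dim 1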
155
lib/train/data/processing.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from lib.utils import TensorDict
|
||||
import lib.train.data.processing_utils as prutils
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def stack_tensors(x):
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
|
||||
return torch.stack(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaseProcessing:
|
||||
""" Base class for Processing. Processing class is used to process the data returned by a dataset, before passing it
|
||||
through the network. For example, it can be used to crop a search region around the object, apply various data
|
||||
augmentations, etc."""
|
||||
def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None, joint_transform=None):
|
||||
"""
|
||||
args:
|
||||
transform - The set of transformations to be applied on the images. Used only if template_transform or
|
||||
search_transform is None.
|
||||
template_transform - The set of transformations to be applied on the template images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
search_transform - The set of transformations to be applied on the search images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
joint_transform - The set of transformations to be applied 'jointly' on the template and search images. For
|
||||
example, it can be used to convert both template and search images to grayscale.
|
||||
"""
|
||||
self.transform = {'template': transform if template_transform is None else template_transform,
|
||||
'search': transform if search_transform is None else search_transform,
|
||||
'joint': joint_transform}
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class STARKProcessing(BaseProcessing):
|
||||
""" The processing class used for training LittleBoy. The images are processed in the following way.
|
||||
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
|
||||
centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is
|
||||
cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is
|
||||
always at the center of the search region. The search region is then resized to a fixed size given by the
|
||||
argument output_sz.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
|
||||
mode='pair', settings=None, *args, **kwargs):
|
||||
"""
|
||||
args:
|
||||
search_area_factor - The size of the search region relative to the target size.
|
||||
output_sz - An integer, denoting the size to which the search region is resized. The search region is always
|
||||
square.
|
||||
center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.search_area_factor = search_area_factor
|
||||
self.output_sz = output_sz
|
||||
self.center_jitter_factor = center_jitter_factor
|
||||
self.scale_jitter_factor = scale_jitter_factor
|
||||
self.mode = mode
|
||||
self.settings = settings
|
||||
|
||||
def _get_jittered_box(self, box, mode):
|
||||
""" Jitter the input box
|
||||
args:
|
||||
box - input bounding box
|
||||
mode - string 'template' or 'search' indicating template or search data
|
||||
|
||||
returns:
|
||||
torch.Tensor - jittered box
|
||||
"""
|
||||
|
||||
jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
|
||||
max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode]).float())
|
||||
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
|
||||
|
||||
return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the following fields:
|
||||
'template_images', search_images', 'template_anno', 'search_anno'
|
||||
returns:
|
||||
TensorDict - output data block with following fields:
|
||||
'template_images', 'search_images', 'template_anno', 'search_anno', 'test_proposals', 'proposal_iou'
|
||||
"""
|
||||
# Apply joint transforms
|
||||
if self.transform['joint'] is not None:
|
||||
data['template_images'], data['template_anno'], data['template_masks'] = self.transform['joint'](
|
||||
image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
|
||||
data['search_images'], data['search_anno'], data['search_masks'] = self.transform['joint'](
|
||||
image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)
|
||||
|
||||
for s in ['template', 'search']:
|
||||
assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
|
||||
"In pair mode, num train/test frames must be 1"
|
||||
|
||||
# Add a uniform noise to the center pos
|
||||
jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
|
||||
|
||||
# 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
|
||||
w, h = torch.stack(jittered_anno, dim=0)[:, 2], torch.stack(jittered_anno, dim=0)[:, 3]
|
||||
|
||||
crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
|
||||
if (crop_sz < 1).any():
|
||||
data['valid'] = False
|
||||
# print("Too small box is found. Replace it with new data.")
|
||||
return data
|
||||
|
||||
# Crop image region centered at jittered_anno box and get the attention mask
|
||||
crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno,
|
||||
data[s + '_anno'], self.search_area_factor[s],
|
||||
self.output_sz[s], masks=data[s + '_masks'])
|
||||
# Apply transforms
|
||||
data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[s + '_masks'] = self.transform[s](
|
||||
image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)
|
||||
|
||||
|
||||
# 2021.1.9 Check whether elements in data[s + '_att'] is all 1
|
||||
# Note that type of data[s + '_att'] is tuple, type of ele is torch.tensor
|
||||
for ele in data[s + '_att']:
|
||||
if (ele == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of original attention mask are all one. Replace it with new data.")
|
||||
return data
|
||||
# 2021.1.10 more strict conditions: require the downsampled masks not to be all 1
|
||||
for ele in data[s + '_att']:
|
||||
feat_size = self.output_sz[s] // 16 # 16 is the backbone stride
|
||||
# (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
|
||||
mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
|
||||
if (mask_down == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of down-sampled attention mask are all one. "
|
||||
# "Replace it with new data.")
|
||||
return data
|
||||
|
||||
data['valid'] = True
|
||||
# if we use copy-and-paste augmentation
|
||||
if data["template_masks"] is None or data["search_masks"] is None:
|
||||
data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
|
||||
data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))
|
||||
# Prepare output
|
||||
if self.mode == 'sequence':
|
||||
data = data.apply(stack_tensors)
|
||||
else:
|
||||
data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
|
||||
|
||||
return data
|
||||
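The docstring above describes jittering the target box before cropping so that the network does not learn a center bias. A hedged numerical sketch of _get_jittered_box with illustrative jitter factors (values chosen for the example, not taken from a config in this commit):

import torch
from lib.train.data.processing import STARKProcessing

proc = STARKProcessing(search_area_factor={'template': 2.0, 'search': 5.0},
                       output_sz={'template': 128, 'search': 320},
                       center_jitter_factor={'template': 0.0, 'search': 4.5},
                       scale_jitter_factor={'template': 0.0, 'search': 0.5},
                       mode='sequence')
box = torch.tensor([50.0, 60.0, 40.0, 30.0])         # [x, y, w, h]
jittered = proc._get_jittered_box(box, 'search')
print(jittered.shape)                                 # torch.Size([4]): a perturbed [x, y, w, h]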
168
lib/train/data/processing_utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
'''modified from the original test implementation
|
||||
Replace cv.BORDER_REPLICATE with cv.BORDER_CONSTANT
|
||||
Add a variable called att_mask for computing attention and positional encoding later'''
|
||||
|
||||
|
||||
def sample_target(im, target_bb, search_area_factor, output_sz=None, mask=None):
|
||||
""" Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area
|
||||
|
||||
args:
|
||||
im - cv image
|
||||
target_bb - target box [x, y, w, h]
|
||||
search_area_factor - Ratio of crop size to target size
|
||||
output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.
|
||||
|
||||
returns:
|
||||
cv image - extracted crop
|
||||
float - the factor by which the crop has been resized to make the crop size equal output_size
|
||||
"""
|
||||
if not isinstance(target_bb, list):
|
||||
x, y, w, h = target_bb.tolist()
|
||||
else:
|
||||
x, y, w, h = target_bb
|
||||
# Crop image
|
||||
crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
|
||||
|
||||
if crop_sz < 1:
|
||||
raise Exception('Too small bounding box.')
|
||||
|
||||
x1 = round(x + 0.5 * w - crop_sz * 0.5)
|
||||
x2 = x1 + crop_sz
|
||||
|
||||
y1 = round(y + 0.5 * h - crop_sz * 0.5)
|
||||
y2 = y1 + crop_sz
|
||||
|
||||
x1_pad = max(0, -x1)
|
||||
x2_pad = max(x2 - im.shape[1] + 1, 0)
|
||||
|
||||
y1_pad = max(0, -y1)
|
||||
y2_pad = max(y2 - im.shape[0] + 1, 0)
|
||||
|
||||
# Crop target
|
||||
im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
|
||||
if mask is not None:
|
||||
mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]
|
||||
|
||||
# Pad
|
||||
im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_CONSTANT)
|
||||
# deal with attention mask
|
||||
H, W, _ = im_crop_padded.shape
|
||||
att_mask = np.ones((H,W))
|
||||
end_x, end_y = -x2_pad, -y2_pad
|
||||
if y2_pad == 0:
|
||||
end_y = None
|
||||
if x2_pad == 0:
|
||||
end_x = None
|
||||
att_mask[y1_pad:end_y, x1_pad:end_x] = 0
|
||||
if mask is not None:
|
||||
mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)
|
||||
|
||||
if output_sz is not None:
|
||||
resize_factor = output_sz / crop_sz
|
||||
im_crop_padded = cv.resize(im_crop_padded, (output_sz, output_sz))
|
||||
att_mask = cv.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
|
||||
if mask is None:
|
||||
return im_crop_padded, resize_factor, att_mask
|
||||
mask_crop_padded = \
|
||||
F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
|
||||
return im_crop_padded, resize_factor, att_mask, mask_crop_padded
|
||||
|
||||
else:
|
||||
if mask is None:
|
||||
return im_crop_padded, att_mask.astype(np.bool_), 1.0
|
||||
return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded
|
||||
|
||||
|
||||
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
|
||||
crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box_in - the box for which the co-ordinates are to be transformed
|
||||
box_extract - the box about which the image crop has been extracted.
|
||||
resize_factor - the ratio between the original image scale and the scale of the image crop
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]
|
||||
|
||||
box_in_center = box_in[0:2] + 0.5 * box_in[2:4]
|
||||
|
||||
box_out_center = (crop_sz - 1) / 2 + (box_in_center - box_extract_center) * resize_factor
|
||||
box_out_wh = box_in[2:4] * resize_factor
|
||||
|
||||
box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
|
||||
def jittered_center_crop(frames, box_extract, box_gt, search_area_factor, output_sz, masks=None):
|
||||
""" For each frame in frames, extracts a square crop centered at box_extract, of area search_area_factor^2
|
||||
times box_extract area. The extracted crops are then resized to output_sz. Further, the co-ordinates of the box
|
||||
box_gt are transformed to the image crop co-ordinates
|
||||
|
||||
args:
|
||||
frames - list of frames
|
||||
box_extract - list of boxes of same length as frames. The crops are extracted using anno_extract
|
||||
box_gt - list of boxes of same length as frames. The co-ordinates of these boxes are transformed from
|
||||
image co-ordinates to the crop co-ordinates
|
||||
search_area_factor - The area of the extracted crop is search_area_factor^2 times box_extract area
|
||||
output_sz - The size to which the extracted crops are resized
|
||||
|
||||
returns:
|
||||
list - list of image crops
|
||||
list - box_gt location in the crop co-ordinates
|
||||
"""
|
||||
|
||||
if masks is None:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz)
|
||||
for f, a in zip(frames, box_extract)]
|
||||
frames_crop, resize_factors, att_mask = zip(*crops_resize_factors)
|
||||
masks_crop = None
|
||||
else:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz, m)
|
||||
for f, a, m in zip(frames, box_extract, masks)]
|
||||
frames_crop, resize_factors, att_mask, masks_crop = zip(*crops_resize_factors)
|
||||
# frames_crop: tuple of ndarray (128,128,3), att_mask: tuple of ndarray (128,128)
|
||||
crop_sz = torch.Tensor([output_sz, output_sz])
|
||||
|
||||
# find the bb location in the crop
|
||||
'''Note that here we use normalized coord'''
|
||||
box_crop = [transform_image_to_crop(a_gt, a_ex, rf, crop_sz, normalize=True)
|
||||
for a_gt, a_ex, rf in zip(box_gt, box_extract, resize_factors)] # (x1,y1,w,h) list of tensors
|
||||
|
||||
return frames_crop, box_crop, att_mask, masks_crop
|
||||
|
||||
|
||||
def transform_box_to_crop(box: torch.Tensor, crop_box: torch.Tensor, crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box - the box for which the co-ordinates are to be transformed
|
||||
crop_box - bounding box defining the crop in the original image
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
|
||||
box_out = box.clone()
|
||||
box_out[:2] -= crop_box[:2]
|
||||
|
||||
scale_factor = crop_sz / crop_box[2:]
|
||||
|
||||
box_out[:2] *= scale_factor
|
||||
box_out[2:] *= scale_factor
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
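sample_target crops a square of side sqrt(w*h) * search_area_factor around the target and resizes it to output_sz, so the resize factor it reports is output_sz / crop_sz. A hedged sketch of that geometry on a synthetic frame (the image and box are made up):

import numpy as np
from lib.train.data.processing_utils import sample_target

im = np.zeros((480, 640, 3), dtype=np.uint8)          # synthetic frame
target_bb = [300, 200, 60, 40]                        # [x, y, w, h]
crop, resize_factor, att_mask = sample_target(im, target_bb, search_area_factor=5.0, output_sz=320)
print(crop.shape, att_mask.shape, round(resize_factor, 3))
# (320, 320, 3) (320, 320) 1.306 -- crop_sz = ceil(sqrt(60*40) * 5) = 245, and 320 / 245 ~ 1.306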
349
lib/train/data/sampler.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
from lib.utils import TensorDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def no_processing(data):
|
||||
return data
|
||||
|
||||
|
||||
class TrackingSampler(torch.utils.data.Dataset):
|
||||
""" Class responsible for sampling frames from training sequences to form batches.
|
||||
|
||||
The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
|
||||
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
|
||||
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
|
||||
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
|
||||
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.
|
||||
|
||||
The sampled frames are then passed through the input 'processing' function for the necessary processing.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
|
||||
train_cls=False, pos_prob=0.5):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the test frames.
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the test frames are sampled in a causal manner,
|
||||
otherwise randomly within the interval.
|
||||
"""
|
||||
self.datasets = datasets
|
||||
self.train_cls = train_cls # whether we are training classification
|
||||
self.pos_prob = pos_prob # probability of sampling positive class when making classification
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.processing = processing
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
|
||||
allow_invisible=False, force_invisible=False):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be sampled
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
# get valid ids
|
||||
if force_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
|
||||
else:
|
||||
if allow_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id)]
|
||||
else:
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.train_cls:
|
||||
return self.getitem_cls()
|
||||
else:
|
||||
return self.getitem()
|
||||
|
||||
def getitem(self):
|
||||
"""
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
|
||||
if is_video_dataset:
|
||||
template_frame_ids = None
|
||||
search_frame_ids = None
|
||||
gap_increase = 0
|
||||
|
||||
if self.frame_sample_mode == 'causal':
|
||||
# Sample test and train frames in a causal manner, i.e. search_frame_ids > template_frame_ids
|
||||
while search_frame_ids is None:
|
||||
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_template_frames - 1,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
|
||||
min_id=base_frame_id[0] - self.max_gap - gap_increase,
|
||||
max_id=base_frame_id[0])
|
||||
if prev_frame_ids is None:
|
||||
gap_increase += 5
|
||||
continue
|
||||
template_frame_ids = base_frame_id + prev_frame_ids
|
||||
search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
|
||||
max_id=template_frame_ids[0] + self.max_gap + gap_increase,
|
||||
num_ids=self.num_search_frames)
|
||||
# Increase gap until a frame is found
|
||||
gap_increase += 5
|
||||
|
||||
elif self.frame_sample_mode == "trident" or self.frame_sample_mode == "trident_pro":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("Illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except Exception:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def getitem_cls(self):
|
||||
# get data for classification
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
aux (bool): whether the current data is for auxiliary use (e.g. copy-and-paste)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
label = None
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample template and search frame ids
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode in ["trident", "trident_pro"]:
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
# "try" is used to handle trackingnet data failure
|
||||
# get images and bounding boxes (for templates)
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
|
||||
seq_info_dict)
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
|
||||
(H, W))] * self.num_template_frames
|
||||
# get images and bounding boxes (for searches)
|
||||
# positive samples
|
||||
if random.random() < self.pos_prob:
|
||||
label = torch.ones(1,)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
# negative samples
|
||||
else:
|
||||
label = torch.zeros(1,)
|
||||
if is_video_dataset:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
|
||||
if search_frame_ids is None:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
|
||||
seq_info_dict)
|
||||
search_anno["bbox"] = [self.get_center_box(H, W)]
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
H, W, _ = search_frames[0].shape
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
# add classification label
|
||||
data["label"] = label
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except Exception:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def get_center_box(self, H, W, ratio=1/8):
|
||||
cx, cy, w, h = W/2, H/2, W * ratio, H * ratio
|
||||
return torch.tensor([int(cx-w/2), int(cy-h/2), int(w), int(h)])
|
||||
|
||||
def sample_seq_from_dataset(self, dataset, is_video_dataset):
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= 20
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
return seq_id, visible, seq_info_dict
|
||||
|
||||
def get_one_search(self):
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
# sample a sequence
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample a frame
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == "stark":
|
||||
search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
|
||||
else:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)
|
||||
else:
|
||||
search_frame_ids = [1]
|
||||
# get the image, bounding box and other info
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
return search_frames, search_anno, meta_obj_test
|
||||
|
||||
def get_frame_ids_trident(self, visible):
|
||||
# get template and search ids in a 'trident' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
if self.frame_sample_mode == "trident_pro":
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id,
|
||||
allow_invisible=True)
|
||||
else:
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
def get_frame_ids_stark(self, visible, valid):
|
||||
# get template and search ids in a 'stark' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
"""we require the frame to be valid but not necessary visible"""
|
||||
f_id = self._sample_visible_ids(valid, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
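The frame-id sampling used throughout the sampler above can be reproduced on a toy visibility mask. The
snippet below is an illustrative sketch (not part of the commit) of what _sample_visible_ids does in the
default case (visible frames only, sampling with replacement):

import random
import torch

visible = torch.tensor([1, 1, 0, 0, 1, 1, 1, 0, 1], dtype=torch.bool)
min_id, max_id, num_ids = 2, 8, 2

# keep only the frame ids in [min_id, max_id) where the target is visible
valid_ids = [i for i in range(min_id, max_id) if visible[i]]   # -> [4, 5, 6]
# random.choices samples with replacement, so the same id may appear twice
ids = random.choices(valid_ids, k=num_ids)
print(ids)   # e.g. [5, 4]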
265
lib/train/data/sequence_sampler.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class SequenceSampler(torch.utils.data.Dataset):
|
||||
"""
|
||||
Sample sequence for sequence-level training
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
|
||||
"""
|
||||
args:
datasets - List of datasets to be used for training
p_datasets - List containing the probabilities by which each dataset will be sampled
samples_per_epoch - Number of training samples per epoch
max_gap - Maximum gap, in frame numbers, between the template frame and the search frames.
max_interval - Maximum interval between sampled search frames
num_search_frames - Number of search frames to sample.
num_template_frames - Number of template frames to sample.
frame_sample_mode - Either 'sequential' or 'random_interval'. If 'sequential', the search frames directly
follow the template frame; with 'random_interval' they are sampled with random intervals.
prob - probability of using sequential sampling; interval sampling is used with probability 1 - prob
"""
|
||||
self.datasets = datasets
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
self.prob = prob
self.extra = 1
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be sampled
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if no visible frames could be found in the given range.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
|
||||
def _sequential_sample(self, visible):
|
||||
# Sample frames in sequential manner
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
if self.max_gap == -1:
|
||||
left = template_frame_ids[0]
|
||||
else:
|
||||
# template frame (1) ->(max_gap) -> search frame (num_search_frames)
|
||||
left_max = min(len(visible) - self.num_search_frames, template_frame_ids[0] + self.max_gap)
|
||||
left = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)[0]
|
||||
|
||||
valid_ids = [i for i in range(left, len(visible)) if visible[i]]
|
||||
search_frame_ids = valid_ids[:self.num_search_frames]
|
||||
|
||||
# if length is not enough
|
||||
last = search_frame_ids[-1]
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
if last >= len(visible) - 1:
|
||||
search_frame_ids.append(last)
|
||||
else:
|
||||
last += 1
|
||||
if visible[last]:
|
||||
search_frame_ids.append(last)
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def _random_interval_sample(self, visible):
|
||||
# Get valid ids
|
||||
valid_ids = [i for i in range(len(visible)) if visible[i]]
|
||||
|
||||
# Sample template frame
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - avg_interval * (self.num_search_frames - 1))
|
||||
if template_frame_ids is None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == 0:
|
||||
template_frame_ids = [valid_ids[0]]
|
||||
break
|
||||
|
||||
# Sample first search frame
|
||||
if self.max_gap == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
else:
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
left_max = min(max(len(visible) - avg_interval * (self.num_search_frames - 1), template_frame_ids[0] + 1),
|
||||
template_frame_ids[0] + self.max_gap)
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)
|
||||
|
||||
if search_frame_ids is None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
break
|
||||
|
||||
# Sample rest of the search frames with random interval
|
||||
last = search_frame_ids[0]
|
||||
while last <= len(visible) - 1 and len(search_frame_ids) < self.num_search_frames:
|
||||
# sample id with interval
|
||||
max_id = min(last + self.max_interval + 1, len(visible))
|
||||
id = self._sample_visible_ids(visible, num_ids=1, min_id=last,
|
||||
max_id=max_id)
|
||||
|
||||
if id is None:
|
||||
# If no visible frame is found in the current range, skip ahead to the next range
|
||||
last = last + self.max_interval
|
||||
else:
|
||||
search_frame_ids.append(id[0])
|
||||
last = search_frame_ids[-1]
|
||||
|
||||
# if length is not enough, randomly sample new ids
|
||||
if len(search_frame_ids) < self.num_search_frames:
|
||||
valid_ids = [x for x in valid_ids if x > search_frame_ids[0] and x not in search_frame_ids]
|
||||
|
||||
if len(valid_ids) > 0:
|
||||
new_ids = random.choices(valid_ids, k=min(len(valid_ids),
|
||||
self.num_search_frames - len(search_frame_ids)))
|
||||
search_frame_ids = search_frame_ids + new_ids
|
||||
search_frame_ids = sorted(search_frame_ids, key=int)
|
||||
|
||||
# if length is still not enough, duplicate last frame
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
search_frame_ids.append(search_frame_ids[-1])
|
||||
|
||||
for i in range(1, self.num_search_frames):
|
||||
if search_frame_ids[i] - search_frame_ids[i - 1] > self.max_interval:
|
||||
print(search_frame_ids[i] - search_frame_ids[i - 1])
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
max_gap = self.max_gap
max_interval = self.max_interval
|
||||
self.max_gap = max_gap * self.extra
|
||||
self.max_interval = max_interval * self.extra
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
while True:
|
||||
try:
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= (self.num_search_frames + self.num_template_frames)
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == 'sequential':
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
|
||||
elif self.frame_sample_mode == 'random_interval':
|
||||
if random.random() < self.prob:
|
||||
template_frame_ids, search_frame_ids = self._random_interval_sample(visible)
|
||||
else:
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
else:
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
#print(dataset.get_name(), search_frame_ids, self.max_gap, self.max_interval)
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
#print(self.max_gap, self.max_interval)
|
||||
template_frames, template_anno, meta_obj_template = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
template_bbox = [bbox.numpy() for bbox in template_anno['bbox']] # tensor -> numpy array
|
||||
search_bbox = [bbox.numpy() for bbox in search_anno['bbox']] # tensor -> numpy array
|
||||
# print("====================================================================================")
|
||||
# print("dataset index: {}".format(index))
|
||||
# print("seq_id: {}".format(seq_id))
|
||||
# print('template_frame_ids: {}'.format(template_frame_ids))
|
||||
# print('search_frame_ids: {}'.format(search_frame_ids))
|
||||
return TensorDict({'template_images': np.array(template_frames).squeeze(), # 1 template images
|
||||
'template_annos': np.array(template_bbox).squeeze(),
|
||||
'search_images': np.array(search_frames), # (num_frames) search images
|
||||
'search_annos': np.array(search_bbox),
|
||||
'seq_id': seq_id,
|
||||
'dataset': dataset.get_name(),
|
||||
'search_class': meta_obj_search.get('object_class_name'),
|
||||
'num_frames': len(search_frames)
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
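The 'sequential' rule in SequenceSampler._sequential_sample can be illustrated on a toy visibility list.
This is a standalone sketch (not part of the commit) of the core idea: pick one visible template frame,
collect the following visible frames as the search clip, and pad with the last id if the sequence runs out:

visible = [1, 0, 1, 1, 1, 0, 1, 0, 0]
num_search_frames = 4
template_id = 2                            # a visible frame chosen as the template

# take the next visible frames starting from the template
search_ids = [i for i in range(template_id, len(visible)) if visible[i]][:num_search_frames]
while len(search_ids) < num_search_frames: # pad by repeating the last id if the clip is too short
    search_ids.append(search_ids[-1])
print(template_id, search_ids)             # 2 [2, 3, 4, 6]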
335
lib/train/data/transforms.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
|
||||
|
||||
class Transform:
|
||||
"""A set of transformations, used for e.g. data augmentation.
|
||||
Args of constructor:
|
||||
transforms: An arbitrary number of transformations, derived from the TransformBase class.
|
||||
They are applied in the order they are given.
|
||||
|
||||
The Transform object can jointly transform images, bounding boxes and segmentation masks.
|
||||
This is done by calling the object with the following key-word arguments (all are optional).
|
||||
|
||||
The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
|
||||
image - Image
|
||||
coords - 2xN dimensional Tensor of 2D image coordinates [y, x]
|
||||
bbox - Bounding box of the form [x, y, w, h]
|
||||
mask - Segmentation mask with discrete classes
|
||||
|
||||
The following parameters can be supplied when calling the transform object:
|
||||
joint [Bool] - If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
|
||||
Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
|
||||
different random rolls. Default: True.
|
||||
new_roll [Bool] - If False, then no new random roll is performed, and the saved result from the previous roll
|
||||
is used instead. Default: True.
|
||||
|
||||
Check the DiMPProcessing class for examples.
|
||||
"""
|
||||
|
||||
def __init__(self, *transforms):
|
||||
if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
|
||||
transforms = transforms[0]
|
||||
self.transforms = transforms
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['joint', 'new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
|
||||
def __call__(self, **inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
for v in inputs.keys():
|
||||
if v not in self._valid_all:
|
||||
raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))
|
||||
|
||||
joint_mode = inputs.get('joint', True)
|
||||
new_roll = inputs.get('new_roll', True)
|
||||
|
||||
if not joint_mode:
|
||||
out = zip(*[self(**inp) for inp in self._split_inputs(inputs)])
|
||||
return tuple(list(o) for o in out)
|
||||
|
||||
out = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
|
||||
for t in self.transforms:
|
||||
out = t(**out, joint=joint_mode, new_roll=new_roll)
|
||||
if len(var_names) == 1:
|
||||
return out[var_names[0]]
|
||||
# Make sure order is correct
|
||||
return tuple(out[v] for v in var_names)
|
||||
|
||||
def _split_inputs(self, inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
split_inputs = [{k: v for k, v in zip(var_names, vals)} for vals in zip(*[inputs[vn] for vn in var_names])]
|
||||
for arg_name, arg_val in filter(lambda it: it[0]!='joint' and it[0] in self._valid_args, inputs.items()):
|
||||
if isinstance(arg_val, list):
|
||||
for inp, av in zip(split_inputs, arg_val):
|
||||
inp[arg_name] = av
|
||||
else:
|
||||
for inp in split_inputs:
|
||||
inp[arg_name] = arg_val
|
||||
return split_inputs
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class TransformBase:
|
||||
"""Base class for transformation objects. See the Transform class for details."""
|
||||
def __init__(self):
|
||||
"""2020.12.24 Add 'att' to valid inputs"""
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
self._rand_params = None
|
||||
|
||||
def __call__(self, **inputs):
|
||||
# Split input
|
||||
input_vars = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
input_args = {k: v for k, v in inputs.items() if k in self._valid_args}
|
||||
|
||||
# Roll random parameters for the transform
|
||||
if input_args.get('new_roll', True):
|
||||
rand_params = self.roll()
|
||||
if rand_params is None:
|
||||
rand_params = ()
|
||||
elif not isinstance(rand_params, tuple):
|
||||
rand_params = (rand_params,)
|
||||
self._rand_params = rand_params
|
||||
|
||||
outputs = dict()
|
||||
for var_name, var in input_vars.items():
|
||||
if var is not None:
|
||||
transform_func = getattr(self, 'transform_' + var_name)
|
||||
if var_name in ['coords', 'bbox']:
|
||||
params = (self._get_image_size(input_vars),) + self._rand_params
|
||||
else:
|
||||
params = self._rand_params
|
||||
if isinstance(var, (list, tuple)):
|
||||
outputs[var_name] = [transform_func(x, *params) for x in var]
|
||||
else:
|
||||
outputs[var_name] = transform_func(var, *params)
|
||||
return outputs
|
||||
|
||||
def _get_image_size(self, inputs):
|
||||
im = None
|
||||
for var_name in ['image', 'mask']:
|
||||
if inputs.get(var_name) is not None:
|
||||
im = inputs[var_name]
|
||||
break
|
||||
if im is None:
|
||||
return None
|
||||
if isinstance(im, (list, tuple)):
|
||||
im = im[0]
|
||||
if isinstance(im, np.ndarray):
|
||||
return im.shape[:2]
|
||||
if torch.is_tensor(im):
|
||||
return (im.shape[-2], im.shape[-1])
|
||||
raise Exception('Unknown image type')
|
||||
|
||||
def roll(self):
|
||||
return None
|
||||
|
||||
def transform_image(self, image, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return coords
|
||||
|
||||
def transform_bbox(self, bbox, image_shape, *rand_params):
|
||||
"""Assumes [x, y, w, h]"""
|
||||
# Check if not overloaded
|
||||
if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
|
||||
return bbox
|
||||
|
||||
coord = bbox.clone().view(-1,2).t().flip(0)
|
||||
|
||||
x1 = coord[1, 0]
|
||||
x2 = coord[1, 0] + coord[1, 1]
|
||||
|
||||
y1 = coord[0, 0]
|
||||
y2 = coord[0, 0] + coord[0, 1]
|
||||
|
||||
coord_all = torch.tensor([[y1, y1, y2, y2], [x1, x2, x2, x1]])
|
||||
|
||||
coord_transf = self.transform_coords(coord_all, image_shape, *rand_params).flip(0)
|
||||
tl = torch.min(coord_transf, dim=1)[0]
|
||||
sz = torch.max(coord_transf, dim=1)[0] - tl
|
||||
bbox_out = torch.cat((tl, sz), dim=-1).reshape(bbox.shape)
|
||||
return bbox_out
|
||||
|
||||
def transform_mask(self, mask, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, *rand_params):
|
||||
"""2020.12.24 Added to deal with attention masks"""
|
||||
return att
|
||||
|
||||
|
||||
class ToTensor(TransformBase):
|
||||
"""Convert to a Tensor"""
|
||||
|
||||
def transform_image(self, image):
|
||||
# handle numpy array
|
||||
if image.ndim == 2:
|
||||
image = image[:, :, None]
|
||||
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(image, torch.ByteTensor):
|
||||
return image.float().div(255)
|
||||
else:
|
||||
return image
|
||||
|
||||
def transform_mask(self, mask):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
|
||||
def transform_att(self, att):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class ToTensorAndJitter(TransformBase):
|
||||
"""Convert to a Tensor and jitter brightness"""
|
||||
def __init__(self, brightness_jitter=0.0, normalize=True):
|
||||
super().__init__()
|
||||
self.brightness_jitter = brightness_jitter
|
||||
self.normalize = normalize
|
||||
|
||||
def roll(self):
|
||||
return np.random.uniform(max(0, 1 - self.brightness_jitter), 1 + self.brightness_jitter)
|
||||
|
||||
def transform_image(self, image, brightness_factor):
|
||||
# handle numpy array
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
|
||||
# backward compatibility
|
||||
if self.normalize:
|
||||
return image.float().mul(brightness_factor/255.0).clamp(0.0, 1.0)
|
||||
else:
|
||||
return image.float().mul(brightness_factor).clamp(0.0, 255.0)
|
||||
|
||||
def transform_mask(self, mask, brightness_factor):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
else:
|
||||
return mask
|
||||
def transform_att(self, att, brightness_factor):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class Normalize(TransformBase):
|
||||
"""Normalize image"""
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
super().__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def transform_image(self, image):
|
||||
return tvisf.normalize(image, self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToGrayscale(TransformBase):
|
||||
"""Converts image to grayscale with probability"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_grayscale):
|
||||
if do_grayscale:
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
|
||||
return np.stack([img_gray, img_gray, img_gray], axis=2)
|
||||
# return np.repeat(np.sum(img * self.color_weights, axis=2, keepdims=True).astype(np.uint8), 3, axis=2)
|
||||
return image
|
||||
|
||||
|
||||
class ToBGR(TransformBase):
|
||||
"""Converts image to BGR"""
|
||||
def transform_image(self, image):
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
|
||||
return img_bgr
|
||||
|
||||
|
||||
class RandomHorizontalFlip(TransformBase):
|
||||
"""Horizontally flip image randomly with a probability p."""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(image):
|
||||
return image.flip((2,))
|
||||
return np.fliplr(image).copy()
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = (image_shape[1] - 1) - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
|
||||
def transform_mask(self, mask, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(mask):
|
||||
return mask.flip((-1,))
|
||||
return np.fliplr(mask).copy()
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(att):
|
||||
return att.flip((-1,))
|
||||
return np.fliplr(att).copy()
|
||||
return att
|
||||
|
||||
|
||||
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
|
||||
"""Horizontally flip image randomly with a probability p.
|
||||
The difference is that the coord is normalized to [0,1]"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
"""we should use 1 rather than image_shape"""
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = 1 - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
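A minimal usage sketch of the transform classes above (assuming an 8-bit RGB numpy image and an [x, y, w, h]
bbox tensor; the concrete shapes and values are illustrative only):

import numpy as np
import torch

transform = Transform(ToTensorAndJitter(brightness_jitter=0.2),
                      RandomHorizontalFlip(probability=0.5),
                      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

img = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)
box = torch.tensor([10.0, 20.0, 40.0, 30.0])

# image and bbox are transformed jointly, using the same random rolls
out_img, out_box = transform(image=img, bbox=box, joint=True)
print(out_img.shape, out_box)   # torch.Size([3, 128, 128]) and the (possibly flipped) box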
33
lib/train/data/wandb_logger.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install wandb" to install wandb')
|
||||
|
||||
|
||||
class WandbWriter:
|
||||
def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
|
||||
self.wandb = wandb
|
||||
self.step = cur_step
|
||||
self.interval = step_interval
|
||||
wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)
|
||||
|
||||
def write_log(self, stats: OrderedDict, epoch=-1):
|
||||
self.step += 1
|
||||
for loader_name, loader_stats in stats.items():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
|
||||
log_dict = {}
|
||||
for var_name, val in loader_stats.items():
|
||||
if hasattr(val, 'avg'):
|
||||
log_dict.update({loader_name + '/' + var_name: val.avg})
|
||||
else:
|
||||
log_dict.update({loader_name + '/' + var_name: val.val})
|
||||
|
||||
if epoch >= 0:
|
||||
log_dict.update({loader_name + '/epoch': epoch})
|
||||
|
||||
self.wandb.log(log_dict, step=self.step*self.interval)
|
||||
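A hypothetical usage sketch of the writer above (the _Meter class is made up here; any object exposing an
.avg attribute, such as the usual AverageMeter-style stat values, works with write_log):

from collections import OrderedDict

class _Meter:
    def __init__(self, avg):
        self.avg = avg

writer = WandbWriter(exp_name="baseline", cfg={"lr": 1e-4}, output_dir="./wandb_out",
                     cur_step=0, step_interval=1)
stats = OrderedDict(train=OrderedDict(Loss=_Meter(0.42), IoU=_Meter(0.71)))
writer.write_log(stats, epoch=1)   # logs train/Loss, train/IoU and train/epoch to wandb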
16
lib/train/data_specs/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# README
|
||||
|
||||
## Description for different text files
|
||||
GOT10K
|
||||
- got10k_train_full_split.txt: the complete GOT-10K training set. (9335 videos)
|
||||
- got10k_train_split.txt: part of videos from the GOT-10K training set
|
||||
- got10k_val_split.txt: another part of videos from the GOT-10K training set
|
||||
- got10k_vot_exclude.txt: 1k videos that are forbidden from "using to train models then testing on VOT" (as required by [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html))
|
||||
- got10k_vot_train_split.txt: part of videos from the "VOT-permitted" GOT-10K training set
|
||||
- got10k_vot_val_split.txt: another part of videos from the "VOT-permitted" GOT-10K training set
|
||||
|
||||
LaSOT
|
||||
- lasot_train_split.txt: the complete LaSOT training set
|
||||
|
||||
TrackingNet
- trackingnet_classmap.txt: the map from sequence name to target class for the TrackingNet dataset
|
||||
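A small sketch of how these split files can be read (assumption: each file lists one sequence identifier per
line; the helper below is illustrative and not part of the commit):

def load_split(path):
    with open(path, "r") as f:
        return [line.strip() for line in f if line.strip()]

train_seqs = load_split("lib/train/data_specs/got10k_train_split.txt")
print(len(train_seqs))   # expected to match the line count shown in the diff (7934)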
9335
lib/train/data_specs/got10k_train_full_split.txt
Normal file
File diff suppressed because it is too large
7934
lib/train/data_specs/got10k_train_split.txt
Normal file
File diff suppressed because it is too large
1401
lib/train/data_specs/got10k_val_split.txt
Normal file
File diff suppressed because it is too large
1000
lib/train/data_specs/got10k_vot_exclude.txt
Normal file
File diff suppressed because it is too large
7086
lib/train/data_specs/got10k_vot_train_split.txt
Normal file
File diff suppressed because it is too large
1249
lib/train/data_specs/got10k_vot_val_split.txt
Normal file
File diff suppressed because it is too large
1120
lib/train/data_specs/lasot_train_split.txt
Normal file
File diff suppressed because it is too large
30134
lib/train/data_specs/trackingnet_classmap.txt
Normal file
File diff suppressed because it is too large
437
lib/train/dataset/COCO_tool.py
Normal file
@@ -0,0 +1,437 @@
|
||||
__author__ = 'tylin'
|
||||
__version__ = '2.0'
|
||||
# Interface for accessing the Microsoft COCO dataset.
|
||||
|
||||
# Microsoft COCO is a large image dataset designed for object detection,
|
||||
# segmentation, and caption generation. pycocotools is a Python API that
|
||||
# assists in loading, parsing and visualizing the annotations in COCO.
|
||||
# Please visit http://mscoco.org/ for more information on COCO, including
|
||||
# for the data, paper, and tutorials. The exact format of the annotations
|
||||
# is also described on the COCO website. For example usage of the pycocotools
|
||||
# please see pycocotools_demo.ipynb. In addition to this API, please download both
|
||||
# the COCO images and annotations in order to run the demo.
|
||||
|
||||
# An alternative to using the API is to load the annotations directly
|
||||
# into Python dictionary
|
||||
# Using the API provides additional utility functions. Note that this API
|
||||
# supports both *instance* and *caption* annotations. In the case of
|
||||
# captions not all functions are defined (e.g. categories are undefined).
|
||||
|
||||
# The following API functions are defined:
|
||||
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
|
||||
# decodeMask - Decode binary mask M encoded via run-length encoding.
|
||||
# encodeMask - Encode binary mask M using run-length encoding.
|
||||
# getAnnIds - Get ann ids that satisfy given filter conditions.
|
||||
# getCatIds - Get cat ids that satisfy given filter conditions.
|
||||
# getImgIds - Get img ids that satisfy given filter conditions.
|
||||
# loadAnns - Load anns with the specified ids.
|
||||
# loadCats - Load cats with the specified ids.
|
||||
# loadImgs - Load imgs with the specified ids.
|
||||
# annToMask - Convert segmentation in an annotation to binary mask.
|
||||
# showAnns - Display the specified annotations.
|
||||
# loadRes - Load algorithm results and create API for accessing them.
|
||||
# download - Download COCO images from mscoco.org server.
|
||||
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
|
||||
# Help on each function can be accessed by: "help COCO>function".
|
||||
|
||||
# See also COCO>decodeMask,
|
||||
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
|
||||
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
|
||||
# COCO>loadImgs, COCO>annToMask, COCO>showAnns
|
||||
|
||||
# Microsoft COCO Toolbox. version 2.0
|
||||
# Data, paper, and tutorials available at: http://mscoco.org/
|
||||
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
|
||||
# Licensed under the Simplified BSD License [see bsd.txt]
|
||||
|
||||
import json
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
import numpy as np
|
||||
import copy
|
||||
import itertools
|
||||
from pycocotools import mask as maskUtils
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
PYTHON_VERSION = sys.version_info[0]
|
||||
if PYTHON_VERSION == 2:
|
||||
from urllib import urlretrieve
|
||||
elif PYTHON_VERSION == 3:
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _isArrayLike(obj):
|
||||
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
||||
|
||||
|
||||
class COCO:
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
||||
:param dataset (dict): already-loaded annotations in COCO format (the parsed content of an annotation file)
|
||||
:return:
|
||||
"""
|
||||
# load dataset
|
||||
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
||||
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
||||
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
||||
self.dataset = dataset
|
||||
self.createIndex()
|
||||
|
||||
def createIndex(self):
|
||||
# create index
|
||||
print('creating index...')
|
||||
anns, cats, imgs = {}, {}, {}
|
||||
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
|
||||
if 'annotations' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
anns[ann['id']] = ann
|
||||
|
||||
if 'images' in self.dataset:
|
||||
for img in self.dataset['images']:
|
||||
imgs[img['id']] = img
|
||||
|
||||
if 'categories' in self.dataset:
|
||||
for cat in self.dataset['categories']:
|
||||
cats[cat['id']] = cat
|
||||
|
||||
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
catToImgs[ann['category_id']].append(ann['image_id'])
|
||||
|
||||
print('index created!')
|
||||
|
||||
# create class members
|
||||
self.anns = anns
|
||||
self.imgToAnns = imgToAnns
|
||||
self.catToImgs = catToImgs
|
||||
self.imgs = imgs
|
||||
self.cats = cats
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Print information about the annotation file.
|
||||
:return:
|
||||
"""
|
||||
for key, value in self.dataset['info'].items():
|
||||
print('{}: {}'.format(key, value))
|
||||
|
||||
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
|
||||
"""
|
||||
Get ann ids that satisfy given filter conditions. default skips that filter
|
||||
:param imgIds (int array) : get anns for given imgs
|
||||
catIds (int array) : get anns for given cats
|
||||
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
||||
iscrowd (boolean) : get anns for given crowd label (False or True)
|
||||
:return: ids (int array) : integer array of ann ids
|
||||
"""
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == len(areaRng) == 0:
|
||||
anns = self.dataset['annotations']
|
||||
else:
|
||||
if not len(imgIds) == 0:
|
||||
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
|
||||
anns = list(itertools.chain.from_iterable(lists))
|
||||
else:
|
||||
anns = self.dataset['annotations']
|
||||
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
||||
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
|
||||
if not iscrowd == None:
|
||||
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
||||
else:
|
||||
ids = [ann['id'] for ann in anns]
|
||||
return ids
|
||||
|
||||
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
||||
"""
|
||||
filtering parameters. default skips that filter.
|
||||
:param catNms (str array) : get cats for given cat names
|
||||
:param supNms (str array) : get cats for given supercategory names
|
||||
:param catIds (int array) : get cats for given cat ids
|
||||
:return: ids (int array) : integer array of cat ids
|
||||
"""
|
||||
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
||||
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(catNms) == len(supNms) == len(catIds) == 0:
|
||||
cats = self.dataset['categories']
|
||||
else:
|
||||
cats = self.dataset['categories']
|
||||
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
||||
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
||||
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
||||
ids = [cat['id'] for cat in cats]
|
||||
return ids
|
||||
|
||||
def getImgIds(self, imgIds=[], catIds=[]):
|
||||
'''
|
||||
Get img ids that satisfy given filter conditions.
|
||||
:param imgIds (int array) : get imgs for given ids
|
||||
:param catIds (int array) : get imgs with all given cats
|
||||
:return: ids (int array) : integer array of img ids
|
||||
'''
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == 0:
|
||||
ids = self.imgs.keys()
|
||||
else:
|
||||
ids = set(imgIds)
|
||||
for i, catId in enumerate(catIds):
|
||||
if i == 0 and len(ids) == 0:
|
||||
ids = set(self.catToImgs[catId])
|
||||
else:
|
||||
ids &= set(self.catToImgs[catId])
|
||||
return list(ids)
|
||||
|
||||
def loadAnns(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying anns
|
||||
:return: anns (object array) : loaded ann objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.anns[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.anns[ids]]
|
||||
|
||||
def loadCats(self, ids=[]):
|
||||
"""
|
||||
Load cats with the specified ids.
|
||||
:param ids (int array) : integer ids specifying cats
|
||||
:return: cats (object array) : loaded cat objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.cats[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.cats[ids]]
|
||||
|
||||
def loadImgs(self, ids=[]):
|
||||
"""
|
||||
Load imgs with the specified ids.
|
||||
:param ids (int array) : integer ids specifying img
|
||||
:return: imgs (object array) : loaded img objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.imgs[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.imgs[ids]]
|
||||
|
||||
def showAnns(self, anns, draw_bbox=False):
|
||||
"""
|
||||
Display the specified annotations.
|
||||
:param anns (array of object): annotations to display
|
||||
:return: None
|
||||
"""
|
||||
if len(anns) == 0:
|
||||
return 0
|
||||
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
|
||||
datasetType = 'instances'
|
||||
elif 'caption' in anns[0]:
|
||||
datasetType = 'captions'
|
||||
else:
|
||||
raise Exception('datasetType not supported')
|
||||
if datasetType == 'instances':
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in anns:
|
||||
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
|
||||
if 'segmentation' in ann:
|
||||
if type(ann['segmentation']) == list:
|
||||
# polygon
|
||||
for seg in ann['segmentation']:
|
||||
poly = np.array(seg).reshape((int(len(seg)/2), 2))
|
||||
polygons.append(Polygon(poly))
|
||||
color.append(c)
|
||||
else:
|
||||
# mask
|
||||
t = self.imgs[ann['image_id']]
|
||||
if type(ann['segmentation']['counts']) == list:
|
||||
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
|
||||
else:
|
||||
rle = [ann['segmentation']]
|
||||
m = maskUtils.decode(rle)
|
||||
img = np.ones( (m.shape[0], m.shape[1], 3) )
|
||||
if ann['iscrowd'] == 1:
|
||||
color_mask = np.array([2.0,166.0,101.0])/255
|
||||
if ann['iscrowd'] == 0:
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack( (img, m*0.5) ))
|
||||
if 'keypoints' in ann and type(ann['keypoints']) == list:
|
||||
# turn skeleton into zero-based index
|
||||
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
|
||||
kp = np.array(ann['keypoints'])
|
||||
x = kp[0::3]
|
||||
y = kp[1::3]
|
||||
v = kp[2::3]
|
||||
for sk in sks:
|
||||
if np.all(v[sk]>0):
|
||||
plt.plot(x[sk],y[sk], linewidth=3, color=c)
|
||||
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
|
||||
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
|
||||
|
||||
if draw_bbox:
|
||||
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
||||
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
||||
np_poly = np.array(poly).reshape((4,2))
|
||||
polygons.append(Polygon(np_poly))
|
||||
color.append(c)
|
||||
|
||||
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
|
||||
ax.add_collection(p)
|
||||
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
|
||||
ax.add_collection(p)
|
||||
elif datasetType == 'captions':
|
||||
for ann in anns:
|
||||
print(ann['caption'])
|
||||
|
||||
def loadRes(self, resFile):
|
||||
"""
|
||||
Load result file and return a result api object.
|
||||
:param resFile (str) : file name of result file
|
||||
:return: res (obj) : result api object
|
||||
"""
|
||||
res = COCO(dict())  # the constructor above requires a dict, so start from an empty one
|
||||
res.dataset['images'] = [img for img in self.dataset['images']]
|
||||
|
||||
print('Loading and preparing results...')
|
||||
tic = time.time()
|
||||
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
||||
with open(resFile) as f:
|
||||
anns = json.load(f)
|
||||
elif type(resFile) == np.ndarray:
|
||||
anns = self.loadNumpyAnnotations(resFile)
|
||||
else:
|
||||
anns = resFile
|
||||
assert type(anns) == list, 'results is not an array of objects'
|
||||
annsImgIds = [ann['image_id'] for ann in anns]
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
|
||||
'Results do not correspond to current coco set'
|
||||
if 'caption' in anns[0]:
|
||||
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
|
||||
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
|
||||
for id, ann in enumerate(anns):
|
||||
ann['id'] = id+1
|
||||
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
bb = ann['bbox']
|
||||
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
|
||||
if not 'segmentation' in ann:
|
||||
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
||||
ann['area'] = bb[2]*bb[3]
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'segmentation' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
# now only support compressed RLE format as segmentation results
|
||||
ann['area'] = maskUtils.area(ann['segmentation'])
|
||||
if not 'bbox' in ann:
|
||||
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'keypoints' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
s = ann['keypoints']
|
||||
x = s[0::3]
|
||||
y = s[1::3]
|
||||
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
||||
ann['area'] = (x1-x0)*(y1-y0)
|
||||
ann['id'] = id + 1
|
||||
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
|
||||
print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
||||
|
||||
res.dataset['annotations'] = anns
|
||||
res.createIndex()
|
||||
return res
|
||||
|
||||
def download(self, tarDir = None, imgIds = [] ):
|
||||
'''
|
||||
Download COCO images from mscoco.org server.
|
||||
:param tarDir (str): COCO results directory name
|
||||
imgIds (list): images to be downloaded
|
||||
:return:
|
||||
'''
|
||||
if tarDir is None:
|
||||
print('Please specify target directory')
|
||||
return -1
|
||||
if len(imgIds) == 0:
|
||||
imgs = self.imgs.values()
|
||||
else:
|
||||
imgs = self.loadImgs(imgIds)
|
||||
N = len(imgs)
|
||||
if not os.path.exists(tarDir):
|
||||
os.makedirs(tarDir)
|
||||
for i, img in enumerate(imgs):
|
||||
tic = time.time()
|
||||
fname = os.path.join(tarDir, img['file_name'])
|
||||
if not os.path.exists(fname):
|
||||
urlretrieve(img['coco_url'], fname)
|
||||
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
|
||||
|
||||
def loadNumpyAnnotations(self, data):
|
||||
"""
|
||||
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
|
||||
:param data (numpy.ndarray)
|
||||
:return: annotations (python nested list)
|
||||
"""
|
||||
print('Converting ndarray to lists...')
|
||||
assert(type(data) == np.ndarray)
|
||||
print(data.shape)
|
||||
assert(data.shape[1] == 7)
|
||||
N = data.shape[0]
|
||||
ann = []
|
||||
for i in range(N):
|
||||
if i % 1000000 == 0:
|
||||
print('{}/{}'.format(i,N))
|
||||
ann += [{
|
||||
'image_id' : int(data[i, 0]),
|
||||
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
|
||||
'score' : data[i, 5],
|
||||
'category_id': int(data[i, 6]),
|
||||
}]
|
||||
return ann
|
||||
|
||||
def annToRLE(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
t = self.imgs[ann['image_id']]
|
||||
h, w = t['height'], t['width']
|
||||
segm = ann['segmentation']
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = ann['segmentation']
|
||||
return rle
|
||||
|
||||
def annToMask(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
rle = self.annToRLE(ann)
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
||||
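A sketch of driving the COCO helper above with an in-memory annotation dict (toy data in COCO format; relies
on the pycocotools import already used by this file):

toy = {
    'images': [{'id': 1, 'height': 10, 'width': 10, 'file_name': 'a.jpg'}],
    'categories': [{'id': 3, 'name': 'cat', 'supercategory': 'animal'}],
    'annotations': [{'id': 7, 'image_id': 1, 'category_id': 3, 'iscrowd': 0, 'area': 16.0,
                     'bbox': [2, 2, 4, 4], 'segmentation': [[2, 2, 6, 2, 6, 6, 2, 6]]}],
}
coco = COCO(toy)
ann_ids = coco.getAnnIds(imgIds=[1], catIds=[3])   # -> [7]
mask = coco.annToMask(coco.loadAnns(ann_ids)[0])   # 10x10 binary numpy mask of the polygon
print(ann_ids, mask.shape, int(mask.sum()))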
11
lib/train/dataset/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .lasot import Lasot
|
||||
from .got10k import Got10k
|
||||
from .tracking_net import TrackingNet
|
||||
from .imagenetvid import ImagenetVID
|
||||
from .coco import MSCOCO
|
||||
from .coco_seq import MSCOCOSeq
|
||||
from .got10k_lmdb import Got10k_lmdb
|
||||
from .lasot_lmdb import Lasot_lmdb
|
||||
from .imagenetvid_lmdb import ImagenetVID_lmdb
|
||||
from .coco_seq_lmdb import MSCOCOSeq_lmdb
|
||||
from .tracking_net_lmdb import TrackingNet_lmdb
|
||||
92
lib/train/dataset/base_image_dataset.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch.utils.data
|
||||
from lib.train.data.image_loader import jpeg4py_loader
|
||||
|
||||
|
||||
class BaseImageDataset(torch.utils.data.Dataset):
|
||||
""" Base class for image datasets """
|
||||
|
||||
def __init__(self, name, root, image_loader=jpeg4py_loader):
|
||||
"""
|
||||
args:
|
||||
root - The root path to the dataset
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
"""
|
||||
self.name = name
|
||||
self.root = root
|
||||
self.image_loader = image_loader
|
||||
|
||||
self.image_list = [] # Contains the list of sequences.
|
||||
self.class_list = []
|
||||
|
||||
def __len__(self):
|
||||
""" Returns size of the dataset
|
||||
returns:
|
||||
int - number of samples in the dataset
|
||||
"""
|
||||
return self.get_num_images()
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Not to be used! Check get_frames() instead.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_name(self):
|
||||
""" Name of the dataset
|
||||
|
||||
returns:
|
||||
string - Name of the dataset
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_images(self):
|
||||
""" Number of sequences in a dataset
|
||||
|
||||
returns:
|
||||
int - number of sequences in the dataset."""
|
||||
return len(self.image_list)
|
||||
|
||||
def has_class_info(self):
|
||||
return False
|
||||
|
||||
def get_class_name(self, image_id):
|
||||
return None
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.class_list
|
||||
|
||||
def get_images_in_class(self, class_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return False
|
||||
|
||||
def get_image_info(self, seq_id):
|
||||
""" Returns information about a particular image,
|
||||
|
||||
args:
|
||||
seq_id - index of the image
|
||||
|
||||
returns:
|
||||
Dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_image(self, image_id, anno=None):
|
||||
""" Get a image
|
||||
|
||||
args:
|
||||
image_id - index of image
|
||||
anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
|
||||
|
||||
returns:
|
||||
image -
|
||||
anno -
|
||||
dict - A dict containing meta information about the sequence, e.g. class of the target object.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
110
lib/train/dataset/base_video_dataset.py
Normal file
@@ -0,0 +1,110 @@
import torch.utils.data
# 2021.1.5 use jpeg4py_loader_w_failsafe as default
from lib.train.data.image_loader import jpeg4py_loader_w_failsafe


class BaseVideoDataset(torch.utils.data.Dataset):
    """ Base class for video datasets """

    def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe):
        """
        args:
            root - The root path to the dataset
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
        """
        self.name = name
        self.root = root
        self.image_loader = image_loader

        self.sequence_list = []     # Contains the list of sequences.
        self.class_list = []

    def __len__(self):
        """ Returns size of the dataset
        returns:
            int - number of samples in the dataset
        """
        return self.get_num_sequences()

    def __getitem__(self, index):
        """ Not to be used! Check get_frames() instead.
        """
        return None

    def is_video_sequence(self):
        """ Returns whether the dataset is a video dataset or an image dataset

        returns:
            bool - True if a video dataset
        """
        return True

    def is_synthetic_video_dataset(self):
        """ Returns whether the dataset contains real videos or synthetic ones

        returns:
            bool - True if a synthetic video dataset
        """
        return False

    def get_name(self):
        """ Name of the dataset

        returns:
            string - Name of the dataset
        """
        raise NotImplementedError

    def get_num_sequences(self):
        """ Number of sequences in a dataset

        returns:
            int - number of sequences in the dataset."""
        return len(self.sequence_list)

    def has_class_info(self):
        return False

    def has_occlusion_info(self):
        return False

    def get_num_classes(self):
        return len(self.class_list)

    def get_class_list(self):
        return self.class_list

    def get_sequences_in_class(self, class_name):
        raise NotImplementedError

    def has_segmentation_info(self):
        return False

    def get_sequence_info(self, seq_id):
        """ Returns information about a particular sequence.

        args:
            seq_id - index of the sequence

        returns:
            Dict
        """
        raise NotImplementedError

    def get_frames(self, seq_id, frame_ids, anno=None):
        """ Get a set of frames from a particular sequence

        args:
            seq_id      - index of sequence
            frame_ids   - a list of frame numbers
            anno(None)  - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.

        returns:
            list - List of frames corresponding to frame_ids
            list - List of dicts for each frame
            dict - A dict containing meta information about the sequence, e.g. class of the target object.

        """
        raise NotImplementedError
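A rough sketch of how a training sampler is expected to consume this interface, assuming dataset is any concrete BaseVideoDataset subclass; the function name and sampling policy below are illustrative, not part of the library.

import random

def sample_template_search_pair(dataset, num_search_frames=1):
    # Pick a random sequence and keep only frames where the target is visible.
    seq_id = random.randint(0, dataset.get_num_sequences() - 1)
    info = dataset.get_sequence_info(seq_id)
    visible_ids = [i for i, v in enumerate(info['visible']) if v]
    if len(visible_ids) < 1 + num_search_frames:
        return None  # caller retries with another sequence
    frame_ids = random.sample(visible_ids, 1 + num_search_frames)
    frames, annos, meta = dataset.get_frames(seq_id, frame_ids, info)
    # frames: list of images, annos['bbox']: per-frame boxes, meta: class information
    return frames, annos['bbox'], meta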
156
lib/train/dataset/coco.py
Normal file
@@ -0,0 +1,156 @@
import os
from .base_image_dataset import BaseImageDataset
import torch
import random
from collections import OrderedDict
from lib.train.data import jpeg4py_loader
from lib.train.admin import env_settings
from pycocotools.coco import COCO


class MSCOCO(BaseImageDataset):
    """ The COCO object detection dataset.

    Publication:
        Microsoft COCO: Common Objects in Context.
        Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
        Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
        ECCV, 2014
        https://arxiv.org/pdf/1405.0312.pdf

    Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
    organized as follows.
        - coco_root
            - annotations
                - instances_train2014.json
                - instances_train2017.json
            - images
                - train2014
                - train2017

    Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None,
                 split="train", version="2014"):
        """
        args:
            root - path to coco root folder
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
            data_fraction - Fraction of dataset to be used. The complete dataset is used by default
            min_area - Objects with area less than min_area are filtered out. Default is 0.0
            split - 'train' or 'val'.
            version - version of coco dataset (2014 or 2017)
        """
        root = env_settings().coco_dir if root is None else root
        super().__init__('COCO', root, image_loader)

        self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
        self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))

        self.coco_set = COCO(self.anno_path)

        self.cats = self.coco_set.cats

        self.class_list = self.get_class_list()     # the parent class thing would happen in the sampler

        self.image_list = self._get_image_list(min_area=min_area)

        if data_fraction is not None:
            self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction))
        self.im_per_class = self._build_im_per_class()

    def _get_image_list(self, min_area=None):
        ann_list = list(self.coco_set.anns.keys())
        image_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]

        if min_area is not None:
            image_list = [a for a in image_list if self.coco_set.anns[a]['area'] > min_area]

        return image_list

    def get_num_classes(self):
        return len(self.class_list)

    def get_name(self):
        return 'coco'

    def has_class_info(self):
        return True

    def has_segmentation_info(self):
        return True

    def get_class_list(self):
        class_list = []
        for cat_id in self.cats.keys():
            class_list.append(self.cats[cat_id]['name'])
        return class_list

    def _build_im_per_class(self):
        im_per_class = {}
        for i, im in enumerate(self.image_list):
            class_name = self.cats[self.coco_set.anns[im]['category_id']]['name']
            if class_name not in im_per_class:
                im_per_class[class_name] = [i]
            else:
                im_per_class[class_name].append(i)

        return im_per_class

    def get_images_in_class(self, class_name):
        return self.im_per_class[class_name]

    def get_image_info(self, im_id):
        anno = self._get_anno(im_id)

        bbox = torch.Tensor(anno['bbox']).view(4,)

        mask = torch.Tensor(self.coco_set.annToMask(anno))

        valid = (bbox[2] > 0) & (bbox[3] > 0)
        visible = valid.clone().byte()

        return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}

    def _get_anno(self, im_id):
        anno = self.coco_set.anns[self.image_list[im_id]]

        return anno

    def _get_image(self, im_id):
        path = self.coco_set.loadImgs([self.coco_set.anns[self.image_list[im_id]]['image_id']])[0]['file_name']
        img = self.image_loader(os.path.join(self.img_pth, path))
        return img

    def get_meta_info(self, im_id):
        try:
            cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
            object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
                                       'motion_class': None,
                                       'major_class': cat_dict_current['supercategory'],
                                       'root_class': None,
                                       'motion_adverb': None})
        except:
            object_meta = OrderedDict({'object_class_name': None,
                                       'motion_class': None,
                                       'major_class': None,
                                       'root_class': None,
                                       'motion_adverb': None})
        return object_meta

    def get_class_name(self, im_id):
        cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
        return cat_dict_current['name']

    def get_image(self, image_id, anno=None):
        frame = self._get_image(image_id)

        if anno is None:
            anno = self.get_image_info(image_id)

        object_meta = self.get_meta_info(image_id)

        return frame, anno, object_meta
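A short usage sketch, assuming a local COCO root laid out as described in the docstring above; the /data/coco path is an assumption.

from lib.train.dataset import MSCOCO

# root is an assumed local path; split/version follow the constructor arguments above
coco = MSCOCO(root='/data/coco', split='train', version='2017', min_area=100)
print(coco.get_name(), coco.get_num_images(), coco.get_num_classes())

image, info, meta = coco.get_image(0)
print(info['bbox'])                   # tensor of shape (4,) in (x, y, w, h)
print(meta['object_class_name'])      # e.g. 'person'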
170
lib/train/dataset/coco_seq.py
Normal file
@@ -0,0 +1,170 @@
import os
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
import torch
import random
from pycocotools.coco import COCO
from collections import OrderedDict
from lib.train.admin import env_settings


class MSCOCOSeq(BaseVideoDataset):
    """ The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.

    Publication:
        Microsoft COCO: Common Objects in Context.
        Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
        Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
        ECCV, 2014
        https://arxiv.org/pdf/1405.0312.pdf

    Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
    organized as follows.
        - coco_root
            - annotations
                - instances_train2014.json
                - instances_train2017.json
            - images
                - train2014
                - train2017

    Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
        """
        args:
            root - path to the coco dataset.
            image_loader (default_image_loader) - The function to read the images. If installed,
                                                  jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
                                                  opencv's imread is used.
            data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
                                   images will be used
            split - 'train' or 'val'.
            version - version of coco dataset (2014 or 2017)
        """
        root = env_settings().coco_dir if root is None else root
        super().__init__('COCO', root, image_loader)

        self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
        self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))

        # Load the COCO set.
        self.coco_set = COCO(self.anno_path)

        self.cats = self.coco_set.cats

        self.class_list = self.get_class_list()

        self.sequence_list = self._get_sequence_list()

        if data_fraction is not None:
            self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
        self.seq_per_class = self._build_seq_per_class()

    def _get_sequence_list(self):
        ann_list = list(self.coco_set.anns.keys())
        seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]

        return seq_list

    def is_video_sequence(self):
        return False

    def get_num_classes(self):
        return len(self.class_list)

    def get_name(self):
        return 'coco'

    def has_class_info(self):
        return True

    def get_class_list(self):
        class_list = []
        for cat_id in self.cats.keys():
            class_list.append(self.cats[cat_id]['name'])
        return class_list

    def has_segmentation_info(self):
        return True

    def get_num_sequences(self):
        return len(self.sequence_list)

    def _build_seq_per_class(self):
        seq_per_class = {}
        for i, seq in enumerate(self.sequence_list):
            class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
            if class_name not in seq_per_class:
                seq_per_class[class_name] = [i]
            else:
                seq_per_class[class_name].append(i)

        return seq_per_class

    def get_sequences_in_class(self, class_name):
        return self.seq_per_class[class_name]

    def get_sequence_info(self, seq_id):
        anno = self._get_anno(seq_id)

        bbox = torch.Tensor(anno['bbox']).view(1, 4)

        mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)

        '''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
        valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)

        visible = valid.clone().byte()

        return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}

    def _get_anno(self, seq_id):
        anno = self.coco_set.anns[self.sequence_list[seq_id]]

        return anno

    def _get_frames(self, seq_id):
        path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
        img = self.image_loader(os.path.join(self.img_pth, path))
        return img

    def get_meta_info(self, seq_id):
        try:
            cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
            object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
                                       'motion_class': None,
                                       'major_class': cat_dict_current['supercategory'],
                                       'root_class': None,
                                       'motion_adverb': None})
        except:
            object_meta = OrderedDict({'object_class_name': None,
                                       'motion_class': None,
                                       'major_class': None,
                                       'root_class': None,
                                       'motion_adverb': None})
        return object_meta

    def get_class_name(self, seq_id):
        cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
        return cat_dict_current['name']

    def get_frames(self, seq_id=None, frame_ids=None, anno=None):
        # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
        # list containing these replicated images.
        frame = self._get_frames(seq_id)

        frame_list = [frame.copy() for _ in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[0, ...] for _ in frame_ids]

        object_meta = self.get_meta_info(seq_id)

        return frame_list, anno_frames, object_meta
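Because every COCO image is treated as a sequence of length 1, get_frames simply replicates the single image; a sketch under the same assumed /data/coco path as above.

from lib.train.dataset import MSCOCOSeq

coco_seq = MSCOCOSeq(root='/data/coco', split='train', version='2017')
info = coco_seq.get_sequence_info(0)
print(info['bbox'].shape, info['valid'])   # torch.Size([1, 4]), validity after the 50-pixel filter

# Asking for three "frames" returns three copies of the same image and annotation
frames, annos, meta = coco_seq.get_frames(seq_id=0, frame_ids=[0, 0, 0], anno=info)
print(len(frames), len(annos['bbox']))      # 3 3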
177
lib/train/dataset/coco_seq_lmdb.py
Normal file
@@ -0,0 +1,177 @@
import os
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
import torch
import random
from collections import OrderedDict
from lib.train.admin import env_settings
from lib.train.dataset.COCO_tool import COCO
from lib.utils.lmdb_utils import decode_img, decode_json
import time


class MSCOCOSeq_lmdb(BaseVideoDataset):
    """ The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.

    Publication:
        Microsoft COCO: Common Objects in Context.
        Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
        Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
        ECCV, 2014
        https://arxiv.org/pdf/1405.0312.pdf

    Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
    organized as follows.
        - coco_root
            - annotations
                - instances_train2014.json
                - instances_train2017.json
            - images
                - train2014
                - train2017

    Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
        """
        args:
            root - path to the coco dataset.
            image_loader (default_image_loader) - The function to read the images. If installed,
                                                  jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
                                                  opencv's imread is used.
            data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
                                   images will be used
            split - 'train' or 'val'.
            version - version of coco dataset (2014 or 2017)
        """
        root = env_settings().coco_dir if root is None else root
        super().__init__('COCO_lmdb', root, image_loader)
        self.root = root
        self.img_pth = 'images/{}{}/'.format(split, version)
        self.anno_path = 'annotations/instances_{}{}.json'.format(split, version)

        # Load the COCO set.
        print('loading annotations into memory...')
        tic = time.time()
        coco_json = decode_json(root, self.anno_path)
        print('Done (t={:0.2f}s)'.format(time.time() - tic))

        self.coco_set = COCO(coco_json)

        self.cats = self.coco_set.cats

        self.class_list = self.get_class_list()

        self.sequence_list = self._get_sequence_list()

        if data_fraction is not None:
            self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
        self.seq_per_class = self._build_seq_per_class()

    def _get_sequence_list(self):
        ann_list = list(self.coco_set.anns.keys())
        seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]

        return seq_list

    def is_video_sequence(self):
        return False

    def get_num_classes(self):
        return len(self.class_list)

    def get_name(self):
        return 'coco_lmdb'

    def has_class_info(self):
        return True

    def get_class_list(self):
        class_list = []
        for cat_id in self.cats.keys():
            class_list.append(self.cats[cat_id]['name'])
        return class_list

    def has_segmentation_info(self):
        return True

    def get_num_sequences(self):
        return len(self.sequence_list)

    def _build_seq_per_class(self):
        seq_per_class = {}
        for i, seq in enumerate(self.sequence_list):
            class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
            if class_name not in seq_per_class:
                seq_per_class[class_name] = [i]
            else:
                seq_per_class[class_name].append(i)

        return seq_per_class

    def get_sequences_in_class(self, class_name):
        return self.seq_per_class[class_name]

    def get_sequence_info(self, seq_id):
        anno = self._get_anno(seq_id)

        bbox = torch.Tensor(anno['bbox']).view(1, 4)

        mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)

        '''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
        valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)

        visible = valid.clone().byte()

        return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}

    def _get_anno(self, seq_id):
        anno = self.coco_set.anns[self.sequence_list[seq_id]]

        return anno

    def _get_frames(self, seq_id):
        path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
        # img = self.image_loader(os.path.join(self.img_pth, path))
        img = decode_img(self.root, os.path.join(self.img_pth, path))
        return img

    def get_meta_info(self, seq_id):
        try:
            cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
            object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
                                       'motion_class': None,
                                       'major_class': cat_dict_current['supercategory'],
                                       'root_class': None,
                                       'motion_adverb': None})
        except:
            object_meta = OrderedDict({'object_class_name': None,
                                       'motion_class': None,
                                       'major_class': None,
                                       'root_class': None,
                                       'motion_adverb': None})
        return object_meta

    def get_class_name(self, seq_id):
        cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
        return cat_dict_current['name']

    def get_frames(self, seq_id=None, frame_ids=None, anno=None):
        # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
        # list containing these replicated images.
        frame = self._get_frames(seq_id)

        frame_list = [frame.copy() for _ in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[0, ...] for _ in frame_ids]

        object_meta = self.get_meta_info(seq_id)

        return frame_list, anno_frames, object_meta
186
lib/train/dataset/got10k.py
Normal file
@@ -0,0 +1,186 @@
import os
import os.path
import numpy as np
import torch
import csv
import pandas
import random
from collections import OrderedDict
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
from lib.train.admin import env_settings


class Got10k(BaseVideoDataset):
    """ GOT-10k dataset.

    Publication:
        GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
        Lianghua Huang, Xin Zhao, and Kaiqi Huang
        arXiv:1810.11981, 2018
        https://arxiv.org/pdf/1810.11981.pdf

    Download dataset from http://got-10k.aitestunion.com/downloads
    """

    def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
        """
        args:
            root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
            split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
                    NOT the official got-10k validation split. To use the official validation split, provide that as
                    the root folder instead.
            seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
                      options can be used at the same time.
            data_fraction - Fraction of dataset to be used. The complete dataset is used by default
        """
        root = env_settings().got10k_dir if root is None else root
        super().__init__('GOT10k', root, image_loader)

        # all folders inside the root
        self.sequence_list = self._get_sequence_list()

        # seq_id is the index of the folder inside the got10k root path
        if split is not None:
            if seq_ids is not None:
                raise ValueError('Cannot set both split_name and seq_ids.')
            ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
            if split == 'train':
                file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_split.txt')
            elif split == 'val':
                file_path = os.path.join(ltr_path, 'data_specs', 'got10k_val_split.txt')
            elif split == 'train_full':
                file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_full_split.txt')
            elif split == 'vottrain':
                file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_train_split.txt')
            elif split == 'votval':
                file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_val_split.txt')
            else:
                raise ValueError('Unknown split name.')
            # seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
            seq_ids = pandas.read_csv(file_path, header=None, dtype=np.int64).squeeze("columns").values.tolist()
        elif seq_ids is None:
            seq_ids = list(range(0, len(self.sequence_list)))

        self.sequence_list = [self.sequence_list[i] for i in seq_ids]

        if data_fraction is not None:
            self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))

        self.sequence_meta_info = self._load_meta_info()
        self.seq_per_class = self._build_seq_per_class()

        self.class_list = list(self.seq_per_class.keys())
        self.class_list.sort()

    def get_name(self):
        return 'got10k'

    def has_class_info(self):
        return True

    def has_occlusion_info(self):
        return True

    def _load_meta_info(self):
        sequence_meta_info = {s: self._read_meta(os.path.join(self.root, s)) for s in self.sequence_list}
        return sequence_meta_info

    def _read_meta(self, seq_path):
        try:
            with open(os.path.join(seq_path, 'meta_info.ini')) as f:
                meta_info = f.readlines()
            object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1][:-1],
                                       'motion_class': meta_info[6].split(': ')[-1][:-1],
                                       'major_class': meta_info[7].split(': ')[-1][:-1],
                                       'root_class': meta_info[8].split(': ')[-1][:-1],
                                       'motion_adverb': meta_info[9].split(': ')[-1][:-1]})
        except:
            object_meta = OrderedDict({'object_class_name': None,
                                       'motion_class': None,
                                       'major_class': None,
                                       'root_class': None,
                                       'motion_adverb': None})
        return object_meta

    def _build_seq_per_class(self):
        seq_per_class = {}

        for i, s in enumerate(self.sequence_list):
            object_class = self.sequence_meta_info[s]['object_class_name']
            if object_class in seq_per_class:
                seq_per_class[object_class].append(i)
            else:
                seq_per_class[object_class] = [i]

        return seq_per_class

    def get_sequences_in_class(self, class_name):
        return self.seq_per_class[class_name]

    def _get_sequence_list(self):
        with open(os.path.join(self.root, 'list.txt')) as f:
            dir_list = list(csv.reader(f))
        dir_list = [dir_name[0] for dir_name in dir_list]
        return dir_list

    def _read_bb_anno(self, seq_path):
        bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
        gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
        return torch.tensor(gt)

    def _read_target_visible(self, seq_path):
        # Read full occlusion and out_of_view
        occlusion_file = os.path.join(seq_path, "absence.label")
        cover_file = os.path.join(seq_path, "cover.label")

        with open(occlusion_file, 'r', newline='') as f:
            occlusion = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])
        with open(cover_file, 'r', newline='') as f:
            cover = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])

        target_visible = ~occlusion & (cover > 0).byte()

        visible_ratio = cover.float() / 8
        return target_visible, visible_ratio

    def _get_sequence_path(self, seq_id):
        return os.path.join(self.root, self.sequence_list[seq_id])

    def get_sequence_info(self, seq_id):
        seq_path = self._get_sequence_path(seq_id)
        bbox = self._read_bb_anno(seq_path)

        valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
        visible, visible_ratio = self._read_target_visible(seq_path)
        visible = visible & valid.byte()

        return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}

    def _get_frame_path(self, seq_path, frame_id):
        return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1))    # frames start from 1

    def _get_frame(self, seq_path, frame_id):
        return self.image_loader(self._get_frame_path(seq_path, frame_id))

    def get_class_name(self, seq_id):
        obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]

        return obj_meta['object_class_name']

    def get_frames(self, seq_id, frame_ids, anno=None):
        seq_path = self._get_sequence_path(seq_id)
        obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]

        frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]

        return frame_list, anno_frames, obj_meta
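A sketch of the two mutually exclusive ways to select sequences, assuming the GOT-10k train folder lives at the path below.

from lib.train.dataset import Got10k

# Option 1: use a predefined split file from lib/train/data_specs
got10k_val = Got10k(root='/data/got10k/train', split='votval')

# Option 2: pass explicit folder indices instead (cannot be combined with split)
got10k_subset = Got10k(root='/data/got10k/train', seq_ids=[0, 1, 2, 3])

info = got10k_val.get_sequence_info(0)
# visible_ratio comes from cover.label (coverage quantized into eighths)
print(info['bbox'].shape, info['visible_ratio'][:5])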
183
lib/train/dataset/got10k_lmdb.py
Normal file
@@ -0,0 +1,183 @@
import os
import os.path
import numpy as np
import torch
import csv
import pandas
import random
from collections import OrderedDict
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
from lib.train.admin import env_settings

'''2021.1.16 Got10k for loading lmdb dataset'''
from lib.utils.lmdb_utils import *


class Got10k_lmdb(BaseVideoDataset):

    def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
        """
        args:
            root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
            image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
                                            is used by default.
            split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
                    NOT the official got-10k validation split. To use the official validation split, provide that as
                    the root folder instead.
            seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
                      options can be used at the same time.
            data_fraction - Fraction of dataset to be used. The complete dataset is used by default
            use_lmdb - whether the dataset is stored in lmdb format
        """
        root = env_settings().got10k_lmdb_dir if root is None else root
        super().__init__('GOT10k_lmdb', root, image_loader)

        # all folders inside the root
        self.sequence_list = self._get_sequence_list()

        # seq_id is the index of the folder inside the got10k root path
        if split is not None:
            if seq_ids is not None:
                raise ValueError('Cannot set both split_name and seq_ids.')
            train_lib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
            if split == 'train':
                file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_split.txt')
            elif split == 'val':
                file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_val_split.txt')
            elif split == 'train_full':
                file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_full_split.txt')
            elif split == 'vottrain':
                file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_train_split.txt')
            elif split == 'votval':
                file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_val_split.txt')
            else:
                raise ValueError('Unknown split name.')
            # seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
            seq_ids = pandas.read_csv(file_path, header=None, dtype=np.int64).squeeze("columns").values.tolist()
        elif seq_ids is None:
            seq_ids = list(range(0, len(self.sequence_list)))

        self.sequence_list = [self.sequence_list[i] for i in seq_ids]

        if data_fraction is not None:
            self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))

        self.sequence_meta_info = self._load_meta_info()
        self.seq_per_class = self._build_seq_per_class()

        self.class_list = list(self.seq_per_class.keys())
        self.class_list.sort()

    def get_name(self):
        return 'got10k_lmdb'

    def has_class_info(self):
        return True

    def has_occlusion_info(self):
        return True

    def _load_meta_info(self):
        def _read_meta(meta_info):

            object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1],
                                       'motion_class': meta_info[6].split(': ')[-1],
                                       'major_class': meta_info[7].split(': ')[-1],
                                       'root_class': meta_info[8].split(': ')[-1],
                                       'motion_adverb': meta_info[9].split(': ')[-1]})

            return object_meta
        sequence_meta_info = {}
        for s in self.sequence_list:
            try:
                meta_str = decode_str(self.root, "train/%s/meta_info.ini" % s)
                sequence_meta_info[s] = _read_meta(meta_str.split('\n'))
            except:
                sequence_meta_info[s] = OrderedDict({'object_class_name': None,
                                                     'motion_class': None,
                                                     'major_class': None,
                                                     'root_class': None,
                                                     'motion_adverb': None})
        return sequence_meta_info

    def _build_seq_per_class(self):
        seq_per_class = {}

        for i, s in enumerate(self.sequence_list):
            object_class = self.sequence_meta_info[s]['object_class_name']
            if object_class in seq_per_class:
                seq_per_class[object_class].append(i)
            else:
                seq_per_class[object_class] = [i]

        return seq_per_class

    def get_sequences_in_class(self, class_name):
        return self.seq_per_class[class_name]

    def _get_sequence_list(self):
        dir_str = decode_str(self.root, 'train/list.txt')
        dir_list = dir_str.split('\n')
        return dir_list

    def _read_bb_anno(self, seq_path):
        bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
        gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1]  # the last line in got10k is empty
        gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
        gt_arr = np.array(gt_list).astype(np.float32)

        return torch.tensor(gt_arr)

    def _read_target_visible(self, seq_path):
        # full occlusion and out_of_view files
        occlusion_file = os.path.join(seq_path, "absence.label")
        cover_file = os.path.join(seq_path, "cover.label")
        # Read these files
        occ_list = list(map(int, decode_str(self.root, occlusion_file).split('\n')[:-1]))  # the last line in got10k is empty
        occlusion = torch.ByteTensor(occ_list)
        cover_list = list(map(int, decode_str(self.root, cover_file).split('\n')[:-1]))  # the last line in got10k is empty
        cover = torch.ByteTensor(cover_list)

        target_visible = ~occlusion & (cover > 0).byte()

        visible_ratio = cover.float() / 8
        return target_visible, visible_ratio

    def _get_sequence_path(self, seq_id):
        return os.path.join("train", self.sequence_list[seq_id])

    def get_sequence_info(self, seq_id):
        seq_path = self._get_sequence_path(seq_id)
        bbox = self._read_bb_anno(seq_path)

        valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
        visible, visible_ratio = self._read_target_visible(seq_path)
        visible = visible & valid.byte()

        return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}

    def _get_frame_path(self, seq_path, frame_id):
        return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1))    # frames start from 1

    def _get_frame(self, seq_path, frame_id):
        return decode_img(self.root, self._get_frame_path(seq_path, frame_id))

    def get_class_name(self, seq_id):
        obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]

        return obj_meta['object_class_name']

    def get_frames(self, seq_id, frame_ids, anno=None):
        seq_path = self._get_sequence_path(seq_id)
        obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]

        frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]

        return frame_list, anno_frames, obj_meta
159
lib/train/dataset/imagenetvid.py
Normal file
@@ -0,0 +1,159 @@
import os
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
import xml.etree.ElementTree as ET
import json
import torch
from collections import OrderedDict
from lib.train.admin import env_settings


def get_target_to_image_ratio(seq):
    anno = torch.Tensor(seq['anno'])
    img_sz = torch.Tensor(seq['image_size'])
    return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()


class ImagenetVID(BaseVideoDataset):
    """ Imagenet VID dataset.

    Publication:
        ImageNet Large Scale Visual Recognition Challenge
        Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
        Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
        IJCV, 2015
        https://arxiv.org/pdf/1409.0575.pdf

    Download the dataset from http://image-net.org/
    """
    def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
        """
        args:
            root - path to the imagenet vid dataset.
            image_loader (default_image_loader) - The function to read the images. If installed,
                                                  jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
                                                  opencv's imread is used.
            min_length - Minimum allowed sequence length.
            max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
                              which cover complete image.
        """
        root = env_settings().imagenet_dir if root is None else root
        super().__init__("imagenetvid", root, image_loader)

        cache_file = os.path.join(root, 'cache.json')
        if os.path.isfile(cache_file):
            # If available, load the pre-processed cache file containing meta-info for each sequence
            with open(cache_file, 'r') as f:
                sequence_list_dict = json.load(f)

            self.sequence_list = sequence_list_dict
        else:
            # Else process the imagenet annotations and generate the cache file
            self.sequence_list = self._process_anno(root)

            with open(cache_file, 'w') as f:
                json.dump(self.sequence_list, f)

        # Filter the sequences based on min_length and max_target_area in the first frame
        self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
                              get_target_to_image_ratio(x) < max_target_area]

    def get_name(self):
        return 'imagenetvid'

    def get_num_sequences(self):
        return len(self.sequence_list)

    def get_sequence_info(self, seq_id):
        bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
        valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
        visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
        return {'bbox': bb_anno, 'valid': valid, 'visible': visible}

    def _get_frame(self, sequence, frame_id):
        set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
        vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
        frame_number = frame_id + sequence['start_frame']
        frame_path = os.path.join(self.root, 'Data', 'VID', 'train', set_name, vid_name,
                                  '{:06d}.JPEG'.format(frame_number))
        return self.image_loader(frame_path)

    def get_frames(self, seq_id, frame_ids, anno=None):
        sequence = self.sequence_list[seq_id]

        frame_list = [self._get_frame(sequence, f) for f in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        # Create anno dict
        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]

        # added the class info to the meta info
        object_meta = OrderedDict({'object_class': sequence['class_name'],
                                   'motion_class': None,
                                   'major_class': None,
                                   'root_class': None,
                                   'motion_adverb': None})

        return frame_list, anno_frames, object_meta

    def _process_anno(self, root):
        # Builds individual tracklets
        base_vid_anno_path = os.path.join(root, 'Annotations', 'VID', 'train')

        all_sequences = []
        for set in sorted(os.listdir(base_vid_anno_path)):
            set_id = int(set.split('_')[-1])
            for vid in sorted(os.listdir(os.path.join(base_vid_anno_path, set))):

                vid_id = int(vid.split('_')[-1])
                anno_files = sorted(os.listdir(os.path.join(base_vid_anno_path, set, vid)))

                frame1_anno = ET.parse(os.path.join(base_vid_anno_path, set, vid, anno_files[0]))
                image_size = [int(frame1_anno.find('size/width').text), int(frame1_anno.find('size/height').text)]

                objects = [ET.ElementTree(file=os.path.join(base_vid_anno_path, set, vid, f)).findall('object')
                           for f in anno_files]

                tracklets = {}

                # Find all tracklets along with start frame
                for f_id, all_targets in enumerate(objects):
                    for target in all_targets:
                        tracklet_id = target.find('trackid').text
                        if tracklet_id not in tracklets:
                            tracklets[tracklet_id] = f_id

                for tracklet_id, tracklet_start in tracklets.items():
                    tracklet_anno = []
                    target_visible = []
                    class_name_id = None

                    for f_id in range(tracklet_start, len(objects)):
                        found = False
                        for target in objects[f_id]:
                            if target.find('trackid').text == tracklet_id:
                                if not class_name_id:
                                    class_name_id = target.find('name').text
                                x1 = int(target.find('bndbox/xmin').text)
                                y1 = int(target.find('bndbox/ymin').text)
                                x2 = int(target.find('bndbox/xmax').text)
                                y2 = int(target.find('bndbox/ymax').text)

                                tracklet_anno.append([x1, y1, x2 - x1, y2 - y1])
                                target_visible.append(target.find('occluded').text == '0')

                                found = True
                                break
                        if not found:
                            break

                    new_sequence = {'set_id': set_id, 'vid_id': vid_id, 'class_name': class_name_id,
                                    'start_frame': tracklet_start, 'anno': tracklet_anno,
                                    'target_visible': target_visible, 'image_size': image_size}
                    all_sequences.append(new_sequence)

        return all_sequences
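The first construction walks the VID annotation XMLs and writes cache.json next to the data, so later runs start quickly; a sketch assuming the standard ILSVRC2015 layout at the path below.

from lib.train.dataset import ImagenetVID

# Slow on the first call (parses Annotations/VID/train), fast afterwards thanks to cache.json
vid = ImagenetVID(root='/data/ILSVRC2015', min_length=20, max_target_area=0.5)
print(vid.get_num_sequences())

frames, annos, meta = vid.get_frames(0, [0, 5, 10])
print(meta['object_class'], [b.tolist() for b in annos['bbox']])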
90
lib/train/dataset/imagenetvid_lmdb.py
Normal file
@@ -0,0 +1,90 @@
import os
from .base_video_dataset import BaseVideoDataset
from lib.train.data import jpeg4py_loader
import torch
from collections import OrderedDict
from lib.train.admin import env_settings
from lib.utils.lmdb_utils import decode_img, decode_json


def get_target_to_image_ratio(seq):
    anno = torch.Tensor(seq['anno'])
    img_sz = torch.Tensor(seq['image_size'])
    return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()


class ImagenetVID_lmdb(BaseVideoDataset):
    """ Imagenet VID dataset.

    Publication:
        ImageNet Large Scale Visual Recognition Challenge
        Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
        Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
        IJCV, 2015
        https://arxiv.org/pdf/1409.0575.pdf

    Download the dataset from http://image-net.org/
    """
    def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
        """
        args:
            root - path to the imagenet vid dataset.
            image_loader (default_image_loader) - The function to read the images. If installed,
                                                  jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
                                                  opencv's imread is used.
            min_length - Minimum allowed sequence length.
            max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
                              which cover complete image.
        """
        root = env_settings().imagenet_dir if root is None else root
        super().__init__("imagenetvid_lmdb", root, image_loader)

        sequence_list_dict = decode_json(root, "cache.json")
        self.sequence_list = sequence_list_dict

        # Filter the sequences based on min_length and max_target_area in the first frame
        self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
                              get_target_to_image_ratio(x) < max_target_area]

    def get_name(self):
        return 'imagenetvid_lmdb'

    def get_num_sequences(self):
        return len(self.sequence_list)

    def get_sequence_info(self, seq_id):
        bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
        valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
        visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
        return {'bbox': bb_anno, 'valid': valid, 'visible': visible}

    def _get_frame(self, sequence, frame_id):
        set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
        vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
        frame_number = frame_id + sequence['start_frame']
        frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name,
                                  '{:06d}.JPEG'.format(frame_number))
        return decode_img(self.root, frame_path)

    def get_frames(self, seq_id, frame_ids, anno=None):
        sequence = self.sequence_list[seq_id]

        frame_list = [self._get_frame(sequence, f) for f in frame_ids]

        if anno is None:
            anno = self.get_sequence_info(seq_id)

        # Create anno dict
        anno_frames = {}
        for key, value in anno.items():
            anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]

        # added the class info to the meta info
        object_meta = OrderedDict({'object_class': sequence['class_name'],
                                   'motion_class': None,
                                   'major_class': None,
                                   'root_class': None,
                                   'motion_adverb': None})

        return frame_list, anno_frames, object_meta
169
lib/train/dataset/lasot.py
Normal file
169
lib/train/dataset/lasot.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class Lasot(BaseVideoDataset):
|
||||
""" LaSOT dataset.
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_dir if root is None else root
|
||||
super().__init__('LaSOT', root, image_loader)
|
||||
|
||||
# Keep a list of all classes
|
||||
self.class_list = [f for f in os.listdir(self.root)]
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
# sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist()
|
||||
sequence_list = pandas.read_csv(file_path, header=None).squeeze("columns").values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
with open(occlusion_file, 'r', newline='') as f:
|
||||
occlusion = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
with open(out_of_view_file, 'r') as f:
|
||||
out_of_view = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(self.root, class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return self.image_loader(self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
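For orientation, a minimal usage sketch of the Lasot class above; the root path is a placeholder and the call pattern simply mirrors how names2datasets in lib/train/train_script.py constructs the dataset:

from lib.train.dataset import Lasot
from lib.train.data import jpeg4py_loader

dataset = Lasot(root='/data/lasot', split='train', image_loader=jpeg4py_loader)  # placeholder root
info = dataset.get_sequence_info(0)                      # {'bbox': Nx4, 'valid': N, 'visible': N}
frames, annos, meta = dataset.get_frames(0, [0, 10, 20], info)
print(dataset.get_name(), dataset.get_num_sequences(), meta['object_class_name'])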
165
lib/train/dataset/lasot_lmdb.py
Normal file

@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
'''2021.1.16 Lasot for loading lmdb dataset'''
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
|
||||
class Lasot_lmdb(BaseVideoDataset):
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_lmdb_dir if root is None else root
|
||||
super().__init__('LaSOT_lmdb', root, image_loader)
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
class_list = [seq_name.split('-')[0] for seq_name in self.sequence_list]
|
||||
self.class_list = []
|
||||
for ele in class_list:
|
||||
if ele not in self.class_list:
|
||||
self.class_list.append(ele)
|
||||
# Keep a list of all classes
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
sequence_list = pandas.read_csv(file_path, header=None).squeeze("columns").values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
occ_list = list(map(int, decode_str(self.root, occlusion_file).split(',')))
|
||||
occlusion = torch.ByteTensor(occ_list)
|
||||
out_view_list = list(map(int, decode_str(self.root, out_of_view_file).split(',')))
|
||||
out_of_view = torch.ByteTensor(out_view_list)
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return decode_img(self.root, self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
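A small sketch of the annotation parsing done in _read_bb_anno above, using a made-up decoded string to show the expected groundtruth.txt layout (one x,y,w,h row per frame); in the class the string comes from decode_str(self.root, bb_anno_file):

import numpy as np
import torch

gt_str = "367,101,41,16\n366,102,42,15\n"                   # hypothetical decoded content
gt_str_list = gt_str.split('\n')[:-1]                       # drop the trailing empty line
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
gt = torch.tensor(np.array(gt_list).astype(np.float32))     # shape (2, 4)
print(gt.shape)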
151
lib/train/dataset/tracking_net.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
def list_sequences(root, set_ids):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
set_ids: Sets (0-11) which are to be used
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
sequence_list = []
|
||||
|
||||
for s in set_ids:
|
||||
anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno")
|
||||
|
||||
sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
|
||||
sequence_list += sequences_cur_set
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller, Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_dir if root is None else root
|
||||
super().__init__('TrackingNet', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root, self.set_ids)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
bb_anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False,
|
||||
low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg")
|
||||
return self.image_loader(frame_path)
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
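A minimal usage sketch of the TrackingNet class above; the root path is a placeholder and set_ids selects the TRAIN_0 ... TRAIN_11 chunks:

from lib.train.dataset import TrackingNet
from lib.train.data import jpeg4py_loader

dataset = TrackingNet(root='/data/trackingnet', image_loader=jpeg4py_loader, set_ids=[0, 1])  # placeholder root
set_id, video_name = dataset.sequence_list[0]            # sequences are (set_id, video_name) tuples
frames, annos, meta = dataset.get_frames(0, [0, 1, 2])
print(dataset.get_name(), meta['object_class_name'])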
147
lib/train/dataset/tracking_net_lmdb.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
import json
|
||||
from lib.utils.lmdb_utils import decode_img, decode_str
|
||||
|
||||
|
||||
def list_sequences(root):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
fname = os.path.join(root, "seq_list.json")
|
||||
with open(fname, "r") as f:
|
||||
sequence_list = json.loads(f.read())
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet_lmdb(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller, Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_lmdb_dir if root is None else root
|
||||
super().__init__('TrackingNet_lmdb', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
gt_str_list = decode_str(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("anno", vid_name + ".txt")).split('\n')[:-1]
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
return decode_img(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("frames", vid_name, str(frame_id) + ".jpg"))
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
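The LMDB variant reads its sequence index from seq_list.json under root; judging from how sequence_list[seq_id][0] and [1] are used above, a minimal file would look like the sketch below (the video names are made up):

import json

# Hypothetical seq_list.json content: a list of [set_id, video_name] pairs.
seq_list = json.loads('[[0, "video_0000"], [0, "video_0001"], [3, "video_1234"]]')
set_id, vid_name = seq_list[0]      # 0, "video_0000"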
113
lib/train/run_training.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import importlib
|
||||
import cv2 as cv
|
||||
import torch.backends.cudnn
|
||||
import torch.distributed as dist
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
import _init_paths
|
||||
import lib.train.admin.settings as ws_settings
|
||||
|
||||
|
||||
def init_seeds(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(1)
|
||||
cv.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None,
|
||||
use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False,
|
||||
distill=None, script_teacher=None, config_teacher=None):
|
||||
"""Run the train script.
|
||||
args:
|
||||
script_name: Name of experiment in the "experiments/" folder.
|
||||
config_name: Name of the yaml file in the "experiments/<script_name>".
|
||||
cudnn_benchmark: Use cudnn benchmark or not (default is True).
|
||||
"""
|
||||
if save_dir is None:
|
||||
print("save_dir dir is not given. Use the default dir instead.")
|
||||
# This is needed to avoid strange crashes related to opencv
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(4)
|
||||
|
||||
torch.backends.cudnn.benchmark = cudnn_benchmark
|
||||
|
||||
print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name))
|
||||
|
||||
'''2021.1.5 set seed for different process'''
|
||||
if base_seed is not None:
|
||||
if local_rank != -1:
|
||||
init_seeds(base_seed + local_rank)
|
||||
else:
|
||||
init_seeds(base_seed)
|
||||
|
||||
settings = ws_settings.Settings()
|
||||
settings.script_name = script_name
|
||||
settings.config_name = config_name
|
||||
settings.project_path = 'train/{}/{}'.format(script_name, config_name)
|
||||
if script_name_prv is not None and config_name_prv is not None:
|
||||
settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv)
|
||||
settings.local_rank = local_rank
|
||||
settings.save_dir = os.path.abspath(save_dir)
|
||||
settings.use_lmdb = use_lmdb
|
||||
prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name))
|
||||
settings.use_wandb = use_wandb
|
||||
if distill:
|
||||
settings.distill = distill
|
||||
settings.script_teacher = script_teacher
|
||||
settings.config_teacher = config_teacher
|
||||
if script_teacher is not None and config_teacher is not None:
|
||||
settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher)
|
||||
settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher))
|
||||
expr_module = importlib.import_module('lib.train.train_script_distill')
|
||||
else:
|
||||
expr_module = importlib.import_module('lib.train.train_script')
|
||||
expr_func = getattr(expr_module, 'run')
|
||||
|
||||
expr_func(settings)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
|
||||
parser.add_argument('--script', type=str, required=True, help='Name of the train script.')
|
||||
parser.add_argument('--config', type=str, required=True, help="Name of the config file.")
|
||||
parser.add_argument('--cudnn_benchmark', type=int, choices=[0, 1], default=0, help='Set cudnn benchmark on (1) or off (0) (default is off).')
|
||||
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
|
||||
parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs')
|
||||
parser.add_argument('--seed', type=int, default=42, help='seed for random numbers')
|
||||
parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format
|
||||
parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.')
|
||||
parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.")
|
||||
parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb
|
||||
# for knowledge distillation
|
||||
parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation
|
||||
parser.add_argument('--script_teacher', type=str, help='teacher script name')
|
||||
parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.local_rank != -1:
|
||||
dist.init_process_group(backend='nccl')
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
else:
|
||||
torch.cuda.set_device(0)
|
||||
run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark,
|
||||
local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed,
|
||||
use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv,
|
||||
use_wandb=args.use_wandb,
|
||||
distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
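For reference, illustrative invocations of this script (the config names are placeholders for yaml files under experiments/<script>/):

# Single GPU:
#   python lib/train/run_training.py --script artrack --config <config_name> --save_dir ./output
#
# Distributed training (torch.distributed.launch supplies --local_rank to each process):
#   python -m torch.distributed.launch --nproc_per_node 2 lib/train/run_training.py \
#       --script artrack_seq --config <config_name> --save_dir ./output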
203
lib/train/train_script.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer, LTRSeqTrainer
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader, sequence_sampler
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
# forward propagation related
|
||||
from lib.train.actors import ARTrackActor, ARTrackSeqActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
from ..utils.focal_loss import FocalLoss
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
def slt_collate(batch):
|
||||
ret = {}
|
||||
for k in batch[0].keys():
|
||||
here_list = []
|
||||
for ex in batch:
|
||||
here_list.append(ex[k])
|
||||
ret[k] = here_list
|
||||
return ret
|
||||
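# Illustrative note (not part of the original commit): unlike the default collate function,
# slt_collate does not stack tensors. For a batch such as
#   [{'num_frames': 3, 'images': imgs_a}, {'num_frames': 5, 'images': imgs_b}]
# it returns {'num_frames': [3, 5], 'images': [imgs_a, imgs_b]}, keeping per-sequence lists
# so that sequences of different lengths can share a batch.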
|
||||
class SLTLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
|
||||
if collate_fn is None:
|
||||
collate_fn = slt_collate
|
||||
|
||||
super(SLTLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for ARTrack and ARTrack-Seq'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
bins = cfg.MODEL.BINS
|
||||
search_size = cfg.DATA.SEARCH.SIZE
|
||||
# Create network
|
||||
if settings.script_name == "artrack":
|
||||
net = build_artrack(cfg)
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
net = build_artrack_seq(cfg)
|
||||
dataset_train = sequence_sampler.SequenceSampler(
|
||||
datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_GAP, max_interval=cfg.DATA.MAX_INTERVAL,
|
||||
num_search_frames=cfg.DATA.SEARCH.NUMBER, num_template_frames=1,
|
||||
frame_sample_mode='random_interval',
|
||||
prob=cfg.DATA.INTERVAL_PROB)
|
||||
loader_train = SLTLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER,
|
||||
shuffle=False, drop_last=True)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
if settings.local_rank != -1:
|
||||
# net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "artrack":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackSeqActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# if cfg.TRAIN.DEEP_SUPERVISION:
|
||||
# raise ValueError("Deep supervision is not supported now.")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
if settings.script_name == "artrack":
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
trainer = LTRSeqTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|
||||
111
lib/train/train_script_distill.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.stark import build_starks, build_starkst
|
||||
from lib.models.stark import build_stark_lightning_x_trt
|
||||
# forward propagation related
|
||||
from lib.train.actors import STARKLightningXtrtdistillActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
|
||||
def build_network(script_name, cfg):
|
||||
# Create network
|
||||
if script_name == "stark_s":
|
||||
net = build_starks(cfg)
|
||||
elif script_name == "stark_st1" or script_name == "stark_st2":
|
||||
net = build_starkst(cfg)
|
||||
elif script_name == "stark_lightning_X_trt":
|
||||
net = build_stark_lightning_x_trt(cfg, phase="train")
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
return net
|
||||
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update the default teacher configs with teacher config file
|
||||
if not os.path.exists(settings.cfg_file_teacher):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file_teacher)
|
||||
config_module_teacher = importlib.import_module("lib.config.%s.config" % settings.script_teacher)
|
||||
cfg_teacher = config_module_teacher.cfg
|
||||
config_module_teacher.update_config_from_file(settings.cfg_file_teacher)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New teacher configuration is shown below.")
|
||||
for key in cfg_teacher.keys():
|
||||
print("%s configuration:" % key, cfg_teacher[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
"""turn on the distillation mode"""
|
||||
cfg.TRAIN.DISTILL = True
|
||||
cfg_teacher.TRAIN.DISTILL = True
|
||||
net = build_network(settings.script_name, cfg)
|
||||
net_teacher = build_network(settings.script_teacher, cfg_teacher)
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
net_teacher.cuda()
|
||||
net_teacher.eval()
|
||||
|
||||
if settings.local_rank != -1:
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
net_teacher = DDP(net_teacher, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
# settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
# settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "L1")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "stark_lightning_X_trt":
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
|
||||
actor = STARKLightningXtrtdistillActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings,
|
||||
net_teacher=net_teacher)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True, distill=True)
|
||||
3
lib/train/trainers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_trainer import BaseTrainer
|
||||
from .ltr_trainer import LTRTrainer
|
||||
from .ltr_seq_trainer import LTRSeqTrainer
|
||||
275
lib/train/trainers/base_trainer.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import traceback
|
||||
from lib.train.admin import multigpu
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""Base trainer class. Contains functions for training and saving/loading checkpoints.
|
||||
Trainer classes should inherit from this one and overload the train_epoch function."""
|
||||
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
self.actor = actor
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.loaders = loaders
|
||||
|
||||
self.update_settings(settings)
|
||||
|
||||
self.epoch = 0
|
||||
self.stats = {}
|
||||
|
||||
self.device = getattr(settings, 'device', None)
|
||||
if self.device is None:
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() and settings.use_gpu else "cpu")
|
||||
|
||||
self.actor.to(self.device)
|
||||
self.settings = settings
|
||||
|
||||
def update_settings(self, settings=None):
|
||||
"""Updates the trainer settings. Must be called to update internal settings."""
|
||||
if settings is not None:
|
||||
self.settings = settings
|
||||
|
||||
if self.settings.env.workspace_dir is not None:
|
||||
self.settings.env.workspace_dir = os.path.expanduser(self.settings.env.workspace_dir)
|
||||
'''2021.1.4 New function: specify checkpoint dir'''
|
||||
if self.settings.save_dir is None:
|
||||
self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir, 'checkpoints')
|
||||
else:
|
||||
self._checkpoint_dir = os.path.join(self.settings.save_dir, 'checkpoints')
|
||||
print("checkpoints will be saved to %s" % self._checkpoint_dir)
|
||||
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(self._checkpoint_dir):
|
||||
print("Training with multiple GPUs. checkpoints directory doesn't exist. "
|
||||
"Create checkpoints directory")
|
||||
os.makedirs(self._checkpoint_dir)
|
||||
else:
|
||||
self._checkpoint_dir = None
|
||||
|
||||
def train(self, max_epochs, load_latest=False, fail_safe=True, load_previous_ckpt=False, distill=False):
|
||||
"""Do training for the given number of epochs.
|
||||
args:
|
||||
max_epochs - Max number of training epochs,
|
||||
load_latest - Bool indicating whether to resume from latest epoch.
|
||||
fail_safe - Bool indicating whether the training should automatically restart in case of any crashes.
|
||||
"""
|
||||
|
||||
epoch = -1
|
||||
num_tries = 1
|
||||
for i in range(num_tries):
|
||||
try:
|
||||
if load_latest:
|
||||
self.load_checkpoint()
|
||||
if load_previous_ckpt:
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_prv)
|
||||
self.load_state_dict(directory)
|
||||
if distill:
|
||||
directory_teacher = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_teacher)
|
||||
self.load_state_dict(directory_teacher, distill=True)
|
||||
for epoch in range(self.epoch+1, max_epochs+1):
|
||||
self.epoch = epoch
|
||||
|
||||
self.train_epoch()
|
||||
|
||||
if self.lr_scheduler is not None:
|
||||
if self.settings.scheduler_type != 'cosine':
|
||||
self.lr_scheduler.step()
|
||||
else:
|
||||
self.lr_scheduler.step(epoch - 1)
|
||||
# save a checkpoint every 5 epochs and for the last few epochs
|
||||
save_every_epoch = getattr(self.settings, "save_every_epoch", False)
|
||||
save_epochs = []
|
||||
if epoch > (max_epochs - 1) or save_every_epoch or epoch % 5 == 0 or epoch in save_epochs or epoch > (max_epochs - 5):
|
||||
# if epoch > (max_epochs - 10) or save_every_epoch or epoch % 100 == 0:
|
||||
if self._checkpoint_dir:
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
self.save_checkpoint()
|
||||
except:
|
||||
print('Training crashed at epoch {}'.format(epoch))
|
||||
if fail_safe:
|
||||
self.epoch -= 1
|
||||
load_latest = True
|
||||
print('Traceback for the error!')
|
||||
print(traceback.format_exc())
|
||||
print('Restarting training from last epoch ...')
|
||||
else:
|
||||
raise
|
||||
|
||||
print('Finished training!')
|
||||
|
||||
def train_epoch(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self):
|
||||
"""Saves a checkpoint of the network and other variables."""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'actor_type': actor_type,
|
||||
'net_type': net_type,
|
||||
'net': net.state_dict(),
|
||||
'net_info': getattr(net, 'info', None),
|
||||
'constructor': getattr(net, 'constructor', None),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'stats': self.stats,
|
||||
'settings': self.settings
|
||||
}
|
||||
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
|
||||
print(directory)
|
||||
if not os.path.exists(directory):
|
||||
print("directory doesn't exist. creating...")
|
||||
os.makedirs(directory)
|
||||
|
||||
# First save as a tmp file
|
||||
tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
|
||||
torch.save(state, tmp_file_path)
|
||||
|
||||
file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)
|
||||
|
||||
# Now rename to actual checkpoint. os.rename seems to be atomic if files are on same filesystem. Not 100% sure
|
||||
os.rename(tmp_file_path, file_path)
|
||||
|
||||
def load_checkpoint(self, checkpoint = None, fields = None, ignore_fields = None, load_constructor = False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
Can be called in three different ways:
|
||||
load_checkpoint():
|
||||
Loads the latest epoch from the workspace. Use this to continue training.
|
||||
load_checkpoint(epoch_num):
|
||||
Loads the network at the given epoch number (int).
|
||||
load_checkpoint(path_to_checkpoint):
|
||||
Loads the file from the given absolute path (str).
|
||||
"""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
|
||||
if checkpoint is None:
|
||||
# Load most recent checkpoint
|
||||
checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
|
||||
self.settings.project_path, net_type)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
print('No matching checkpoint file found')
|
||||
return
|
||||
elif isinstance(checkpoint, int):
|
||||
# Checkpoint is the epoch number
|
||||
checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
|
||||
net_type, checkpoint)
|
||||
elif isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
if fields is None:
|
||||
fields = checkpoint_dict.keys()
|
||||
if ignore_fields is None:
|
||||
ignore_fields = ['settings']
|
||||
|
||||
# Never load the scheduler. It exists in older checkpoints.
|
||||
ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])
|
||||
|
||||
# Load all fields
|
||||
for key in fields:
|
||||
if key in ignore_fields:
|
||||
continue
|
||||
if key == 'net':
|
||||
net.load_state_dict(checkpoint_dict[key])
|
||||
elif key == 'optimizer':
|
||||
self.optimizer.load_state_dict(checkpoint_dict[key])
|
||||
else:
|
||||
setattr(self, key, checkpoint_dict[key])
|
||||
|
||||
# Set the net info
|
||||
if load_constructor and 'constructor' in checkpoint_dict and checkpoint_dict['constructor'] is not None:
|
||||
net.constructor = checkpoint_dict['constructor']
|
||||
if 'net_info' in checkpoint_dict and checkpoint_dict['net_info'] is not None:
|
||||
net.info = checkpoint_dict['net_info']
|
||||
|
||||
# Update the epoch in lr scheduler
|
||||
if 'epoch' in fields:
|
||||
self.lr_scheduler.last_epoch = self.epoch
|
||||
# 2021.1.10 Update the epoch in data_samplers
|
||||
for loader in self.loaders:
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
return True
|
||||
|
||||
def load_state_dict(self, checkpoint=None, distill=False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
checkpoint must be a string: either a directory, in which case the most recent
|
||||
*_ep*.pth.tar file in it is loaded, or the path to a specific checkpoint file.
|
||||
When distill is True, the weights are loaded into the teacher network instead.
|
||||
"""
|
||||
if distill:
|
||||
net = self.actor.net_teacher.module if multigpu.is_multi_gpu(self.actor.net_teacher) \
|
||||
else self.actor.net_teacher
|
||||
else:
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
net_type = type(net).__name__
|
||||
|
||||
if isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
print("Loading pretrained model from ", checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
|
||||
print("previous checkpoint is loaded.")
|
||||
print("missing keys: ", missing_k)
|
||||
print("unexpected keys:", unexpected_k)
|
||||
|
||||
return True
|
||||
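A short sketch of the checkpoint layout implied by save_checkpoint and load_checkpoint above; the concrete directory, project path, and network class name are placeholders:

checkpoint_dir = './output/checkpoints'              # <save_dir>/checkpoints
project_path = 'train/artrack_seq/<config_name>'     # settings.project_path
net_type, epoch = 'ARTrackSeq', 7                    # type(net).__name__ (placeholder) and current epoch

file_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(checkpoint_dir, project_path, net_type, epoch)
# -> './output/checkpoints/train/artrack_seq/<config_name>/ARTrackSeq_ep0007.pth.tar'
# load_checkpoint() with no arguments resumes from the lexicographically last *_ep*.pth.tar in that folder.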
322
lib/train/trainers/ltr_seq_trainer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
# from lib.train.data.wandb_logger import WandbWriter
|
||||
from lib.train.trainers import BaseTrainer
|
||||
from lib.train.admin import AverageMeter, StatValue
|
||||
from memory_profiler import profile
|
||||
# from lib.train.admin import TensorboardWriter
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lib.utils.misc import get_world_size
|
||||
|
||||
|
||||
class LTRSeqTrainer(BaseTrainer):
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
super().__init__(actor, loaders, optimizer, settings, lr_scheduler)
|
||||
|
||||
self._set_default_settings()
|
||||
|
||||
# Initialize statistics variables
|
||||
self.stats = OrderedDict({loader.name: None for loader in self.loaders})
|
||||
|
||||
# Initialize tensorboard and wandb
|
||||
# self.wandb_writer = None
|
||||
# if settings.local_rank in [-1, 0]:
|
||||
# tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
|
||||
# if not os.path.exists(tensorboard_writer_dir):
|
||||
# os.makedirs(tensorboard_writer_dir)
|
||||
# self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])
|
||||
|
||||
# if settings.use_wandb:
|
||||
# world_size = get_world_size()
|
||||
# cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
|
||||
# interval = (world_size * settings.batchsize) # * interval
|
||||
# self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)
|
||||
|
||||
self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
|
||||
print("move_data", self.move_data_to_gpu)
|
||||
self.settings = settings
|
||||
self.use_amp = use_amp
|
||||
if use_amp:
|
||||
self.scaler = GradScaler()
|
||||
|
||||
def _set_default_settings(self):
|
||||
# Dict of all default values
|
||||
default = {'print_interval': 10,
|
||||
'print_stats': None,
|
||||
'description': ''}
|
||||
|
||||
for param, default_value in default.items():
|
||||
if getattr(self.settings, param, None) is None:
|
||||
setattr(self.settings, param, default_value)
|
||||
|
||||
self.miou_list = []
|
||||
|
||||
def cycle_dataset(self, loader):
|
||||
"""Do a cycle of training or validation."""
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
self.actor.train(loader.training)
|
||||
torch.set_grad_enabled(loader.training)
|
||||
|
||||
self._init_timing()
|
||||
|
||||
for i, data in enumerate(loader, 1):
|
||||
self.actor.eval()
|
||||
self.data_read_done_time = time.time()
|
||||
with torch.no_grad():
|
||||
explore_result = self.actor.explore(data)
|
||||
if explore_result is None:
|
||||
print("this time i skip")
|
||||
# self._update_stats(stats, batch_size, loader)
|
||||
continue
|
||||
# get inputs
|
||||
# print(data)
|
||||
|
||||
self.data_to_gpu_time = time.time()
|
||||
|
||||
data['epoch'] = self.epoch
|
||||
data['settings'] = self.settings
|
||||
|
||||
stats = {}
|
||||
reward_record = []
|
||||
miou_record = []
|
||||
e_miou_record = []
|
||||
num_seq = len(data['num_frames'])
|
||||
|
||||
# Calculate reward tensor
|
||||
# reward_tensor = torch.zeros(explore_result['baseline_iou'].size())
|
||||
baseline_iou = explore_result['baseline_iou']
|
||||
# explore_iou = explore_result['explore_iou']
|
||||
for seq_idx in range(num_seq):
|
||||
num_frames = data['num_frames'][seq_idx] - 1
|
||||
b_miou = torch.mean(baseline_iou[:num_frames, seq_idx])
|
||||
# e_miou = torch.mean(explore_iou[:num_frames, seq_idx])
|
||||
miou_record.append(b_miou.item())
|
||||
# e_miou_record.append(e_miou.item())
|
||||
|
||||
b_reward = b_miou.item()
|
||||
# e_reward = e_miou.item()
|
||||
# iou_gap = e_reward - b_reward
|
||||
# reward_record.append(iou_gap)
|
||||
# reward_tensor[:num_frames, seq_idx] = iou_gap
|
||||
|
||||
# Training mode
|
||||
cursor = 0
|
||||
bs_backward = 1
|
||||
|
||||
# print(self.actor.net.module.box_head.decoder.layers[2].mlpx.fc1.weight)
|
||||
self.optimizer.zero_grad()
|
||||
            while cursor < num_seq:
                # print("now is ", cursor , "and all is ", num_seq)
                model_inputs = {}
                model_inputs['slt_loss_weight'] = 15
                if cursor < num_seq:
                    model_inputs['template_images'] = explore_result['template_images'][
                                                      cursor:cursor + bs_backward].cuda()
                else:
                    model_inputs['template_images'] = explore_result['template_images_reverse'][
                                                      cursor - num_seq:cursor - num_seq + bs_backward].cuda()
                model_inputs['search_images'] = explore_result['search_images'][:, cursor:cursor + bs_backward].cuda()
                model_inputs['search_anno'] = explore_result['search_anno'][:, cursor:cursor + bs_backward].cuda()
                model_inputs['pre_seq'] = explore_result['pre_seq'][:, cursor:cursor + bs_backward].cuda()
                model_inputs['x_feat'] = explore_result['x_feat'].squeeze(1)[:, cursor:cursor + bs_backward].cuda()
                model_inputs['epoch'] = data['epoch']
                # model_inputs['template_update'] = explore_result['template_update'].squeeze(1)[:,
                #                                   cursor:cursor + bs_backward].cuda()
                # print("this is cursor")
                # print(explore_result['pre_seq'].shape)
                # print(explore_result['x_feat'].squeeze(1).shape)
                # model_inputs['action_tensor'] = explore_result['action_tensor'][:, cursor:cursor + bs_backward].cuda()
                # model_inputs['reward_tensor'] = reward_tensor[:, cursor:cursor + bs_backward].cuda()

                loss, stats_cur = self.actor.compute_sequence_losses(model_inputs)
                # for name, param in self.actor.net.named_parameters():
                #     shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
                #     print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')
                # print("i make this!")
                loss.backward()
                # print("i made that?")

                for key, val in stats_cur.items():
                    if key in stats:
                        stats[key] += val * (bs_backward / num_seq)
                    else:
                        stats[key] = val * (bs_backward / num_seq)
                cursor += bs_backward
            grad_norm = clip_grad_norm_(self.actor.net.parameters(), 100)
            stats['grad_norm'] = grad_norm
            # print(self.actor.net.module.backbone.blocks[8].mlp.fc1.weight)
            self.optimizer.step()
            # print(self.optimizer)
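
            # mIoU for this batch plus running means over the last 10 / 100 batches,
            # tracked via self.miou_list for logging.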
            miou = np.mean(miou_record)
            self.miou_list.append(miou)
            # stats['reward'] = np.mean(reward_record)
            # stats['e_mIoU'] = np.mean(e_miou_record)
            stats['mIoU'] = miou
            stats['mIoU10'] = np.mean(self.miou_list[-10:])
            stats['mIoU100'] = np.mean(self.miou_list[-100:])

            batch_size = num_seq * np.max(data['num_frames'])
            self._update_stats(stats, batch_size, loader)
            self._print_stats(i, loader, batch_size)
            torch.cuda.empty_cache()

            # # forward pass
            # if not self.use_amp:
            #     loss, stats = self.actor(data)
            # else:
            #     with autocast():
            #         loss, stats = self.actor(data)
            #
            # # backward pass and update weights
            # if loader.training:
            #     self.optimizer.zero_grad()
            #     if not self.use_amp:
            #         loss.backward()
            #         if self.settings.grad_clip_norm > 0:
            #             torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
            #         self.optimizer.step()
            #     else:
            #         self.scaler.scale(loss).backward()
            #         self.scaler.step(self.optimizer)
            #         self.scaler.update()

            # update statistics
            # batch_size = data['template_images'].shape[loader.stack_dim]
            # self._update_stats(stats, batch_size, loader)

            # print statistics
            # self._print_stats(i, loader, batch_size)

            # update wandb status
            # if self.wandb_writer is not None and i % self.settings.print_interval == 0:
            #     if self.settings.local_rank in [-1, 0]:
            #         self.wandb_writer.write_log(self.stats, self.epoch)

        # calculate ETA after every epoch
        # epoch_time = self.prev_time - self.start_time
        # print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
        # print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
        # print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
        # print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))

    def train_epoch(self):
        """Do one epoch for each loader."""
        for loader in self.loaders:
            if self.epoch % loader.epoch_interval == 0:
                # 2021.1.10 Set epoch
                if isinstance(loader.sampler, DistributedSampler):
                    loader.sampler.set_epoch(self.epoch)
                self.cycle_dataset(loader)

        self._stats_new_epoch()
        # if self.settings.local_rank in [-1, 0]:
        #     self._write_tensorboard()

    def _init_timing(self):
        self.num_frames = 0
        self.start_time = time.time()
        self.prev_time = self.start_time
        self.avg_date_time = 0
        self.avg_gpu_trans_time = 0
        self.avg_forward_time = 0

    def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
        # Initialize stats if not initialized yet
        if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
            self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})

        # add lr state
        if loader.training:
            lr_list = self.lr_scheduler.get_last_lr()
            for i, lr in enumerate(lr_list):
                var_name = 'LearningRate/group{}'.format(i)
                if var_name not in self.stats[loader.name].keys():
                    self.stats[loader.name][var_name] = StatValue()
                self.stats[loader.name][var_name].update(lr)

        for name, val in new_stats.items():
            if name not in self.stats[loader.name].keys():
                self.stats[loader.name][name] = AverageMeter()
            self.stats[loader.name][name].update(val, batch_size)
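
    # Timing bookkeeping: avg_date_time / avg_gpu_trans_time / avg_forward_time accumulate
    # per-iteration wall-clock deltas in _print_stats below; dividing by num_frames and
    # scaling by batch_size reports them as average seconds per batch.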
    def _print_stats(self, i, loader, batch_size):
        self.num_frames += batch_size
        current_time = time.time()
        batch_fps = batch_size / (current_time - self.prev_time)
        average_fps = self.num_frames / (current_time - self.start_time)
        prev_frame_time_backup = self.prev_time
        self.prev_time = current_time

        self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
        self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
        self.avg_forward_time += current_time - self.data_to_gpu_time

        if i % self.settings.print_interval == 0 or i == loader.__len__():
            print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
            print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)

            # 2021.12.14 add data time print
            print_str += 'DataTime: %.3f (%.3f) , ' % (
                self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
            print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
            print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
            # print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
            # print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
            # print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)

            for name, val in self.stats[loader.name].items():
                if (self.settings.print_stats is None or name in self.settings.print_stats):
                    if hasattr(val, 'avg'):
                        print_str += '%s: %.5f , ' % (name, val.avg)
                    # else:
                    #     print_str += '%s: %r , ' % (name, val)

            print(print_str[:-5])
            log_str = print_str[:-5] + '\n'
            with open(self.settings.log_file, 'a') as f:
                f.write(log_str)

    def _stats_new_epoch(self):
        # Record learning rate
        for loader in self.loaders:
            if loader.training:
                try:
                    lr_list = self.lr_scheduler.get_last_lr()
                except:
                    lr_list = self.lr_scheduler._get_lr(self.epoch)
                for i, lr in enumerate(lr_list):
                    var_name = 'LearningRate/group{}'.format(i)
                    if var_name not in self.stats[loader.name].keys():
                        self.stats[loader.name][var_name] = StatValue()
                    self.stats[loader.name][var_name].update(lr)

        for loader_stats in self.stats.values():
            if loader_stats is None:
                continue
            for stat_value in loader_stats.values():
                if hasattr(stat_value, 'new_epoch'):
                    stat_value.new_epoch()

    # def _write_tensorboard(self):
    #     if self.epoch == 1:
    #         self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
    #
    #     self.tensorboard_writer.write_epoch(self.stats, self.epoch)
225
lib/train/trainers/ltr_trainer.py
Normal file
@@ -0,0 +1,225 @@
import os
import datetime
from collections import OrderedDict

#from lib.train.data.wandb_logger import WandbWriter
from lib.train.trainers import BaseTrainer
from lib.train.admin import AverageMeter, StatValue
#from lib.train.admin import TensorboardWriter
import torch
import time
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

from lib.utils.misc import get_world_size


class LTRTrainer(BaseTrainer):
    def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
        """
        args:
            actor - The actor for training the network
            loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
                        epoch for each loader.
            optimizer - The optimizer used for training, e.g. Adam
            settings - Training settings
            lr_scheduler - Learning rate scheduler
        """
        super().__init__(actor, loaders, optimizer, settings, lr_scheduler)

        self._set_default_settings()

        # Initialize statistics variables
        self.stats = OrderedDict({loader.name: None for loader in self.loaders})

        # Initialize tensorboard and wandb
        #self.wandb_writer = None
        #if settings.local_rank in [-1, 0]:
        #    tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
        #    if not os.path.exists(tensorboard_writer_dir):
        #        os.makedirs(tensorboard_writer_dir)
        #    self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])

        #    if settings.use_wandb:
        #        world_size = get_world_size()
        #        cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
        #        interval = (world_size * settings.batchsize)  # * interval
        #        self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)

        self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
        print("move_data", self.move_data_to_gpu)
        self.settings = settings
        self.use_amp = use_amp
        if use_amp:
            self.scaler = GradScaler()

    def _set_default_settings(self):
        # Dict of all default values
        default = {'print_interval': 10,
                   'print_stats': None,
                   'description': ''}

        for param, default_value in default.items():
            if getattr(self.settings, param, None) is None:
                setattr(self.settings, param, default_value)

    def cycle_dataset(self, loader):
        """Do a cycle of training or validation."""

        self.actor.train(loader.training)
        torch.set_grad_enabled(loader.training)

        self._init_timing()

        for i, data in enumerate(loader, 1):
            self.data_read_done_time = time.time()
            # get inputs
            if self.move_data_to_gpu:
                data = data.to(self.device)

            self.data_to_gpu_time = time.time()

            data['epoch'] = self.epoch
            data['settings'] = self.settings
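            # Mixed-precision handling below: with use_amp the forward pass runs under autocast
            # and the backward/step go through the GradScaler (scale -> backward -> step -> update);
            # otherwise plain fp32 is used with optional gradient clipping.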
            # forward pass
            if not self.use_amp:
                loss, stats = self.actor(data)
            else:
                with autocast():
                    loss, stats = self.actor(data)

            # backward pass and update weights
            if loader.training:
                self.optimizer.zero_grad()
                if not self.use_amp:
                    loss.backward()
                    if self.settings.grad_clip_norm > 0:
                        torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
                    self.optimizer.step()
                else:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

            # update statistics
            batch_size = data['template_images'].shape[loader.stack_dim]
            self._update_stats(stats, batch_size, loader)

            # print statistics
            self._print_stats(i, loader, batch_size)

            # update wandb status
            #if self.wandb_writer is not None and i % self.settings.print_interval == 0:
            #    if self.settings.local_rank in [-1, 0]:
            #        self.wandb_writer.write_log(self.stats, self.epoch)

        # calculate ETA after every epoch
        epoch_time = self.prev_time - self.start_time
        print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
        print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
        print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
        print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))

    def train_epoch(self):
        """Do one epoch for each loader."""
        for loader in self.loaders:
            if self.epoch % loader.epoch_interval == 0:
                # 2021.1.10 Set epoch
                if isinstance(loader.sampler, DistributedSampler):
                    loader.sampler.set_epoch(self.epoch)
                self.cycle_dataset(loader)

        self._stats_new_epoch()
        #if self.settings.local_rank in [-1, 0]:
        #    self._write_tensorboard()

    def _init_timing(self):
        self.num_frames = 0
        self.start_time = time.time()
        self.prev_time = self.start_time
        self.avg_date_time = 0
        self.avg_gpu_trans_time = 0
        self.avg_forward_time = 0

    def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
        # Initialize stats if not initialized yet
        if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
            self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})

        # add lr state
        if loader.training:
            lr_list = self.lr_scheduler.get_last_lr()
            for i, lr in enumerate(lr_list):
                var_name = 'LearningRate/group{}'.format(i)
                if var_name not in self.stats[loader.name].keys():
                    self.stats[loader.name][var_name] = StatValue()
                self.stats[loader.name][var_name].update(lr)

        for name, val in new_stats.items():
            if name not in self.stats[loader.name].keys():
                self.stats[loader.name][name] = AverageMeter()
            self.stats[loader.name][name].update(val, batch_size)

    def _print_stats(self, i, loader, batch_size):
        self.num_frames += batch_size
        current_time = time.time()
        batch_fps = batch_size / (current_time - self.prev_time)
        average_fps = self.num_frames / (current_time - self.start_time)
        prev_frame_time_backup = self.prev_time
        self.prev_time = current_time

        self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
        self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
        self.avg_forward_time += current_time - self.data_to_gpu_time

        if i % self.settings.print_interval == 0 or i == loader.__len__():
            print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
            print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)

            # 2021.12.14 add data time print
            print_str += 'DataTime: %.3f (%.3f) , ' % (
                self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
            print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
            print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
            # print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
            # print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
            # print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)

            for name, val in self.stats[loader.name].items():
                if (self.settings.print_stats is None or name in self.settings.print_stats):
                    if hasattr(val, 'avg'):
                        print_str += '%s: %.5f , ' % (name, val.avg)
                    # else:
                    #     print_str += '%s: %r , ' % (name, val)

            print(print_str[:-5])
            log_str = print_str[:-5] + '\n'
            with open(self.settings.log_file, 'a') as f:
                f.write(log_str)

    def _stats_new_epoch(self):
        # Record learning rate
        for loader in self.loaders:
            if loader.training:
                try:
                    lr_list = self.lr_scheduler.get_last_lr()
                except:
                    lr_list = self.lr_scheduler._get_lr(self.epoch)
                for i, lr in enumerate(lr_list):
                    var_name = 'LearningRate/group{}'.format(i)
                    if var_name not in self.stats[loader.name].keys():
                        self.stats[loader.name][var_name] = StatValue()
                    self.stats[loader.name][var_name].update(lr)

        for loader_stats in self.stats.values():
            if loader_stats is None:
                continue
            for stat_value in loader_stats.values():
                if hasattr(stat_value, 'new_epoch'):
                    stat_value.new_epoch()

    #def _write_tensorboard(self):
    #    if self.epoch == 1:
    #        self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
    #
    #    self.tensorboard_writer.write_epoch(self.stats, self.epoch)
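A minimal sketch of how this trainer might be wired up; the helper argument names, settings.num_epochs, and the train() entry point (assumed to come from BaseTrainer) are illustrative and not part of this commit:

# Hypothetical wiring sketch -- names and entry point are assumptions, not repo API.
from lib.train.trainers.ltr_trainer import LTRTrainer

def run_training(actor, train_loader, val_loader, optimizer, settings, lr_scheduler):
    trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings,
                         lr_scheduler=lr_scheduler, use_amp=False)
    # train_epoch() runs one cycle_dataset() pass per loader whose epoch_interval divides the epoch
    trainer.train(settings.num_epochs)  # assumed BaseTrainer method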
1
lib/utils/__init__.py
Normal file
@@ -0,0 +1 @@
from .tensor import TensorDict, TensorList
106
lib/utils/box_ops.py
Normal file
@@ -0,0 +1,106 @@
import torch
from torchvision.ops.boxes import box_area
import numpy as np


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xywh_to_xyxy(x):
    x1, y1, w, h = x.unbind(-1)
    b = [x1, y1, x1 + w, y1 + h]
    return torch.stack(b, dim=-1)


def box_xyxy_to_xywh(x):
    x1, y1, x2, y2 = x.unbind(-1)
    b = [x1, y1, x2 - x1, y2 - y1]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)

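# These conversions operate on the last dimension, so they accept any leading batch shape:
# box_xyxy_to_cxcywh(box_cxcywh_to_xyxy(x)) recovers x for a (..., 4) tensor up to float error.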

# modified from torchvision to also return the union
'''Note that this function only supports shape (N,4)'''


def box_iou(boxes1, boxes2):
    """
    :param boxes1: (N, 4) (x1,y1,x2,y2)
    :param boxes2: (N, 4) (x1,y1,x2,y2)
    :return:
    """
    area1 = box_area(boxes1)  # (N,)
    area2 = box_area(boxes2)  # (N,)

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # (N,2)
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # (N,2)

    wh = (rb - lt).clamp(min=0)  # (N,2)
    inter = wh[:, 0] * wh[:, 1]  # (N,)

    union = area1 + area2 - inter

    iou = inter / union
    return iou, union

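# Unlike torchvision's pairwise box_iou (which returns an N x M matrix), this version pairs
# boxes1[i] with boxes2[i] and returns per-pair IoU and union, each of shape (N,).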

'''Note that this implementation is different from DETR's'''


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    boxes1: (N, 4)
    boxes2: (N, 4)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    # try:
    # assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    # assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)  # (N,)

    lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # (N,2)
    area = wh[:, 0] * wh[:, 1]  # (N,)

    return iou - (area - union) / area, iou


def giou_loss(boxes1, boxes2):
    """
    :param boxes1: (N, 4) (x1,y1,x2,y2)
    :param boxes2: (N, 4) (x1,y1,x2,y2)
    :return:
    """
    giou, iou = generalized_box_iou(boxes1, boxes2)
    return (1 - giou).mean(), iou


def clip_box(box: list, H, W, margin=0):
    x1, y1, w, h = box
    x2, y2 = x1 + w, y1 + h
    x1 = min(max(0, x1), W - margin)
    x2 = min(max(margin, x2), W)
    y1 = min(max(0, y1), H - margin)
    y2 = min(max(margin, y2), H)
    w = max(margin, x2 - x1)
    h = max(margin, y2 - y1)
    return [x1, y1, w, h]
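A minimal usage sketch for these utilities (the tensor values are illustrative only): convert predictions from (cx, cy, w, h) to corner format and score them against ground truth with the GIoU loss.

import torch
from lib.utils.box_ops import box_cxcywh_to_xyxy, giou_loss

# two predicted boxes in normalized (cx, cy, w, h) format and their ground truth in (x1, y1, x2, y2)
pred_cxcywh = torch.tensor([[0.50, 0.50, 0.40, 0.40],
                            [0.30, 0.30, 0.20, 0.20]])
gt_xyxy = torch.tensor([[0.25, 0.25, 0.65, 0.65],
                        [0.20, 0.20, 0.45, 0.45]])

pred_xyxy = box_cxcywh_to_xyxy(pred_cxcywh)   # -> (N, 4) in corner format
loss, iou = giou_loss(pred_xyxy, gt_xyxy)     # scalar GIoU loss, per-pair IoU of shape (N,)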
Some files were not shown because too many files have changed in this diff.