init commit of samurai
This commit is contained in:
BIN
lib/train/.DS_Store
vendored
Normal file
BIN
lib/train/.DS_Store
vendored
Normal file
Binary file not shown.
1
lib/train/__init__.py
Normal file
1
lib/train/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .admin.multigpu import MultiGPU
|
17
lib/train/_init_paths.py
Normal file
17
lib/train/_init_paths.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path as osp
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
|
||||
|
||||
this_dir = osp.dirname(__file__)
|
||||
|
||||
prj_path = osp.join(this_dir, '../..')
|
||||
add_path(prj_path)
|
3
lib/train/actors/__init__.py
Normal file
3
lib/train/actors/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_actor import BaseActor
|
||||
from .artrack import ARTrackActor
|
||||
from .artrack_seq import ARTrackSeqActor
|
281
lib/train/actors/artrack.py
Normal file
281
lib/train/actors/artrack.py
Normal file
@@ -0,0 +1,281 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
from lib.utils.merge import merge_template_search
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt)**2 + (cy_pred - cy_gt)**2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2*torch.sin(torch.arcsin(x)-torch.pi/4)**2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw+eps))**2
|
||||
py = ((cy_gt - cy_pred) / (ch+eps))**2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
#shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
#IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw**2 + ch**2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi**2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v**2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
class ARTrackActor(BaseActor):
|
||||
""" Actor for training ARTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.range = self.cfg.MODEL.RANGE
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.loss_weight['KL'] = 100
|
||||
self.loss_weight['focal'] = 2
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins * self.range
|
||||
end = self.bins * self.range + 1
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=(-1*magic_num), max=(1+magic_num))
|
||||
data['real_bbox'] = gt_bbox
|
||||
|
||||
seq_ori = (gt_bbox + magic_num) * (self.bins - 1)
|
||||
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
|
||||
def compute_losses(self, pred_dict, gt_dict, return_status=True):
|
||||
bins = self.bins
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
seq_output = gt_dict['seq_output']
|
||||
pred_feat = pred_dict["feat"]
|
||||
if self.focal == None:
|
||||
weight = torch.ones(bins*self.range+2) * 1
|
||||
weight[bins*self.range+1] = 0.1
|
||||
weight[bins*self.range] = 0.1
|
||||
weight.to(pred_feat)
|
||||
self.klloss = torch.nn.KLDivLoss(reduction='none').to(pred_feat)
|
||||
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
|
||||
# compute varfifocal loss
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, bins*2+2)
|
||||
target = seq_output.reshape(-1).to(torch.int64)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
# compute giou and L1 loss
|
||||
beta = 1
|
||||
pred = pred_feat[0:4, :, 0:bins*self.range] * beta
|
||||
target = seq_output[:, 0:4].to(pred_feat)
|
||||
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range((-1*magic_num+1/(self.bins*self.range)), (1+magic_num-1/(self.bins*self.range)), 2/(self.bins*self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
target = target / (bins - 1) - magic_num
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
sious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
sious = sious.mean()
|
||||
siou_loss = sious
|
||||
l1_loss = self.objective['l1'](extra_seq, target)
|
||||
|
||||
loss = self.loss_weight['giou'] * siou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight['focal'] * varifocal_loss
|
||||
|
||||
if return_status:
|
||||
# status for log
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": loss.item(),
|
||||
"Loss/giou": siou_loss.item(),
|
||||
"Loss/l1": l1_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
return loss, status
|
||||
else:
|
||||
return loss
|
629
lib/train/actors/artrack_seq.py
Normal file
629
lib/train/actors/artrack_seq.py
Normal file
@@ -0,0 +1,629 @@
|
||||
from . import BaseActor
|
||||
from lib.utils.misc import NestedTensor
|
||||
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
import lib.train.data.bounding_box_utils as bbutils
|
||||
from lib.utils.merge import merge_template_search
|
||||
from torch.distributions.categorical import Categorical
|
||||
from ...utils.heapmap_utils import generate_heatmap
|
||||
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
|
||||
|
||||
|
||||
def IoU(rect1, rect2):
|
||||
""" caculate interection over union
|
||||
Args:
|
||||
rect1: (x1, y1, x2, y2)
|
||||
rect2: (x1, y1, x2, y2)
|
||||
Returns:
|
||||
iou
|
||||
"""
|
||||
# overlap
|
||||
x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
|
||||
tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
|
||||
|
||||
xx1 = np.maximum(tx1, x1)
|
||||
yy1 = np.maximum(ty1, y1)
|
||||
xx2 = np.minimum(tx2, x2)
|
||||
yy2 = np.minimum(ty2, y2)
|
||||
|
||||
ww = np.maximum(0, xx2 - xx1)
|
||||
hh = np.maximum(0, yy2 - yy1)
|
||||
|
||||
area = (x2 - x1) * (y2 - y1)
|
||||
target_a = (tx2 - tx1) * (ty2 - ty1)
|
||||
inter = ww * hh
|
||||
iou = inter / (area + target_a - inter)
|
||||
return iou
|
||||
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
|
||||
def generate_sa_simdr(joints):
|
||||
'''
|
||||
:param joints: [num_joints, 3]
|
||||
:param joints_vis: [num_joints, 3]
|
||||
:return: target, target_weight(1: visible, 0: invisible)
|
||||
'''
|
||||
num_joints = 48
|
||||
image_size = [256, 256]
|
||||
simdr_split_ratio = 1.5625
|
||||
sigma = 6
|
||||
|
||||
target_x1 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y1 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_x2 = np.zeros((num_joints,
|
||||
int(image_size[0] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
target_y2 = np.zeros((num_joints,
|
||||
int(image_size[1] * simdr_split_ratio)),
|
||||
dtype=np.float32)
|
||||
zero_4_begin = np.zeros((num_joints, 1), dtype=np.float32)
|
||||
|
||||
tmp_size = sigma * 3
|
||||
|
||||
for joint_id in range(num_joints):
|
||||
mu_x1 = joints[joint_id][0]
|
||||
mu_y1 = joints[joint_id][1]
|
||||
mu_x2 = joints[joint_id][2]
|
||||
mu_y2 = joints[joint_id][3]
|
||||
|
||||
x1 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y1 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
x2 = np.arange(0, int(image_size[0] * simdr_split_ratio), 1, np.float32)
|
||||
y2 = np.arange(0, int(image_size[1] * simdr_split_ratio), 1, np.float32)
|
||||
|
||||
target_x1[joint_id] = (np.exp(- ((x1 - mu_x1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y1[joint_id] = (np.exp(- ((y1 - mu_y1) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_x2[joint_id] = (np.exp(- ((x2 - mu_x2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
target_y2[joint_id] = (np.exp(- ((y2 - mu_y2) ** 2) / (2 * sigma ** 2))) / (
|
||||
sigma * np.sqrt(np.pi * 2))
|
||||
return target_x1, target_y1, target_x2, target_y2
|
||||
|
||||
|
||||
# angle cost
|
||||
def SIoU_loss(test1, test2, theta=4):
|
||||
eps = 1e-7
|
||||
cx_pred = (test1[:, 0] + test1[:, 2]) / 2
|
||||
cy_pred = (test1[:, 1] + test1[:, 3]) / 2
|
||||
cx_gt = (test2[:, 0] + test2[:, 2]) / 2
|
||||
cy_gt = (test2[:, 1] + test2[:, 3]) / 2
|
||||
|
||||
dist = ((cx_pred - cx_gt) ** 2 + (cy_pred - cy_gt) ** 2) ** 0.5
|
||||
ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred)
|
||||
x = ch / (dist + eps)
|
||||
|
||||
angle = 1 - 2 * torch.sin(torch.arcsin(x) - torch.pi / 4) ** 2
|
||||
# distance cost
|
||||
xmin = torch.min(test1[:, 0], test2[:, 0])
|
||||
xmax = torch.max(test1[:, 2], test2[:, 2])
|
||||
ymin = torch.min(test1[:, 1], test2[:, 1])
|
||||
ymax = torch.max(test1[:, 3], test2[:, 3])
|
||||
cw = xmax - xmin
|
||||
ch = ymax - ymin
|
||||
px = ((cx_gt - cx_pred) / (cw + eps)) ** 2
|
||||
py = ((cy_gt - cy_pred) / (ch + eps)) ** 2
|
||||
gama = 2 - angle
|
||||
dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py))
|
||||
|
||||
# shape cost
|
||||
w_pred = test1[:, 2] - test1[:, 0]
|
||||
h_pred = test1[:, 3] - test1[:, 1]
|
||||
w_gt = test2[:, 2] - test2[:, 0]
|
||||
h_gt = test2[:, 3] - test2[:, 1]
|
||||
ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps)
|
||||
wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps)
|
||||
omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta
|
||||
|
||||
# IoU loss
|
||||
lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2]
|
||||
rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2]
|
||||
|
||||
wh = fp16_clamp(rb - lt, min=0)
|
||||
overlap = wh[..., 0] * wh[..., 1]
|
||||
area1 = (test1[..., 2] - test1[..., 0]) * (
|
||||
test1[..., 3] - test1[..., 1])
|
||||
area2 = (test2[..., 2] - test2[..., 0]) * (
|
||||
test2[..., 3] - test2[..., 1])
|
||||
iou = overlap / (area1 + area2 - overlap)
|
||||
|
||||
SIoU = 1 - iou + (omega + dis) / 2
|
||||
return SIoU, iou
|
||||
|
||||
|
||||
def ciou(pred, target, eps=1e-7):
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
||||
# enclose area
|
||||
enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
|
||||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
|
||||
c2 = cw ** 2 + ch ** 2 + eps
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
|
||||
rho2 = left + right
|
||||
|
||||
factor = 4 / math.pi ** 2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + v ** 2 / (1 - ious + v))
|
||||
return cious, ious
|
||||
|
||||
|
||||
class ARTrackSeqActor(BaseActor):
|
||||
""" Actor for training OSTrack models """
|
||||
|
||||
def __init__(self, net, objective, loss_weight, settings, bins, search_size, cfg=None):
|
||||
super().__init__(net, objective)
|
||||
self.loss_weight = loss_weight
|
||||
self.settings = settings
|
||||
self.bs = self.settings.batchsize # batch size
|
||||
self.cfg = cfg
|
||||
self.bins = bins
|
||||
self.search_size = search_size
|
||||
self.logsoftmax = torch.nn.LogSoftmax(dim=1)
|
||||
self.focal = None
|
||||
self.range = cfg.MODEL.RANGE
|
||||
self.pre_num = cfg.MODEL.PRENUM
|
||||
self.loss_weight['KL'] = 0
|
||||
self.loss_weight['focal'] = 0
|
||||
self.pre_bbox = None
|
||||
self.x_feat_rem = None
|
||||
self.update_rem = None
|
||||
|
||||
def __call__(self, data):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
|
||||
template_images: (N_t, batch, 3, H, W)
|
||||
search_images: (N_s, batch, 3, H, W)
|
||||
returns:
|
||||
loss - the training loss
|
||||
status - dict containing detailed losses
|
||||
"""
|
||||
# forward pass
|
||||
out_dict = self.forward_pass(data)
|
||||
|
||||
# compute losses
|
||||
loss, status = self.compute_losses(out_dict, data)
|
||||
|
||||
return loss, status
|
||||
|
||||
def _bbox_clip(self, cx, cy, width, height, boundary):
|
||||
cx = max(0, min(cx, boundary[1]))
|
||||
cy = max(0, min(cy, boundary[0]))
|
||||
width = max(10, min(width, boundary[1]))
|
||||
height = max(10, min(height, boundary[0]))
|
||||
return cx, cy, width, height
|
||||
|
||||
def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
|
||||
"""
|
||||
args:
|
||||
im: bgr based image
|
||||
pos: center position
|
||||
model_sz: exemplar size
|
||||
s_z: original size
|
||||
avg_chans: channel average
|
||||
"""
|
||||
if isinstance(pos, float):
|
||||
pos = [pos, pos]
|
||||
sz = original_sz
|
||||
im_sz = im.shape
|
||||
c = (original_sz + 1) / 2
|
||||
# context_xmin = round(pos[0] - c) # py2 and py3 round
|
||||
context_xmin = np.floor(pos[0] - c + 0.5)
|
||||
context_xmax = context_xmin + sz - 1
|
||||
# context_ymin = round(pos[1] - c)
|
||||
context_ymin = np.floor(pos[1] - c + 0.5)
|
||||
context_ymax = context_ymin + sz - 1
|
||||
left_pad = int(max(0., -context_xmin))
|
||||
top_pad = int(max(0., -context_ymin))
|
||||
right_pad = int(max(0., context_xmax - im_sz[1] + 1))
|
||||
bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
|
||||
|
||||
context_xmin = context_xmin + left_pad
|
||||
context_xmax = context_xmax + left_pad
|
||||
context_ymin = context_ymin + top_pad
|
||||
context_ymax = context_ymax + top_pad
|
||||
|
||||
r, c, k = im.shape
|
||||
if any([top_pad, bottom_pad, left_pad, right_pad]):
|
||||
size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k)
|
||||
te_im = np.zeros(size, np.uint8)
|
||||
te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
|
||||
if top_pad:
|
||||
te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
|
||||
if bottom_pad:
|
||||
te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
|
||||
if left_pad:
|
||||
te_im[:, 0:left_pad, :] = avg_chans
|
||||
if right_pad:
|
||||
te_im[:, c + left_pad:, :] = avg_chans
|
||||
im_patch = te_im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
else:
|
||||
im_patch = im[int(context_ymin):int(context_ymax + 1),
|
||||
int(context_xmin):int(context_xmax + 1), :]
|
||||
|
||||
if not np.array_equal(model_sz, original_sz):
|
||||
try:
|
||||
im_patch = cv2.resize(im_patch, (model_sz, model_sz))
|
||||
except:
|
||||
return None
|
||||
im_patch = im_patch.transpose(2, 0, 1)
|
||||
im_patch = im_patch[np.newaxis, :, :, :]
|
||||
im_patch = im_patch.astype(np.float32)
|
||||
im_patch = torch.from_numpy(im_patch)
|
||||
im_patch = im_patch.cuda()
|
||||
return im_patch
|
||||
|
||||
def batch_init(self, images, template_bbox, initial_bbox) -> dict:
|
||||
self.frame_num = 1
|
||||
self.device = 'cuda'
|
||||
# Convert bbox (x1, y1, w, h) -> (cx, cy, w, h)
|
||||
|
||||
template_bbox = bbutils.batch_xywh2center2(template_bbox) # ndarray:(2*num_seq,4)
|
||||
initial_bbox = bbutils.batch_xywh2center2(initial_bbox) # ndarray:(2*num_seq,4)
|
||||
self.center_pos = initial_bbox[:, :2] # ndarray:(2*num_seq,2)
|
||||
self.size = initial_bbox[:, 2:] # ndarray:(2*num_seq,2)
|
||||
self.pre_bbox = initial_bbox
|
||||
for i in range(self.pre_num - 1):
|
||||
self.pre_bbox = numpy.concatenate((self.pre_bbox, initial_bbox), axis=1)
|
||||
# print(self.pre_bbox.shape)
|
||||
|
||||
template_factor = self.cfg.DATA.TEMPLATE.FACTOR
|
||||
w_z = template_bbox[:, 2] * template_factor # ndarray:(2*num_seq)
|
||||
h_z = template_bbox[:, 3] * template_factor # ndarray:(2*num_seq)
|
||||
s_z = np.ceil(np.sqrt(w_z * h_z)) # ndarray:(2*num_seq)
|
||||
|
||||
self.channel_average = []
|
||||
for img in images:
|
||||
self.channel_average.append(np.mean(img, axis=(0, 1)))
|
||||
self.channel_average = np.array(self.channel_average) # ndarray:(2*num_seq,3)
|
||||
|
||||
# get crop
|
||||
z_crop_list = []
|
||||
for i in range(len(images)):
|
||||
here_crop = self.get_subwindow(images[i], template_bbox[i, :2],
|
||||
self.cfg.DATA.TEMPLATE.SIZE, s_z[i], self.channel_average[i])
|
||||
z_crop = here_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
self.mean = [0.485, 0.456, 0.406]
|
||||
self.std = [0.229, 0.224, 0.225]
|
||||
self.inplace = False
|
||||
z_crop[0] = tvisf.normalize(z_crop[0], self.mean, self.std, self.inplace)
|
||||
z_crop_list.append(z_crop.clone())
|
||||
z_crop = torch.cat(z_crop_list, dim=0) # Tensor(2*num_seq,3,128,128)
|
||||
|
||||
self.update_rem = None
|
||||
|
||||
out = {'template_images': z_crop}
|
||||
return out
|
||||
|
||||
def batch_track(self, img, gt_boxes, template, action_mode='max') -> dict:
|
||||
search_factor = self.cfg.DATA.SEARCH.FACTOR
|
||||
w_x = self.size[:, 0] * search_factor
|
||||
h_x = self.size[:, 1] * search_factor
|
||||
s_x = np.ceil(np.sqrt(w_x * h_x))
|
||||
|
||||
gt_boxes_corner = bbutils.batch_xywh2corner(gt_boxes) # ndarray:(2*num_seq,4)
|
||||
|
||||
x_crop_list = []
|
||||
gt_in_crop_list = []
|
||||
pre_seq_list = []
|
||||
pre_seq_in_list = []
|
||||
x_feat_list = []
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
for i in range(len(img)):
|
||||
channel_avg = np.mean(img[i], axis=(0, 1))
|
||||
x_crop = self.get_subwindow(img[i], self.center_pos[i], self.cfg.DATA.SEARCH.SIZE,
|
||||
round(s_x[i]), channel_avg)
|
||||
if x_crop == None:
|
||||
return None
|
||||
for q in range(self.pre_num):
|
||||
pre_seq_temp = bbutils.batch_center2corner(self.pre_bbox[:, 0 + 4 * q:4 + 4 * q])
|
||||
if q == 0:
|
||||
pre_seq = pre_seq_temp
|
||||
else:
|
||||
pre_seq = numpy.concatenate((pre_seq, pre_seq_temp), axis=1)
|
||||
|
||||
if gt_boxes_corner is not None and np.sum(np.abs(gt_boxes_corner[i] - np.zeros(4))) > 10:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
for w in range(self.pre_num):
|
||||
|
||||
pre_in[0 + w * 4:2 + w * 4] = pre_seq[i, 0 + w * 4:2 + w * 4] - self.center_pos[i]
|
||||
pre_in[2 + w * 4:4 + w * 4] = pre_seq[i, 2 + w * 4:4 + w * 4] - self.center_pos[i]
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] * (
|
||||
self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
pre_in[0 + w * 4:4 + w * 4] = pre_in[0 + w * 4:4 + w * 4] / self.cfg.DATA.SEARCH.SIZE
|
||||
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop = np.zeros(4)
|
||||
gt_in_crop[:2] = gt_boxes_corner[i, :2] - self.center_pos[i]
|
||||
gt_in_crop[2:] = gt_boxes_corner[i, 2:] - self.center_pos[i]
|
||||
gt_in_crop = gt_in_crop * (self.cfg.DATA.SEARCH.SIZE / s_x[i]) + self.cfg.DATA.SEARCH.SIZE / 2
|
||||
gt_in_crop[2:] = gt_in_crop[2:] - gt_in_crop[:2] # (x1,y1,x2,y2) to (x1,y1,w,h)
|
||||
gt_in_crop_list.append(gt_in_crop)
|
||||
else:
|
||||
pre_in = np.zeros(4 * self.pre_num)
|
||||
pre_seq_list.append(pre_in)
|
||||
gt_in_crop_list.append(np.zeros(4))
|
||||
pre_seq_input = torch.from_numpy(pre_in).clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq_input = (pre_seq_input + 0.5) * (self.bins - 1)
|
||||
pre_seq_in_list.append(pre_seq_input.clone())
|
||||
x_crop = x_crop.float().mul(1.0 / 255.0).clamp(0.0, 1.0)
|
||||
x_crop[0] = tvisf.normalize(x_crop[0], self.mean, self.std, self.inplace)
|
||||
x_crop_list.append(x_crop.clone())
|
||||
|
||||
x_crop = torch.cat(x_crop_list, dim=0)
|
||||
pre_seq_output = torch.cat(pre_seq_in_list, dim=0).reshape(-1, 4 * self.pre_num)
|
||||
|
||||
outputs = self.net(template, x_crop, seq_input=pre_seq_output, head_type=None, stage="batch_track",
|
||||
search_feature=self.x_feat_rem, update=None)
|
||||
selected_indices = outputs['seqs'].detach()
|
||||
x_feat = outputs['x_feat'].detach().cpu()
|
||||
self.x_feat_rem = x_feat.clone()
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
pred_bbox = selected_indices[:, 0:4].data.cpu().numpy()
|
||||
bbox = (pred_bbox / (self.bins - 1) - magic_num) * s_x.reshape(-1, 1)
|
||||
cx = bbox[:, 0] + self.center_pos[:, 0] - s_x / 2
|
||||
cy = bbox[:, 1] + self.center_pos[:, 1] - s_x / 2
|
||||
width = bbox[:, 2] - bbox[:, 0]
|
||||
height = bbox[:, 3] - bbox[:, 1]
|
||||
cx = cx + width / 2
|
||||
cy = cy + height / 2
|
||||
|
||||
for i in range(len(img)):
|
||||
cx[i], cy[i], width[i], height[i] = self._bbox_clip(cx[i], cy[i], width[i],
|
||||
height[i], img[i].shape[:2])
|
||||
self.center_pos = np.stack([cx, cy], 1)
|
||||
self.size = np.stack([width, height], 1)
|
||||
for e in range(self.pre_num):
|
||||
if e != self.pre_num - 1:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = self.pre_bbox[:, 4 + e * 4:8 + e * 4]
|
||||
else:
|
||||
self.pre_bbox[:, 0 + e * 4:4 + e * 4] = numpy.stack([cx, cy, width, height], 1)
|
||||
|
||||
bbox = np.stack([cx - width / 2, cy - height / 2, width, height], 1)
|
||||
|
||||
out = {
|
||||
'search_images': x_crop,
|
||||
'pred_bboxes': bbox,
|
||||
'selected_indices': selected_indices.cpu(),
|
||||
'gt_in_crop': torch.tensor(np.stack(gt_in_crop_list, axis=0), dtype=torch.float),
|
||||
'pre_seq': torch.tensor(np.stack(pre_seq_list, axis=0), dtype=torch.float),
|
||||
'x_feat': torch.tensor([item.cpu().detach().numpy() for item in x_feat_list], dtype=torch.float),
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
def explore(self, data):
|
||||
results = {}
|
||||
search_images_list = []
|
||||
search_anno_list = []
|
||||
iou_list = []
|
||||
pre_seq_list = []
|
||||
x_feat_list = []
|
||||
|
||||
num_frames = data['num_frames']
|
||||
images = data['search_images']
|
||||
gt_bbox = data['search_annos']
|
||||
template = data['template_images']
|
||||
template_bbox = data['template_annos']
|
||||
|
||||
template = template
|
||||
template_bbox = template_bbox
|
||||
template_bbox = np.array(template_bbox)
|
||||
num_seq = len(num_frames)
|
||||
|
||||
for idx in range(np.max(num_frames)):
|
||||
here_images = [img[idx] for img in images] # S, N
|
||||
here_gt_bbox = np.array([gt[idx] for gt in gt_bbox])
|
||||
|
||||
here_images = here_images
|
||||
here_gt_bbox = np.concatenate([here_gt_bbox], 0)
|
||||
|
||||
if idx == 0:
|
||||
outputs_template = self.batch_init(template, template_bbox, here_gt_bbox)
|
||||
results['template_images'] = outputs_template['template_images']
|
||||
|
||||
else:
|
||||
outputs = self.batch_track(here_images, here_gt_bbox, outputs_template['template_images'],
|
||||
action_mode='half')
|
||||
if outputs == None:
|
||||
return None
|
||||
|
||||
x_feat = outputs['x_feat']
|
||||
pred_bbox = outputs['pred_bboxes']
|
||||
search_images_list.append(outputs['search_images'])
|
||||
search_anno_list.append(outputs['gt_in_crop'])
|
||||
if len(outputs['pre_seq']) != 8:
|
||||
print(outputs['pre_seq'])
|
||||
print(len(outputs['pre_seq']))
|
||||
print(idx)
|
||||
print(data['num_frames'])
|
||||
print(data['search_annos'])
|
||||
return None
|
||||
pre_seq_list.append(outputs['pre_seq'])
|
||||
pred_bbox_corner = bbutils.batch_xywh2corner(pred_bbox)
|
||||
gt_bbox_corner = bbutils.batch_xywh2corner(here_gt_bbox)
|
||||
here_iou = []
|
||||
for i in range(num_seq):
|
||||
bbox_iou = IoU(pred_bbox_corner[i], gt_bbox_corner[i])
|
||||
here_iou.append(bbox_iou)
|
||||
iou_list.append(here_iou)
|
||||
x_feat_list.append(x_feat.clone())
|
||||
|
||||
results['x_feat'] = torch.cat([torch.stack(x_feat_list)], dim=2)
|
||||
|
||||
results['search_images'] = torch.cat([torch.stack(search_images_list)],
|
||||
dim=1)
|
||||
results['search_anno'] = torch.cat([torch.stack(search_anno_list)],
|
||||
dim=1)
|
||||
results['pre_seq'] = torch.cat([torch.stack(pre_seq_list)], dim=1)
|
||||
|
||||
iou_tensor = torch.tensor(iou_list, dtype=torch.float)
|
||||
results['baseline_iou'] = torch.cat([iou_tensor[:, :num_seq]], dim=1)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def forward_pass(self, data):
|
||||
# currently only support 1 template and 1 search region
|
||||
assert len(data['template_images']) == 1
|
||||
assert len(data['search_images']) == 1
|
||||
|
||||
template_list = []
|
||||
for i in range(self.settings.num_template):
|
||||
template_img_i = data['template_images'][i].view(-1,
|
||||
*data['template_images'].shape[2:]) # (batch, 3, 128, 128)
|
||||
template_list.append(template_img_i)
|
||||
|
||||
search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 3, 320, 320)
|
||||
|
||||
box_mask_z = None
|
||||
ce_keep_rate = None
|
||||
if self.cfg.MODEL.BACKBONE.CE_LOC:
|
||||
box_mask_z = generate_mask_cond(self.cfg, template_list[0].shape[0], template_list[0].device,
|
||||
data['template_anno'][0])
|
||||
|
||||
ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH
|
||||
ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH
|
||||
ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch,
|
||||
total_epochs=ce_start_epoch + ce_warm_epoch,
|
||||
ITERS_PER_EPOCH=1,
|
||||
base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0])
|
||||
|
||||
if len(template_list) == 1:
|
||||
template_list = template_list[0]
|
||||
gt_bbox = data['search_anno'][-1]
|
||||
begin = self.bins
|
||||
end = self.bins + 1
|
||||
gt_bbox[:, 2] = gt_bbox[:, 0] + gt_bbox[:, 2]
|
||||
gt_bbox[:, 3] = gt_bbox[:, 1] + gt_bbox[:, 3]
|
||||
gt_bbox = gt_bbox.clamp(min=0.5, max=1.5)
|
||||
data['real_bbox'] = gt_bbox
|
||||
seq_ori = gt_bbox * (self.bins - 1)
|
||||
seq_ori = seq_ori.int().to(search_img)
|
||||
B = seq_ori.shape[0]
|
||||
seq_input = torch.cat([torch.ones((B, 1)).to(search_img) * begin, seq_ori], dim=1)
|
||||
seq_output = torch.cat([seq_ori, torch.ones((B, 1)).to(search_img) * end], dim=1)
|
||||
data['seq_input'] = seq_input
|
||||
data['seq_output'] = seq_output
|
||||
out_dict = self.net(template=template_list,
|
||||
search=search_img,
|
||||
ce_template_mask=box_mask_z,
|
||||
ce_keep_rate=ce_keep_rate,
|
||||
return_last_attn=False,
|
||||
seq_input=seq_input)
|
||||
|
||||
return out_dict
|
||||
|
||||
def compute_sequence_losses(self, data):
|
||||
num_frames = data['search_images'].shape[0]
|
||||
template_images = data['template_images'].repeat(num_frames, 1, 1, 1, 1)
|
||||
template_images = template_images.view(-1, *template_images.size()[2:])
|
||||
search_images = data['search_images'].reshape(-1, *data['search_images'].size()[2:])
|
||||
search_anno = data['search_anno'].reshape(-1, *data['search_anno'].size()[2:])
|
||||
|
||||
magic_num = (self.range - 1) * 0.5
|
||||
self.loss_weight['focal'] = 0
|
||||
pre_seq = data['pre_seq'].reshape(-1, 4 * self.pre_num)
|
||||
x_feat = data['x_feat'].reshape(-1, *data['x_feat'].size()[2:])
|
||||
pre_seq = pre_seq.clamp(-1 * magic_num, 1 + magic_num)
|
||||
pre_seq = (pre_seq + magic_num) * (self.bins - 1)
|
||||
|
||||
outputs = self.net(template_images, search_images, seq_input=pre_seq, stage="forward_pass",
|
||||
search_feature=x_feat, update=None)
|
||||
|
||||
pred_feat = outputs["feat"]
|
||||
# generate labels
|
||||
if self.focal == None:
|
||||
weight = torch.ones(self.bins * self.range + 2) * 1
|
||||
weight[self.bins * self.range + 1] = 0.1
|
||||
weight[self.bins * self.range] = 0.1
|
||||
weight.to(pred_feat)
|
||||
self.focal = torch.nn.CrossEntropyLoss(weight=weight, size_average=True).to(pred_feat)
|
||||
|
||||
search_anno[:, 2] = search_anno[:, 2] + search_anno[:, 0]
|
||||
search_anno[:, 3] = search_anno[:, 3] + search_anno[:, 1]
|
||||
target = (search_anno / self.cfg.DATA.SEARCH.SIZE + 0.5) * (self.bins - 1)
|
||||
|
||||
target = target.clamp(min=0.0, max=(self.bins * self.range - 0.0001))
|
||||
target_iou = target
|
||||
target = torch.cat([target], dim=1)
|
||||
target = target.reshape(-1).to(torch.int64)
|
||||
pred = pred_feat.permute(1, 0, 2).reshape(-1, self.bins * self.range + 2)
|
||||
varifocal_loss = self.focal(pred, target)
|
||||
pred = pred_feat[0:4, :, 0:self.bins * self.range]
|
||||
target = target_iou[:, 0:4].to(pred_feat) / (self.bins - 1) - magic_num
|
||||
out = pred.softmax(-1).to(pred)
|
||||
mul = torch.range(-1 * magic_num + 1 / (self.bins * self.range), 1 + magic_num - 1 / (self.bins * self.range), 2 / (self.bins * self.range)).to(pred)
|
||||
ans = out * mul
|
||||
ans = ans.sum(dim=-1)
|
||||
ans = ans.permute(1, 0).to(pred)
|
||||
extra_seq = ans
|
||||
extra_seq = extra_seq.to(pred)
|
||||
|
||||
cious, iou = SIoU_loss(extra_seq, target, 4)
|
||||
cious = cious.mean()
|
||||
|
||||
giou_loss = cious
|
||||
loss_bb = self.loss_weight['giou'] * giou_loss + self.loss_weight[
|
||||
'focal'] * varifocal_loss
|
||||
|
||||
total_losses = loss_bb
|
||||
|
||||
mean_iou = iou.detach().mean()
|
||||
status = {"Loss/total": total_losses.item(),
|
||||
"Loss/giou": giou_loss.item(),
|
||||
"Loss/location": varifocal_loss.item(),
|
||||
"IoU": mean_iou.item()}
|
||||
|
||||
return total_losses, status
|
||||
|
44
lib/train/actors/base_actor.py
Normal file
44
lib/train/actors/base_actor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class BaseActor:
|
||||
""" Base class for actor. The actor class handles the passing of the data through the network
|
||||
and calculation the loss"""
|
||||
def __init__(self, net, objective):
|
||||
"""
|
||||
args:
|
||||
net - The network to train
|
||||
objective - The loss function
|
||||
"""
|
||||
self.net = net
|
||||
self.objective = objective
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
""" Called in each training iteration. Should pass in input data through the network, calculate the loss, and
|
||||
return the training stats for the input data
|
||||
args:
|
||||
data - A TensorDict containing all the necessary data blocks.
|
||||
|
||||
returns:
|
||||
loss - loss for the input data
|
||||
stats - a dict containing detailed losses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device):
|
||||
""" Move the network to device
|
||||
args:
|
||||
device - device to use. 'cpu' or 'cuda'
|
||||
"""
|
||||
self.net.to(device)
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set whether the network is in train mode.
|
||||
args:
|
||||
mode (True) - Bool specifying whether in training mode.
|
||||
"""
|
||||
self.net.train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set network to eval mode"""
|
||||
self.train(False)
|
3
lib/train/admin/__init__.py
Normal file
3
lib/train/admin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .environment import env_settings, create_default_local_file_ITP_train
|
||||
from .stats import AverageMeter, StatValue
|
||||
#from .tensorboard import TensorboardWriter
|
102
lib/train/admin/environment.py
Normal file
102
lib/train/admin/environment.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_default_local_file():
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': empty_str,
|
||||
'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
|
||||
'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
|
||||
'lasot_dir': empty_str,
|
||||
'got10k_dir': empty_str,
|
||||
'trackingnet_dir': empty_str,
|
||||
'coco_dir': empty_str,
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': empty_str,
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def create_default_local_file_ITP_train(workspace_dir, data_dir):
|
||||
path = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
empty_str = '\'\''
|
||||
default_settings = OrderedDict({
|
||||
'workspace_dir': workspace_dir,
|
||||
'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files.
|
||||
'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'),
|
||||
'lasot_dir': os.path.join(data_dir, 'lasot'),
|
||||
'got10k_dir': os.path.join(data_dir, 'got10k/train'),
|
||||
'got10k_val_dir': os.path.join(data_dir, 'got10k/val'),
|
||||
'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'),
|
||||
'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'),
|
||||
'trackingnet_dir': os.path.join(data_dir, 'trackingnet'),
|
||||
'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'),
|
||||
'coco_dir': os.path.join(data_dir, 'coco'),
|
||||
'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'),
|
||||
'lvis_dir': empty_str,
|
||||
'sbd_dir': empty_str,
|
||||
'imagenet_dir': os.path.join(data_dir, 'vid'),
|
||||
'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'),
|
||||
'imagenetdet_dir': empty_str,
|
||||
'ecssd_dir': empty_str,
|
||||
'hkuis_dir': empty_str,
|
||||
'msra10k_dir': empty_str,
|
||||
'davis_dir': empty_str,
|
||||
'youtubevos_dir': empty_str})
|
||||
|
||||
comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
|
||||
'tensorboard_dir': 'Directory for tensorboard files.'}
|
||||
|
||||
with open(path, 'w') as f:
|
||||
f.write('class EnvironmentSettings:\n')
|
||||
f.write(' def __init__(self):\n')
|
||||
|
||||
for attr, attr_val in default_settings.items():
|
||||
comment_str = None
|
||||
if attr in comment:
|
||||
comment_str = comment[attr]
|
||||
if comment_str is None:
|
||||
if attr_val == empty_str:
|
||||
f.write(' self.{} = {}\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\'\n'.format(attr, attr_val))
|
||||
else:
|
||||
f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def env_settings():
|
||||
env_module_name = 'lib.train.admin.local'
|
||||
try:
|
||||
env_module = importlib.import_module(env_module_name)
|
||||
return env_module.EnvironmentSettings()
|
||||
except:
|
||||
env_file = os.path.join(os.path.dirname(__file__), 'local.py')
|
||||
|
||||
create_default_local_file()
|
||||
raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file))
|
24
lib/train/admin/local.py
Normal file
24
lib/train/admin/local.py
Normal file
@@ -0,0 +1,24 @@
|
||||
class EnvironmentSettings:
|
||||
def __init__(self):
|
||||
self.workspace_dir = '/home/baiyifan/code/2stage_update_intrain' # Base directory for saving network checkpoints.
|
||||
self.tensorboard_dir = '/home/baiyifan/code/2stage/tensorboard' # Directory for tensorboard files.
|
||||
self.pretrained_networks = '/home/baiyifan/code/2stage/pretrained_networks'
|
||||
self.lasot_dir = '/home/baiyifan/LaSOT/LaSOTBenchmark'
|
||||
self.got10k_dir = '/home/baiyifan/GOT-10k/train'
|
||||
self.got10k_val_dir = '/home/baiyifan/GOT-10k/val'
|
||||
self.lasot_lmdb_dir = '/home/baiyifan/code/2stage/data/lasot_lmdb'
|
||||
self.got10k_lmdb_dir = '/home/baiyifan/code/2stage/data/got10k_lmdb'
|
||||
self.trackingnet_dir = '/ssddata/TrackingNet/all_zip'
|
||||
self.trackingnet_lmdb_dir = '/home/baiyifan/code/2stage/data/trackingnet_lmdb'
|
||||
self.coco_dir = '/home/baiyifan/coco'
|
||||
self.coco_lmdb_dir = '/home/baiyifan/code/2stage/data/coco_lmdb'
|
||||
self.lvis_dir = ''
|
||||
self.sbd_dir = ''
|
||||
self.imagenet_dir = '/home/baiyifan/code/2stage/data/vid'
|
||||
self.imagenet_lmdb_dir = '/home/baiyifan/code/2stage/data/vid_lmdb'
|
||||
self.imagenetdet_dir = ''
|
||||
self.ecssd_dir = ''
|
||||
self.hkuis_dir = ''
|
||||
self.msra10k_dir = ''
|
||||
self.davis_dir = ''
|
||||
self.youtubevos_dir = ''
|
15
lib/train/admin/multigpu.py
Normal file
15
lib/train/admin/multigpu.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
# Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training
|
||||
|
||||
|
||||
def is_multi_gpu(net):
|
||||
return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))
|
||||
|
||||
|
||||
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except:
|
||||
pass
|
||||
return getattr(self.module, item)
|
13
lib/train/admin/settings.py
Normal file
13
lib/train/admin/settings.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from lib.train.admin.environment import env_settings
|
||||
|
||||
|
||||
class Settings:
|
||||
""" Training settings, e.g. the paths to datasets and networks."""
|
||||
def __init__(self):
|
||||
self.set_default()
|
||||
|
||||
def set_default(self):
|
||||
self.env = env_settings()
|
||||
self.use_gpu = True
|
||||
|
||||
|
71
lib/train/admin/stats.py
Normal file
71
lib/train/admin/stats.py
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
|
||||
class StatValue:
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
|
||||
def clear(self):
|
||||
self.reset()
|
||||
self.history = []
|
||||
|
||||
def update(self, val):
|
||||
self.val = val
|
||||
self.history.append(self.val)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.clear()
|
||||
self.has_new_data = False
|
||||
|
||||
def reset(self):
|
||||
self.avg = 0
|
||||
self.val = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def clear(self):
|
||||
self.reset()
|
||||
self.history = []
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def new_epoch(self):
|
||||
if self.count > 0:
|
||||
self.history.append(self.avg)
|
||||
self.reset()
|
||||
self.has_new_data = True
|
||||
else:
|
||||
self.has_new_data = False
|
||||
|
||||
|
||||
def topk_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
single_input = not isinstance(topk, (tuple, list))
|
||||
if single_input:
|
||||
topk = (topk,)
|
||||
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0]
|
||||
res.append(correct_k * 100.0 / batch_size)
|
||||
|
||||
if single_input:
|
||||
return res[0]
|
||||
|
||||
return res
|
27
lib/train/admin/tensorboard.py
Normal file
27
lib/train/admin/tensorboard.py
Normal file
@@ -0,0 +1,27 @@
|
||||
#import os
|
||||
#from collections import OrderedDict
|
||||
#try:
|
||||
# from torch.utils.tensorboard import SummaryWriter
|
||||
#except:
|
||||
# print('WARNING: You are using tensorboardX instead sis you have a too old pytorch version.')
|
||||
# from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
#class TensorboardWriter:
|
||||
# def __init__(self, directory, loader_names):
|
||||
# self.directory = directory
|
||||
# self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})
|
||||
|
||||
# def write_info(self, script_name, description):
|
||||
# tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
|
||||
# tb_info_writer.add_text('Script_name', script_name)
|
||||
# tb_info_writer.add_text('Description', description)
|
||||
# tb_info_writer.close()
|
||||
|
||||
# def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
|
||||
# for loader_name, loader_stats in stats.items():
|
||||
# if loader_stats is None:
|
||||
# continue
|
||||
# for var_name, val in loader_stats.items():
|
||||
# if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
|
||||
# self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
|
193
lib/train/base_functions.py
Normal file
193
lib/train/base_functions.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
# datasets related
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader
|
||||
import lib.train.data.transforms as tfm
|
||||
from lib.utils.misc import is_main_process
|
||||
|
||||
|
||||
def update_settings(settings, cfg):
|
||||
settings.print_interval = cfg.TRAIN.PRINT_INTERVAL
|
||||
settings.search_area_factor = {'template': cfg.DATA.TEMPLATE.FACTOR,
|
||||
'search': cfg.DATA.SEARCH.FACTOR}
|
||||
settings.output_sz = {'template': cfg.DATA.TEMPLATE.SIZE,
|
||||
'search': cfg.DATA.SEARCH.SIZE}
|
||||
settings.center_jitter_factor = {'template': cfg.DATA.TEMPLATE.CENTER_JITTER,
|
||||
'search': cfg.DATA.SEARCH.CENTER_JITTER}
|
||||
settings.scale_jitter_factor = {'template': cfg.DATA.TEMPLATE.SCALE_JITTER,
|
||||
'search': cfg.DATA.SEARCH.SCALE_JITTER}
|
||||
settings.grad_clip_norm = cfg.TRAIN.GRAD_CLIP_NORM
|
||||
settings.print_stats = None
|
||||
settings.batchsize = cfg.TRAIN.BATCH_SIZE
|
||||
settings.scheduler_type = cfg.TRAIN.SCHEDULER.TYPE
|
||||
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
|
||||
def build_dataloaders(cfg, settings):
|
||||
# Data transform
|
||||
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05),
|
||||
tfm.RandomHorizontalFlip(probability=0.5))
|
||||
|
||||
transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
|
||||
tfm.RandomHorizontalFlip_Norm(probability=0.5),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
transform_val = tfm.Transform(tfm.ToTensor(),
|
||||
tfm.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD))
|
||||
|
||||
# The tracking pairs processing module
|
||||
output_sz = settings.output_sz
|
||||
search_area_factor = settings.search_area_factor
|
||||
|
||||
data_processing_train = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_train,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
data_processing_val = processing.STARKProcessing(search_area_factor=search_area_factor,
|
||||
output_sz=output_sz,
|
||||
center_jitter_factor=settings.center_jitter_factor,
|
||||
scale_jitter_factor=settings.scale_jitter_factor,
|
||||
mode='sequence',
|
||||
transform=transform_val,
|
||||
joint_transform=transform_joint,
|
||||
settings=settings)
|
||||
|
||||
# Train sampler and loader
|
||||
settings.num_template = getattr(cfg.DATA.TEMPLATE, "NUMBER", 1)
|
||||
settings.num_search = getattr(cfg.DATA.SEARCH, "NUMBER", 1)
|
||||
sampler_mode = getattr(cfg.DATA, "SAMPLER_MODE", "causal")
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
print("sampler_mode", sampler_mode)
|
||||
dataset_train = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_train,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
|
||||
train_sampler = DistributedSampler(dataset_train) if settings.local_rank != -1 else None
|
||||
shuffle = False if settings.local_rank != -1 else True
|
||||
|
||||
loader_train = LTRLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=shuffle,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=train_sampler)
|
||||
|
||||
# Validation samplers and loaders
|
||||
dataset_val = sampler.TrackingSampler(datasets=names2datasets(cfg.DATA.VAL.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.VAL.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.VAL.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_SAMPLE_INTERVAL, num_search_frames=settings.num_search,
|
||||
num_template_frames=settings.num_template, processing=data_processing_val,
|
||||
frame_sample_mode=sampler_mode, train_cls=train_cls)
|
||||
val_sampler = DistributedSampler(dataset_val) if settings.local_rank != -1 else None
|
||||
loader_val = LTRLoader('val', dataset_val, training=False, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER, drop_last=True, stack_dim=1, sampler=val_sampler,
|
||||
epoch_interval=cfg.TRAIN.VAL_EPOCH_INTERVAL)
|
||||
|
||||
return loader_train, loader_val
|
||||
|
||||
|
||||
def get_optimizer_scheduler(net, cfg):
|
||||
train_cls = getattr(cfg.TRAIN, "TRAIN_CLS", False)
|
||||
if train_cls:
|
||||
print("Only training classification head. Learnable parameters are shown below.")
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "cls" in n and p.requires_grad]}
|
||||
]
|
||||
|
||||
for n, p in net.named_parameters():
|
||||
if "cls" not in n:
|
||||
p.requires_grad = False
|
||||
else:
|
||||
print(n)
|
||||
else:
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in net.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in net.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.TRAIN.LR * cfg.TRAIN.BACKBONE_MULTIPLIER,
|
||||
},
|
||||
]
|
||||
if is_main_process():
|
||||
print("Learnable parameters are shown below.")
|
||||
for n, p in net.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
if cfg.TRAIN.OPTIMIZER == "ADAMW":
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR,
|
||||
weight_decay=cfg.TRAIN.WEIGHT_DECAY)
|
||||
else:
|
||||
raise ValueError("Unsupported Optimizer")
|
||||
if cfg.TRAIN.SCHEDULER.TYPE == 'step':
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP_EPOCH)
|
||||
elif cfg.TRAIN.SCHEDULER.TYPE == "Mstep":
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=cfg.TRAIN.SCHEDULER.MILESTONES,
|
||||
gamma=cfg.TRAIN.SCHEDULER.GAMMA)
|
||||
else:
|
||||
raise ValueError("Unsupported scheduler")
|
||||
return optimizer, lr_scheduler
|
2
lib/train/data/__init__.py
Normal file
2
lib/train/data/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .loader import LTRLoader
|
||||
from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader
|
150
lib/train/data/bounding_box_utils.py
Normal file
150
lib/train/data/bounding_box_utils.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def batch_center2corner(boxes):
|
||||
xmin = boxes[:, 0] - boxes[:, 2] * 0.5
|
||||
ymin = boxes[:, 1] - boxes[:, 3] * 0.5
|
||||
xmax = boxes[:, 0] + boxes[:, 2] * 0.5
|
||||
ymax = boxes[:, 1] + boxes[:, 3] * 0.5
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def batch_corner2center(boxes):
|
||||
cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
|
||||
cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
|
||||
w = (boxes[:, 2] - boxes[:, 0])
|
||||
h = (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center(boxes):
|
||||
cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2
|
||||
cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
def batch_xywh2center2(boxes):
|
||||
cx = boxes[:, 0] + boxes[:, 2] / 2
|
||||
cy = boxes[:, 1] + boxes[:, 3] / 2
|
||||
w = boxes[:, 2]
|
||||
h = boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([cx, cy, w, h], 1)
|
||||
else:
|
||||
return torch.stack([cx, cy, w, h], 1)
|
||||
|
||||
|
||||
def batch_xywh2corner(boxes):
|
||||
xmin = boxes[:, 0]
|
||||
ymin = boxes[:, 1]
|
||||
xmax = boxes[:, 0] + boxes[:, 2]
|
||||
ymax = boxes[:, 1] + boxes[:, 3]
|
||||
|
||||
if isinstance(boxes, np.ndarray):
|
||||
return np.stack([xmin, ymin, xmax, ymax], 1)
|
||||
else:
|
||||
return torch.stack([xmin, ymin, xmax, ymax], 1)
|
||||
|
||||
def rect_to_rel(bb, sz_norm=None):
|
||||
"""Convert standard rectangular parametrization of the bounding box [x, y, w, h]
|
||||
to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
|
||||
args:
|
||||
bb - N x 4 tensor of boxes.
|
||||
sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
|
||||
"""
|
||||
|
||||
c = bb[...,:2] + 0.5 * bb[...,2:]
|
||||
if sz_norm is None:
|
||||
c_rel = c / bb[...,2:]
|
||||
else:
|
||||
c_rel = c / sz_norm
|
||||
sz_rel = torch.log(bb[...,2:])
|
||||
return torch.cat((c_rel, sz_rel), dim=-1)
|
||||
|
||||
|
||||
def rel_to_rect(bb, sz_norm=None):
|
||||
"""Inverts the effect of rect_to_rel. See above."""
|
||||
|
||||
sz = torch.exp(bb[...,2:])
|
||||
if sz_norm is None:
|
||||
c = bb[...,:2] * sz
|
||||
else:
|
||||
c = bb[...,:2] * sz_norm
|
||||
tl = c - 0.5 * sz
|
||||
return torch.cat((tl, sz), dim=-1)
|
||||
|
||||
|
||||
def masks_to_bboxes(mask, fmt='c'):
|
||||
|
||||
""" Convert a mask tensor to one or more bounding boxes.
|
||||
Note: This function is a bit new, make sure it does what it says. /Andreas
|
||||
:param mask: Tensor of masks, shape = (..., H, W)
|
||||
:param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
|
||||
't' => "top left + size" or (x_left, y_top, width, height)
|
||||
'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
|
||||
:return: tensor containing a batch of bounding boxes, shape = (..., 4)
|
||||
"""
|
||||
batch_shape = mask.shape[:-2]
|
||||
mask = mask.reshape((-1, *mask.shape[-2:]))
|
||||
bboxes = []
|
||||
|
||||
for m in mask:
|
||||
mx = m.sum(dim=-2).nonzero()
|
||||
my = m.sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
bboxes.append(bb)
|
||||
|
||||
bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
|
||||
bboxes = bboxes.reshape(batch_shape + (4,))
|
||||
|
||||
if fmt == 'v':
|
||||
return bboxes
|
||||
|
||||
x1 = bboxes[..., :2]
|
||||
s = bboxes[..., 2:] - x1 + 1
|
||||
|
||||
if fmt == 'c':
|
||||
return torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
return torch.cat((x1, s), dim=-1)
|
||||
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
|
||||
|
||||
def masks_to_bboxes_multi(mask, ids, fmt='c'):
|
||||
assert mask.dim() == 2
|
||||
bboxes = []
|
||||
|
||||
for id in ids:
|
||||
mx = (mask == id).sum(dim=-2).nonzero()
|
||||
my = (mask == id).float().sum(dim=-1).nonzero()
|
||||
bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
|
||||
|
||||
bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
|
||||
|
||||
x1 = bb[:2]
|
||||
s = bb[2:] - x1 + 1
|
||||
|
||||
if fmt == 'v':
|
||||
pass
|
||||
elif fmt == 'c':
|
||||
bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
|
||||
elif fmt == 't':
|
||||
bb = torch.cat((x1, s), dim=-1)
|
||||
else:
|
||||
raise ValueError("Undefined bounding box layout '%s'" % fmt)
|
||||
bboxes.append(bb)
|
||||
|
||||
return bboxes
|
103
lib/train/data/image_loader.py
Normal file
103
lib/train/data/image_loader.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import jpeg4py
|
||||
import cv2 as cv
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
|
||||
davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
|
||||
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
|
||||
[64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
|
||||
[64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
|
||||
[0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
|
||||
[0, 64, 128], [128, 64, 128]]
|
||||
|
||||
|
||||
def default_image_loader(path):
|
||||
"""The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
|
||||
but reverts to the opencv_loader if the former is not available."""
|
||||
if default_image_loader.use_jpeg4py is None:
|
||||
# Try using jpeg4py
|
||||
im = jpeg4py_loader(path)
|
||||
if im is None:
|
||||
default_image_loader.use_jpeg4py = False
|
||||
print('Using opencv_loader instead.')
|
||||
else:
|
||||
default_image_loader.use_jpeg4py = True
|
||||
return im
|
||||
if default_image_loader.use_jpeg4py:
|
||||
return jpeg4py_loader(path)
|
||||
return opencv_loader(path)
|
||||
|
||||
default_image_loader.use_jpeg4py = None
|
||||
|
||||
|
||||
def jpeg4py_loader(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_loader(path):
|
||||
""" Read image using opencv's imread function and returns it in rgb format"""
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def jpeg4py_loader_w_failsafe(path):
|
||||
""" Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
|
||||
try:
|
||||
return jpeg4py.JPEG(path).decode()
|
||||
except:
|
||||
try:
|
||||
im = cv.imread(path, cv.IMREAD_COLOR)
|
||||
|
||||
# convert to rgb and return
|
||||
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def opencv_seg_loader(path):
|
||||
""" Read segmentation annotation using opencv's imread function"""
|
||||
try:
|
||||
return cv.imread(path)
|
||||
except Exception as e:
|
||||
print('ERROR: Could not read image "{}"'.format(path))
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
def imread_indexed(filename):
|
||||
""" Load indexed image with given filename. Used to read segmentation annotations."""
|
||||
|
||||
im = Image.open(filename)
|
||||
|
||||
annotation = np.atleast_3d(im)[...,0]
|
||||
return annotation
|
||||
|
||||
|
||||
def imwrite_indexed(filename, array, color_palette=None):
|
||||
""" Save indexed image as png. Used to save segmentation annotation."""
|
||||
|
||||
if color_palette is None:
|
||||
color_palette = davis_palette
|
||||
|
||||
if np.atleast_3d(array).shape[2] != 1:
|
||||
raise Exception("Saving indexed PNGs requires 2D array.")
|
||||
|
||||
im = Image.fromarray(array)
|
||||
im.putpalette(color_palette.ravel())
|
||||
im.save(filename, format='PNG')
|
199
lib/train/data/loader.py
Normal file
199
lib/train/data/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
import torch.utils.data.dataloader
|
||||
import importlib
|
||||
import collections
|
||||
# from torch._six import string_classes
|
||||
from lib.utils import TensorDict, TensorList
|
||||
|
||||
if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3:
|
||||
int_classes = int
|
||||
else:
|
||||
from torch._six import int_classes
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
string_classes = str
|
||||
|
||||
def _check_use_shared_memory():
|
||||
if hasattr(torch.utils.data.dataloader, '_use_shared_memory'):
|
||||
return getattr(torch.utils.data.dataloader, '_use_shared_memory')
|
||||
collate_lib = importlib.import_module('torch.utils.data._utils.collate')
|
||||
if hasattr(collate_lib, '_use_shared_memory'):
|
||||
return getattr(collate_lib, '_use_shared_memory')
|
||||
return torch.utils.data.get_worker_info() is not None
|
||||
|
||||
|
||||
def ltr_collate(batch):
|
||||
"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 0, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
def ltr_collate_stack1(batch):
|
||||
"""Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if isinstance(batch[0], torch.Tensor):
|
||||
out = None
|
||||
if _check_use_shared_memory():
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum([x.numel() for x in batch])
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
return torch.stack(batch, 1, out=out)
|
||||
# if batch[0].dim() < 4:
|
||||
# return torch.stack(batch, 0, out=out)
|
||||
# return torch.cat(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return torch.stack([torch.from_numpy(b) for b in batch], 1)
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
||||
elif isinstance(batch[0], int_classes):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], string_classes):
|
||||
return batch
|
||||
elif isinstance(batch[0], TensorDict):
|
||||
return TensorDict({key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]})
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], TensorList):
|
||||
transposed = zip(*batch)
|
||||
return TensorList([ltr_collate_stack1(samples) for samples in transposed])
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [ltr_collate_stack1(samples) for samples in transposed]
|
||||
elif batch[0] is None:
|
||||
return batch
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
|
||||
class LTRLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
|
||||
Note: The only difference with default pytorch DataLoader is that an additional option stack_dim is available to
|
||||
select along which dimension the data should be stacked to form a batch.
|
||||
|
||||
Arguments:
|
||||
dataset (Dataset): dataset from which to load the data.
|
||||
batch_size (int, optional): how many samples per batch to load
|
||||
(default: 1).
|
||||
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
||||
at every epoch (default: False).
|
||||
sampler (Sampler, optional): defines the strategy to draw samples from
|
||||
the dataset. If specified, ``shuffle`` must be False.
|
||||
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
||||
indices at a time. Mutually exclusive with batch_size, shuffle,
|
||||
sampler, and drop_last.
|
||||
num_workers (int, optional): how many subprocesses to use for data
|
||||
loading. 0 means that the data will be loaded in the main process.
|
||||
(default: 0)
|
||||
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
||||
stack_dim (int): Dimension along which to stack to form the batch. (default: 0)
|
||||
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
||||
into CUDA pinned memory before returning them.
|
||||
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
||||
if the dataset size is not divisible by the batch size. If ``False`` and
|
||||
the size of dataset is not divisible by the batch size, then the last batch
|
||||
will be smaller. (default: False)
|
||||
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
||||
from workers. Should always be non-negative. (default: 0)
|
||||
worker_init_fn (callable, optional): If not None, this will be called on each
|
||||
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
||||
input, after seeding and before data loading. (default: None)
|
||||
|
||||
.. note:: By default, each worker will have its PyTorch seed set to
|
||||
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
||||
by main process using its RNG. However, seeds for other libraries
|
||||
may be duplicated upon initializing workers (w.g., NumPy), causing
|
||||
each worker to return identical random numbers. (See
|
||||
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
|
||||
use ``torch.initial_seed()`` to access the PyTorch seed for each
|
||||
worker in :attr:`worker_init_fn`, and use it to set other seeds
|
||||
before data loading.
|
||||
|
||||
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
|
||||
unpicklable object, e.g., a lambda function.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
print("pin_memory is", pin_memory)
|
||||
if collate_fn is None:
|
||||
if stack_dim == 0:
|
||||
collate_fn = ltr_collate
|
||||
elif stack_dim == 1:
|
||||
collate_fn = ltr_collate_stack1
|
||||
else:
|
||||
raise ValueError('Stack dim no supported. Must be 0 or 1.')
|
||||
|
||||
super(LTRLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
155
lib/train/data/processing.py
Normal file
155
lib/train/data/processing.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from lib.utils import TensorDict
|
||||
import lib.train.data.processing_utils as prutils
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def stack_tensors(x):
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
|
||||
return torch.stack(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaseProcessing:
|
||||
""" Base class for Processing. Processing class is used to process the data returned by a dataset, before passing it
|
||||
through the network. For example, it can be used to crop a search region around the object, apply various data
|
||||
augmentations, etc."""
|
||||
def __init__(self, transform=transforms.ToTensor(), template_transform=None, search_transform=None, joint_transform=None):
|
||||
"""
|
||||
args:
|
||||
transform - The set of transformations to be applied on the images. Used only if template_transform or
|
||||
search_transform is None.
|
||||
template_transform - The set of transformations to be applied on the template images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
search_transform - The set of transformations to be applied on the search images. If None, the 'transform'
|
||||
argument is used instead.
|
||||
joint_transform - The set of transformations to be applied 'jointly' on the template and search images. For
|
||||
example, it can be used to convert both template and search images to grayscale.
|
||||
"""
|
||||
self.transform = {'template': transform if template_transform is None else template_transform,
|
||||
'search': transform if search_transform is None else search_transform,
|
||||
'joint': joint_transform}
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class STARKProcessing(BaseProcessing):
|
||||
""" The processing class used for training LittleBoy. The images are processed in the following way.
|
||||
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
|
||||
centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is
|
||||
cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is
|
||||
always at the center of the search region. The search region is then resized to a fixed size given by the
|
||||
argument output_sz.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor,
|
||||
mode='pair', settings=None, *args, **kwargs):
|
||||
"""
|
||||
args:
|
||||
search_area_factor - The size of the search region relative to the target size.
|
||||
output_sz - An integer, denoting the size to which the search region is resized. The search region is always
|
||||
square.
|
||||
center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before
|
||||
extracting the search region. See _get_jittered_box for how the jittering is done.
|
||||
mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.search_area_factor = search_area_factor
|
||||
self.output_sz = output_sz
|
||||
self.center_jitter_factor = center_jitter_factor
|
||||
self.scale_jitter_factor = scale_jitter_factor
|
||||
self.mode = mode
|
||||
self.settings = settings
|
||||
|
||||
def _get_jittered_box(self, box, mode):
|
||||
""" Jitter the input box
|
||||
args:
|
||||
box - input bounding box
|
||||
mode - string 'template' or 'search' indicating template or search data
|
||||
|
||||
returns:
|
||||
torch.Tensor - jittered box
|
||||
"""
|
||||
|
||||
jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
|
||||
max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode]).float())
|
||||
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
|
||||
|
||||
return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
"""
|
||||
args:
|
||||
data - The input data, should contain the following fields:
|
||||
'template_images', search_images', 'template_anno', 'search_anno'
|
||||
returns:
|
||||
TensorDict - output data block with following fields:
|
||||
'template_images', 'search_images', 'template_anno', 'search_anno', 'test_proposals', 'proposal_iou'
|
||||
"""
|
||||
# Apply joint transforms
|
||||
if self.transform['joint'] is not None:
|
||||
data['template_images'], data['template_anno'], data['template_masks'] = self.transform['joint'](
|
||||
image=data['template_images'], bbox=data['template_anno'], mask=data['template_masks'])
|
||||
data['search_images'], data['search_anno'], data['search_masks'] = self.transform['joint'](
|
||||
image=data['search_images'], bbox=data['search_anno'], mask=data['search_masks'], new_roll=False)
|
||||
|
||||
for s in ['template', 'search']:
|
||||
assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
|
||||
"In pair mode, num train/test frames must be 1"
|
||||
|
||||
# Add a uniform noise to the center pos
|
||||
jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
|
||||
|
||||
# 2021.1.9 Check whether data is valid. Avoid too small bounding boxes
|
||||
w, h = torch.stack(jittered_anno, dim=0)[:, 2], torch.stack(jittered_anno, dim=0)[:, 3]
|
||||
|
||||
crop_sz = torch.ceil(torch.sqrt(w * h) * self.search_area_factor[s])
|
||||
if (crop_sz < 1).any():
|
||||
data['valid'] = False
|
||||
# print("Too small box is found. Replace it with new data.")
|
||||
return data
|
||||
|
||||
# Crop image region centered at jittered_anno box and get the attention mask
|
||||
crops, boxes, att_mask, mask_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno,
|
||||
data[s + '_anno'], self.search_area_factor[s],
|
||||
self.output_sz[s], masks=data[s + '_masks'])
|
||||
# Apply transforms
|
||||
data[s + '_images'], data[s + '_anno'], data[s + '_att'], data[s + '_masks'] = self.transform[s](
|
||||
image=crops, bbox=boxes, att=att_mask, mask=mask_crops, joint=False)
|
||||
|
||||
|
||||
# 2021.1.9 Check whether elements in data[s + '_att'] is all 1
|
||||
# Note that type of data[s + '_att'] is tuple, type of ele is torch.tensor
|
||||
for ele in data[s + '_att']:
|
||||
if (ele == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of original attention mask are all one. Replace it with new data.")
|
||||
return data
|
||||
# 2021.1.10 more strict conditions: require the donwsampled masks not to be all 1
|
||||
for ele in data[s + '_att']:
|
||||
feat_size = self.output_sz[s] // 16 # 16 is the backbone stride
|
||||
# (1,1,128,128) (1,1,256,256) --> (1,1,8,8) (1,1,16,16)
|
||||
mask_down = F.interpolate(ele[None, None].float(), size=feat_size).to(torch.bool)[0]
|
||||
if (mask_down == 1).all():
|
||||
data['valid'] = False
|
||||
# print("Values of down-sampled attention mask are all one. "
|
||||
# "Replace it with new data.")
|
||||
return data
|
||||
|
||||
data['valid'] = True
|
||||
# if we use copy-and-paste augmentation
|
||||
if data["template_masks"] is None or data["search_masks"] is None:
|
||||
data["template_masks"] = torch.zeros((1, self.output_sz["template"], self.output_sz["template"]))
|
||||
data["search_masks"] = torch.zeros((1, self.output_sz["search"], self.output_sz["search"]))
|
||||
# Prepare output
|
||||
if self.mode == 'sequence':
|
||||
data = data.apply(stack_tensors)
|
||||
else:
|
||||
data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
|
||||
|
||||
return data
|
168
lib/train/data/processing_utils.py
Normal file
168
lib/train/data/processing_utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
'''modified from the original test implementation
|
||||
Replace cv.BORDER_REPLICATE with cv.BORDER_CONSTANT
|
||||
Add a variable called att_mask for computing attention and positional encoding later'''
|
||||
|
||||
|
||||
def sample_target(im, target_bb, search_area_factor, output_sz=None, mask=None):
|
||||
""" Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area
|
||||
|
||||
args:
|
||||
im - cv image
|
||||
target_bb - target box [x, y, w, h]
|
||||
search_area_factor - Ratio of crop size to target size
|
||||
output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.
|
||||
|
||||
returns:
|
||||
cv image - extracted crop
|
||||
float - the factor by which the crop has been resized to make the crop size equal output_size
|
||||
"""
|
||||
if not isinstance(target_bb, list):
|
||||
x, y, w, h = target_bb.tolist()
|
||||
else:
|
||||
x, y, w, h = target_bb
|
||||
# Crop image
|
||||
crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
|
||||
|
||||
if crop_sz < 1:
|
||||
raise Exception('Too small bounding box.')
|
||||
|
||||
x1 = round(x + 0.5 * w - crop_sz * 0.5)
|
||||
x2 = x1 + crop_sz
|
||||
|
||||
y1 = round(y + 0.5 * h - crop_sz * 0.5)
|
||||
y2 = y1 + crop_sz
|
||||
|
||||
x1_pad = max(0, -x1)
|
||||
x2_pad = max(x2 - im.shape[1] + 1, 0)
|
||||
|
||||
y1_pad = max(0, -y1)
|
||||
y2_pad = max(y2 - im.shape[0] + 1, 0)
|
||||
|
||||
# Crop target
|
||||
im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
|
||||
if mask is not None:
|
||||
mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]
|
||||
|
||||
# Pad
|
||||
im_crop_padded = cv.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv.BORDER_CONSTANT)
|
||||
# deal with attention mask
|
||||
H, W, _ = im_crop_padded.shape
|
||||
att_mask = np.ones((H,W))
|
||||
end_x, end_y = -x2_pad, -y2_pad
|
||||
if y2_pad == 0:
|
||||
end_y = None
|
||||
if x2_pad == 0:
|
||||
end_x = None
|
||||
att_mask[y1_pad:end_y, x1_pad:end_x] = 0
|
||||
if mask is not None:
|
||||
mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)
|
||||
|
||||
if output_sz is not None:
|
||||
resize_factor = output_sz / crop_sz
|
||||
im_crop_padded = cv.resize(im_crop_padded, (output_sz, output_sz))
|
||||
att_mask = cv.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
|
||||
if mask is None:
|
||||
return im_crop_padded, resize_factor, att_mask
|
||||
mask_crop_padded = \
|
||||
F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
|
||||
return im_crop_padded, resize_factor, att_mask, mask_crop_padded
|
||||
|
||||
else:
|
||||
if mask is None:
|
||||
return im_crop_padded, att_mask.astype(np.bool_), 1.0
|
||||
return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded
|
||||
|
||||
|
||||
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
|
||||
crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box_in - the box for which the co-ordinates are to be transformed
|
||||
box_extract - the box about which the image crop has been extracted.
|
||||
resize_factor - the ratio between the original image scale and the scale of the image crop
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]
|
||||
|
||||
box_in_center = box_in[0:2] + 0.5 * box_in[2:4]
|
||||
|
||||
box_out_center = (crop_sz - 1) / 2 + (box_in_center - box_extract_center) * resize_factor
|
||||
box_out_wh = box_in[2:4] * resize_factor
|
||||
|
||||
box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
||||
|
||||
def jittered_center_crop(frames, box_extract, box_gt, search_area_factor, output_sz, masks=None):
|
||||
""" For each frame in frames, extracts a square crop centered at box_extract, of area search_area_factor^2
|
||||
times box_extract area. The extracted crops are then resized to output_sz. Further, the co-ordinates of the box
|
||||
box_gt are transformed to the image crop co-ordinates
|
||||
|
||||
args:
|
||||
frames - list of frames
|
||||
box_extract - list of boxes of same length as frames. The crops are extracted using anno_extract
|
||||
box_gt - list of boxes of same length as frames. The co-ordinates of these boxes are transformed from
|
||||
image co-ordinates to the crop co-ordinates
|
||||
search_area_factor - The area of the extracted crop is search_area_factor^2 times box_extract area
|
||||
output_sz - The size to which the extracted crops are resized
|
||||
|
||||
returns:
|
||||
list - list of image crops
|
||||
list - box_gt location in the crop co-ordinates
|
||||
"""
|
||||
|
||||
if masks is None:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz)
|
||||
for f, a in zip(frames, box_extract)]
|
||||
frames_crop, resize_factors, att_mask = zip(*crops_resize_factors)
|
||||
masks_crop = None
|
||||
else:
|
||||
crops_resize_factors = [sample_target(f, a, search_area_factor, output_sz, m)
|
||||
for f, a, m in zip(frames, box_extract, masks)]
|
||||
frames_crop, resize_factors, att_mask, masks_crop = zip(*crops_resize_factors)
|
||||
# frames_crop: tuple of ndarray (128,128,3), att_mask: tuple of ndarray (128,128)
|
||||
crop_sz = torch.Tensor([output_sz, output_sz])
|
||||
|
||||
# find the bb location in the crop
|
||||
'''Note that here we use normalized coord'''
|
||||
box_crop = [transform_image_to_crop(a_gt, a_ex, rf, crop_sz, normalize=True)
|
||||
for a_gt, a_ex, rf in zip(box_gt, box_extract, resize_factors)] # (x1,y1,w,h) list of tensors
|
||||
|
||||
return frames_crop, box_crop, att_mask, masks_crop
|
||||
|
||||
|
||||
def transform_box_to_crop(box: torch.Tensor, crop_box: torch.Tensor, crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
|
||||
""" Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image
|
||||
args:
|
||||
box - the box for which the co-ordinates are to be transformed
|
||||
crop_box - bounding box defining the crop in the original image
|
||||
crop_sz - size of the cropped image
|
||||
|
||||
returns:
|
||||
torch.Tensor - transformed co-ordinates of box_in
|
||||
"""
|
||||
|
||||
box_out = box.clone()
|
||||
box_out[:2] -= crop_box[:2]
|
||||
|
||||
scale_factor = crop_sz / crop_box[2:]
|
||||
|
||||
box_out[:2] *= scale_factor
|
||||
box_out[2:] *= scale_factor
|
||||
if normalize:
|
||||
return box_out / crop_sz[0]
|
||||
else:
|
||||
return box_out
|
||||
|
349
lib/train/data/sampler.py
Normal file
349
lib/train/data/sampler.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
from lib.utils import TensorDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def no_processing(data):
|
||||
return data
|
||||
|
||||
|
||||
class TrackingSampler(torch.utils.data.Dataset):
|
||||
""" Class responsible for sampling frames from training sequences to form batches.
|
||||
|
||||
The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
|
||||
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
|
||||
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
|
||||
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
|
||||
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.
|
||||
|
||||
The sampled frames are then passed through the input 'processing' function for the necessary processing-
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, processing=no_processing, frame_sample_mode='causal',
|
||||
train_cls=False, pos_prob=0.5):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the test frames.
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the test frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
"""
|
||||
self.datasets = datasets
|
||||
self.train_cls = train_cls # whether we are training classification
|
||||
self.pos_prob = pos_prob # probability of sampling positive class when making classification
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.processing = processing
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None,
|
||||
allow_invisible=False, force_invisible=False):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
# get valid ids
|
||||
if force_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id) if not visible[i]]
|
||||
else:
|
||||
if allow_invisible:
|
||||
valid_ids = [i for i in range(min_id, max_id)]
|
||||
else:
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.train_cls:
|
||||
return self.getitem_cls()
|
||||
else:
|
||||
return self.getitem()
|
||||
|
||||
def getitem(self):
|
||||
"""
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
|
||||
if is_video_dataset:
|
||||
template_frame_ids = None
|
||||
search_frame_ids = None
|
||||
gap_increase = 0
|
||||
|
||||
if self.frame_sample_mode == 'causal':
|
||||
# Sample test and train frames in a causal manner, i.e. search_frame_ids > template_frame_ids
|
||||
while search_frame_ids is None:
|
||||
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_template_frames - 1,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_template_frames - 1,
|
||||
min_id=base_frame_id[0] - self.max_gap - gap_increase,
|
||||
max_id=base_frame_id[0])
|
||||
if prev_frame_ids is None:
|
||||
gap_increase += 5
|
||||
continue
|
||||
template_frame_ids = base_frame_id + prev_frame_ids
|
||||
search_frame_ids = self._sample_visible_ids(visible, min_id=template_frame_ids[0] + 1,
|
||||
max_id=template_frame_ids[0] + self.max_gap + gap_increase,
|
||||
num_ids=self.num_search_frames)
|
||||
# Increase gap until a frame is found
|
||||
gap_increase += 5
|
||||
|
||||
elif self.frame_sample_mode == "trident" or self.frame_sample_mode == "trident_pro":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("Illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros((H, W))] * self.num_template_frames
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros((H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def getitem_cls(self):
|
||||
# get data for classification
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
aux (bool): whether the current data is for auxiliary use (e.g. copy-and-paste)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
valid = False
|
||||
label = None
|
||||
while not valid:
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# sample a sequence from the given dataset
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample template and search frame ids
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode in ["trident", "trident_pro"]:
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_trident(visible)
|
||||
elif self.frame_sample_mode == "stark":
|
||||
template_frame_ids, search_frame_ids = self.get_frame_ids_stark(visible, seq_info_dict["valid"])
|
||||
else:
|
||||
raise ValueError("illegal frame sample mode")
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
try:
|
||||
# "try" is used to handle trackingnet data failure
|
||||
# get images and bounding boxes (for templates)
|
||||
template_frames, template_anno, meta_obj_train = dataset.get_frames(seq_id, template_frame_ids,
|
||||
seq_info_dict)
|
||||
H, W, _ = template_frames[0].shape
|
||||
template_masks = template_anno['mask'] if 'mask' in template_anno else [torch.zeros(
|
||||
(H, W))] * self.num_template_frames
|
||||
# get images and bounding boxes (for searches)
|
||||
# positive samples
|
||||
if random.random() < self.pos_prob:
|
||||
label = torch.ones(1,)
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
# negative samples
|
||||
else:
|
||||
label = torch.zeros(1,)
|
||||
if is_video_dataset:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, force_invisible=True)
|
||||
if search_frame_ids is None:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids,
|
||||
seq_info_dict)
|
||||
search_anno["bbox"] = [self.get_center_box(H, W)]
|
||||
else:
|
||||
search_frames, search_anno, meta_obj_test = self.get_one_search()
|
||||
H, W, _ = search_frames[0].shape
|
||||
search_masks = search_anno['mask'] if 'mask' in search_anno else [torch.zeros(
|
||||
(H, W))] * self.num_search_frames
|
||||
|
||||
data = TensorDict({'template_images': template_frames,
|
||||
'template_anno': template_anno['bbox'],
|
||||
'template_masks': template_masks,
|
||||
'search_images': search_frames,
|
||||
'search_anno': search_anno['bbox'],
|
||||
'search_masks': search_masks,
|
||||
'dataset': dataset.get_name(),
|
||||
'test_class': meta_obj_test.get('object_class_name')})
|
||||
|
||||
# make data augmentation
|
||||
data = self.processing(data)
|
||||
# add classification label
|
||||
data["label"] = label
|
||||
# check whether data is valid
|
||||
valid = data['valid']
|
||||
except:
|
||||
valid = False
|
||||
|
||||
return data
|
||||
|
||||
def get_center_box(self, H, W, ratio=1/8):
|
||||
cx, cy, w, h = W/2, H/2, W * ratio, H * ratio
|
||||
return torch.tensor([int(cx-w/2), int(cy-h/2), int(w), int(h)])
|
||||
|
||||
def sample_seq_from_dataset(self, dataset, is_video_dataset):
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= 20
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
return seq_id, visible, seq_info_dict
|
||||
|
||||
def get_one_search(self):
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
# sample a sequence
|
||||
seq_id, visible, seq_info_dict = self.sample_seq_from_dataset(dataset, is_video_dataset)
|
||||
# sample a frame
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == "stark":
|
||||
search_frame_ids = self._sample_visible_ids(seq_info_dict["valid"], num_ids=1)
|
||||
else:
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, allow_invisible=True)
|
||||
else:
|
||||
search_frame_ids = [1]
|
||||
# get the image, bounding box and other info
|
||||
search_frames, search_anno, meta_obj_test = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
|
||||
return search_frames, search_anno, meta_obj_test
|
||||
|
||||
def get_frame_ids_trident(self, visible):
|
||||
# get template and search ids in a 'trident' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
if self.frame_sample_mode == "trident_pro":
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id,
|
||||
allow_invisible=True)
|
||||
else:
|
||||
f_id = self._sample_visible_ids(visible, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
def get_frame_ids_stark(self, visible, valid):
|
||||
# get template and search ids in a 'stark' manner
|
||||
template_frame_ids_extra = []
|
||||
while None in template_frame_ids_extra or len(template_frame_ids_extra) == 0:
|
||||
template_frame_ids_extra = []
|
||||
# first randomly sample two frames from a video
|
||||
template_frame_id1 = self._sample_visible_ids(visible, num_ids=1) # the initial template id
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1) # the search region id
|
||||
# get the dynamic template id
|
||||
for max_gap in self.max_gap:
|
||||
if template_frame_id1[0] >= search_frame_ids[0]:
|
||||
min_id, max_id = search_frame_ids[0], search_frame_ids[0] + max_gap
|
||||
else:
|
||||
min_id, max_id = search_frame_ids[0] - max_gap, search_frame_ids[0]
|
||||
"""we require the frame to be valid but not necessary visible"""
|
||||
f_id = self._sample_visible_ids(valid, num_ids=1, min_id=min_id, max_id=max_id)
|
||||
if f_id is None:
|
||||
template_frame_ids_extra += [None]
|
||||
else:
|
||||
template_frame_ids_extra += f_id
|
||||
|
||||
template_frame_ids = template_frame_id1 + template_frame_ids_extra
|
||||
return template_frame_ids, search_frame_ids
|
265
lib/train/data/sequence_sampler.py
Normal file
265
lib/train/data/sequence_sampler.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import random
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class SequenceSampler(torch.utils.data.Dataset):
|
||||
"""
|
||||
Sample sequence for sequence-level training
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
|
||||
num_search_frames, num_template_frames=1, frame_sample_mode='sequential', max_interval=10, prob=0.7):
|
||||
"""
|
||||
args:
|
||||
datasets - List of datasets to be used for training
|
||||
p_datasets - List containing the probabilities by which each dataset will be sampled
|
||||
samples_per_epoch - Number of training samples per epoch
|
||||
max_gap - Maximum gap, in frame numbers, between the train frames and the search frames.\
|
||||
max_interval - Maximum interval between sampled frames
|
||||
num_search_frames - Number of search frames to sample.
|
||||
num_template_frames - Number of template frames to sample.
|
||||
processing - An instance of Processing class which performs the necessary processing of the data.
|
||||
frame_sample_mode - Either 'causal' or 'interval'. If 'causal', then the search frames are sampled in a causally,
|
||||
otherwise randomly within the interval.
|
||||
prob - sequential sampling by prob / interval sampling by 1-prob
|
||||
"""
|
||||
self.datasets = datasets
|
||||
|
||||
# If p not provided, sample uniformly from all videos
|
||||
if p_datasets is None:
|
||||
p_datasets = [len(d) for d in self.datasets]
|
||||
|
||||
# Normalize
|
||||
p_total = sum(p_datasets)
|
||||
self.p_datasets = [x / p_total for x in p_datasets]
|
||||
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
self.num_search_frames = num_search_frames
|
||||
self.num_template_frames = num_template_frames
|
||||
self.frame_sample_mode = frame_sample_mode
|
||||
self.prob=prob
|
||||
self.extra=1
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_per_epoch
|
||||
|
||||
def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
|
||||
""" Samples num_ids frames between min_id and max_id for which target is visible
|
||||
|
||||
args:
|
||||
visible - 1d Tensor indicating whether target is visible for each frame
|
||||
num_ids - number of frames to be samples
|
||||
min_id - Minimum allowed frame number
|
||||
max_id - Maximum allowed frame number
|
||||
|
||||
returns:
|
||||
list - List of sampled frame numbers. None if not sufficient visible frames could be found.
|
||||
"""
|
||||
if num_ids == 0:
|
||||
return []
|
||||
if min_id is None or min_id < 0:
|
||||
min_id = 0
|
||||
if max_id is None or max_id > len(visible):
|
||||
max_id = len(visible)
|
||||
|
||||
valid_ids = [i for i in range(min_id, max_id) if visible[i]]
|
||||
|
||||
# No visible ids
|
||||
if len(valid_ids) == 0:
|
||||
return None
|
||||
|
||||
return random.choices(valid_ids, k=num_ids)
|
||||
|
||||
|
||||
def _sequential_sample(self, visible):
|
||||
# Sample frames in sequential manner
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - self.num_search_frames)
|
||||
if self.max_gap == -1:
|
||||
left = template_frame_ids[0]
|
||||
else:
|
||||
# template frame (1) ->(max_gap) -> search frame (num_search_frames)
|
||||
left_max = min(len(visible) - self.num_search_frames, template_frame_ids[0] + self.max_gap)
|
||||
left = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)[0]
|
||||
|
||||
valid_ids = [i for i in range(left, len(visible)) if visible[i]]
|
||||
search_frame_ids = valid_ids[:self.num_search_frames]
|
||||
|
||||
# if length is not enough
|
||||
last = search_frame_ids[-1]
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
if last >= len(visible) - 1:
|
||||
search_frame_ids.append(last)
|
||||
else:
|
||||
last += 1
|
||||
if visible[last]:
|
||||
search_frame_ids.append(last)
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def _random_interval_sample(self, visible):
|
||||
# Get valid ids
|
||||
valid_ids = [i for i in range(len(visible)) if visible[i]]
|
||||
|
||||
# Sample template frame
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
template_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=0,
|
||||
max_id=len(visible) - avg_interval * (self.num_search_frames - 1))
|
||||
if template_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == 0:
|
||||
template_frame_ids = [valid_ids[0]]
|
||||
break
|
||||
|
||||
# Sample first search frame
|
||||
if self.max_gap == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
else:
|
||||
avg_interval = self.max_interval
|
||||
while avg_interval * (self.num_search_frames - 1) > len(visible):
|
||||
avg_interval = max(avg_interval - 1, 1)
|
||||
|
||||
while True:
|
||||
left_max = min(max(len(visible) - avg_interval * (self.num_search_frames - 1), template_frame_ids[0] + 1),
|
||||
template_frame_ids[0] + self.max_gap)
|
||||
search_frame_ids = self._sample_visible_ids(visible, num_ids=1, min_id=template_frame_ids[0],
|
||||
max_id=left_max)
|
||||
|
||||
if search_frame_ids == None:
|
||||
avg_interval = avg_interval - 1
|
||||
else:
|
||||
break
|
||||
|
||||
if avg_interval == -1:
|
||||
search_frame_ids = template_frame_ids
|
||||
break
|
||||
|
||||
# Sample rest of the search frames with random interval
|
||||
last = search_frame_ids[0]
|
||||
while last <= len(visible) - 1 and len(search_frame_ids) < self.num_search_frames:
|
||||
# sample id with interval
|
||||
max_id = min(last + self.max_interval + 1, len(visible))
|
||||
id = self._sample_visible_ids(visible, num_ids=1, min_id=last,
|
||||
max_id=max_id)
|
||||
|
||||
if id is None:
|
||||
# If not found in current range, find from previous range
|
||||
last = last + self.max_interval
|
||||
else:
|
||||
search_frame_ids.append(id[0])
|
||||
last = search_frame_ids[-1]
|
||||
|
||||
# if length is not enough, randomly sample new ids
|
||||
if len(search_frame_ids) < self.num_search_frames:
|
||||
valid_ids = [x for x in valid_ids if x > search_frame_ids[0] and x not in search_frame_ids]
|
||||
|
||||
if len(valid_ids) > 0:
|
||||
new_ids = random.choices(valid_ids, k=min(len(valid_ids),
|
||||
self.num_search_frames - len(search_frame_ids)))
|
||||
search_frame_ids = search_frame_ids + new_ids
|
||||
search_frame_ids = sorted(search_frame_ids, key=int)
|
||||
|
||||
# if length is still not enough, duplicate last frame
|
||||
while len(search_frame_ids) < self.num_search_frames:
|
||||
search_frame_ids.append(search_frame_ids[-1])
|
||||
|
||||
for i in range(1, self.num_search_frames):
|
||||
if search_frame_ids[i] - search_frame_ids[i - 1] > self.max_interval:
|
||||
print(search_frame_ids[i] - search_frame_ids[i - 1])
|
||||
|
||||
return template_frame_ids, search_frame_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
args:
|
||||
index (int): Index (Ignored since we sample randomly)
|
||||
|
||||
returns:
|
||||
TensorDict - dict containing all the data blocks
|
||||
"""
|
||||
|
||||
# Select a dataset
|
||||
dataset = random.choices(self.datasets, self.p_datasets)[0]
|
||||
if dataset.get_name() == 'got10k' :
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
else:
|
||||
max_gap = self.max_gap
|
||||
max_interval = self.max_interval
|
||||
self.max_gap = max_gap * self.extra
|
||||
self.max_interval = max_interval * self.extra
|
||||
|
||||
is_video_dataset = dataset.is_video_sequence()
|
||||
|
||||
# Sample a sequence with enough visible frames
|
||||
while True:
|
||||
try:
|
||||
enough_visible_frames = False
|
||||
while not enough_visible_frames:
|
||||
# Sample a sequence
|
||||
seq_id = random.randint(0, dataset.get_num_sequences() - 1)
|
||||
|
||||
# Sample frames
|
||||
seq_info_dict = dataset.get_sequence_info(seq_id)
|
||||
visible = seq_info_dict['visible']
|
||||
|
||||
enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
|
||||
self.num_search_frames + self.num_template_frames) and len(visible) >= (self.num_search_frames + self.num_template_frames)
|
||||
|
||||
enough_visible_frames = enough_visible_frames or not is_video_dataset
|
||||
|
||||
if is_video_dataset:
|
||||
if self.frame_sample_mode == 'sequential':
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
|
||||
elif self.frame_sample_mode == 'random_interval':
|
||||
if random.random() < self.prob:
|
||||
template_frame_ids, search_frame_ids = self._random_interval_sample(visible)
|
||||
else:
|
||||
template_frame_ids, search_frame_ids = self._sequential_sample(visible)
|
||||
else:
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# In case of image dataset, just repeat the image to generate synthetic video
|
||||
template_frame_ids = [1] * self.num_template_frames
|
||||
search_frame_ids = [1] * self.num_search_frames
|
||||
#print(dataset.get_name(), search_frame_ids, self.max_gap, self.max_interval)
|
||||
self.max_gap = max_gap
|
||||
self.max_interval = max_interval
|
||||
#print(self.max_gap, self.max_interval)
|
||||
template_frames, template_anno, meta_obj_template = dataset.get_frames(seq_id, template_frame_ids, seq_info_dict)
|
||||
search_frames, search_anno, meta_obj_search = dataset.get_frames(seq_id, search_frame_ids, seq_info_dict)
|
||||
template_bbox = [bbox.numpy() for bbox in template_anno['bbox']] # tensor -> numpy array
|
||||
search_bbox = [bbox.numpy() for bbox in search_anno['bbox']] # tensor -> numpy array
|
||||
# print("====================================================================================")
|
||||
# print("dataset index: {}".format(index))
|
||||
# print("seq_id: {}".format(seq_id))
|
||||
# print('template_frame_ids: {}'.format(template_frame_ids))
|
||||
# print('search_frame_ids: {}'.format(search_frame_ids))
|
||||
return TensorDict({'template_images': np.array(template_frames).squeeze(), # 1 template images
|
||||
'template_annos': np.array(template_bbox).squeeze(),
|
||||
'search_images': np.array(search_frames), # (num_frames) search images
|
||||
'search_annos': np.array(search_bbox),
|
||||
'seq_id': seq_id,
|
||||
'dataset': dataset.get_name(),
|
||||
'search_class': meta_obj_search.get('object_class_name'),
|
||||
'num_frames': len(search_frames)
|
||||
})
|
||||
except Exception:
|
||||
pass
|
335
lib/train/data/transforms.py
Normal file
335
lib/train/data/transforms.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as tvisf
|
||||
|
||||
|
||||
class Transform:
|
||||
"""A set of transformations, used for e.g. data augmentation.
|
||||
Args of constructor:
|
||||
transforms: An arbitrary number of transformations, derived from the TransformBase class.
|
||||
They are applied in the order they are given.
|
||||
|
||||
The Transform object can jointly transform images, bounding boxes and segmentation masks.
|
||||
This is done by calling the object with the following key-word arguments (all are optional).
|
||||
|
||||
The following arguments are inputs to be transformed. They are either supplied as a single instance, or a list of instances.
|
||||
image - Image
|
||||
coords - 2xN dimensional Tensor of 2D image coordinates [y, x]
|
||||
bbox - Bounding box on the form [x, y, w, h]
|
||||
mask - Segmentation mask with discrete classes
|
||||
|
||||
The following parameters can be supplied with calling the transform object:
|
||||
joint [Bool] - If True then transform all images/coords/bbox/mask in the list jointly using the same transformation.
|
||||
Otherwise each tuple (images, coords, bbox, mask) will be transformed independently using
|
||||
different random rolls. Default: True.
|
||||
new_roll [Bool] - If False, then no new random roll is performed, and the saved result from the previous roll
|
||||
is used instead. Default: True.
|
||||
|
||||
Check the DiMPProcessing class for examples.
|
||||
"""
|
||||
|
||||
def __init__(self, *transforms):
|
||||
if len(transforms) == 1 and isinstance(transforms[0], (list, tuple)):
|
||||
transforms = transforms[0]
|
||||
self.transforms = transforms
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['joint', 'new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
|
||||
def __call__(self, **inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
for v in inputs.keys():
|
||||
if v not in self._valid_all:
|
||||
raise ValueError('Incorrect input \"{}\" to transform. Only supports inputs {} and arguments {}.'.format(v, self._valid_inputs, self._valid_args))
|
||||
|
||||
joint_mode = inputs.get('joint', True)
|
||||
new_roll = inputs.get('new_roll', True)
|
||||
|
||||
if not joint_mode:
|
||||
out = zip(*[self(**inp) for inp in self._split_inputs(inputs)])
|
||||
return tuple(list(o) for o in out)
|
||||
|
||||
out = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
|
||||
for t in self.transforms:
|
||||
out = t(**out, joint=joint_mode, new_roll=new_roll)
|
||||
if len(var_names) == 1:
|
||||
return out[var_names[0]]
|
||||
# Make sure order is correct
|
||||
return tuple(out[v] for v in var_names)
|
||||
|
||||
def _split_inputs(self, inputs):
|
||||
var_names = [k for k in inputs.keys() if k in self._valid_inputs]
|
||||
split_inputs = [{k: v for k, v in zip(var_names, vals)} for vals in zip(*[inputs[vn] for vn in var_names])]
|
||||
for arg_name, arg_val in filter(lambda it: it[0]!='joint' and it[0] in self._valid_args, inputs.items()):
|
||||
if isinstance(arg_val, list):
|
||||
for inp, av in zip(split_inputs, arg_val):
|
||||
inp[arg_name] = av
|
||||
else:
|
||||
for inp in split_inputs:
|
||||
inp[arg_name] = arg_val
|
||||
return split_inputs
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class TransformBase:
|
||||
"""Base class for transformation objects. See the Transform class for details."""
|
||||
def __init__(self):
|
||||
"""2020.12.24 Add 'att' to valid inputs"""
|
||||
self._valid_inputs = ['image', 'coords', 'bbox', 'mask', 'att']
|
||||
self._valid_args = ['new_roll']
|
||||
self._valid_all = self._valid_inputs + self._valid_args
|
||||
self._rand_params = None
|
||||
|
||||
def __call__(self, **inputs):
|
||||
# Split input
|
||||
input_vars = {k: v for k, v in inputs.items() if k in self._valid_inputs}
|
||||
input_args = {k: v for k, v in inputs.items() if k in self._valid_args}
|
||||
|
||||
# Roll random parameters for the transform
|
||||
if input_args.get('new_roll', True):
|
||||
rand_params = self.roll()
|
||||
if rand_params is None:
|
||||
rand_params = ()
|
||||
elif not isinstance(rand_params, tuple):
|
||||
rand_params = (rand_params,)
|
||||
self._rand_params = rand_params
|
||||
|
||||
outputs = dict()
|
||||
for var_name, var in input_vars.items():
|
||||
if var is not None:
|
||||
transform_func = getattr(self, 'transform_' + var_name)
|
||||
if var_name in ['coords', 'bbox']:
|
||||
params = (self._get_image_size(input_vars),) + self._rand_params
|
||||
else:
|
||||
params = self._rand_params
|
||||
if isinstance(var, (list, tuple)):
|
||||
outputs[var_name] = [transform_func(x, *params) for x in var]
|
||||
else:
|
||||
outputs[var_name] = transform_func(var, *params)
|
||||
return outputs
|
||||
|
||||
def _get_image_size(self, inputs):
|
||||
im = None
|
||||
for var_name in ['image', 'mask']:
|
||||
if inputs.get(var_name) is not None:
|
||||
im = inputs[var_name]
|
||||
break
|
||||
if im is None:
|
||||
return None
|
||||
if isinstance(im, (list, tuple)):
|
||||
im = im[0]
|
||||
if isinstance(im, np.ndarray):
|
||||
return im.shape[:2]
|
||||
if torch.is_tensor(im):
|
||||
return (im.shape[-2], im.shape[-1])
|
||||
raise Exception('Unknown image type')
|
||||
|
||||
def roll(self):
|
||||
return None
|
||||
|
||||
def transform_image(self, image, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return coords
|
||||
|
||||
def transform_bbox(self, bbox, image_shape, *rand_params):
|
||||
"""Assumes [x, y, w, h]"""
|
||||
# Check if not overloaded
|
||||
if self.transform_coords.__code__ == TransformBase.transform_coords.__code__:
|
||||
return bbox
|
||||
|
||||
coord = bbox.clone().view(-1,2).t().flip(0)
|
||||
|
||||
x1 = coord[1, 0]
|
||||
x2 = coord[1, 0] + coord[1, 1]
|
||||
|
||||
y1 = coord[0, 0]
|
||||
y2 = coord[0, 0] + coord[0, 1]
|
||||
|
||||
coord_all = torch.tensor([[y1, y1, y2, y2], [x1, x2, x2, x1]])
|
||||
|
||||
coord_transf = self.transform_coords(coord_all, image_shape, *rand_params).flip(0)
|
||||
tl = torch.min(coord_transf, dim=1)[0]
|
||||
sz = torch.max(coord_transf, dim=1)[0] - tl
|
||||
bbox_out = torch.cat((tl, sz), dim=-1).reshape(bbox.shape)
|
||||
return bbox_out
|
||||
|
||||
def transform_mask(self, mask, *rand_params):
|
||||
"""Must be deterministic"""
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, *rand_params):
|
||||
"""2020.12.24 Added to deal with attention masks"""
|
||||
return att
|
||||
|
||||
|
||||
class ToTensor(TransformBase):
|
||||
"""Convert to a Tensor"""
|
||||
|
||||
def transform_image(self, image):
|
||||
# handle numpy array
|
||||
if image.ndim == 2:
|
||||
image = image[:, :, None]
|
||||
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(image, torch.ByteTensor):
|
||||
return image.float().div(255)
|
||||
else:
|
||||
return image
|
||||
|
||||
def transfrom_mask(self, mask):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
|
||||
def transform_att(self, att):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class ToTensorAndJitter(TransformBase):
|
||||
"""Convert to a Tensor and jitter brightness"""
|
||||
def __init__(self, brightness_jitter=0.0, normalize=True):
|
||||
super().__init__()
|
||||
self.brightness_jitter = brightness_jitter
|
||||
self.normalize = normalize
|
||||
|
||||
def roll(self):
|
||||
return np.random.uniform(max(0, 1 - self.brightness_jitter), 1 + self.brightness_jitter)
|
||||
|
||||
def transform_image(self, image, brightness_factor):
|
||||
# handle numpy array
|
||||
image = torch.from_numpy(image.transpose((2, 0, 1)))
|
||||
|
||||
# backward compatibility
|
||||
if self.normalize:
|
||||
return image.float().mul(brightness_factor/255.0).clamp(0.0, 1.0)
|
||||
else:
|
||||
return image.float().mul(brightness_factor).clamp(0.0, 255.0)
|
||||
|
||||
def transform_mask(self, mask, brightness_factor):
|
||||
if isinstance(mask, np.ndarray):
|
||||
return torch.from_numpy(mask)
|
||||
else:
|
||||
return mask
|
||||
def transform_att(self, att, brightness_factor):
|
||||
if isinstance(att, np.ndarray):
|
||||
return torch.from_numpy(att).to(torch.bool)
|
||||
elif isinstance(att, torch.Tensor):
|
||||
return att.to(torch.bool)
|
||||
else:
|
||||
raise ValueError ("dtype must be np.ndarray or torch.Tensor")
|
||||
|
||||
|
||||
class Normalize(TransformBase):
|
||||
"""Normalize image"""
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
super().__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def transform_image(self, image):
|
||||
return tvisf.normalize(image, self.mean, self.std, self.inplace)
|
||||
|
||||
|
||||
class ToGrayscale(TransformBase):
|
||||
"""Converts image to grayscale with probability"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
self.color_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_grayscale):
|
||||
if do_grayscale:
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_gray = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
|
||||
return np.stack([img_gray, img_gray, img_gray], axis=2)
|
||||
# return np.repeat(np.sum(img * self.color_weights, axis=2, keepdims=True).astype(np.uint8), 3, axis=2)
|
||||
return image
|
||||
|
||||
|
||||
class ToBGR(TransformBase):
|
||||
"""Converts image to BGR"""
|
||||
def transform_image(self, image):
|
||||
if torch.is_tensor(image):
|
||||
raise NotImplementedError('Implement torch variant.')
|
||||
img_bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
|
||||
return img_bgr
|
||||
|
||||
|
||||
class RandomHorizontalFlip(TransformBase):
|
||||
"""Horizontally flip image randomly with a probability p."""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def roll(self):
|
||||
return random.random() < self.probability
|
||||
|
||||
def transform_image(self, image, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(image):
|
||||
return image.flip((2,))
|
||||
return np.fliplr(image).copy()
|
||||
return image
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = (image_shape[1] - 1) - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
||||
|
||||
def transform_mask(self, mask, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(mask):
|
||||
return mask.flip((-1,))
|
||||
return np.fliplr(mask).copy()
|
||||
return mask
|
||||
|
||||
def transform_att(self, att, do_flip):
|
||||
if do_flip:
|
||||
if torch.is_tensor(att):
|
||||
return att.flip((-1,))
|
||||
return np.fliplr(att).copy()
|
||||
return att
|
||||
|
||||
|
||||
class RandomHorizontalFlip_Norm(RandomHorizontalFlip):
|
||||
"""Horizontally flip image randomly with a probability p.
|
||||
The difference is that the coord is normalized to [0,1]"""
|
||||
def __init__(self, probability = 0.5):
|
||||
super().__init__()
|
||||
self.probability = probability
|
||||
|
||||
def transform_coords(self, coords, image_shape, do_flip):
|
||||
"""we should use 1 rather than image_shape"""
|
||||
if do_flip:
|
||||
coords_flip = coords.clone()
|
||||
coords_flip[1,:] = 1 - coords[1,:]
|
||||
return coords_flip
|
||||
return coords
|
33
lib/train/data/wandb_logger.py
Normal file
33
lib/train/data/wandb_logger.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please run "pip install wandb" to install wandb')
|
||||
|
||||
|
||||
class WandbWriter:
|
||||
def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
|
||||
self.wandb = wandb
|
||||
self.step = cur_step
|
||||
self.interval = step_interval
|
||||
wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)
|
||||
|
||||
def write_log(self, stats: OrderedDict, epoch=-1):
|
||||
self.step += 1
|
||||
for loader_name, loader_stats in stats.items():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
|
||||
log_dict = {}
|
||||
for var_name, val in loader_stats.items():
|
||||
if hasattr(val, 'avg'):
|
||||
log_dict.update({loader_name + '/' + var_name: val.avg})
|
||||
else:
|
||||
log_dict.update({loader_name + '/' + var_name: val.val})
|
||||
|
||||
if epoch >= 0:
|
||||
log_dict.update({loader_name + '/epoch': epoch})
|
||||
|
||||
self.wandb.log(log_dict, step=self.step*self.interval)
|
16
lib/train/data_specs/README.md
Normal file
16
lib/train/data_specs/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# README
|
||||
|
||||
## Description for different text files
|
||||
GOT10K
|
||||
- got10k_train_full_split.txt: the complete GOT-10K training set. (9335 videos)
|
||||
- got10k_train_split.txt: part of videos from the GOT-10K training set
|
||||
- got10k_val_split.txt: another part of videos from the GOT-10K training set
|
||||
- got10k_vot_exclude.txt: 1k videos that are forbidden from "using to train models then testing on VOT" (as required by [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html))
|
||||
- got10k_vot_train_split.txt: part of videos from the "VOT-permitted" GOT-10K training set
|
||||
- got10k_vot_val_split.txt: another part of videos from the "VOT-permitted" GOT-10K training set
|
||||
|
||||
LaSOT
|
||||
- lasot_train_split.txt: the complete LaSOT training set
|
||||
|
||||
TrackingNnet
|
||||
- trackingnet_classmap.txt: The map from the sequence name to the target class for the TrackingNet
|
9335
lib/train/data_specs/got10k_train_full_split.txt
Normal file
9335
lib/train/data_specs/got10k_train_full_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
7934
lib/train/data_specs/got10k_train_split.txt
Normal file
7934
lib/train/data_specs/got10k_train_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
1401
lib/train/data_specs/got10k_val_split.txt
Normal file
1401
lib/train/data_specs/got10k_val_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
1000
lib/train/data_specs/got10k_vot_exclude.txt
Normal file
1000
lib/train/data_specs/got10k_vot_exclude.txt
Normal file
File diff suppressed because it is too large
Load Diff
7086
lib/train/data_specs/got10k_vot_train_split.txt
Normal file
7086
lib/train/data_specs/got10k_vot_train_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
1249
lib/train/data_specs/got10k_vot_val_split.txt
Normal file
1249
lib/train/data_specs/got10k_vot_val_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
1120
lib/train/data_specs/lasot_train_split.txt
Normal file
1120
lib/train/data_specs/lasot_train_split.txt
Normal file
File diff suppressed because it is too large
Load Diff
30134
lib/train/data_specs/trackingnet_classmap.txt
Normal file
30134
lib/train/data_specs/trackingnet_classmap.txt
Normal file
File diff suppressed because it is too large
Load Diff
437
lib/train/dataset/COCO_tool.py
Normal file
437
lib/train/dataset/COCO_tool.py
Normal file
@@ -0,0 +1,437 @@
|
||||
__author__ = 'tylin'
|
||||
__version__ = '2.0'
|
||||
# Interface for accessing the Microsoft COCO dataset.
|
||||
|
||||
# Microsoft COCO is a large image dataset designed for object detection,
|
||||
# segmentation, and caption generation. pycocotools is a Python API that
|
||||
# assists in loading, parsing and visualizing the annotations in COCO.
|
||||
# Please visit http://mscoco.org/ for more information on COCO, including
|
||||
# for the data, paper, and tutorials. The exact format of the annotations
|
||||
# is also described on the COCO website. For example usage of the pycocotools
|
||||
# please see pycocotools_demo.ipynb. In addition to this API, please download both
|
||||
# the COCO images and annotations in order to run the demo.
|
||||
|
||||
# An alternative to using the API is to load the annotations directly
|
||||
# into Python dictionary
|
||||
# Using the API provides additional utility functions. Note that this API
|
||||
# supports both *instance* and *caption* annotations. In the case of
|
||||
# captions not all functions are defined (e.g. categories are undefined).
|
||||
|
||||
# The following API functions are defined:
|
||||
# COCO - COCO api class that loads COCO annotation file and prepare data structures.
|
||||
# decodeMask - Decode binary mask M encoded via run-length encoding.
|
||||
# encodeMask - Encode binary mask M using run-length encoding.
|
||||
# getAnnIds - Get ann ids that satisfy given filter conditions.
|
||||
# getCatIds - Get cat ids that satisfy given filter conditions.
|
||||
# getImgIds - Get img ids that satisfy given filter conditions.
|
||||
# loadAnns - Load anns with the specified ids.
|
||||
# loadCats - Load cats with the specified ids.
|
||||
# loadImgs - Load imgs with the specified ids.
|
||||
# annToMask - Convert segmentation in an annotation to binary mask.
|
||||
# showAnns - Display the specified annotations.
|
||||
# loadRes - Load algorithm results and create API for accessing them.
|
||||
# download - Download COCO images from mscoco.org server.
|
||||
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
|
||||
# Help on each functions can be accessed by: "help COCO>function".
|
||||
|
||||
# See also COCO>decodeMask,
|
||||
# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
|
||||
# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
|
||||
# COCO>loadImgs, COCO>annToMask, COCO>showAnns
|
||||
|
||||
# Microsoft COCO Toolbox. version 2.0
|
||||
# Data, paper, and tutorials available at: http://mscoco.org/
|
||||
# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
|
||||
# Licensed under the Simplified BSD License [see bsd.txt]
|
||||
|
||||
import json
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import Polygon
|
||||
import numpy as np
|
||||
import copy
|
||||
import itertools
|
||||
from pycocotools import mask as maskUtils
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
PYTHON_VERSION = sys.version_info[0]
|
||||
if PYTHON_VERSION == 2:
|
||||
from urllib import urlretrieve
|
||||
elif PYTHON_VERSION == 3:
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
def _isArrayLike(obj):
|
||||
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
|
||||
|
||||
|
||||
class COCO:
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
|
||||
:param annotation_file (str): location of annotation file
|
||||
:param image_folder (str): location to the folder that hosts images.
|
||||
:return:
|
||||
"""
|
||||
# load dataset
|
||||
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
|
||||
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
|
||||
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
|
||||
self.dataset = dataset
|
||||
self.createIndex()
|
||||
|
||||
def createIndex(self):
|
||||
# create index
|
||||
print('creating index...')
|
||||
anns, cats, imgs = {}, {}, {}
|
||||
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
|
||||
if 'annotations' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
anns[ann['id']] = ann
|
||||
|
||||
if 'images' in self.dataset:
|
||||
for img in self.dataset['images']:
|
||||
imgs[img['id']] = img
|
||||
|
||||
if 'categories' in self.dataset:
|
||||
for cat in self.dataset['categories']:
|
||||
cats[cat['id']] = cat
|
||||
|
||||
if 'annotations' in self.dataset and 'categories' in self.dataset:
|
||||
for ann in self.dataset['annotations']:
|
||||
catToImgs[ann['category_id']].append(ann['image_id'])
|
||||
|
||||
print('index created!')
|
||||
|
||||
# create class members
|
||||
self.anns = anns
|
||||
self.imgToAnns = imgToAnns
|
||||
self.catToImgs = catToImgs
|
||||
self.imgs = imgs
|
||||
self.cats = cats
|
||||
|
||||
def info(self):
|
||||
"""
|
||||
Print information about the annotation file.
|
||||
:return:
|
||||
"""
|
||||
for key, value in self.dataset['info'].items():
|
||||
print('{}: {}'.format(key, value))
|
||||
|
||||
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
|
||||
"""
|
||||
Get ann ids that satisfy given filter conditions. default skips that filter
|
||||
:param imgIds (int array) : get anns for given imgs
|
||||
catIds (int array) : get anns for given cats
|
||||
areaRng (float array) : get anns for given area range (e.g. [0 inf])
|
||||
iscrowd (boolean) : get anns for given crowd label (False or True)
|
||||
:return: ids (int array) : integer array of ann ids
|
||||
"""
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == len(areaRng) == 0:
|
||||
anns = self.dataset['annotations']
|
||||
else:
|
||||
if not len(imgIds) == 0:
|
||||
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
|
||||
anns = list(itertools.chain.from_iterable(lists))
|
||||
else:
|
||||
anns = self.dataset['annotations']
|
||||
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
|
||||
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
|
||||
if not iscrowd == None:
|
||||
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
|
||||
else:
|
||||
ids = [ann['id'] for ann in anns]
|
||||
return ids
|
||||
|
||||
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
|
||||
"""
|
||||
filtering parameters. default skips that filter.
|
||||
:param catNms (str array) : get cats for given cat names
|
||||
:param supNms (str array) : get cats for given supercategory names
|
||||
:param catIds (int array) : get cats for given cat ids
|
||||
:return: ids (int array) : integer array of cat ids
|
||||
"""
|
||||
catNms = catNms if _isArrayLike(catNms) else [catNms]
|
||||
supNms = supNms if _isArrayLike(supNms) else [supNms]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(catNms) == len(supNms) == len(catIds) == 0:
|
||||
cats = self.dataset['categories']
|
||||
else:
|
||||
cats = self.dataset['categories']
|
||||
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
|
||||
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
|
||||
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
|
||||
ids = [cat['id'] for cat in cats]
|
||||
return ids
|
||||
|
||||
def getImgIds(self, imgIds=[], catIds=[]):
|
||||
'''
|
||||
Get img ids that satisfy given filter conditions.
|
||||
:param imgIds (int array) : get imgs for given ids
|
||||
:param catIds (int array) : get imgs with all given cats
|
||||
:return: ids (int array) : integer array of img ids
|
||||
'''
|
||||
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
|
||||
catIds = catIds if _isArrayLike(catIds) else [catIds]
|
||||
|
||||
if len(imgIds) == len(catIds) == 0:
|
||||
ids = self.imgs.keys()
|
||||
else:
|
||||
ids = set(imgIds)
|
||||
for i, catId in enumerate(catIds):
|
||||
if i == 0 and len(ids) == 0:
|
||||
ids = set(self.catToImgs[catId])
|
||||
else:
|
||||
ids &= set(self.catToImgs[catId])
|
||||
return list(ids)
|
||||
|
||||
def loadAnns(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying anns
|
||||
:return: anns (object array) : loaded ann objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.anns[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.anns[ids]]
|
||||
|
||||
def loadCats(self, ids=[]):
|
||||
"""
|
||||
Load cats with the specified ids.
|
||||
:param ids (int array) : integer ids specifying cats
|
||||
:return: cats (object array) : loaded cat objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.cats[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.cats[ids]]
|
||||
|
||||
def loadImgs(self, ids=[]):
|
||||
"""
|
||||
Load anns with the specified ids.
|
||||
:param ids (int array) : integer ids specifying img
|
||||
:return: imgs (object array) : loaded img objects
|
||||
"""
|
||||
if _isArrayLike(ids):
|
||||
return [self.imgs[id] for id in ids]
|
||||
elif type(ids) == int:
|
||||
return [self.imgs[ids]]
|
||||
|
||||
def showAnns(self, anns, draw_bbox=False):
|
||||
"""
|
||||
Display the specified annotations.
|
||||
:param anns (array of object): annotations to display
|
||||
:return: None
|
||||
"""
|
||||
if len(anns) == 0:
|
||||
return 0
|
||||
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
|
||||
datasetType = 'instances'
|
||||
elif 'caption' in anns[0]:
|
||||
datasetType = 'captions'
|
||||
else:
|
||||
raise Exception('datasetType not supported')
|
||||
if datasetType == 'instances':
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in anns:
|
||||
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
|
||||
if 'segmentation' in ann:
|
||||
if type(ann['segmentation']) == list:
|
||||
# polygon
|
||||
for seg in ann['segmentation']:
|
||||
poly = np.array(seg).reshape((int(len(seg)/2), 2))
|
||||
polygons.append(Polygon(poly))
|
||||
color.append(c)
|
||||
else:
|
||||
# mask
|
||||
t = self.imgs[ann['image_id']]
|
||||
if type(ann['segmentation']['counts']) == list:
|
||||
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
|
||||
else:
|
||||
rle = [ann['segmentation']]
|
||||
m = maskUtils.decode(rle)
|
||||
img = np.ones( (m.shape[0], m.shape[1], 3) )
|
||||
if ann['iscrowd'] == 1:
|
||||
color_mask = np.array([2.0,166.0,101.0])/255
|
||||
if ann['iscrowd'] == 0:
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack( (img, m*0.5) ))
|
||||
if 'keypoints' in ann and type(ann['keypoints']) == list:
|
||||
# turn skeleton into zero-based index
|
||||
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
|
||||
kp = np.array(ann['keypoints'])
|
||||
x = kp[0::3]
|
||||
y = kp[1::3]
|
||||
v = kp[2::3]
|
||||
for sk in sks:
|
||||
if np.all(v[sk]>0):
|
||||
plt.plot(x[sk],y[sk], linewidth=3, color=c)
|
||||
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
|
||||
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
|
||||
|
||||
if draw_bbox:
|
||||
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
||||
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
||||
np_poly = np.array(poly).reshape((4,2))
|
||||
polygons.append(Polygon(np_poly))
|
||||
color.append(c)
|
||||
|
||||
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
|
||||
ax.add_collection(p)
|
||||
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
|
||||
ax.add_collection(p)
|
||||
elif datasetType == 'captions':
|
||||
for ann in anns:
|
||||
print(ann['caption'])
|
||||
|
||||
def loadRes(self, resFile):
|
||||
"""
|
||||
Load result file and return a result api object.
|
||||
:param resFile (str) : file name of result file
|
||||
:return: res (obj) : result api object
|
||||
"""
|
||||
res = COCO()
|
||||
res.dataset['images'] = [img for img in self.dataset['images']]
|
||||
|
||||
print('Loading and preparing results...')
|
||||
tic = time.time()
|
||||
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
|
||||
with open(resFile) as f:
|
||||
anns = json.load(f)
|
||||
elif type(resFile) == np.ndarray:
|
||||
anns = self.loadNumpyAnnotations(resFile)
|
||||
else:
|
||||
anns = resFile
|
||||
assert type(anns) == list, 'results in not an array of objects'
|
||||
annsImgIds = [ann['image_id'] for ann in anns]
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
|
||||
'Results do not correspond to current coco set'
|
||||
if 'caption' in anns[0]:
|
||||
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
|
||||
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
|
||||
for id, ann in enumerate(anns):
|
||||
ann['id'] = id+1
|
||||
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
bb = ann['bbox']
|
||||
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
|
||||
if not 'segmentation' in ann:
|
||||
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
|
||||
ann['area'] = bb[2]*bb[3]
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'segmentation' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
# now only support compressed RLE format as segmentation results
|
||||
ann['area'] = maskUtils.area(ann['segmentation'])
|
||||
if not 'bbox' in ann:
|
||||
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
|
||||
ann['id'] = id+1
|
||||
ann['iscrowd'] = 0
|
||||
elif 'keypoints' in anns[0]:
|
||||
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
|
||||
for id, ann in enumerate(anns):
|
||||
s = ann['keypoints']
|
||||
x = s[0::3]
|
||||
y = s[1::3]
|
||||
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
|
||||
ann['area'] = (x1-x0)*(y1-y0)
|
||||
ann['id'] = id + 1
|
||||
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
|
||||
print('DONE (t={:0.2f}s)'.format(time.time()- tic))
|
||||
|
||||
res.dataset['annotations'] = anns
|
||||
res.createIndex()
|
||||
return res
|
||||
|
||||
def download(self, tarDir = None, imgIds = [] ):
|
||||
'''
|
||||
Download COCO images from mscoco.org server.
|
||||
:param tarDir (str): COCO results directory name
|
||||
imgIds (list): images to be downloaded
|
||||
:return:
|
||||
'''
|
||||
if tarDir is None:
|
||||
print('Please specify target directory')
|
||||
return -1
|
||||
if len(imgIds) == 0:
|
||||
imgs = self.imgs.values()
|
||||
else:
|
||||
imgs = self.loadImgs(imgIds)
|
||||
N = len(imgs)
|
||||
if not os.path.exists(tarDir):
|
||||
os.makedirs(tarDir)
|
||||
for i, img in enumerate(imgs):
|
||||
tic = time.time()
|
||||
fname = os.path.join(tarDir, img['file_name'])
|
||||
if not os.path.exists(fname):
|
||||
urlretrieve(img['coco_url'], fname)
|
||||
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
|
||||
|
||||
def loadNumpyAnnotations(self, data):
|
||||
"""
|
||||
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
|
||||
:param data (numpy.ndarray)
|
||||
:return: annotations (python nested list)
|
||||
"""
|
||||
print('Converting ndarray to lists...')
|
||||
assert(type(data) == np.ndarray)
|
||||
print(data.shape)
|
||||
assert(data.shape[1] == 7)
|
||||
N = data.shape[0]
|
||||
ann = []
|
||||
for i in range(N):
|
||||
if i % 1000000 == 0:
|
||||
print('{}/{}'.format(i,N))
|
||||
ann += [{
|
||||
'image_id' : int(data[i, 0]),
|
||||
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
|
||||
'score' : data[i, 5],
|
||||
'category_id': int(data[i, 6]),
|
||||
}]
|
||||
return ann
|
||||
|
||||
def annToRLE(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE to RLE.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
t = self.imgs[ann['image_id']]
|
||||
h, w = t['height'], t['width']
|
||||
segm = ann['segmentation']
|
||||
if type(segm) == list:
|
||||
# polygon -- a single object might consist of multiple parts
|
||||
# we merge all parts into one mask rle code
|
||||
rles = maskUtils.frPyObjects(segm, h, w)
|
||||
rle = maskUtils.merge(rles)
|
||||
elif type(segm['counts']) == list:
|
||||
# uncompressed RLE
|
||||
rle = maskUtils.frPyObjects(segm, h, w)
|
||||
else:
|
||||
# rle
|
||||
rle = ann['segmentation']
|
||||
return rle
|
||||
|
||||
def annToMask(self, ann):
|
||||
"""
|
||||
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
|
||||
:return: binary mask (numpy 2D array)
|
||||
"""
|
||||
rle = self.annToRLE(ann)
|
||||
m = maskUtils.decode(rle)
|
||||
return m
|
11
lib/train/dataset/__init__.py
Normal file
11
lib/train/dataset/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .lasot import Lasot
|
||||
from .got10k import Got10k
|
||||
from .tracking_net import TrackingNet
|
||||
from .imagenetvid import ImagenetVID
|
||||
from .coco import MSCOCO
|
||||
from .coco_seq import MSCOCOSeq
|
||||
from .got10k_lmdb import Got10k_lmdb
|
||||
from .lasot_lmdb import Lasot_lmdb
|
||||
from .imagenetvid_lmdb import ImagenetVID_lmdb
|
||||
from .coco_seq_lmdb import MSCOCOSeq_lmdb
|
||||
from .tracking_net_lmdb import TrackingNet_lmdb
|
92
lib/train/dataset/base_image_dataset.py
Normal file
92
lib/train/dataset/base_image_dataset.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch.utils.data
|
||||
from lib.train.data.image_loader import jpeg4py_loader
|
||||
|
||||
|
||||
class BaseImageDataset(torch.utils.data.Dataset):
|
||||
""" Base class for image datasets """
|
||||
|
||||
def __init__(self, name, root, image_loader=jpeg4py_loader):
|
||||
"""
|
||||
args:
|
||||
root - The root path to the dataset
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
"""
|
||||
self.name = name
|
||||
self.root = root
|
||||
self.image_loader = image_loader
|
||||
|
||||
self.image_list = [] # Contains the list of sequences.
|
||||
self.class_list = []
|
||||
|
||||
def __len__(self):
|
||||
""" Returns size of the dataset
|
||||
returns:
|
||||
int - number of samples in the dataset
|
||||
"""
|
||||
return self.get_num_images()
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Not to be used! Check get_frames() instead.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_name(self):
|
||||
""" Name of the dataset
|
||||
|
||||
returns:
|
||||
string - Name of the dataset
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_images(self):
|
||||
""" Number of sequences in a dataset
|
||||
|
||||
returns:
|
||||
int - number of sequences in the dataset."""
|
||||
return len(self.image_list)
|
||||
|
||||
def has_class_info(self):
|
||||
return False
|
||||
|
||||
def get_class_name(self, image_id):
|
||||
return None
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.class_list
|
||||
|
||||
def get_images_in_class(self, class_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return False
|
||||
|
||||
def get_image_info(self, seq_id):
|
||||
""" Returns information about a particular image,
|
||||
|
||||
args:
|
||||
seq_id - index of the image
|
||||
|
||||
returns:
|
||||
Dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_image(self, image_id, anno=None):
|
||||
""" Get a image
|
||||
|
||||
args:
|
||||
image_id - index of image
|
||||
anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
|
||||
|
||||
returns:
|
||||
image -
|
||||
anno -
|
||||
dict - A dict containing meta information about the sequence, e.g. class of the target object.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
110
lib/train/dataset/base_video_dataset.py
Normal file
110
lib/train/dataset/base_video_dataset.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import torch.utils.data
|
||||
# 2021.1.5 use jpeg4py_loader_w_failsafe as default
|
||||
from lib.train.data.image_loader import jpeg4py_loader_w_failsafe
|
||||
|
||||
|
||||
class BaseVideoDataset(torch.utils.data.Dataset):
|
||||
""" Base class for video datasets """
|
||||
|
||||
def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe):
|
||||
"""
|
||||
args:
|
||||
root - The root path to the dataset
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
"""
|
||||
self.name = name
|
||||
self.root = root
|
||||
self.image_loader = image_loader
|
||||
|
||||
self.sequence_list = [] # Contains the list of sequences.
|
||||
self.class_list = []
|
||||
|
||||
def __len__(self):
|
||||
""" Returns size of the dataset
|
||||
returns:
|
||||
int - number of samples in the dataset
|
||||
"""
|
||||
return self.get_num_sequences()
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Not to be used! Check get_frames() instead.
|
||||
"""
|
||||
return None
|
||||
|
||||
def is_video_sequence(self):
|
||||
""" Returns whether the dataset is a video dataset or an image dataset
|
||||
|
||||
returns:
|
||||
bool - True if a video dataset
|
||||
"""
|
||||
return True
|
||||
|
||||
def is_synthetic_video_dataset(self):
|
||||
""" Returns whether the dataset contains real videos or synthetic
|
||||
|
||||
returns:
|
||||
bool - True if a video dataset
|
||||
"""
|
||||
return False
|
||||
|
||||
def get_name(self):
|
||||
""" Name of the dataset
|
||||
|
||||
returns:
|
||||
string - Name of the dataset
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_sequences(self):
|
||||
""" Number of sequences in a dataset
|
||||
|
||||
returns:
|
||||
int - number of sequences in the dataset."""
|
||||
return len(self.sequence_list)
|
||||
|
||||
def has_class_info(self):
|
||||
return False
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.class_list
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
raise NotImplementedError
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return False
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
""" Returns information about a particular sequences,
|
||||
|
||||
args:
|
||||
seq_id - index of the sequence
|
||||
|
||||
returns:
|
||||
Dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
""" Get a set of frames from a particular sequence
|
||||
|
||||
args:
|
||||
seq_id - index of sequence
|
||||
frame_ids - a list of frame numbers
|
||||
anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
|
||||
|
||||
returns:
|
||||
list - List of frames corresponding to frame_ids
|
||||
list - List of dicts for each frame
|
||||
dict - A dict containing meta information about the sequence, e.g. class of the target object.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
156
lib/train/dataset/coco.py
Normal file
156
lib/train/dataset/coco.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
from .base_image_dataset import BaseImageDataset
|
||||
import torch
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
|
||||
class MSCOCO(BaseImageDataset):
|
||||
""" The COCO object detection dataset.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None,
|
||||
split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to coco root folder
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
min_area - Objects with area less than min_area are filtered out. Default is 0.0
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO', root, image_loader)
|
||||
|
||||
self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
|
||||
self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))
|
||||
|
||||
self.coco_set = COCO(self.anno_path)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list() # the parent class thing would happen in the sampler
|
||||
|
||||
self.image_list = self._get_image_list(min_area=min_area)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction))
|
||||
self.im_per_class = self._build_im_per_class()
|
||||
|
||||
def _get_image_list(self, min_area=None):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
image_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
if min_area is not None:
|
||||
image_list = [a for a in image_list if self.coco_set.anns[a]['area'] > min_area]
|
||||
|
||||
return image_list
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def _build_im_per_class(self):
|
||||
im_per_class = {}
|
||||
for i, im in enumerate(self.image_list):
|
||||
class_name = self.cats[self.coco_set.anns[im]['category_id']]['name']
|
||||
if class_name not in im_per_class:
|
||||
im_per_class[class_name] = [i]
|
||||
else:
|
||||
im_per_class[class_name].append(i)
|
||||
|
||||
return im_per_class
|
||||
|
||||
def get_images_in_class(self, class_name):
|
||||
return self.im_per_class[class_name]
|
||||
|
||||
def get_image_info(self, im_id):
|
||||
anno = self._get_anno(im_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(4,)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno))
|
||||
|
||||
valid = (bbox[2] > 0) & (bbox[3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, im_id):
|
||||
anno = self.coco_set.anns[self.image_list[im_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_image(self, im_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.image_list[im_id]]['image_id']])[0]['file_name']
|
||||
img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, im_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
def get_class_name(self, im_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_image(self, image_id, anno=None):
|
||||
frame = self._get_image(image_id)
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_image_info(image_id)
|
||||
|
||||
object_meta = self.get_meta_info(image_id)
|
||||
|
||||
return frame, anno, object_meta
|
170
lib/train/dataset/coco_seq.py
Normal file
170
lib/train/dataset/coco_seq.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
import random
|
||||
from pycocotools.coco import COCO
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class MSCOCOSeq(BaseVideoDataset):
|
||||
""" The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to the coco dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
|
||||
images will be used
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO', root, image_loader)
|
||||
|
||||
self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version))
|
||||
self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version))
|
||||
|
||||
# Load the COCO set.
|
||||
self.coco_set = COCO(self.anno_path)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list()
|
||||
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
def _get_sequence_list(self):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
return seq_list
|
||||
|
||||
def is_video_sequence(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
anno = self._get_anno(seq_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(1, 4)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)
|
||||
|
||||
'''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
|
||||
valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)
|
||||
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, seq_id):
|
||||
anno = self.coco_set.anns[self.sequence_list[seq_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_frames(self, seq_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
|
||||
img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, seq_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
|
||||
# COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
|
||||
# list containing these replicated images.
|
||||
frame = self._get_frames(seq_id)
|
||||
|
||||
frame_list = [frame.copy() for _ in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[0, ...] for _ in frame_ids]
|
||||
|
||||
object_meta = self.get_meta_info(seq_id)
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
177
lib/train/dataset/coco_seq_lmdb.py
Normal file
177
lib/train/dataset/coco_seq_lmdb.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
from lib.train.dataset.COCO_tool import COCO
|
||||
from lib.utils.lmdb_utils import decode_img, decode_json
|
||||
import time
|
||||
|
||||
class MSCOCOSeq_lmdb(BaseVideoDataset):
|
||||
""" The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.
|
||||
|
||||
Publication:
|
||||
Microsoft COCO: Common Objects in Context.
|
||||
Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
|
||||
Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
|
||||
ECCV, 2014
|
||||
https://arxiv.org/pdf/1405.0312.pdf
|
||||
|
||||
Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
|
||||
organized as follows.
|
||||
- coco_root
|
||||
- annotations
|
||||
- instances_train2014.json
|
||||
- instances_train2017.json
|
||||
- images
|
||||
- train2014
|
||||
- train2017
|
||||
|
||||
Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014"):
|
||||
"""
|
||||
args:
|
||||
root - path to the coco dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the
|
||||
images will be used
|
||||
split - 'train' or 'val'.
|
||||
version - version of coco dataset (2014 or 2017)
|
||||
"""
|
||||
root = env_settings().coco_dir if root is None else root
|
||||
super().__init__('COCO_lmdb', root, image_loader)
|
||||
self.root = root
|
||||
self.img_pth = 'images/{}{}/'.format(split, version)
|
||||
self.anno_path = 'annotations/instances_{}{}.json'.format(split, version)
|
||||
|
||||
# Load the COCO set.
|
||||
print('loading annotations into memory...')
|
||||
tic = time.time()
|
||||
coco_json = decode_json(root, self.anno_path)
|
||||
print('Done (t={:0.2f}s)'.format(time.time() - tic))
|
||||
|
||||
self.coco_set = COCO(coco_json)
|
||||
|
||||
self.cats = self.coco_set.cats
|
||||
|
||||
self.class_list = self.get_class_list()
|
||||
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
def _get_sequence_list(self):
|
||||
ann_list = list(self.coco_set.anns.keys())
|
||||
seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
|
||||
|
||||
return seq_list
|
||||
|
||||
def is_video_sequence(self):
|
||||
return False
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_name(self):
|
||||
return 'coco_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_class_list(self):
|
||||
class_list = []
|
||||
for cat_id in self.cats.keys():
|
||||
class_list.append(self.cats[cat_id]['name'])
|
||||
return class_list
|
||||
|
||||
def has_segmentation_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name']
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
anno = self._get_anno(seq_id)
|
||||
|
||||
bbox = torch.Tensor(anno['bbox']).view(1, 4)
|
||||
|
||||
mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0)
|
||||
|
||||
'''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels'''
|
||||
valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50)
|
||||
|
||||
visible = valid.clone().byte()
|
||||
|
||||
return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_anno(self, seq_id):
|
||||
anno = self.coco_set.anns[self.sequence_list[seq_id]]
|
||||
|
||||
return anno
|
||||
|
||||
def _get_frames(self, seq_id):
|
||||
path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
|
||||
# img = self.image_loader(os.path.join(self.img_pth, path))
|
||||
img = decode_img(self.root, os.path.join(self.img_pth, path))
|
||||
return img
|
||||
|
||||
def get_meta_info(self, seq_id):
|
||||
try:
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
object_meta = OrderedDict({'object_class_name': cat_dict_current['name'],
|
||||
'motion_class': None,
|
||||
'major_class': cat_dict_current['supercategory'],
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
|
||||
return cat_dict_current['name']
|
||||
|
||||
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
|
||||
# COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
|
||||
# list containing these replicated images.
|
||||
frame = self._get_frames(seq_id)
|
||||
|
||||
frame_list = [frame.copy() for _ in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[0, ...] for _ in frame_ids]
|
||||
|
||||
object_meta = self.get_meta_info(seq_id)
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
186
lib/train/dataset/got10k.py
Normal file
186
lib/train/dataset/got10k.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import torch
|
||||
import csv
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class Got10k(BaseVideoDataset):
|
||||
""" GOT-10k dataset.
|
||||
|
||||
Publication:
|
||||
GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
|
||||
Lianghua Huang, Xin Zhao, and Kaiqi Huang
|
||||
arXiv:1810.11981, 2018
|
||||
https://arxiv.org/pdf/1810.11981.pdf
|
||||
|
||||
Download dataset from http://got-10k.aitestunion.com/downloads
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
|
||||
not NOT the official got-10k validation split. To use the official validation split, provide that as
|
||||
the root folder instead.
|
||||
seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
|
||||
options can be used at the same time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().got10k_dir if root is None else root
|
||||
super().__init__('GOT10k', root, image_loader)
|
||||
|
||||
# all folders inside the root
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
# seq_id is the index of the folder inside the got10k root path
|
||||
if split is not None:
|
||||
if seq_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and seq_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_split.txt')
|
||||
elif split == 'val':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_val_split.txt')
|
||||
elif split == 'train_full':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_train_full_split.txt')
|
||||
elif split == 'vottrain':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_train_split.txt')
|
||||
elif split == 'votval':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'got10k_vot_val_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
# seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
|
||||
seq_ids = pandas.read_csv(file_path, header=None, dtype=np.int64).squeeze("columns").values.tolist()
|
||||
elif seq_ids is None:
|
||||
seq_ids = list(range(0, len(self.sequence_list)))
|
||||
|
||||
self.sequence_list = [self.sequence_list[i] for i in seq_ids]
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.sequence_meta_info = self._load_meta_info()
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def get_name(self):
|
||||
return 'got10k'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def _load_meta_info(self):
|
||||
sequence_meta_info = {s: self._read_meta(os.path.join(self.root, s)) for s in self.sequence_list}
|
||||
return sequence_meta_info
|
||||
|
||||
def _read_meta(self, seq_path):
|
||||
try:
|
||||
with open(os.path.join(seq_path, 'meta_info.ini')) as f:
|
||||
meta_info = f.readlines()
|
||||
object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1][:-1],
|
||||
'motion_class': meta_info[6].split(': ')[-1][:-1],
|
||||
'major_class': meta_info[7].split(': ')[-1][:-1],
|
||||
'root_class': meta_info[8].split(': ')[-1][:-1],
|
||||
'motion_adverb': meta_info[9].split(': ')[-1][:-1]})
|
||||
except:
|
||||
object_meta = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return object_meta
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
|
||||
for i, s in enumerate(self.sequence_list):
|
||||
object_class = self.sequence_meta_info[s]['object_class_name']
|
||||
if object_class in seq_per_class:
|
||||
seq_per_class[object_class].append(i)
|
||||
else:
|
||||
seq_per_class[object_class] = [i]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _get_sequence_list(self):
|
||||
with open(os.path.join(self.root, 'list.txt')) as f:
|
||||
dir_list = list(csv.reader(f))
|
||||
dir_list = [dir_name[0] for dir_name in dir_list]
|
||||
return dir_list
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "absence.label")
|
||||
cover_file = os.path.join(seq_path, "cover.label")
|
||||
|
||||
with open(occlusion_file, 'r', newline='') as f:
|
||||
occlusion = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])
|
||||
with open(cover_file, 'r', newline='') as f:
|
||||
cover = torch.ByteTensor([int(v[0]) for v in csv.reader(f)])
|
||||
|
||||
target_visible = ~occlusion & (cover>0).byte()
|
||||
|
||||
visible_ratio = cover.float() / 8
|
||||
return target_visible, visible_ratio
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
return os.path.join(self.root, self.sequence_list[seq_id])
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible, visible_ratio = self._read_target_visible(seq_path)
|
||||
visible = visible & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return self.image_loader(self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
return obj_meta['object_class_name']
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
return frame_list, anno_frames, obj_meta
|
183
lib/train/dataset/got10k_lmdb.py
Normal file
183
lib/train/dataset/got10k_lmdb.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import torch
|
||||
import csv
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
'''2021.1.16 Gok10k for loading lmdb dataset'''
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
|
||||
class Got10k_lmdb(BaseVideoDataset):
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, split=None, seq_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the got-10k training data. Note: This should point to the 'train' folder inside GOT-10k
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
split - 'train' or 'val'. Note: The validation split here is a subset of the official got-10k train split,
|
||||
not NOT the official got-10k validation split. To use the official validation split, provide that as
|
||||
the root folder instead.
|
||||
seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids'
|
||||
options can be used at the same time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
use_lmdb - whether the dataset is stored in lmdb format
|
||||
"""
|
||||
root = env_settings().got10k_lmdb_dir if root is None else root
|
||||
super().__init__('GOT10k_lmdb', root, image_loader)
|
||||
|
||||
# all folders inside the root
|
||||
self.sequence_list = self._get_sequence_list()
|
||||
|
||||
# seq_id is the index of the folder inside the got10k root path
|
||||
if split is not None:
|
||||
if seq_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and seq_ids.')
|
||||
train_lib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_split.txt')
|
||||
elif split == 'val':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_val_split.txt')
|
||||
elif split == 'train_full':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_train_full_split.txt')
|
||||
elif split == 'vottrain':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_train_split.txt')
|
||||
elif split == 'votval':
|
||||
file_path = os.path.join(train_lib_path, 'data_specs', 'got10k_vot_val_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
seq_ids = pandas.read_csv(file_path, header=None, squeeze=True, dtype=np.int64).values.tolist()
|
||||
elif seq_ids is None:
|
||||
seq_ids = list(range(0, len(self.sequence_list)))
|
||||
|
||||
self.sequence_list = [self.sequence_list[i] for i in seq_ids]
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.sequence_meta_info = self._load_meta_info()
|
||||
self.seq_per_class = self._build_seq_per_class()
|
||||
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def get_name(self):
|
||||
return 'got10k_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def _load_meta_info(self):
|
||||
def _read_meta(meta_info):
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': meta_info[5].split(': ')[-1],
|
||||
'motion_class': meta_info[6].split(': ')[-1],
|
||||
'major_class': meta_info[7].split(': ')[-1],
|
||||
'root_class': meta_info[8].split(': ')[-1],
|
||||
'motion_adverb': meta_info[9].split(': ')[-1]})
|
||||
|
||||
return object_meta
|
||||
sequence_meta_info = {}
|
||||
for s in self.sequence_list:
|
||||
try:
|
||||
meta_str = decode_str(self.root, "train/%s/meta_info.ini" %s)
|
||||
sequence_meta_info[s] = _read_meta(meta_str.split('\n'))
|
||||
except:
|
||||
sequence_meta_info[s] = OrderedDict({'object_class_name': None,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
return sequence_meta_info
|
||||
|
||||
def _build_seq_per_class(self):
|
||||
seq_per_class = {}
|
||||
|
||||
for i, s in enumerate(self.sequence_list):
|
||||
object_class = self.sequence_meta_info[s]['object_class_name']
|
||||
if object_class in seq_per_class:
|
||||
seq_per_class[object_class].append(i)
|
||||
else:
|
||||
seq_per_class[object_class] = [i]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _get_sequence_list(self):
|
||||
dir_str = decode_str(self.root, 'train/list.txt')
|
||||
dir_list = dir_str.split('\n')
|
||||
return dir_list
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line in got10k is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# full occlusion and out_of_view files
|
||||
occlusion_file = os.path.join(seq_path, "absence.label")
|
||||
cover_file = os.path.join(seq_path, "cover.label")
|
||||
# Read these files
|
||||
occ_list = list(map(int, decode_str(self.root, occlusion_file).split('\n')[:-1])) # the last line in got10k is empty
|
||||
occlusion = torch.ByteTensor(occ_list)
|
||||
cover_list = list(map(int, decode_str(self.root, cover_file).split('\n')[:-1])) # the last line in got10k is empty
|
||||
cover = torch.ByteTensor(cover_list)
|
||||
|
||||
target_visible = ~occlusion & (cover>0).byte()
|
||||
|
||||
visible_ratio = cover.float() / 8
|
||||
return target_visible, visible_ratio
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
return os.path.join("train", self.sequence_list[seq_id])
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible, visible_ratio = self._read_target_visible(seq_path)
|
||||
visible = visible & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible, 'visible_ratio': visible_ratio}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return decode_img(self.root, self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
return obj_meta['object_class_name']
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_meta = self.sequence_meta_info[self.sequence_list[seq_id]]
|
||||
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
return frame_list, anno_frames, obj_meta
|
159
lib/train/dataset/imagenetvid.py
Normal file
159
lib/train/dataset/imagenetvid.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import xml.etree.ElementTree as ET
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
def get_target_to_image_ratio(seq):
|
||||
anno = torch.Tensor(seq['anno'])
|
||||
img_sz = torch.Tensor(seq['image_size'])
|
||||
return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()
|
||||
|
||||
|
||||
class ImagenetVID(BaseVideoDataset):
|
||||
""" Imagenet VID dataset.
|
||||
|
||||
Publication:
|
||||
ImageNet Large Scale Visual Recognition Challenge
|
||||
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
|
||||
Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
|
||||
IJCV, 2015
|
||||
https://arxiv.org/pdf/1409.0575.pdf
|
||||
|
||||
Download the dataset from http://image-net.org/
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
|
||||
"""
|
||||
args:
|
||||
root - path to the imagenet vid dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
min_length - Minimum allowed sequence length.
|
||||
max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
|
||||
which cover complete image.
|
||||
"""
|
||||
root = env_settings().imagenet_dir if root is None else root
|
||||
super().__init__("imagenetvid", root, image_loader)
|
||||
|
||||
cache_file = os.path.join(root, 'cache.json')
|
||||
if os.path.isfile(cache_file):
|
||||
# If available, load the pre-processed cache file containing meta-info for each sequence
|
||||
with open(cache_file, 'r') as f:
|
||||
sequence_list_dict = json.load(f)
|
||||
|
||||
self.sequence_list = sequence_list_dict
|
||||
else:
|
||||
# Else process the imagenet annotations and generate the cache file
|
||||
self.sequence_list = self._process_anno(root)
|
||||
|
||||
with open(cache_file, 'w') as f:
|
||||
json.dump(self.sequence_list, f)
|
||||
|
||||
# Filter the sequences based on min_length and max_target_area in the first frame
|
||||
self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
|
||||
get_target_to_image_ratio(x) < max_target_area]
|
||||
|
||||
def get_name(self):
|
||||
return 'imagenetvid'
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
|
||||
valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
|
||||
visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
|
||||
return {'bbox': bb_anno, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, sequence, frame_id):
|
||||
set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
|
||||
vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
|
||||
frame_number = frame_id + sequence['start_frame']
|
||||
frame_path = os.path.join(self.root, 'Data', 'VID', 'train', set_name, vid_name,
|
||||
'{:06d}.JPEG'.format(frame_number))
|
||||
return self.image_loader(frame_path)
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
sequence = self.sequence_list[seq_id]
|
||||
|
||||
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
# Create anno dict
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
# added the class info to the meta info
|
||||
object_meta = OrderedDict({'object_class': sequence['class_name'],
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
|
||||
def _process_anno(self, root):
|
||||
# Builds individual tracklets
|
||||
base_vid_anno_path = os.path.join(root, 'Annotations', 'VID', 'train')
|
||||
|
||||
all_sequences = []
|
||||
for set in sorted(os.listdir(base_vid_anno_path)):
|
||||
set_id = int(set.split('_')[-1])
|
||||
for vid in sorted(os.listdir(os.path.join(base_vid_anno_path, set))):
|
||||
|
||||
vid_id = int(vid.split('_')[-1])
|
||||
anno_files = sorted(os.listdir(os.path.join(base_vid_anno_path, set, vid)))
|
||||
|
||||
frame1_anno = ET.parse(os.path.join(base_vid_anno_path, set, vid, anno_files[0]))
|
||||
image_size = [int(frame1_anno.find('size/width').text), int(frame1_anno.find('size/height').text)]
|
||||
|
||||
objects = [ET.ElementTree(file=os.path.join(base_vid_anno_path, set, vid, f)).findall('object')
|
||||
for f in anno_files]
|
||||
|
||||
tracklets = {}
|
||||
|
||||
# Find all tracklets along with start frame
|
||||
for f_id, all_targets in enumerate(objects):
|
||||
for target in all_targets:
|
||||
tracklet_id = target.find('trackid').text
|
||||
if tracklet_id not in tracklets:
|
||||
tracklets[tracklet_id] = f_id
|
||||
|
||||
for tracklet_id, tracklet_start in tracklets.items():
|
||||
tracklet_anno = []
|
||||
target_visible = []
|
||||
class_name_id = None
|
||||
|
||||
for f_id in range(tracklet_start, len(objects)):
|
||||
found = False
|
||||
for target in objects[f_id]:
|
||||
if target.find('trackid').text == tracklet_id:
|
||||
if not class_name_id:
|
||||
class_name_id = target.find('name').text
|
||||
x1 = int(target.find('bndbox/xmin').text)
|
||||
y1 = int(target.find('bndbox/ymin').text)
|
||||
x2 = int(target.find('bndbox/xmax').text)
|
||||
y2 = int(target.find('bndbox/ymax').text)
|
||||
|
||||
tracklet_anno.append([x1, y1, x2 - x1, y2 - y1])
|
||||
target_visible.append(target.find('occluded').text == '0')
|
||||
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
break
|
||||
|
||||
new_sequence = {'set_id': set_id, 'vid_id': vid_id, 'class_name': class_name_id,
|
||||
'start_frame': tracklet_start, 'anno': tracklet_anno,
|
||||
'target_visible': target_visible, 'image_size': image_size}
|
||||
all_sequences.append(new_sequence)
|
||||
|
||||
return all_sequences
|
90
lib/train/dataset/imagenetvid_lmdb.py
Normal file
90
lib/train/dataset/imagenetvid_lmdb.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from lib.train.admin import env_settings
|
||||
from lib.utils.lmdb_utils import decode_img, decode_json
|
||||
|
||||
|
||||
def get_target_to_image_ratio(seq):
|
||||
anno = torch.Tensor(seq['anno'])
|
||||
img_sz = torch.Tensor(seq['image_size'])
|
||||
return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()
|
||||
|
||||
|
||||
class ImagenetVID_lmdb(BaseVideoDataset):
|
||||
""" Imagenet VID dataset.
|
||||
|
||||
Publication:
|
||||
ImageNet Large Scale Visual Recognition Challenge
|
||||
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
|
||||
Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
|
||||
IJCV, 2015
|
||||
https://arxiv.org/pdf/1409.0575.pdf
|
||||
|
||||
Download the dataset from http://image-net.org/
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
|
||||
"""
|
||||
args:
|
||||
root - path to the imagenet vid dataset.
|
||||
image_loader (default_image_loader) - The function to read the images. If installed,
|
||||
jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
|
||||
opencv's imread is used.
|
||||
min_length - Minimum allowed sequence length.
|
||||
max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
|
||||
which cover complete image.
|
||||
"""
|
||||
root = env_settings().imagenet_dir if root is None else root
|
||||
super().__init__("imagenetvid_lmdb", root, image_loader)
|
||||
|
||||
sequence_list_dict = decode_json(root, "cache.json")
|
||||
self.sequence_list = sequence_list_dict
|
||||
|
||||
# Filter the sequences based on min_length and max_target_area in the first frame
|
||||
self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
|
||||
get_target_to_image_ratio(x) < max_target_area]
|
||||
|
||||
def get_name(self):
|
||||
return 'imagenetvid_lmdb'
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
|
||||
valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
|
||||
visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
|
||||
return {'bbox': bb_anno, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, sequence, frame_id):
|
||||
set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
|
||||
vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
|
||||
frame_number = frame_id + sequence['start_frame']
|
||||
frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name,
|
||||
'{:06d}.JPEG'.format(frame_number))
|
||||
return decode_img(self.root, frame_path)
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
sequence = self.sequence_list[seq_id]
|
||||
|
||||
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
# Create anno dict
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
# added the class info to the meta info
|
||||
object_meta = OrderedDict({'object_class': sequence['class_name'],
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
||||
|
169
lib/train/dataset/lasot.py
Normal file
169
lib/train/dataset/lasot.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
class Lasot(BaseVideoDataset):
|
||||
""" LaSOT dataset.
|
||||
|
||||
Publication:
|
||||
LaSOT: A High-quality Benchmark for Large-scale Single Object Tracking
|
||||
Heng Fan, Liting Lin, Fan Yang, Peng Chu, Ge Deng, Sijia Yu, Hexin Bai, Yong Xu, Chunyuan Liao and Haibin Ling
|
||||
CVPR, 2019
|
||||
https://arxiv.org/pdf/1809.07845.pdf
|
||||
|
||||
Download the dataset from https://cis.temple.edu/lasot/download.html
|
||||
"""
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_dir if root is None else root
|
||||
super().__init__('LaSOT', root, image_loader)
|
||||
|
||||
# Keep a list of all classes
|
||||
self.class_list = [f for f in os.listdir(self.root)]
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
# sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist()
|
||||
sequence_list = pandas.read_csv(file_path, header=None).squeeze("columns").values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
with open(occlusion_file, 'r', newline='') as f:
|
||||
occlusion = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
with open(out_of_view_file, 'r') as f:
|
||||
out_of_view = torch.ByteTensor([int(v) for v in list(csv.reader(f))[0]])
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(self.root, class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return self.image_loader(self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
165
lib/train/dataset/lasot_lmdb.py
Normal file
165
lib/train/dataset/lasot_lmdb.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
import os.path
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas
|
||||
import csv
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from lib.train.admin import env_settings
|
||||
'''2021.1.16 Lasot for loading lmdb dataset'''
|
||||
from lib.utils.lmdb_utils import *
|
||||
|
||||
|
||||
class Lasot_lmdb(BaseVideoDataset):
|
||||
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - path to the lasot dataset.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the
|
||||
videos with subscripts -1, -3, and -5 from each class will be used for training.
|
||||
split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of
|
||||
vid_ids or split option can be used at a time.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().lasot_lmdb_dir if root is None else root
|
||||
super().__init__('LaSOT_lmdb', root, image_loader)
|
||||
|
||||
self.sequence_list = self._build_sequence_list(vid_ids, split)
|
||||
class_list = [seq_name.split('-')[0] for seq_name in self.sequence_list]
|
||||
self.class_list = []
|
||||
for ele in class_list:
|
||||
if ele not in self.class_list:
|
||||
self.class_list.append(ele)
|
||||
# Keep a list of all classes
|
||||
self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)}
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction))
|
||||
|
||||
self.seq_per_class = self._build_class_list()
|
||||
|
||||
def _build_sequence_list(self, vid_ids=None, split=None):
|
||||
if split is not None:
|
||||
if vid_ids is not None:
|
||||
raise ValueError('Cannot set both split_name and vid_ids.')
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
if split == 'train':
|
||||
file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt')
|
||||
else:
|
||||
raise ValueError('Unknown split name.')
|
||||
sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist()
|
||||
elif vid_ids is not None:
|
||||
sequence_list = [c+'-'+str(v) for c in self.class_list for v in vid_ids]
|
||||
else:
|
||||
raise ValueError('Set either split_name or vid_ids.')
|
||||
|
||||
return sequence_list
|
||||
|
||||
def _build_class_list(self):
|
||||
seq_per_class = {}
|
||||
for seq_id, seq_name in enumerate(self.sequence_list):
|
||||
class_name = seq_name.split('-')[0]
|
||||
if class_name in seq_per_class:
|
||||
seq_per_class[class_name].append(seq_id)
|
||||
else:
|
||||
seq_per_class[class_name] = [seq_id]
|
||||
|
||||
return seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'lasot_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def has_occlusion_info(self):
|
||||
return True
|
||||
|
||||
def get_num_sequences(self):
|
||||
return len(self.sequence_list)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.class_list)
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_path):
|
||||
bb_anno_file = os.path.join(seq_path, "groundtruth.txt")
|
||||
gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line is empty
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def _read_target_visible(self, seq_path):
|
||||
# Read full occlusion and out_of_view
|
||||
occlusion_file = os.path.join(seq_path, "full_occlusion.txt")
|
||||
out_of_view_file = os.path.join(seq_path, "out_of_view.txt")
|
||||
|
||||
occ_list = list(map(int, decode_str(self.root, occlusion_file).split(',')))
|
||||
occlusion = torch.ByteTensor(occ_list)
|
||||
out_view_list = list(map(int, decode_str(self.root, out_of_view_file).split(',')))
|
||||
out_of_view = torch.ByteTensor(out_view_list)
|
||||
|
||||
target_visible = ~occlusion & ~out_of_view
|
||||
|
||||
return target_visible
|
||||
|
||||
def _get_sequence_path(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id]
|
||||
class_name = seq_name.split('-')[0]
|
||||
vid_id = seq_name.split('-')[1]
|
||||
|
||||
return os.path.join(class_name, class_name + '-' + vid_id)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
bbox = self._read_bb_anno(seq_path)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = self._read_target_visible(seq_path) & valid.byte()
|
||||
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame_path(self, seq_path, frame_id):
|
||||
return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id+1)) # frames start from 1
|
||||
|
||||
def _get_frame(self, seq_path, frame_id):
|
||||
return decode_img(self.root, self._get_frame_path(seq_path, frame_id))
|
||||
|
||||
def _get_class(self, seq_path):
|
||||
raw_class = seq_path.split('/')[-2]
|
||||
return raw_class
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
obj_class = self._get_class(seq_path)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
seq_path = self._get_sequence_path(seq_id)
|
||||
|
||||
obj_class = self._get_class(seq_path)
|
||||
frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
151
lib/train/dataset/tracking_net.py
Normal file
151
lib/train/dataset/tracking_net.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import pandas
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
|
||||
|
||||
def list_sequences(root, set_ids):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
set_ids: Sets (0-11) which are to be used
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
sequence_list = []
|
||||
|
||||
for s in set_ids:
|
||||
anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno")
|
||||
|
||||
sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
|
||||
sequence_list += sequences_cur_set
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_dir if root is None else root
|
||||
super().__init__('TrackingNet', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root, self.set_ids)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
bb_anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt")
|
||||
gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False,
|
||||
low_memory=False).values
|
||||
return torch.tensor(gt)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg")
|
||||
return self.image_loader(frame_path)
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
147
lib/train/dataset/tracking_net_lmdb.py
Normal file
147
lib/train/dataset/tracking_net_lmdb.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
from lib.train.data import jpeg4py_loader
|
||||
from .base_video_dataset import BaseVideoDataset
|
||||
from lib.train.admin import env_settings
|
||||
import json
|
||||
from lib.utils.lmdb_utils import decode_img, decode_str
|
||||
|
||||
|
||||
def list_sequences(root):
|
||||
""" Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
|
||||
|
||||
args:
|
||||
root: Root directory to TrackingNet
|
||||
|
||||
returns:
|
||||
list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
|
||||
"""
|
||||
fname = os.path.join(root, "seq_list.json")
|
||||
with open(fname, "r") as f:
|
||||
sequence_list = json.loads(f.read())
|
||||
return sequence_list
|
||||
|
||||
|
||||
class TrackingNet_lmdb(BaseVideoDataset):
|
||||
""" TrackingNet dataset.
|
||||
|
||||
Publication:
|
||||
TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
|
||||
Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
|
||||
ECCV, 2018
|
||||
https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
|
||||
|
||||
Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
|
||||
"""
|
||||
def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None):
|
||||
"""
|
||||
args:
|
||||
root - The path to the TrackingNet folder, containing the training sets.
|
||||
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
|
||||
is used by default.
|
||||
set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
|
||||
sets (0 - 11) will be used.
|
||||
data_fraction - Fraction of dataset to be used. The complete dataset is used by default
|
||||
"""
|
||||
root = env_settings().trackingnet_lmdb_dir if root is None else root
|
||||
super().__init__('TrackingNet_lmdb', root, image_loader)
|
||||
|
||||
if set_ids is None:
|
||||
set_ids = [i for i in range(12)]
|
||||
|
||||
self.set_ids = set_ids
|
||||
|
||||
# Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
|
||||
# video_name for each sequence
|
||||
self.sequence_list = list_sequences(self.root)
|
||||
|
||||
if data_fraction is not None:
|
||||
self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction))
|
||||
|
||||
self.seq_to_class_map, self.seq_per_class = self._load_class_info()
|
||||
|
||||
# we do not have the class_lists for the tracking net
|
||||
self.class_list = list(self.seq_per_class.keys())
|
||||
self.class_list.sort()
|
||||
|
||||
def _load_class_info(self):
|
||||
ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
|
||||
class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt')
|
||||
|
||||
with open(class_map_path, 'r') as f:
|
||||
seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f}
|
||||
|
||||
seq_per_class = {}
|
||||
for i, seq in enumerate(self.sequence_list):
|
||||
class_name = seq_to_class_map.get(seq[1], 'Unknown')
|
||||
if class_name not in seq_per_class:
|
||||
seq_per_class[class_name] = [i]
|
||||
else:
|
||||
seq_per_class[class_name].append(i)
|
||||
|
||||
return seq_to_class_map, seq_per_class
|
||||
|
||||
def get_name(self):
|
||||
return 'trackingnet_lmdb'
|
||||
|
||||
def has_class_info(self):
|
||||
return True
|
||||
|
||||
def get_sequences_in_class(self, class_name):
|
||||
return self.seq_per_class[class_name]
|
||||
|
||||
def _read_bb_anno(self, seq_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
gt_str_list = decode_str(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("anno", vid_name + ".txt")).split('\n')[:-1]
|
||||
gt_list = [list(map(float, line.split(','))) for line in gt_str_list]
|
||||
gt_arr = np.array(gt_list).astype(np.float32)
|
||||
return torch.tensor(gt_arr)
|
||||
|
||||
def get_sequence_info(self, seq_id):
|
||||
bbox = self._read_bb_anno(seq_id)
|
||||
|
||||
valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
|
||||
visible = valid.clone().byte()
|
||||
return {'bbox': bbox, 'valid': valid, 'visible': visible}
|
||||
|
||||
def _get_frame(self, seq_id, frame_id):
|
||||
set_id = self.sequence_list[seq_id][0]
|
||||
vid_name = self.sequence_list[seq_id][1]
|
||||
return decode_img(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id),
|
||||
os.path.join("frames", vid_name, str(frame_id) + ".jpg"))
|
||||
|
||||
def _get_class(self, seq_id):
|
||||
seq_name = self.sequence_list[seq_id][1]
|
||||
return self.seq_to_class_map[seq_name]
|
||||
|
||||
def get_class_name(self, seq_id):
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
return obj_class
|
||||
|
||||
def get_frames(self, seq_id, frame_ids, anno=None):
|
||||
frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
|
||||
|
||||
if anno is None:
|
||||
anno = self.get_sequence_info(seq_id)
|
||||
|
||||
anno_frames = {}
|
||||
for key, value in anno.items():
|
||||
anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
|
||||
|
||||
obj_class = self._get_class(seq_id)
|
||||
|
||||
object_meta = OrderedDict({'object_class_name': obj_class,
|
||||
'motion_class': None,
|
||||
'major_class': None,
|
||||
'root_class': None,
|
||||
'motion_adverb': None})
|
||||
|
||||
return frame_list, anno_frames, object_meta
|
113
lib/train/run_training.py
Normal file
113
lib/train/run_training.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import importlib
|
||||
import cv2 as cv
|
||||
import torch.backends.cudnn
|
||||
import torch.distributed as dist
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
import _init_paths
|
||||
import lib.train.admin.settings as ws_settings
|
||||
|
||||
|
||||
def init_seeds(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(1)
|
||||
cv.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None,
|
||||
use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False,
|
||||
distill=None, script_teacher=None, config_teacher=None):
|
||||
"""Run the train script.
|
||||
args:
|
||||
script_name: Name of emperiment in the "experiments/" folder.
|
||||
config_name: Name of the yaml file in the "experiments/<script_name>".
|
||||
cudnn_benchmark: Use cudnn benchmark or not (default is True).
|
||||
"""
|
||||
if save_dir is None:
|
||||
print("save_dir dir is not given. Use the default dir instead.")
|
||||
# This is needed to avoid strange crashes related to opencv
|
||||
torch.set_num_threads(4)
|
||||
cv.setNumThreads(4)
|
||||
|
||||
torch.backends.cudnn.benchmark = cudnn_benchmark
|
||||
|
||||
print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name))
|
||||
|
||||
'''2021.1.5 set seed for different process'''
|
||||
if base_seed is not None:
|
||||
if local_rank != -1:
|
||||
init_seeds(base_seed + local_rank)
|
||||
else:
|
||||
init_seeds(base_seed)
|
||||
|
||||
settings = ws_settings.Settings()
|
||||
settings.script_name = script_name
|
||||
settings.config_name = config_name
|
||||
settings.project_path = 'train/{}/{}'.format(script_name, config_name)
|
||||
if script_name_prv is not None and config_name_prv is not None:
|
||||
settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv)
|
||||
settings.local_rank = local_rank
|
||||
settings.save_dir = os.path.abspath(save_dir)
|
||||
settings.use_lmdb = use_lmdb
|
||||
prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name))
|
||||
settings.use_wandb = use_wandb
|
||||
if distill:
|
||||
settings.distill = distill
|
||||
settings.script_teacher = script_teacher
|
||||
settings.config_teacher = config_teacher
|
||||
if script_teacher is not None and config_teacher is not None:
|
||||
settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher)
|
||||
settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher))
|
||||
expr_module = importlib.import_module('lib.train.train_script_distill')
|
||||
else:
|
||||
expr_module = importlib.import_module('lib.train.train_script')
|
||||
expr_func = getattr(expr_module, 'run')
|
||||
|
||||
expr_func(settings)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
|
||||
parser.add_argument('--script', type=str, required=True, help='Name of the train script.')
|
||||
parser.add_argument('--config', type=str, required=True, help="Name of the config file.")
|
||||
parser.add_argument('--cudnn_benchmark', type=bool, default=False, help='Set cudnn benchmark on (1) or off (0) (default is on).')
|
||||
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
|
||||
parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs')
|
||||
parser.add_argument('--seed', type=int, default=42, help='seed for random numbers')
|
||||
parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format
|
||||
parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.')
|
||||
parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.")
|
||||
parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb
|
||||
# for knowledge distillation
|
||||
parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation
|
||||
parser.add_argument('--script_teacher', type=str, help='teacher script name')
|
||||
parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.local_rank != -1:
|
||||
dist.init_process_group(backend='nccl')
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
else:
|
||||
torch.cuda.set_device(0)
|
||||
run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark,
|
||||
local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed,
|
||||
use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv,
|
||||
use_wandb=args.use_wandb,
|
||||
distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
203
lib/train/train_script.py
Normal file
203
lib/train/train_script.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer, LTRSeqTrainer
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader, sequence_sampler
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
# forward propagation related
|
||||
from lib.train.actors import ARTrackActor, ARTrackSeqActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
from ..utils.focal_loss import FocalLoss
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
def slt_collate(batch):
|
||||
ret = {}
|
||||
for k in batch[0].keys():
|
||||
here_list = []
|
||||
for ex in batch:
|
||||
here_list.append(ex[k])
|
||||
ret[k] = here_list
|
||||
return ret
|
||||
|
||||
class SLTLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
|
||||
if collate_fn is None:
|
||||
collate_fn = slt_collate
|
||||
|
||||
super(SLTLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
bins = cfg.MODEL.BINS
|
||||
search_size = cfg.DATA.SEARCH.SIZE
|
||||
# Create network
|
||||
if settings.script_name == "artrack":
|
||||
net = build_artrack(cfg)
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
net = build_artrack_seq(cfg)
|
||||
dataset_train = sequence_sampler.SequenceSampler(
|
||||
datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_GAP, max_interval=cfg.DATA.MAX_INTERVAL,
|
||||
num_search_frames=cfg.DATA.SEARCH.NUMBER, num_template_frames=1,
|
||||
frame_sample_mode='random_interval',
|
||||
prob=cfg.DATA.INTERVAL_PROB)
|
||||
loader_train = SLTLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER,
|
||||
shuffle=False, drop_last=True)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
if settings.local_rank != -1:
|
||||
# net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "artrack":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackSeqActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# if cfg.TRAIN.DEEP_SUPERVISION:
|
||||
# raise ValueError("Deep supervision is not supported now.")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
if settings.script_name == "artrack":
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
trainer = LTRSeqTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|
111
lib/train/train_script_distill.py
Normal file
111
lib/train/train_script_distill.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.stark import build_starks, build_starkst
|
||||
from lib.models.stark import build_stark_lightning_x_trt
|
||||
# forward propagation related
|
||||
from lib.train.actors import STARKLightningXtrtdistillActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
|
||||
def build_network(script_name, cfg):
|
||||
# Create network
|
||||
if script_name == "stark_s":
|
||||
net = build_starks(cfg)
|
||||
elif script_name == "stark_st1" or script_name == "stark_st2":
|
||||
net = build_starkst(cfg)
|
||||
elif script_name == "stark_lightning_X_trt":
|
||||
net = build_stark_lightning_x_trt(cfg, phase="train")
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
return net
|
||||
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update the default teacher configs with teacher config file
|
||||
if not os.path.exists(settings.cfg_file_teacher):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file_teacher)
|
||||
config_module_teacher = importlib.import_module("lib.config.%s.config" % settings.script_teacher)
|
||||
cfg_teacher = config_module_teacher.cfg
|
||||
config_module_teacher.update_config_from_file(settings.cfg_file_teacher)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New teacher configuration is shown below.")
|
||||
for key in cfg_teacher.keys():
|
||||
print("%s configuration:" % key, cfg_teacher[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
"""turn on the distillation mode"""
|
||||
cfg.TRAIN.DISTILL = True
|
||||
cfg_teacher.TRAIN.DISTILL = True
|
||||
net = build_network(settings.script_name, cfg)
|
||||
net_teacher = build_network(settings.script_teacher, cfg_teacher)
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
net_teacher.cuda()
|
||||
net_teacher.eval()
|
||||
|
||||
if settings.local_rank != -1:
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
net_teacher = DDP(net_teacher, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
# settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
# settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "L1")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "stark_lightning_X_trt":
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
|
||||
actor = STARKLightningXtrtdistillActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings,
|
||||
net_teacher=net_teacher)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True, distill=True)
|
3
lib/train/trainers/__init__.py
Normal file
3
lib/train/trainers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_trainer import BaseTrainer
|
||||
from .ltr_trainer import LTRTrainer
|
||||
from .ltr_seq_trainer import LTRSeqTrainer
|
275
lib/train/trainers/base_trainer.py
Normal file
275
lib/train/trainers/base_trainer.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import traceback
|
||||
from lib.train.admin import multigpu
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""Base trainer class. Contains functions for training and saving/loading checkpoints.
|
||||
Trainer classes should inherit from this one and overload the train_epoch function."""
|
||||
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
self.actor = actor
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.loaders = loaders
|
||||
|
||||
self.update_settings(settings)
|
||||
|
||||
self.epoch = 0
|
||||
self.stats = {}
|
||||
|
||||
self.device = getattr(settings, 'device', None)
|
||||
if self.device is None:
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() and settings.use_gpu else "cpu")
|
||||
|
||||
self.actor.to(self.device)
|
||||
self.settings = settings
|
||||
|
||||
def update_settings(self, settings=None):
|
||||
"""Updates the trainer settings. Must be called to update internal settings."""
|
||||
if settings is not None:
|
||||
self.settings = settings
|
||||
|
||||
if self.settings.env.workspace_dir is not None:
|
||||
self.settings.env.workspace_dir = os.path.expanduser(self.settings.env.workspace_dir)
|
||||
'''2021.1.4 New function: specify checkpoint dir'''
|
||||
if self.settings.save_dir is None:
|
||||
self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir, 'checkpoints')
|
||||
else:
|
||||
self._checkpoint_dir = os.path.join(self.settings.save_dir, 'checkpoints')
|
||||
print("checkpoints will be saved to %s" % self._checkpoint_dir)
|
||||
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(self._checkpoint_dir):
|
||||
print("Training with multiple GPUs. checkpoints directory doesn't exist. "
|
||||
"Create checkpoints directory")
|
||||
os.makedirs(self._checkpoint_dir)
|
||||
else:
|
||||
self._checkpoint_dir = None
|
||||
|
||||
def train(self, max_epochs, load_latest=False, fail_safe=True, load_previous_ckpt=False, distill=False):
|
||||
"""Do training for the given number of epochs.
|
||||
args:
|
||||
max_epochs - Max number of training epochs,
|
||||
load_latest - Bool indicating whether to resume from latest epoch.
|
||||
fail_safe - Bool indicating whether the training to automatically restart in case of any crashes.
|
||||
"""
|
||||
|
||||
epoch = -1
|
||||
num_tries = 1
|
||||
for i in range(num_tries):
|
||||
try:
|
||||
if load_latest:
|
||||
self.load_checkpoint()
|
||||
if load_previous_ckpt:
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_prv)
|
||||
self.load_state_dict(directory)
|
||||
if distill:
|
||||
directory_teacher = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_teacher)
|
||||
self.load_state_dict(directory_teacher, distill=True)
|
||||
for epoch in range(self.epoch+1, max_epochs+1):
|
||||
self.epoch = epoch
|
||||
|
||||
self.train_epoch()
|
||||
|
||||
if self.lr_scheduler is not None:
|
||||
if self.settings.scheduler_type != 'cosine':
|
||||
self.lr_scheduler.step()
|
||||
else:
|
||||
self.lr_scheduler.step(epoch - 1)
|
||||
# only save the last 10 checkpoints
|
||||
save_every_epoch = getattr(self.settings, "save_every_epoch", False)
|
||||
save_epochs = []
|
||||
if epoch > (max_epochs - 1) or save_every_epoch or epoch % 5 == 0 or epoch in save_epochs or epoch > (max_epochs - 5):
|
||||
# if epoch > (max_epochs - 10) or save_every_epoch or epoch % 100 == 0:
|
||||
if self._checkpoint_dir:
|
||||
if self.settings.local_rank in [-1, 0]:
|
||||
self.save_checkpoint()
|
||||
except:
|
||||
print('Training crashed at epoch {}'.format(epoch))
|
||||
if fail_safe:
|
||||
self.epoch -= 1
|
||||
load_latest = True
|
||||
print('Traceback for the error!')
|
||||
print(traceback.format_exc())
|
||||
print('Restarting training from last epoch ...')
|
||||
else:
|
||||
raise
|
||||
|
||||
print('Finished training!')
|
||||
|
||||
def train_epoch(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self):
|
||||
"""Saves a checkpoint of the network and other variables."""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
state = {
|
||||
'epoch': self.epoch,
|
||||
'actor_type': actor_type,
|
||||
'net_type': net_type,
|
||||
'net': net.state_dict(),
|
||||
'net_info': getattr(net, 'info', None),
|
||||
'constructor': getattr(net, 'constructor', None),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'stats': self.stats,
|
||||
'settings': self.settings
|
||||
}
|
||||
|
||||
directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
|
||||
print(directory)
|
||||
if not os.path.exists(directory):
|
||||
print("directory doesn't exist. creating...")
|
||||
os.makedirs(directory)
|
||||
|
||||
# First save as a tmp file
|
||||
tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
|
||||
torch.save(state, tmp_file_path)
|
||||
|
||||
file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)
|
||||
|
||||
# Now rename to actual checkpoint. os.rename seems to be atomic if files are on same filesystem. Not 100% sure
|
||||
os.rename(tmp_file_path, file_path)
|
||||
|
||||
def load_checkpoint(self, checkpoint = None, fields = None, ignore_fields = None, load_constructor = False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
Can be called in three different ways:
|
||||
load_checkpoint():
|
||||
Loads the latest epoch from the workspace. Use this to continue training.
|
||||
load_checkpoint(epoch_num):
|
||||
Loads the network at the given epoch number (int).
|
||||
load_checkpoint(path_to_checkpoint):
|
||||
Loads the file from the given absolute path (str).
|
||||
"""
|
||||
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
actor_type = type(self.actor).__name__
|
||||
net_type = type(net).__name__
|
||||
|
||||
if checkpoint is None:
|
||||
# Load most recent checkpoint
|
||||
checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
|
||||
self.settings.project_path, net_type)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
print('No matching checkpoint file found')
|
||||
return
|
||||
elif isinstance(checkpoint, int):
|
||||
# Checkpoint is the epoch number
|
||||
checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
|
||||
net_type, checkpoint)
|
||||
elif isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
if fields is None:
|
||||
fields = checkpoint_dict.keys()
|
||||
if ignore_fields is None:
|
||||
ignore_fields = ['settings']
|
||||
|
||||
# Never load the scheduler. It exists in older checkpoints.
|
||||
ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])
|
||||
|
||||
# Load all fields
|
||||
for key in fields:
|
||||
if key in ignore_fields:
|
||||
continue
|
||||
if key == 'net':
|
||||
net.load_state_dict(checkpoint_dict[key])
|
||||
elif key == 'optimizer':
|
||||
self.optimizer.load_state_dict(checkpoint_dict[key])
|
||||
else:
|
||||
setattr(self, key, checkpoint_dict[key])
|
||||
|
||||
# Set the net info
|
||||
if load_constructor and 'constructor' in checkpoint_dict and checkpoint_dict['constructor'] is not None:
|
||||
net.constructor = checkpoint_dict['constructor']
|
||||
if 'net_info' in checkpoint_dict and checkpoint_dict['net_info'] is not None:
|
||||
net.info = checkpoint_dict['net_info']
|
||||
|
||||
# Update the epoch in lr scheduler
|
||||
if 'epoch' in fields:
|
||||
self.lr_scheduler.last_epoch = self.epoch
|
||||
# 2021.1.10 Update the epoch in data_samplers
|
||||
for loader in self.loaders:
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
return True
|
||||
|
||||
def load_state_dict(self, checkpoint=None, distill=False):
|
||||
"""Loads a network checkpoint file.
|
||||
|
||||
Can be called in three different ways:
|
||||
load_checkpoint():
|
||||
Loads the latest epoch from the workspace. Use this to continue training.
|
||||
load_checkpoint(epoch_num):
|
||||
Loads the network at the given epoch number (int).
|
||||
load_checkpoint(path_to_checkpoint):
|
||||
Loads the file from the given absolute path (str).
|
||||
"""
|
||||
if distill:
|
||||
net = self.actor.net_teacher.module if multigpu.is_multi_gpu(self.actor.net_teacher) \
|
||||
else self.actor.net_teacher
|
||||
else:
|
||||
net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net
|
||||
|
||||
net_type = type(net).__name__
|
||||
|
||||
if isinstance(checkpoint, str):
|
||||
# checkpoint is the path
|
||||
if os.path.isdir(checkpoint):
|
||||
checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
|
||||
if checkpoint_list:
|
||||
checkpoint_path = checkpoint_list[-1]
|
||||
else:
|
||||
raise Exception('No checkpoint found')
|
||||
else:
|
||||
checkpoint_path = os.path.expanduser(checkpoint)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
# Load network
|
||||
print("Loading pretrained model from ", checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'
|
||||
|
||||
missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
|
||||
print("previous checkpoint is loaded.")
|
||||
print("missing keys: ", missing_k)
|
||||
print("unexpected keys:", unexpected_k)
|
||||
|
||||
return True
|
322
lib/train/trainers/ltr_seq_trainer.py
Normal file
322
lib/train/trainers/ltr_seq_trainer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
# from lib.train.data.wandb_logger import WandbWriter
|
||||
from lib.train.trainers import BaseTrainer
|
||||
from lib.train.admin import AverageMeter, StatValue
|
||||
from memory_profiler import profile
|
||||
# from lib.train.admin import TensorboardWriter
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lib.utils.misc import get_world_size
|
||||
|
||||
|
||||
class LTRSeqTrainer(BaseTrainer):
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
super().__init__(actor, loaders, optimizer, settings, lr_scheduler)
|
||||
|
||||
self._set_default_settings()
|
||||
|
||||
# Initialize statistics variables
|
||||
self.stats = OrderedDict({loader.name: None for loader in self.loaders})
|
||||
|
||||
# Initialize tensorboard and wandb
|
||||
# self.wandb_writer = None
|
||||
# if settings.local_rank in [-1, 0]:
|
||||
# tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
|
||||
# if not os.path.exists(tensorboard_writer_dir):
|
||||
# os.makedirs(tensorboard_writer_dir)
|
||||
# self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])
|
||||
|
||||
# if settings.use_wandb:
|
||||
# world_size = get_world_size()
|
||||
# cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
|
||||
# interval = (world_size * settings.batchsize) # * interval
|
||||
# self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)
|
||||
|
||||
self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
|
||||
print("move_data", self.move_data_to_gpu)
|
||||
self.settings = settings
|
||||
self.use_amp = use_amp
|
||||
if use_amp:
|
||||
self.scaler = GradScaler()
|
||||
|
||||
def _set_default_settings(self):
|
||||
# Dict of all default values
|
||||
default = {'print_interval': 10,
|
||||
'print_stats': None,
|
||||
'description': ''}
|
||||
|
||||
for param, default_value in default.items():
|
||||
if getattr(self.settings, param, None) is None:
|
||||
setattr(self.settings, param, default_value)
|
||||
|
||||
self.miou_list = []
|
||||
|
||||
def cycle_dataset(self, loader):
|
||||
"""Do a cycle of training or validation."""
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
self.actor.train(loader.training)
|
||||
torch.set_grad_enabled(loader.training)
|
||||
|
||||
self._init_timing()
|
||||
|
||||
for i, data in enumerate(loader, 1):
|
||||
self.actor.eval()
|
||||
self.data_read_done_time = time.time()
|
||||
with torch.no_grad():
|
||||
explore_result = self.actor.explore(data)
|
||||
if explore_result == None:
|
||||
print("this time i skip")
|
||||
# self._update_stats(stats, batch_size, loader)
|
||||
continue
|
||||
# get inputs
|
||||
# print(data)
|
||||
|
||||
self.data_to_gpu_time = time.time()
|
||||
|
||||
data['epoch'] = self.epoch
|
||||
data['settings'] = self.settings
|
||||
|
||||
stats = {}
|
||||
reward_record = []
|
||||
miou_record = []
|
||||
e_miou_record = []
|
||||
num_seq = len(data['num_frames'])
|
||||
|
||||
# Calculate reward tensor
|
||||
# reward_tensor = torch.zeros(explore_result['baseline_iou'].size())
|
||||
baseline_iou = explore_result['baseline_iou']
|
||||
# explore_iou = explore_result['explore_iou']
|
||||
for seq_idx in range(num_seq):
|
||||
num_frames = data['num_frames'][seq_idx] - 1
|
||||
b_miou = torch.mean(baseline_iou[:num_frames, seq_idx])
|
||||
# e_miou = torch.mean(explore_iou[:num_frames, seq_idx])
|
||||
miou_record.append(b_miou.item())
|
||||
# e_miou_record.append(e_miou.item())
|
||||
|
||||
b_reward = b_miou.item()
|
||||
# e_reward = e_miou.item()
|
||||
# iou_gap = e_reward - b_reward
|
||||
# reward_record.append(iou_gap)
|
||||
# reward_tensor[:num_frames, seq_idx] = iou_gap
|
||||
|
||||
# Training mode
|
||||
cursor = 0
|
||||
bs_backward = 1
|
||||
|
||||
# print(self.actor.net.module.box_head.decoder.layers[2].mlpx.fc1.weight)
|
||||
self.optimizer.zero_grad()
|
||||
while cursor < num_seq:
|
||||
# print("now is ", cursor , "and all is ", num_seq)
|
||||
model_inputs = {}
|
||||
model_inputs['slt_loss_weight'] = 15
|
||||
if cursor < num_seq:
|
||||
model_inputs['template_images'] = explore_result['template_images'][
|
||||
cursor:cursor + bs_backward].cuda()
|
||||
else:
|
||||
model_inputs['template_images'] = explore_result['template_images_reverse'][
|
||||
cursor - num_seq:cursor - num_seq + bs_backward].cuda()
|
||||
model_inputs['search_images'] = explore_result['search_images'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['search_anno'] = explore_result['search_anno'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['pre_seq'] = explore_result['pre_seq'][:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['x_feat'] = explore_result['x_feat'].squeeze(1)[:, cursor:cursor + bs_backward].cuda()
|
||||
model_inputs['epoch'] = data['epoch']
|
||||
# model_inputs['template_update'] = explore_result['template_update'].squeeze(1)[:,
|
||||
# cursor:cursor + bs_backward].cuda()
|
||||
# print("this is cursor")
|
||||
# print(explore_result['pre_seq'].shape)
|
||||
# print(explore_result['x_feat'].squeeze(1).shape)
|
||||
# model_inputs['action_tensor'] = explore_result['action_tensor'][:, cursor:cursor + bs_backward].cuda()
|
||||
# model_inputs['reward_tensor'] = reward_tensor[:, cursor:cursor + bs_backward].cuda()
|
||||
|
||||
loss, stats_cur = self.actor.compute_sequence_losses(model_inputs)
|
||||
# for name, param in self.actor.net.named_parameters():
|
||||
# shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
|
||||
# print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')
|
||||
# print("i make this!")
|
||||
loss.backward()
|
||||
# print("i made that?")
|
||||
|
||||
for key, val in stats_cur.items():
|
||||
if key in stats:
|
||||
stats[key] += val * (bs_backward / num_seq)
|
||||
else:
|
||||
stats[key] = val * (bs_backward / num_seq)
|
||||
cursor += bs_backward
|
||||
grad_norm = clip_grad_norm_(self.actor.net.parameters(), 100)
|
||||
stats['grad_norm'] = grad_norm
|
||||
# print(self.actor.net.module.backbone.blocks[8].mlp.fc1.weight)
|
||||
self.optimizer.step()
|
||||
# print(self.optimizer)
|
||||
|
||||
miou = np.mean(miou_record)
|
||||
self.miou_list.append(miou)
|
||||
# stats['reward'] = np.mean(reward_record)
|
||||
# stats['e_mIoU'] = np.mean(e_miou_record)
|
||||
stats['mIoU'] = miou
|
||||
stats['mIoU10'] = np.mean(self.miou_list[-10:])
|
||||
stats['mIoU100'] = np.mean(self.miou_list[-100:])
|
||||
|
||||
batch_size = num_seq * np.max(data['num_frames'])
|
||||
self._update_stats(stats, batch_size, loader)
|
||||
self._print_stats(i, loader, batch_size)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# # forward pass
|
||||
# if not self.use_amp:
|
||||
# loss, stats = self.actor(data)
|
||||
# else:
|
||||
# with autocast():
|
||||
# loss, stats = self.actor(data)
|
||||
#
|
||||
# # backward pass and update weights
|
||||
# if loader.training:
|
||||
# self.optimizer.zero_grad()
|
||||
# if not self.use_amp:
|
||||
# loss.backward()
|
||||
# if self.settings.grad_clip_norm > 0:
|
||||
# torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
|
||||
# self.optimizer.step()
|
||||
# else:
|
||||
# self.scaler.scale(loss).backward()
|
||||
# self.scaler.step(self.optimizer)
|
||||
# self.scaler.update()
|
||||
|
||||
# update statistics
|
||||
# batch_size = data['template_images'].shape[loader.stack_dim]
|
||||
# self._update_stats(stats, batch_size, loader)
|
||||
|
||||
# print statistics
|
||||
# self._print_stats(i, loader, batch_size)
|
||||
|
||||
# update wandb status
|
||||
# if self.wandb_writer is not None and i % self.settings.print_interval == 0:
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self.wandb_writer.write_log(self.stats, self.epoch)
|
||||
|
||||
# calculate ETA after every epoch
|
||||
# epoch_time = self.prev_time - self.start_time
|
||||
# print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
|
||||
# print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
|
||||
# print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
|
||||
# print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))
|
||||
|
||||
def train_epoch(self):
|
||||
"""Do one epoch for each loader."""
|
||||
for loader in self.loaders:
|
||||
if self.epoch % loader.epoch_interval == 0:
|
||||
# 2021.1.10 Set epoch
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
self.cycle_dataset(loader)
|
||||
|
||||
self._stats_new_epoch()
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self._write_tensorboard()
|
||||
|
||||
def _init_timing(self):
|
||||
self.num_frames = 0
|
||||
self.start_time = time.time()
|
||||
self.prev_time = self.start_time
|
||||
self.avg_date_time = 0
|
||||
self.avg_gpu_trans_time = 0
|
||||
self.avg_forward_time = 0
|
||||
|
||||
def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
|
||||
# Initialize stats if not initialized yet
|
||||
if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
|
||||
self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})
|
||||
|
||||
# add lr state
|
||||
if loader.training:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for name, val in new_stats.items():
|
||||
if name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][name] = AverageMeter()
|
||||
self.stats[loader.name][name].update(val, batch_size)
|
||||
|
||||
def _print_stats(self, i, loader, batch_size):
|
||||
self.num_frames += batch_size
|
||||
current_time = time.time()
|
||||
batch_fps = batch_size / (current_time - self.prev_time)
|
||||
average_fps = self.num_frames / (current_time - self.start_time)
|
||||
prev_frame_time_backup = self.prev_time
|
||||
self.prev_time = current_time
|
||||
|
||||
self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
|
||||
self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
|
||||
self.avg_forward_time += current_time - self.data_to_gpu_time
|
||||
|
||||
if i % self.settings.print_interval == 0 or i == loader.__len__():
|
||||
print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
|
||||
print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)
|
||||
|
||||
# 2021.12.14 add data time print
|
||||
print_str += 'DataTime: %.3f (%.3f) , ' % (
|
||||
self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
|
||||
print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
|
||||
print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
|
||||
# print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
|
||||
# print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
|
||||
# print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)
|
||||
|
||||
for name, val in self.stats[loader.name].items():
|
||||
if (self.settings.print_stats is None or name in self.settings.print_stats):
|
||||
if hasattr(val, 'avg'):
|
||||
print_str += '%s: %.5f , ' % (name, val.avg)
|
||||
# else:
|
||||
# print_str += '%s: %r , ' % (name, val)
|
||||
|
||||
print(print_str[:-5])
|
||||
log_str = print_str[:-5] + '\n'
|
||||
with open(self.settings.log_file, 'a') as f:
|
||||
f.write(log_str)
|
||||
|
||||
def _stats_new_epoch(self):
|
||||
# Record learning rate
|
||||
for loader in self.loaders:
|
||||
if loader.training:
|
||||
try:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
except:
|
||||
lr_list = self.lr_scheduler._get_lr(self.epoch)
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for loader_stats in self.stats.values():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
for stat_value in loader_stats.values():
|
||||
if hasattr(stat_value, 'new_epoch'):
|
||||
stat_value.new_epoch()
|
||||
|
||||
# def _write_tensorboard(self):
|
||||
# if self.epoch == 1:
|
||||
# self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
|
||||
|
||||
# self.tensorboard_writer.write_epoch(self.stats, self.epoch)
|
225
lib/train/trainers/ltr_trainer.py
Normal file
225
lib/train/trainers/ltr_trainer.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
|
||||
#from lib.train.data.wandb_logger import WandbWriter
|
||||
from lib.train.trainers import BaseTrainer
|
||||
from lib.train.admin import AverageMeter, StatValue
|
||||
#from lib.train.admin import TensorboardWriter
|
||||
import torch
|
||||
import time
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lib.utils.misc import get_world_size
|
||||
|
||||
|
||||
class LTRTrainer(BaseTrainer):
|
||||
def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None, use_amp=False):
|
||||
"""
|
||||
args:
|
||||
actor - The actor for training the network
|
||||
loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
|
||||
epoch for each loader.
|
||||
optimizer - The optimizer used for training, e.g. Adam
|
||||
settings - Training settings
|
||||
lr_scheduler - Learning rate scheduler
|
||||
"""
|
||||
super().__init__(actor, loaders, optimizer, settings, lr_scheduler)
|
||||
|
||||
self._set_default_settings()
|
||||
|
||||
# Initialize statistics variables
|
||||
self.stats = OrderedDict({loader.name: None for loader in self.loaders})
|
||||
|
||||
# Initialize tensorboard and wandb
|
||||
#self.wandb_writer = None
|
||||
#if settings.local_rank in [-1, 0]:
|
||||
# tensorboard_writer_dir = os.path.join(self.settings.env.tensorboard_dir, self.settings.project_path)
|
||||
# if not os.path.exists(tensorboard_writer_dir):
|
||||
# os.makedirs(tensorboard_writer_dir)
|
||||
# self.tensorboard_writer = TensorboardWriter(tensorboard_writer_dir, [l.name for l in loaders])
|
||||
|
||||
# if settings.use_wandb:
|
||||
# world_size = get_world_size()
|
||||
# cur_train_samples = self.loaders[0].dataset.samples_per_epoch * max(0, self.epoch - 1)
|
||||
# interval = (world_size * settings.batchsize) # * interval
|
||||
# self.wandb_writer = WandbWriter(settings.project_path[6:], {}, tensorboard_writer_dir, cur_train_samples, interval)
|
||||
|
||||
self.move_data_to_gpu = getattr(settings, 'move_data_to_gpu', True)
|
||||
print("move_data", self.move_data_to_gpu)
|
||||
self.settings = settings
|
||||
self.use_amp = use_amp
|
||||
if use_amp:
|
||||
self.scaler = GradScaler()
|
||||
|
||||
def _set_default_settings(self):
|
||||
# Dict of all default values
|
||||
default = {'print_interval': 10,
|
||||
'print_stats': None,
|
||||
'description': ''}
|
||||
|
||||
for param, default_value in default.items():
|
||||
if getattr(self.settings, param, None) is None:
|
||||
setattr(self.settings, param, default_value)
|
||||
|
||||
def cycle_dataset(self, loader):
|
||||
"""Do a cycle of training or validation."""
|
||||
|
||||
self.actor.train(loader.training)
|
||||
torch.set_grad_enabled(loader.training)
|
||||
|
||||
self._init_timing()
|
||||
|
||||
for i, data in enumerate(loader, 1):
|
||||
self.data_read_done_time = time.time()
|
||||
# get inputs
|
||||
if self.move_data_to_gpu:
|
||||
data = data.to(self.device)
|
||||
|
||||
self.data_to_gpu_time = time.time()
|
||||
|
||||
data['epoch'] = self.epoch
|
||||
data['settings'] = self.settings
|
||||
# forward pass
|
||||
if not self.use_amp:
|
||||
loss, stats = self.actor(data)
|
||||
else:
|
||||
with autocast():
|
||||
loss, stats = self.actor(data)
|
||||
|
||||
# backward pass and update weights
|
||||
if loader.training:
|
||||
self.optimizer.zero_grad()
|
||||
if not self.use_amp:
|
||||
loss.backward()
|
||||
if self.settings.grad_clip_norm > 0:
|
||||
torch.nn.utils.clip_grad_norm_(self.actor.net.parameters(), self.settings.grad_clip_norm)
|
||||
self.optimizer.step()
|
||||
else:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
# update statistics
|
||||
batch_size = data['template_images'].shape[loader.stack_dim]
|
||||
self._update_stats(stats, batch_size, loader)
|
||||
|
||||
# print statistics
|
||||
self._print_stats(i, loader, batch_size)
|
||||
|
||||
# update wandb status
|
||||
#if self.wandb_writer is not None and i % self.settings.print_interval == 0:
|
||||
# if self.settings.local_rank in [-1, 0]:
|
||||
# self.wandb_writer.write_log(self.stats, self.epoch)
|
||||
|
||||
# calculate ETA after every epoch
|
||||
epoch_time = self.prev_time - self.start_time
|
||||
print("Epoch Time: " + str(datetime.timedelta(seconds=epoch_time)))
|
||||
print("Avg Data Time: %.5f" % (self.avg_date_time / self.num_frames * batch_size))
|
||||
print("Avg GPU Trans Time: %.5f" % (self.avg_gpu_trans_time / self.num_frames * batch_size))
|
||||
print("Avg Forward Time: %.5f" % (self.avg_forward_time / self.num_frames * batch_size))
|
||||
|
||||
def train_epoch(self):
|
||||
"""Do one epoch for each loader."""
|
||||
for loader in self.loaders:
|
||||
if self.epoch % loader.epoch_interval == 0:
|
||||
# 2021.1.10 Set epoch
|
||||
if isinstance(loader.sampler, DistributedSampler):
|
||||
loader.sampler.set_epoch(self.epoch)
|
||||
self.cycle_dataset(loader)
|
||||
|
||||
self._stats_new_epoch()
|
||||
#if self.settings.local_rank in [-1, 0]:
|
||||
# self._write_tensorboard()
|
||||
|
||||
def _init_timing(self):
|
||||
self.num_frames = 0
|
||||
self.start_time = time.time()
|
||||
self.prev_time = self.start_time
|
||||
self.avg_date_time = 0
|
||||
self.avg_gpu_trans_time = 0
|
||||
self.avg_forward_time = 0
|
||||
|
||||
def _update_stats(self, new_stats: OrderedDict, batch_size, loader):
|
||||
# Initialize stats if not initialized yet
|
||||
if loader.name not in self.stats.keys() or self.stats[loader.name] is None:
|
||||
self.stats[loader.name] = OrderedDict({name: AverageMeter() for name in new_stats.keys()})
|
||||
|
||||
# add lr state
|
||||
if loader.training:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for name, val in new_stats.items():
|
||||
if name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][name] = AverageMeter()
|
||||
self.stats[loader.name][name].update(val, batch_size)
|
||||
|
||||
def _print_stats(self, i, loader, batch_size):
|
||||
self.num_frames += batch_size
|
||||
current_time = time.time()
|
||||
batch_fps = batch_size / (current_time - self.prev_time)
|
||||
average_fps = self.num_frames / (current_time - self.start_time)
|
||||
prev_frame_time_backup = self.prev_time
|
||||
self.prev_time = current_time
|
||||
|
||||
self.avg_date_time += (self.data_read_done_time - prev_frame_time_backup)
|
||||
self.avg_gpu_trans_time += (self.data_to_gpu_time - self.data_read_done_time)
|
||||
self.avg_forward_time += current_time - self.data_to_gpu_time
|
||||
|
||||
if i % self.settings.print_interval == 0 or i == loader.__len__():
|
||||
print_str = '[%s: %d, %d / %d] ' % (loader.name, self.epoch, i, loader.__len__())
|
||||
print_str += 'FPS: %.1f (%.1f) , ' % (average_fps, batch_fps)
|
||||
|
||||
# 2021.12.14 add data time print
|
||||
print_str += 'DataTime: %.3f (%.3f) , ' % (self.avg_date_time / self.num_frames * batch_size, self.avg_gpu_trans_time / self.num_frames * batch_size)
|
||||
print_str += 'ForwardTime: %.3f , ' % (self.avg_forward_time / self.num_frames * batch_size)
|
||||
print_str += 'TotalTime: %.3f , ' % ((current_time - self.start_time) / self.num_frames * batch_size)
|
||||
# print_str += 'DataTime: %.3f (%.3f) , ' % (self.data_read_done_time - prev_frame_time_backup, self.data_to_gpu_time - self.data_read_done_time)
|
||||
# print_str += 'ForwardTime: %.3f , ' % (current_time - self.data_to_gpu_time)
|
||||
# print_str += 'TotalTime: %.3f , ' % (current_time - prev_frame_time_backup)
|
||||
|
||||
for name, val in self.stats[loader.name].items():
|
||||
if (self.settings.print_stats is None or name in self.settings.print_stats):
|
||||
if hasattr(val, 'avg'):
|
||||
print_str += '%s: %.5f , ' % (name, val.avg)
|
||||
# else:
|
||||
# print_str += '%s: %r , ' % (name, val)
|
||||
|
||||
print(print_str[:-5])
|
||||
log_str = print_str[:-5] + '\n'
|
||||
with open(self.settings.log_file, 'a') as f:
|
||||
f.write(log_str)
|
||||
|
||||
def _stats_new_epoch(self):
|
||||
# Record learning rate
|
||||
for loader in self.loaders:
|
||||
if loader.training:
|
||||
try:
|
||||
lr_list = self.lr_scheduler.get_last_lr()
|
||||
except:
|
||||
lr_list = self.lr_scheduler._get_lr(self.epoch)
|
||||
for i, lr in enumerate(lr_list):
|
||||
var_name = 'LearningRate/group{}'.format(i)
|
||||
if var_name not in self.stats[loader.name].keys():
|
||||
self.stats[loader.name][var_name] = StatValue()
|
||||
self.stats[loader.name][var_name].update(lr)
|
||||
|
||||
for loader_stats in self.stats.values():
|
||||
if loader_stats is None:
|
||||
continue
|
||||
for stat_value in loader_stats.values():
|
||||
if hasattr(stat_value, 'new_epoch'):
|
||||
stat_value.new_epoch()
|
||||
|
||||
#def _write_tensorboard(self):
|
||||
# if self.epoch == 1:
|
||||
# self.tensorboard_writer.write_info(self.settings.script_name, self.settings.description)
|
||||
|
||||
# self.tensorboard_writer.write_epoch(self.stats, self.epoch)
|
Reference in New Issue
Block a user