init commit of samurai
This commit is contained in:
203
lib/train/train_script.py
Normal file
203
lib/train/train_script.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
# loss function related
|
||||
from lib.utils.box_ops import giou_loss
|
||||
from torch.nn.functional import l1_loss
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
# train pipeline related
|
||||
from lib.train.trainers import LTRTrainer, LTRSeqTrainer
|
||||
from lib.train.dataset import Lasot, Got10k, MSCOCOSeq, ImagenetVID, TrackingNet
|
||||
from lib.train.dataset import Lasot_lmdb, Got10k_lmdb, MSCOCOSeq_lmdb, ImagenetVID_lmdb, TrackingNet_lmdb
|
||||
from lib.train.data import sampler, opencv_loader, processing, LTRLoader, sequence_sampler
|
||||
# distributed training related
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
# some more advanced functions
|
||||
from .base_functions import *
|
||||
# network related
|
||||
from lib.models.artrack import build_artrack
|
||||
from lib.models.artrack_seq import build_artrack_seq
|
||||
# forward propagation related
|
||||
from lib.train.actors import ARTrackActor, ARTrackSeqActor
|
||||
# for import modules
|
||||
import importlib
|
||||
|
||||
from ..utils.focal_loss import FocalLoss
|
||||
|
||||
def names2datasets(name_list: list, settings, image_loader):
|
||||
assert isinstance(name_list, list)
|
||||
datasets = []
|
||||
#settings.use_lmdb = True
|
||||
for name in name_list:
|
||||
assert name in ["LASOT", "GOT10K_vottrain", "GOT10K_votval", "GOT10K_train_full", "GOT10K_official_val",
|
||||
"COCO17", "VID", "TRACKINGNET"]
|
||||
if name == "LASOT":
|
||||
if settings.use_lmdb:
|
||||
print("Building lasot dataset from lmdb")
|
||||
datasets.append(Lasot_lmdb(settings.env.lasot_lmdb_dir, split='train', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Lasot(settings.env.lasot_dir, split='train', image_loader=image_loader))
|
||||
if name == "GOT10K_vottrain":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='vottrain', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='vottrain', image_loader=image_loader))
|
||||
if name == "GOT10K_train_full":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k_train_full from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='train_full', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='train_full', image_loader=image_loader))
|
||||
if name == "GOT10K_votval":
|
||||
if settings.use_lmdb:
|
||||
print("Building got10k from lmdb")
|
||||
datasets.append(Got10k_lmdb(settings.env.got10k_lmdb_dir, split='votval', image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_dir, split='votval', image_loader=image_loader))
|
||||
if name == "GOT10K_official_val":
|
||||
if settings.use_lmdb:
|
||||
raise ValueError("Not implement")
|
||||
else:
|
||||
datasets.append(Got10k(settings.env.got10k_val_dir, split=None, image_loader=image_loader))
|
||||
if name == "COCO17":
|
||||
if settings.use_lmdb:
|
||||
print("Building COCO2017 from lmdb")
|
||||
datasets.append(MSCOCOSeq_lmdb(settings.env.coco_lmdb_dir, version="2017", image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(MSCOCOSeq(settings.env.coco_dir, version="2017", image_loader=image_loader))
|
||||
if name == "VID":
|
||||
if settings.use_lmdb:
|
||||
print("Building VID from lmdb")
|
||||
datasets.append(ImagenetVID_lmdb(settings.env.imagenet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
datasets.append(ImagenetVID(settings.env.imagenet_dir, image_loader=image_loader))
|
||||
if name == "TRACKINGNET":
|
||||
if settings.use_lmdb:
|
||||
print("Building TrackingNet from lmdb")
|
||||
datasets.append(TrackingNet_lmdb(settings.env.trackingnet_lmdb_dir, image_loader=image_loader))
|
||||
else:
|
||||
# raise ValueError("NOW WE CAN ONLY USE TRACKINGNET FROM LMDB")
|
||||
datasets.append(TrackingNet(settings.env.trackingnet_dir, image_loader=image_loader))
|
||||
return datasets
|
||||
|
||||
def slt_collate(batch):
|
||||
ret = {}
|
||||
for k in batch[0].keys():
|
||||
here_list = []
|
||||
for ex in batch:
|
||||
here_list.append(ex[k])
|
||||
ret[k] = here_list
|
||||
return ret
|
||||
|
||||
class SLTLoader(torch.utils.data.dataloader.DataLoader):
|
||||
"""
|
||||
Data loader. Combines a dataset and a sampler, and provides
|
||||
single- or multi-process iterators over the dataset.
|
||||
"""
|
||||
|
||||
__initialized = False
|
||||
|
||||
def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
||||
num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False,
|
||||
timeout=0, worker_init_fn=None):
|
||||
|
||||
if collate_fn is None:
|
||||
collate_fn = slt_collate
|
||||
|
||||
super(SLTLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
|
||||
num_workers, collate_fn, pin_memory, drop_last,
|
||||
timeout, worker_init_fn)
|
||||
|
||||
self.name = name
|
||||
self.training = training
|
||||
self.epoch_interval = epoch_interval
|
||||
self.stack_dim = stack_dim
|
||||
|
||||
def run(settings):
|
||||
settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2'
|
||||
|
||||
# update the default configs with config file
|
||||
if not os.path.exists(settings.cfg_file):
|
||||
raise ValueError("%s doesn't exist." % settings.cfg_file)
|
||||
config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
|
||||
cfg = config_module.cfg
|
||||
config_module.update_config_from_file(settings.cfg_file)
|
||||
if settings.local_rank in [-1, 0]:
|
||||
print("New configuration is shown below.")
|
||||
for key in cfg.keys():
|
||||
print("%s configuration:" % key, cfg[key])
|
||||
print('\n')
|
||||
|
||||
# update settings based on cfg
|
||||
update_settings(settings, cfg)
|
||||
|
||||
# Record the training log
|
||||
log_dir = os.path.join(settings.save_dir, 'logs')
|
||||
if settings.local_rank in [-1, 0]:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
|
||||
|
||||
# Build dataloaders
|
||||
|
||||
if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
|
||||
cfg.ckpt_dir = settings.save_dir
|
||||
bins = cfg.MODEL.BINS
|
||||
search_size = cfg.DATA.SEARCH.SIZE
|
||||
# Create network
|
||||
if settings.script_name == "artrack":
|
||||
net = build_artrack(cfg)
|
||||
loader_train, loader_val = build_dataloaders(cfg, settings)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
net = build_artrack_seq(cfg)
|
||||
dataset_train = sequence_sampler.SequenceSampler(
|
||||
datasets=names2datasets(cfg.DATA.TRAIN.DATASETS_NAME, settings, opencv_loader),
|
||||
p_datasets=cfg.DATA.TRAIN.DATASETS_RATIO,
|
||||
samples_per_epoch=cfg.DATA.TRAIN.SAMPLE_PER_EPOCH,
|
||||
max_gap=cfg.DATA.MAX_GAP, max_interval=cfg.DATA.MAX_INTERVAL,
|
||||
num_search_frames=cfg.DATA.SEARCH.NUMBER, num_template_frames=1,
|
||||
frame_sample_mode='random_interval',
|
||||
prob=cfg.DATA.INTERVAL_PROB)
|
||||
loader_train = SLTLoader('train', dataset_train, training=True, batch_size=cfg.TRAIN.BATCH_SIZE,
|
||||
num_workers=cfg.TRAIN.NUM_WORKER,
|
||||
shuffle=False, drop_last=True)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# wrap networks to distributed one
|
||||
net.cuda()
|
||||
if settings.local_rank != -1:
|
||||
# net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
|
||||
net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
|
||||
settings.device = torch.device("cuda:%d" % settings.local_rank)
|
||||
else:
|
||||
settings.device = torch.device("cuda:0")
|
||||
settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
|
||||
settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
|
||||
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
|
||||
# Loss functions and Actors
|
||||
if settings.script_name == "artrack":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
focal_loss = FocalLoss()
|
||||
objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss}
|
||||
loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 2.}
|
||||
actor = ARTrackSeqActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg, bins=bins, search_size=search_size)
|
||||
else:
|
||||
raise ValueError("illegal script name")
|
||||
|
||||
# if cfg.TRAIN.DEEP_SUPERVISION:
|
||||
# raise ValueError("Deep supervision is not supported now.")
|
||||
|
||||
# Optimizer, parameters, and learning rates
|
||||
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
|
||||
use_amp = getattr(cfg.TRAIN, "AMP", False)
|
||||
if settings.script_name == "artrack":
|
||||
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
elif settings.script_name == "artrack_seq":
|
||||
trainer = LTRSeqTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)
|
||||
|
||||
# train process
|
||||
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|
Reference in New Issue
Block a user