init commit of samurai
This commit is contained in:
3
lib/train/admin/__init__.py
Normal file
3
lib/train/admin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .environment import env_settings, create_default_local_file_ITP_train
|
||||
from .stats import AverageMeter, StatValue
|
||||
#from .tensorboard import TensorboardWriter
|
102
lib/train/admin/environment.py
Normal file
102
lib/train/admin/environment.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import importlib
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_default_local_file():
    """Write a template ``local.py`` next to this module.

    The generated file defines an ``EnvironmentSettings`` class whose
    attributes hold the workspace/dataset paths.  Most values are emitted as
    empty strings for the user to fill in; a couple are emitted as
    expressions derived from ``workspace_dir``.

    Fix: the body lines are written with proper 4/8-space indentation so the
    generated file is valid Python (previously ``__init__`` and its body were
    emitted at the same indent level, which raises IndentationError on import).
    """
    path = os.path.join(os.path.dirname(__file__), 'local.py')

    # Literal two-quote string, so the generated file reads ``attr = ''``.
    empty_str = '\'\''
    default_settings = OrderedDict({
        'workspace_dir': empty_str,
        'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
        'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
        'lasot_dir': empty_str,
        'got10k_dir': empty_str,
        'trackingnet_dir': empty_str,
        'coco_dir': empty_str,
        'lvis_dir': empty_str,
        'sbd_dir': empty_str,
        'imagenet_dir': empty_str,
        'imagenetdet_dir': empty_str,
        'ecssd_dir': empty_str,
        'hkuis_dir': empty_str,
        'msra10k_dir': empty_str,
        'davis_dir': empty_str,
        'youtubevos_dir': empty_str})

    comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
               'tensorboard_dir': 'Directory for tensorboard files.'}

    with open(path, 'w') as f:
        f.write('class EnvironmentSettings:\n')
        f.write('    def __init__(self):\n')

        for attr, attr_val in default_settings.items():
            # Values are written verbatim (they are already valid expressions).
            comment_str = comment.get(attr)
            if comment_str is None:
                f.write('        self.{} = {}\n'.format(attr, attr_val))
            else:
                f.write('        self.{} = {}    # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def create_default_local_file_ITP_train(workspace_dir, data_dir):
    """Write a pre-filled ``local.py`` next to this module.

    Unlike :func:`create_default_local_file`, dataset paths are derived from
    the given roots and emitted as quoted string literals.

    Args:
        workspace_dir: base directory for checkpoints / tensorboard output.
        data_dir: root directory that contains the dataset folders.

    Fix: the body lines are written with proper 4/8-space indentation so the
    generated file is valid Python (previously ``__init__`` and its body were
    emitted at the same indent level, which raises IndentationError on import).
    """
    path = os.path.join(os.path.dirname(__file__), 'local.py')

    # Sentinel meaning "emit a literal empty string" (written unquoted below).
    empty_str = '\'\''
    default_settings = OrderedDict({
        'workspace_dir': workspace_dir,
        'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'),  # Directory for tensorboard files.
        'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'),
        'lasot_dir': os.path.join(data_dir, 'lasot'),
        'got10k_dir': os.path.join(data_dir, 'got10k/train'),
        'got10k_val_dir': os.path.join(data_dir, 'got10k/val'),
        'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'),
        'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'),
        'trackingnet_dir': os.path.join(data_dir, 'trackingnet'),
        'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'),
        'coco_dir': os.path.join(data_dir, 'coco'),
        'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'),
        'lvis_dir': empty_str,
        'sbd_dir': empty_str,
        'imagenet_dir': os.path.join(data_dir, 'vid'),
        'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'),
        'imagenetdet_dir': empty_str,
        'ecssd_dir': empty_str,
        'hkuis_dir': empty_str,
        'msra10k_dir': empty_str,
        'davis_dir': empty_str,
        'youtubevos_dir': empty_str})

    comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
               'tensorboard_dir': 'Directory for tensorboard files.'}

    with open(path, 'w') as f:
        f.write('class EnvironmentSettings:\n')
        f.write('    def __init__(self):\n')

        for attr, attr_val in default_settings.items():
            comment_str = comment.get(attr)
            if comment_str is None:
                if attr_val == empty_str:
                    # Already a quoted-empty-string literal; write verbatim.
                    f.write('        self.{} = {}\n'.format(attr, attr_val))
                else:
                    # Real path: wrap in quotes so it becomes a string literal.
                    f.write('        self.{} = \'{}\'\n'.format(attr, attr_val))
            else:
                f.write('        self.{} = \'{}\'    # {}\n'.format(attr, attr_val, comment_str))
|
||||
|
||||
|
||||
def env_settings():
    """Load the user's training environment settings.

    Returns:
        The ``EnvironmentSettings`` instance from ``lib.train.admin.local``.

    Raises:
        RuntimeError: if ``local.py`` cannot be imported/instantiated.  A
            default template is generated first so the user only needs to
            fill in the paths and rerun.
    """
    env_module_name = 'lib.train.admin.local'
    try:
        env_module = importlib.import_module(env_module_name)
        return env_module.EnvironmentSettings()
    # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit still
    # propagate; the original cause is chained onto the RuntimeError.
    except Exception as exc:
        env_file = os.path.join(os.path.dirname(__file__), 'local.py')

        create_default_local_file()
        raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file)) from exc
|
24
lib/train/admin/local.py
Normal file
24
lib/train/admin/local.py
Normal file
@@ -0,0 +1,24 @@
|
||||
class EnvironmentSettings:
    """Machine-specific dataset and output paths for training."""

    def __init__(self):
        # Common path roots (machine-specific; edit for your setup).
        home = '/home/baiyifan'
        code_root = home + '/code/2stage'
        data_root = code_root + '/data'

        # Output locations.
        self.workspace_dir = home + '/code/2stage_update_intrain'    # Base directory for saving network checkpoints.
        self.tensorboard_dir = code_root + '/tensorboard'    # Directory for tensorboard files.
        self.pretrained_networks = code_root + '/pretrained_networks'

        # Raw dataset locations.
        self.lasot_dir = home + '/LaSOT/LaSOTBenchmark'
        self.got10k_dir = home + '/GOT-10k/train'
        self.got10k_val_dir = home + '/GOT-10k/val'
        self.trackingnet_dir = '/ssddata/TrackingNet/all_zip'
        self.coco_dir = home + '/coco'
        self.imagenet_dir = data_root + '/vid'

        # LMDB-converted dataset locations.
        self.lasot_lmdb_dir = data_root + '/lasot_lmdb'
        self.got10k_lmdb_dir = data_root + '/got10k_lmdb'
        self.trackingnet_lmdb_dir = data_root + '/trackingnet_lmdb'
        self.coco_lmdb_dir = data_root + '/coco_lmdb'
        self.imagenet_lmdb_dir = data_root + '/vid_lmdb'

        # Datasets not configured on this machine.
        self.lvis_dir = ''
        self.sbd_dir = ''
        self.imagenetdet_dir = ''
        self.ecssd_dir = ''
        self.hkuis_dir = ''
        self.msra10k_dir = ''
        self.davis_dir = ''
        self.youtubevos_dir = ''
|
15
lib/train/admin/multigpu.py
Normal file
15
lib/train/admin/multigpu.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
# Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training
|
||||
|
||||
|
||||
def is_multi_gpu(net):
    """Return True when ``net`` is wrapped for multi-GPU (DDP) training."""
    wrapped_types = (MultiGPU, nn.parallel.distributed.DistributedDataParallel)
    return isinstance(net, wrapped_types)
|
||||
|
||||
|
||||
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
    """DistributedDataParallel subclass that forwards attribute access.

    DDP hides the wrapped network behind ``self.module``; this subclass lets
    callers read the underlying network's attributes directly through the
    wrapper.
    """

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        # Narrowed from a bare ``except`` (which also swallowed
        # KeyboardInterrupt/SystemExit): only fall through to the wrapped
        # module when the attribute genuinely is not found on the wrapper.
        except AttributeError:
            return getattr(self.module, item)
|
13
lib/train/admin/settings.py
Normal file
13
lib/train/admin/settings.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from lib.train.admin.environment import env_settings
|
||||
|
||||
|
||||
class Settings:
    """ Training settings, e.g. the paths to datasets and networks."""

    def __init__(self):
        self.set_default()

    def set_default(self):
        """Reset every field to its default value."""
        # GPU training is the default; environment paths come from local.py.
        self.use_gpu = True
        self.env = env_settings()
|
||||
|
||||
|
71
lib/train/admin/stats.py
Normal file
71
lib/train/admin/stats.py
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
|
||||
class StatValue:
    """Tracks a scalar value together with the history of all updates."""

    def __init__(self):
        self.clear()

    def reset(self):
        """Zero the current value; the history is kept."""
        self.val = 0

    def clear(self):
        """Zero the current value and drop the recorded history."""
        self.reset()
        self.history = []

    def update(self, val):
        """Record ``val`` as the current value and append it to the history."""
        self.val = val
        self.history += [val]
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``new_epoch`` archives the running average into ``history`` and starts a
    fresh accumulation; ``has_new_data`` says whether the last epoch saw any
    samples.
    """

    def __init__(self):
        self.clear()
        self.has_new_data = False

    def reset(self):
        """Restart the running statistics; the epoch history is untouched."""
        self.avg = 0
        self.val = 0
        self.sum = 0
        self.count = 0

    def clear(self):
        """Restart the running statistics and wipe the epoch history."""
        self.reset()
        self.history = []

    def update(self, val, n=1):
        """Fold in ``val`` observed ``n`` times."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count

    def new_epoch(self):
        """Archive the epoch average (if any samples arrived) and reset."""
        if self.count > 0:
            self.history.append(self.avg)
            self.reset()
            self.has_new_data = True
        else:
            self.has_new_data = False
|
||||
|
||||
|
||||
def topk_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: ``(batch, num_classes)`` tensor of class scores.
        target: ``(batch,)`` tensor of ground-truth class indices.
        topk: an int ``k`` or a tuple/list of ``k`` values.

    Returns:
        A list of accuracies in percent (0-dim tensors), one per ``k``;
        a single value when ``topk`` was passed as a plain int.
    """
    single_input = not isinstance(topk, (tuple, list))
    if single_input:
        topk = (topk,)

    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) indices of the top-maxk predictions per sample.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``reshape`` instead of ``view``: the slice of ``correct`` may be
        # non-contiguous on newer PyTorch versions, where ``view`` raises.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)[0]
        res.append(correct_k * 100.0 / batch_size)

    if single_input:
        return res[0]

    return res
|
27
lib/train/admin/tensorboard.py
Normal file
27
lib/train/admin/tensorboard.py
Normal file
@@ -0,0 +1,27 @@
|
||||
#import os
|
||||
#from collections import OrderedDict
|
||||
#try:
|
||||
# from torch.utils.tensorboard import SummaryWriter
|
||||
#except:
|
||||
# print('WARNING: You are using tensorboardX instead since you have a too old pytorch version.')
|
||||
# from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
#class TensorboardWriter:
|
||||
# def __init__(self, directory, loader_names):
|
||||
# self.directory = directory
|
||||
# self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})
|
||||
|
||||
# def write_info(self, script_name, description):
|
||||
# tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
|
||||
# tb_info_writer.add_text('Script_name', script_name)
|
||||
# tb_info_writer.add_text('Description', description)
|
||||
# tb_info_writer.close()
|
||||
|
||||
# def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
|
||||
# for loader_name, loader_stats in stats.items():
|
||||
# if loader_stats is None:
|
||||
# continue
|
||||
# for var_name, val in loader_stats.items():
|
||||
# if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
|
||||
# self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
|
Reference in New Issue
Block a user