import os
import glob
import torch
import traceback

from lib.train.admin import multigpu
from torch.utils.data.distributed import DistributedSampler

class BaseTrainer:
    """Base trainer class. Contains functions for training and saving/loading checkpoints.
    Trainer classes should inherit from this one and overload the train_epoch function."""
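
    # A minimal sketch of a concrete subclass (hypothetical name; it assumes the actor
    # returns a (loss, stats) pair when called on a batch of data):
    #
    #   class MyTrainer(BaseTrainer):
    #       def train_epoch(self):
    #           for data in self.loaders[0]:
    #               loss, stats = self.actor(data)
    #               self.optimizer.zero_grad()
    #               loss.backward()
    #               self.optimizer.step()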

    def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
        """
        args:
            actor - The actor for training the network
            loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
                      epoch for each loader.
            optimizer - The optimizer used for training, e.g. Adam
            settings - Training settings
            lr_scheduler - Learning rate scheduler
        """
        self.actor = actor
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loaders = loaders

        self.update_settings(settings)

        self.epoch = 0
        self.stats = {}

        self.device = getattr(settings, 'device', None)
        if self.device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() and settings.use_gpu else "cpu")

        self.actor.to(self.device)
        self.settings = settings

    def update_settings(self, settings=None):
        """Updates the trainer settings. Must be called to update internal settings."""
        if settings is not None:
            self.settings = settings

        if self.settings.env.workspace_dir is not None:
            self.settings.env.workspace_dir = os.path.expanduser(self.settings.env.workspace_dir)
            # 2021.1.4: the checkpoint dir can be specified explicitly via settings.save_dir
            if self.settings.save_dir is None:
                self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir, 'checkpoints')
            else:
                self._checkpoint_dir = os.path.join(self.settings.save_dir, 'checkpoints')
            print("checkpoints will be saved to %s" % self._checkpoint_dir)

            # Only the main process (rank -1 for single-GPU, rank 0 for distributed) creates the directory
            if self.settings.local_rank in [-1, 0]:
                if not os.path.exists(self._checkpoint_dir):
                    print("checkpoints directory doesn't exist. creating it ...")
                    os.makedirs(self._checkpoint_dir)
        else:
            self._checkpoint_dir = None

    def train(self, max_epochs, load_latest=False, fail_safe=True, load_previous_ckpt=False, distill=False):
        """Do training for the given number of epochs.
        args:
            max_epochs - Max number of training epochs.
            load_latest - Bool indicating whether to resume from the latest epoch.
            fail_safe - Bool indicating whether training should automatically restart in case of any crashes.
            load_previous_ckpt - Bool indicating whether to initialize from a checkpoint of a previous run
                                 (settings.project_path_prv).
            distill - Bool indicating whether to load a teacher network for distillation
                      (settings.project_path_teacher).
        """

        epoch = -1
        num_tries = 10 if fail_safe else 1  # allow up to 10 automatic restarts when fail_safe is enabled
        for _ in range(num_tries):
            try:
                if load_latest:
                    self.load_checkpoint()
                if load_previous_ckpt:
                    directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_prv)
                    self.load_state_dict(directory)
                if distill:
                    directory_teacher = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_teacher)
                    self.load_state_dict(directory_teacher, distill=True)
                for epoch in range(self.epoch + 1, max_epochs + 1):
                    self.epoch = epoch

                    self.train_epoch()

                    if self.lr_scheduler is not None:
                        if self.settings.scheduler_type != 'cosine':
                            self.lr_scheduler.step()
                        else:
                            self.lr_scheduler.step(epoch - 1)

                    # Save a checkpoint every 5 epochs, during the last 5 epochs,
                    # or every epoch if settings.save_every_epoch is set
                    save_every_epoch = getattr(self.settings, "save_every_epoch", False)
                    if save_every_epoch or epoch % 5 == 0 or epoch > (max_epochs - 5):
                        if self._checkpoint_dir:
                            if self.settings.local_rank in [-1, 0]:
                                self.save_checkpoint()
                break  # training completed; do not retry
            except Exception:
                print('Training crashed at epoch {}'.format(epoch))
                if fail_safe:
                    self.epoch -= 1
                    load_latest = True
                    print('Traceback for the error!')
                    print(traceback.format_exc())
                    print('Restarting training from last epoch ...')
                else:
                    raise

        print('Finished training!')
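
    # Examples (hypothetical settings): resume from the latest checkpoint of this run,
    # or initialize from a previous run / a teacher network for distillation:
    #
    #   trainer.train(max_epochs=300, load_latest=True)
    #   trainer.train(max_epochs=300, load_previous_ckpt=True)   # uses settings.project_path_prv
    #   trainer.train(max_epochs=300, distill=True)              # uses settings.project_path_teacher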

    def train_epoch(self):
        raise NotImplementedError

    def save_checkpoint(self):
        """Saves a checkpoint of the network and other variables."""

        # Unwrap the network from DataParallel/DistributedDataParallel if needed
        net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        actor_type = type(self.actor).__name__
        net_type = type(net).__name__
        state = {
            'epoch': self.epoch,
            'actor_type': actor_type,
            'net_type': net_type,
            'net': net.state_dict(),
            'net_info': getattr(net, 'info', None),
            'constructor': getattr(net, 'constructor', None),
            'optimizer': self.optimizer.state_dict(),
            'stats': self.stats,
            'settings': self.settings
        }

        directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
        print(directory)
        if not os.path.exists(directory):
            print("directory doesn't exist. creating...")
            os.makedirs(directory)

        # First save as a tmp file
        tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
        torch.save(state, tmp_file_path)

        file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)

        # Now rename to the actual checkpoint. os.rename is atomic on POSIX systems as long as
        # both paths are on the same filesystem, so a crash cannot leave a partial checkpoint.
        os.rename(tmp_file_path, file_path)

    def load_checkpoint(self, checkpoint=None, fields=None, ignore_fields=None, load_constructor=False):
        """Loads a network checkpoint file.

        Can be called in three different ways:
            load_checkpoint():
                Loads the latest epoch from the workspace. Use this to continue training.
            load_checkpoint(epoch_num):
                Loads the network at the given epoch number (int).
            load_checkpoint(path_to_checkpoint):
                Loads the file from the given absolute path (str).
        """

        net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        net_type = type(net).__name__

        if checkpoint is None:
            # Load most recent checkpoint
            checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
                                                                             self.settings.project_path, net_type)))
            if checkpoint_list:
                checkpoint_path = checkpoint_list[-1]
            else:
                print('No matching checkpoint file found')
                return
        elif isinstance(checkpoint, int):
            # Checkpoint is the epoch number
            checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
                                                                 net_type, checkpoint)
        elif isinstance(checkpoint, str):
            # Checkpoint is a path; if it is a directory, load the latest checkpoint in it
            if os.path.isdir(checkpoint):
                checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
                if checkpoint_list:
                    checkpoint_path = checkpoint_list[-1]
                else:
                    raise Exception('No checkpoint found')
            else:
                checkpoint_path = os.path.expanduser(checkpoint)
        else:
            raise TypeError('checkpoint must be None, an epoch number (int) or a path (str)')

        # Load network
        checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

        assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'

        if fields is None:
            fields = checkpoint_dict.keys()
        if ignore_fields is None:
            ignore_fields = ['settings']

        # Never load the scheduler. It exists in older checkpoints.
        ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])

        # Load all fields
        for key in fields:
            if key in ignore_fields:
                continue
            if key == 'net':
                net.load_state_dict(checkpoint_dict[key])
            elif key == 'optimizer':
                self.optimizer.load_state_dict(checkpoint_dict[key])
            else:
                setattr(self, key, checkpoint_dict[key])

        # Set the net info
        if load_constructor and 'constructor' in checkpoint_dict and checkpoint_dict['constructor'] is not None:
            net.constructor = checkpoint_dict['constructor']
        if 'net_info' in checkpoint_dict and checkpoint_dict['net_info'] is not None:
            net.info = checkpoint_dict['net_info']

        # Update the epoch in the lr scheduler
        if 'epoch' in fields:
            if self.lr_scheduler is not None:
                self.lr_scheduler.last_epoch = self.epoch
            # 2021.1.10: update the epoch in the distributed data samplers
            for loader in self.loaders:
                if isinstance(loader.sampler, DistributedSampler):
                    loader.sampler.set_epoch(self.epoch)
        return True
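
    # Examples of the three call forms (hypothetical epoch number and path):
    #
    #   trainer.load_checkpoint()                          # latest epoch in the workspace
    #   trainer.load_checkpoint(50)                        # the checkpoint saved at epoch 50
    #   trainer.load_checkpoint('/path/to/ckpt.pth.tar')   # an explicit checkpoint file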

    def load_state_dict(self, checkpoint=None, distill=False):
        """Loads network weights (non-strictly) from a checkpoint file, or from the latest
        checkpoint in a directory. Used to initialize from a previous run or, when
        distill=True, to load the teacher network for distillation.

        args:
            checkpoint - Path to a checkpoint file, or to a directory containing checkpoints (str).
            distill - Bool indicating whether to load the weights into the teacher network.
        """
        if distill:
            net = self.actor.net_teacher.module if multigpu.is_multi_gpu(self.actor.net_teacher) \
                else self.actor.net_teacher
        else:
            net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        net_type = type(net).__name__

        if isinstance(checkpoint, str):
            # Checkpoint is a path; if it is a directory, load the latest checkpoint in it
            if os.path.isdir(checkpoint):
                checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
                if checkpoint_list:
                    checkpoint_path = checkpoint_list[-1]
                else:
                    raise Exception('No checkpoint found')
            else:
                checkpoint_path = os.path.expanduser(checkpoint)
        else:
            raise TypeError('checkpoint must be a path (str)')

        # Load network
        print("Loading pretrained model from", checkpoint_path)
        checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

        assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'

        missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
        print("previous checkpoint is loaded.")
        print("missing keys:", missing_k)
        print("unexpected keys:", unexpected_k)

        return True
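
# Typical usage (a minimal sketch; MyTrainer, the actor, the loaders and the settings
# object are assumptions about the surrounding training code, not part of this file):
#
#   trainer = MyTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
#   trainer.train(max_epochs=300, load_latest=True, fail_safe=True)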