289 lines
8.6 KiB
Python
289 lines
8.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import re
|
|
from datetime import timedelta
|
|
from typing import Optional
|
|
|
|
import hydra
|
|
|
|
import numpy as np
|
|
import omegaconf
|
|
import torch
|
|
import torch.distributed as dist
|
|
from iopath.common.file_io import g_pathmgr
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
def multiply_all(*args):
    """Return the product of all positional arguments as a plain Python scalar."""
    values = np.array(args)
    return values.prod().item()
|
|
|
|
|
|
def collect_dict_keys(config):
    """Recursively walk a dataset configuration and gather every ``dict_key`` it defines."""
    # A node whose _target_ names a collate function carries exactly one dict_key;
    # such nodes are leaves — no further recursion below them.
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        return [config["dict_key"]]
    found = []
    for value in config.values():
        if isinstance(value, type(config)):
            # Nested mapping of the same config type: descend into it.
            found.extend(collect_dict_keys(value))
        elif isinstance(value, omegaconf.listconfig.ListConfig):
            # Lists may hold further config mappings; descend into each one.
            for entry in value:
                if isinstance(entry, type(config)):
                    found.extend(collect_dict_keys(entry))
    return found
|
|
|
|
|
|
class Phase:
    """Names of the two training phases, used as string constants throughout."""

    TRAIN = "train"
    VAL = "val"
|
|
|
|
|
|
def register_omegaconf_resolvers():
    """Register this project's custom OmegaConf interpolation resolvers."""
    # Registration order matches the historical one; each name becomes usable
    # as ${name:...} inside config files.
    resolvers = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
    }
    for name, resolver in resolvers.items():
        OmegaConf.register_new_resolver(name, resolver)
|
|
|
|
|
|
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize torch.distributed and set the CUDA device.
    Expects environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.

    Returns the distributed rank of this process.
    """
    # With TORCH_NCCL_ASYNC_ERROR_HANDLING enabled, dist nccl ops fail after
    # timeout_mins of waiting instead of hanging forever.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    timeout = timedelta(minutes=timeout_mins)
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    dist.init_process_group(backend=backend, timeout=timeout)
    return dist.get_rank()
|
|
|
|
|
|
def get_machine_local_and_dist_rank():
    """
    Get the distributed and local rank of the current gpu.

    Returns:
        tuple[int, int]: (local_rank, distributed_rank) parsed from the
        LOCAL_RANK and RANK environment variables.

    Raises:
        AssertionError: if either environment variable is not set.
    """
    # Read the raw values first: the previous int(os.environ.get(..., None))
    # raised a confusing TypeError on a missing variable before the helpful
    # assertion message below could ever fire.
    local_rank = os.environ.get("LOCAL_RANK", None)
    distributed_rank = os.environ.get("RANK", None)
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
|
|
|
|
|
|
def print_cfg(cfg):
    """
    Log the full training configuration as YAML.
    Supports printing both Hydra DictConfig and also the AttrDict config.
    """
    logging.info("Training with config:")
    logging.info(OmegaConf.to_yaml(cfg))
|
|
|
|
|
|
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed python's random, numpy and torch (plus all CUDA devices when
    available) so each gpu draws a deterministic, rank-specific stream.
    """
    # The pytorch sampler increments the seed by 1 every epoch, so per-rank
    # seeds are spaced max_epochs apart to keep the streams from colliding.
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
|
|
|
|
|
|
def makedir(dir_path):
    """
    Create the directory (and any missing parents) if it does not exist.

    Returns:
        bool: True if the directory exists or was created, False on error.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Best-effort: report and return False instead of crashing the caller.
        # The previous `except BaseException` also swallowed KeyboardInterrupt
        # and SystemExit, which should propagate.
        logging.error(f"Error creating directory: {dir_path}")
    return is_success
|
|
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
|
|
|
|
|
def get_amp_type(amp_type: Optional[str] = None):
    """
    Map an automatic-mixed-precision type name to its torch dtype.

    Returns None when amp_type is None; otherwise torch.bfloat16 or
    torch.float16. Raises AssertionError for any other string.
    """
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    return torch.bfloat16 if amp_type == "bfloat16" else torch.float16
|
|
|
|
|
|
def log_env_variables():
    """Write every environment variable, one name=value per line, to the log."""
    # Sort by name so the dump is stable across runs.
    dump = "".join(f"{key}={os.environ[key]}\n" for key in sorted(os.environ.keys()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(dump)
|
|
|
|
|
|
class AverageMeter:
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Zero out all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
|
|
|
|
|
|
class MemMeter:
    """Tracks per-iteration peak GPU memory usage: current, average, and lifetime max."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Clear every statistic, including the lifetime peak."""
        self.val = 0  # max usage within the last iteration
        self.avg = 0  # average of the per-iteration max usages
        self.peak = 0  # highest usage seen over the program's lifetime
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        """Sample CUDA peak memory (in GB) for the iteration just finished."""
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.val > self.peak:
            self.peak = self.val
        if reset_peak_usage:
            # Open a fresh peak-tracking window for the next iteration.
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        spec = self.fmt
        template = (
            "{name}: {val" + spec + "} ({avg" + spec + "}/{peak" + spec + "})"
        )
        return template.format(**self.__dict__)
|
|
|
|
|
|
def human_readable_time(time_seconds):
    """Format a duration in seconds as 'DDd HHh MMm' (seconds are truncated)."""
    total_minutes, _ = divmod(int(time_seconds), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
|
|
|
|
|
|
class DurationMeter:
    """Accumulates an elapsed-time value (seconds) and renders it human-readably."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        """Set the accumulated duration back to zero."""
        self.val = 0

    def update(self, val):
        """Overwrite the accumulated duration with `val`."""
        self.val = val

    def add(self, val):
        """Increase the accumulated duration by `val`."""
        self.val = self.val + val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"
|
|
|
|
|
|
class ProgressMeter:
    """Logs batch progress alongside the current values of a set of meters."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Emit one progress line to the log; optionally echo it to stdout too."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        # real_meters compute a dict of sub-metrics; join each meter's
        # "<name>/<subname>: value" entries into a single segment.
        for name, meter in self.real_meters.items():
            computed = meter.compute()
            parts.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in computed.items()
                )
            )
        line = " | ".join(parts)
        logging.info(line)
        if enable_print:
            print(line)

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the digit width of num_batches.
        width = len(str(num_batches // 1))
        fmt = "{:" + str(width) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
|
|
|
|
|
def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path to 'checkpoint.pt' inside the save dir, or None if absent."""
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return ckpt_file if g_pathmgr.isfile(ckpt_file) else None
|