# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
import logging
import math
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional

import numpy as np

import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr

from training.optimizer import construct_optimizer

from training.utils.checkpoint_utils import (
    assert_skipped_parameters_are_frozen,
    exclude_params_matching_unix_pattern,
    load_state_dict_into_model,
    with_check_parameter_frozen,
)
from training.utils.data_utils import BatchedVideoDatapoint
from training.utils.distributed import all_reduce_max, barrier, get_rank

from training.utils.logger import Logger, setup_logging

from training.utils.train_utils import (
    AverageMeter,
    collect_dict_keys,
    DurationMeter,
    get_amp_type,
    get_machine_local_and_dist_rank,
    get_resume_checkpoint,
    human_readable_time,
    is_dist_avail_and_initialized,
    log_env_variables,
    makedir,
    MemMeter,
    Phase,
    ProgressMeter,
    set_seeds,
    setup_distributed_backend,
)


CORE_LOSS_KEY = "core_loss"


def unwrap_ddp_if_wrapped(model):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


@dataclass
class OptimAMPConf:
    enabled: bool = False
    amp_dtype: str = "float16"


@dataclass
class OptimConf:
    optimizer: torch.optim.Optimizer = None
    options: Optional[Dict[str, Any]] = None
    param_group_modifiers: Optional[List] = None
    amp: Optional[Dict[str, Any]] = None
    gradient_clip: Any = None
    gradient_logger: Any = None

    def __post_init__(self):
        # amp
        if not isinstance(self.amp, OptimAMPConf):
            if self.amp is None:
                self.amp = {}
            assert isinstance(self.amp, Mapping)
            self.amp = OptimAMPConf(**self.amp)


@dataclass
class DistributedConf:
    backend: Optional[str] = None  # inferred from accelerator type
    comms_dtype: Optional[str] = None
    find_unused_parameters: bool = False
    timeout_mins: int = 30


@dataclass
class CudaConf:
    cudnn_deterministic: bool = False
    cudnn_benchmark: bool = True
    allow_tf32: bool = False
    # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
    matmul_allow_tf32: Optional[bool] = None
    # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
    cudnn_allow_tf32: Optional[bool] = None


@dataclass
class CheckpointConf:
    save_dir: str
    save_freq: int
    save_list: List[int] = field(default_factory=list)
    model_weight_initializer: Any = None
    save_best_meters: List[str] = None
    skip_saving_parameters: List[str] = field(default_factory=list)
    initialize_after_preemption: Optional[bool] = None
    # if not None, training will be resumed from this checkpoint
    resume_from: Optional[str] = None

    def infer_missing(self):
        if self.initialize_after_preemption is None:
            with_skip_saving = len(self.skip_saving_parameters) > 0
            self.initialize_after_preemption = with_skip_saving
        return self


@dataclass
class LoggingConf:
    log_dir: str
    log_freq: int  # In iterations
    tensorboard_writer: Any
    log_level_primary: str = "INFO"
    log_level_secondary: str = "ERROR"
    log_scalar_frequency: int = 100
    log_visual_frequency: int = 100
    scalar_keys_to_log: Optional[Dict[str, Any]] = None
    log_batch_stats: bool = False


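# Illustrative example of how these config dataclasses are built from plain dicts
# (typically a Hydra/OmegaConf config); `OptimConf.__post_init__` coerces the nested
# `amp` mapping into an `OptimAMPConf`:
#
#     optim_conf = OptimConf(optimizer=None, amp={"enabled": True, "amp_dtype": "bfloat16"})
#     assert isinstance(optim_conf.amp, OptimAMPConf)
#
# The Trainer below performs the same coercion for its `logging`, `checkpoint`,
# `distributed`, `cuda`, and `optim` keyword arguments.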
class Trainer:
    """
    Trainer supporting the DDP training strategies.
    """

    EPSILON = 1e-8

    def __init__(
        self,
        *,  # the order of these args can change at any time, so they are keyword-only
        data: Dict[str, Any],
        model: Dict[str, Any],
        logging: Dict[str, Any],
        checkpoint: Dict[str, Any],
        max_epochs: int,
        mode: str = "train",
        accelerator: str = "cuda",
        seed_value: int = 123,
        val_epoch_freq: int = 1,
        distributed: Dict[str, bool] = None,
        cuda: Dict[str, bool] = None,
        env_variables: Optional[Dict[str, Any]] = None,
        optim: Optional[Dict[str, Any]] = None,
        optim_overrides: Optional[List[Dict[str, Any]]] = None,
        meters: Optional[Dict[str, Any]] = None,
        loss: Optional[Dict[str, Any]] = None,
    ):
        self._setup_env_variables(env_variables)
        self._setup_timers()

        self.data_conf = data
        self.model_conf = model
        self.logging_conf = LoggingConf(**logging)
        self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
        self.max_epochs = max_epochs
        self.mode = mode
        self.val_epoch_freq = val_epoch_freq
        self.optim_conf = OptimConf(**optim) if optim is not None else None
        self.meters_conf = meters
        self.loss_conf = loss
        distributed = DistributedConf(**distributed or {})
        cuda = CudaConf(**cuda or {})
        self.where = 0.0

        self._infer_distributed_backend_if_none(distributed, accelerator)

        self._setup_device(accelerator)

        self._setup_torch_dist_and_backend(cuda, distributed)

        makedir(self.logging_conf.log_dir)
        setup_logging(
            __name__,
            output_dir=self.logging_conf.log_dir,
            rank=self.rank,
            log_level_primary=self.logging_conf.log_level_primary,
            log_level_secondary=self.logging_conf.log_level_secondary,
        )

        set_seeds(seed_value, self.max_epochs, self.distributed_rank)
        log_env_variables()

        assert (
            is_dist_avail_and_initialized()
        ), "Torch distributed needs to be initialized before calling the trainer."

        self._setup_components()  # Except the optimizer, everything is set up here.
        self._move_to_device()
        self._construct_optimizers()
        self._setup_dataloaders()

        self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")

        if self.checkpoint_conf.resume_from is not None:
            assert os.path.exists(
                self.checkpoint_conf.resume_from
            ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
            dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
            if self.distributed_rank == 0 and not os.path.exists(dst):
                # Copy the "resume_from" checkpoint to the checkpoint folder
                # if there is not a checkpoint to resume from already there
                makedir(self.checkpoint_conf.save_dir)
                g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
            barrier()

        self.load_checkpoint()
        self._setup_ddp_distributed_training(distributed, accelerator)
        barrier()

    def _setup_timers(self):
        """
        Initializes counters for elapsed time and eta.
        """
        self.start_time = time.time()
        self.ckpt_time_elapsed = 0
        self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)

    def _get_meters(self, phase_filters=None):
        if self.meters is None:
            return {}
        meters = {}
        for phase, phase_meters in self.meters.items():
            if phase_filters is not None and phase not in phase_filters:
                continue
            for key, key_meters in phase_meters.items():
                if key_meters is None:
                    continue
                for name, meter in key_meters.items():
                    meters[f"{phase}_{key}/{name}"] = meter
        return meters

    def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
        if distributed_conf.backend is None:
            distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"

    def _setup_env_variables(self, env_variables_conf) -> None:
        if env_variables_conf is not None:
            for variable_name, value in env_variables_conf.items():
                os.environ[variable_name] = value

    def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
            torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
            torch.backends.cuda.matmul.allow_tf32 = (
                cuda_conf.matmul_allow_tf32
                if cuda_conf.matmul_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )
            torch.backends.cudnn.allow_tf32 = (
                cuda_conf.cudnn_allow_tf32
                if cuda_conf.cudnn_allow_tf32 is not None
                else cuda_conf.allow_tf32
            )

        self.rank = setup_distributed_backend(
            distributed_conf.backend, distributed_conf.timeout_mins
        )

    def _setup_device(self, accelerator):
        self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
        if accelerator == "cuda":
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif accelerator == "cpu":
            self.device = torch.device("cpu")
        else:
            raise ValueError(f"Unsupported accelerator: {accelerator}")

    def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
        assert isinstance(self.model, torch.nn.Module)

        self.model = nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank] if accelerator == "cuda" else [],
            find_unused_parameters=distributed_conf.find_unused_parameters,
        )
        if distributed_conf.comms_dtype is not None:  # noqa
            from torch.distributed.algorithms import ddp_comm_hooks

            amp_type = get_amp_type(distributed_conf.comms_dtype)
            if amp_type == torch.bfloat16:
                hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
                logging.info("Enabling bfloat16 grad communication")
            else:
                hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
                logging.info("Enabling fp16 grad communication")
            process_group = None
            self.model.register_comm_hook(process_group, hook)

    def _move_to_device(self):
        logging.info(
            f"Moving components to device {self.device} and local rank {self.local_rank}."
        )

        self.model.to(self.device)

        logging.info(
            f"Done moving components to device {self.device} and local rank {self.local_rank}."
        )

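    # Checkpointing: every call writes "checkpoint.pt"; epochs matching `save_freq`
    # or listed in `save_list` additionally get a "checkpoint_{epoch}.pt" copy.
    # Parameters matching `skip_saving_parameters` are excluded from the saved
    # state dict, and only rank 0 writes to disk (all DDP replicas are identical).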
    def save_checkpoint(self, epoch, checkpoint_names=None):
        checkpoint_folder = self.checkpoint_conf.save_dir
        makedir(checkpoint_folder)
        if checkpoint_names is None:
            checkpoint_names = ["checkpoint"]
            if (
                self.checkpoint_conf.save_freq > 0
                and (int(epoch) % self.checkpoint_conf.save_freq == 0)
            ) or int(epoch) in self.checkpoint_conf.save_list:
                checkpoint_names.append(f"checkpoint_{int(epoch)}")

        checkpoint_paths = []
        for ckpt_name in checkpoint_names:
            checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))

        state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
        state_dict = exclude_params_matching_unix_pattern(
            patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
        )

        checkpoint = {
            "model": state_dict,
            "optimizer": self.optim.optimizer.state_dict(),
            "epoch": epoch,
            "loss": self.loss.state_dict(),
            "steps": self.steps,
            "time_elapsed": self.time_elapsed_meter.val,
            "best_meter_values": self.best_meter_values,
        }
        if self.optim_conf.amp.enabled:
            checkpoint["scaler"] = self.scaler.state_dict()

        # DDP checkpoints are only saved on rank 0 (all workers are identical)
        if self.distributed_rank != 0:
            return

        for checkpoint_path in checkpoint_paths:
            self._save_checkpoint(checkpoint, checkpoint_path)

    def _save_checkpoint(self, checkpoint, checkpoint_path):
        """
        Save a checkpoint while guarding against the job being killed in the middle
        of checkpoint saving (which corrupts the checkpoint file and ruins the
        entire training since usually only the last checkpoint is kept per run).

        We first save the new checkpoint to a temp file (with a '.tmp' suffix), and
        then move it to overwrite the old checkpoint_path.
        """
        checkpoint_path_tmp = f"{checkpoint_path}.tmp"
        with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
            torch.save(checkpoint, f)
        # after torch.save is completed, replace the old checkpoint with the new one
        if g_pathmgr.exists(checkpoint_path):
            # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
            g_pathmgr.rm(checkpoint_path)
        success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
        assert success

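    # Resume logic: if a checkpoint already exists in `save_dir`, resume from it
    # (optionally re-running the weight initializer first when
    # `initialize_after_preemption` is set); otherwise initialize the model
    # weights from `model_weight_initializer`.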
    def load_checkpoint(self):
        ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
        if ckpt_path is None:
            self._init_model_state()
        else:
            if self.checkpoint_conf.initialize_after_preemption:
                self._call_model_initializer()
            self._load_resuming_checkpoint(ckpt_path)

    def _init_model_state(self):
        # Checking that parameters that won't be saved are indeed frozen
        # We do this check here before even saving the model to catch errors
        # as early as possible and not at the end of the first epoch
        assert_skipped_parameters_are_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
        )

        # Checking that parameters that won't be saved are initialized from
        # within the model definition, unless `initialize_after_preemption`
        # is explicitly set to `True`. If not, this is a bug, and after
        # preemption, the `skip_saving_parameters` will have random values
        allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
        with with_check_parameter_frozen(
            patterns=self.checkpoint_conf.skip_saving_parameters,
            model=self.model,
            disabled=allow_init_skip_parameters,
        ):
            self._call_model_initializer()

    def _call_model_initializer(self):
        model_weight_initializer = instantiate(
            self.checkpoint_conf.model_weight_initializer
        )
        if model_weight_initializer is not None:
            logging.info(
                f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
            )
            self.model = model_weight_initializer(model=self.model)

    def _load_resuming_checkpoint(self, ckpt_path: str):
        logging.info(f"Resuming training from {ckpt_path}")

        with g_pathmgr.open(ckpt_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        load_state_dict_into_model(
            model=self.model,
            state_dict=checkpoint["model"],
            ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
        )

        self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
        self.loss.load_state_dict(checkpoint["loss"], strict=True)
        self.epoch = checkpoint["epoch"]
        self.steps = checkpoint["steps"]
        self.ckpt_time_elapsed = checkpoint.get("time_elapsed")

        if self.optim_conf.amp.enabled and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.best_meter_values = checkpoint.get("best_meter_values", {})

        if "train_dataset" in checkpoint and self.train_dataset is not None:
            self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])

    def is_intermediate_val_epoch(self, epoch):
        return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1

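    # Forward pass shared by train and val. Returns a 3-tuple:
    #   ({loss_name: core_loss_tensor}, batch_size, {sub_loss_name: value, ...})
    # When the loss module returns a dict, the sub-losses are logged separately and
    # only the CORE_LOSS_KEY entry is kept as the scalar loss to optimize.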
    def _step(
        self,
        batch: BatchedVideoDatapoint,
        model: nn.Module,
        phase: str,
    ):

        outputs = model(batch)
        targets = batch.masks
        batch_size = len(batch.img_batch)

        key = batch.dict_key  # key for dataset
        loss = self.loss[key](outputs, targets)
        loss_str = f"Losses/{phase}_{key}_loss"

        loss_log_str = os.path.join("Step_Losses", loss_str)

        # loss contains multiple sub-components we wish to log
        step_losses = {}
        if isinstance(loss, dict):
            step_losses.update(
                {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
            )
            loss = self._log_loss_detailed_and_return_core_loss(
                loss, loss_log_str, self.steps[phase]
            )

        if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
            self.logger.log(
                loss_log_str,
                loss,
                self.steps[phase],
            )

        self.steps[phase] += 1

        ret_tuple = {loss_str: loss}, batch_size, step_losses

        if phase in self.meters and key in self.meters[phase]:
            meters_dict = self.meters[phase][key]
            if meters_dict is not None:
                for _, meter in meters_dict.items():
                    meter.update(
                        find_stages=outputs,
                        find_metadatas=batch.metadata,
                    )

        return ret_tuple

    def run(self):
        assert self.mode in ["train", "train_only", "val"]
        if self.mode == "train":
            if self.epoch > 0:
                logging.info(f"Resuming training from epoch: {self.epoch}")
                # resuming from a checkpoint
                if self.is_intermediate_val_epoch(self.epoch - 1):
                    logging.info("Running previous val epoch")
                    self.epoch -= 1
                    self.run_val()
                    self.epoch += 1
            self.run_train()
            self.run_val()
        elif self.mode == "val":
            self.run_val()
        elif self.mode == "train_only":
            self.run_train()

    def _setup_dataloaders(self):
        self.train_dataset = None
        self.val_dataset = None

        if self.mode in ["train", "val"]:
            self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))

        if self.mode in ["train", "train_only"]:
            self.train_dataset = instantiate(self.data_conf.train)

    def run_train(self):

        while self.epoch < self.max_epochs:
            dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
            barrier()
            outs = self.train_epoch(dataloader)
            self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

            # Log train stats to a text file.
            if self.distributed_rank == 0:
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(outs) + "\n")

            # Save checkpoint before validating
            self.save_checkpoint(self.epoch + 1)

            del dataloader
            gc.collect()

            # Run val, skipping the last epoch since it runs after the loop anyway
            if self.is_intermediate_val_epoch(self.epoch):
                self.run_val()

            if self.distributed_rank == 0:
                self.best_meter_values.update(self._get_trainer_state("train"))
                with g_pathmgr.open(
                    os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                    "a",
                ) as f:
                    f.write(json.dumps(self.best_meter_values) + "\n")

            self.epoch += 1
        # epoch was incremented in the loop, but the final val step runs outside the loop
        self.epoch -= 1

    def run_val(self):
        if not self.val_dataset:
            return

        dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
        outs = self.val_epoch(dataloader, phase=Phase.VAL)
        del dataloader
        gc.collect()
        self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

        if self.distributed_rank == 0:
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "val_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(outs) + "\n")

    def val_epoch(self, val_loader, phase):
        batch_time = AverageMeter("Batch Time", self.device, ":.2f")
        data_time = AverageMeter("Data Time", self.device, ":.2f")
        mem = MemMeter("Mem (GB)", self.device, ":.2f")

        iters_per_epoch = len(val_loader)

        curr_phases = [phase]
        curr_models = [self.model]

        loss_names = []
        for p in curr_phases:
            for key in self.loss.keys():
                loss_names.append(f"Losses/{p}_{key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        for model in curr_models:
            model.eval()
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_start()

        progress = ProgressMeter(
            iters_per_epoch,
            [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
            self._get_meters(curr_phases),
            prefix="Val Epoch: [{}]".format(self.epoch),
        )

        end = time.time()

        for data_iter, batch in enumerate(val_loader):

            # measure data loading time
            data_time.update(time.time() - end)

            batch = batch.to(self.device, non_blocking=True)

            # compute output
            with torch.no_grad():
                with torch.cuda.amp.autocast(
                    enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                    dtype=(
                        get_amp_type(self.optim_conf.amp.amp_dtype)
                        if self.optim_conf
                        else None
                    ),
                ):
                    for phase, model in zip(curr_phases, curr_models):
                        loss_dict, batch_size, extra_losses = self._step(
                            batch,
                            model,
                            phase,
                        )

                        assert len(loss_dict) == 1
                        loss_key, loss = loss_dict.popitem()

                        loss_mts[loss_key].update(loss.item(), batch_size)

                        for k, v in extra_losses.items():
                            if k not in extra_loss_mts:
                                extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                            extra_loss_mts[k].update(v.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )

            if torch.cuda.is_available():
                mem.update(reset_peak_usage=True)

            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)

            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[Phase.VAL],
                    )

            if data_iter % 10 == 0:
                dist.barrier()

        self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
        self._log_timers(phase)
        for model in curr_models:
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_end()

        out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        for phase in curr_phases:
            out_dict.update(self._get_trainer_state(phase))
        self._reset_meters(curr_phases)
        logging.info(f"Meters: {out_dict}")
        return out_dict

    def _get_trainer_state(self, phase):
        return {
            "Trainer/where": self.where,
            "Trainer/epoch": self.epoch,
            f"Trainer/steps_{phase}": self.steps[phase],
        }

    def train_epoch(self, train_loader):

        # Init stat meters
        batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
        data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
        mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
        data_times = []
        phase = Phase.TRAIN

        iters_per_epoch = len(train_loader)

        loss_names = []
        for batch_key in self.loss.keys():
            loss_names.append(f"Losses/{phase}_{batch_key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}

        progress = ProgressMeter(
            iters_per_epoch,
            [
                batch_time_meter,
                data_time_meter,
                mem_meter,
                self.time_elapsed_meter,
                *loss_mts.values(),
            ],
            self._get_meters([phase]),
            prefix="Train Epoch: [{}]".format(self.epoch),
        )

        # Model training loop
        self.model.train()
        end = time.time()

        for data_iter, batch in enumerate(train_loader):
            # measure data loading time
            data_time_meter.update(time.time() - end)
            data_times.append(data_time_meter.val)
            batch = batch.to(
                self.device, non_blocking=True
            )  # move tensors in a tensorclass

            try:
                self._run_step(batch, phase, loss_mts, extra_loss_mts)

                # compute gradient and do optim step
                exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
                self.where = float(exact_epoch) / self.max_epochs
                assert self.where <= 1 + self.EPSILON
                if self.where < 1.0:
                    self.optim.step_schedulers(
                        self.where, step=int(exact_epoch * iters_per_epoch)
                    )
                else:
                    logging.warning(
                        f"Skipping scheduler update since the training is at the end, i.e., {self.where} of [0,1]."
                    )

                # Log schedulers
                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    for j, param_group in enumerate(self.optim.optimizer.param_groups):
                        for option in self.optim.schedulers[j]:
                            optim_prefix = (
                                "" + f"{j}_"
                                if len(self.optim.optimizer.param_groups) > 1
                                else ""
                            )
                            self.logger.log(
                                os.path.join("Optim", f"{optim_prefix}", option),
                                param_group[option],
                                self.steps[phase],
                            )

                # Clipping gradients and detecting diverging gradients
                if self.gradient_clipper is not None:
                    self.scaler.unscale_(self.optim.optimizer)
                    self.gradient_clipper(model=self.model)

                if self.gradient_logger is not None:
                    self.gradient_logger(
                        self.model, rank=self.distributed_rank, where=self.where
                    )

                # Optimizer step: the scaler will make sure gradients are not
                # applied if the gradients are infinite
                self.scaler.step(self.optim.optimizer)
                self.scaler.update()

                # measure elapsed time
                batch_time_meter.update(time.time() - end)
                end = time.time()

                self.time_elapsed_meter.update(
                    time.time() - self.start_time + self.ckpt_time_elapsed
                )

                mem_meter.update(reset_peak_usage=True)
                if data_iter % self.logging_conf.log_freq == 0:
                    progress.display(data_iter)

                if data_iter % self.logging_conf.log_scalar_frequency == 0:
                    # Log progress meters.
                    for progress_meter in progress.meters:
                        self.logger.log(
                            os.path.join("Step_Stats", phase, progress_meter.name),
                            progress_meter.val,
                            self.steps[phase],
                        )

            # Catching NaN/Inf errors in the loss
            except FloatingPointError as e:
                raise e

        self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
        self._log_timers(Phase.TRAIN)
        self._log_sync_data_times(Phase.TRAIN, data_times)

        out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg
        out_dict.update(self._get_trainer_state(phase))
        logging.info(f"Losses and meters: {out_dict}")
        self._reset_meters([phase])
        return out_dict

    def _log_sync_data_times(self, phase, data_times):
        data_times = all_reduce_max(torch.tensor(data_times)).tolist()
        steps = range(self.steps[phase] - len(data_times), self.steps[phase])
        for step, data_time in zip(steps, data_times):
            if step % self.logging_conf.log_scalar_frequency == 0:
                self.logger.log(
                    os.path.join("Step_Stats", phase, "Data Time Synced"),
                    data_time,
                    step,
                )

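    # One AMP training step (backward only): zero the grads, run the forward pass
    # under autocast, check that the loss is finite, then scale it and backpropagate.
    # The optimizer/scaler update itself happens in `train_epoch`, after optional
    # gradient clipping and logging.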
    def _run_step(
        self,
        batch: BatchedVideoDatapoint,
        phase: str,
        loss_mts: Dict[str, AverageMeter],
        extra_loss_mts: Dict[str, AverageMeter],
        raise_on_error: bool = True,
    ):
        """
        Run the forward / backward
        """

        # it's important to set grads to None, especially with Adam since 0
        # grads will also update a model even if the step doesn't produce
        # gradients
        self.optim.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(
            enabled=self.optim_conf.amp.enabled,
            dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
        ):
            loss_dict, batch_size, extra_losses = self._step(
                batch,
                self.model,
                phase,
            )

        assert len(loss_dict) == 1
        loss_key, loss = loss_dict.popitem()

        if not math.isfinite(loss.item()):
            error_msg = f"Loss is {loss.item()}, attempting to stop training"
            logging.error(error_msg)
            if raise_on_error:
                raise FloatingPointError(error_msg)
            else:
                return

        self.scaler.scale(loss).backward()
        loss_mts[loss_key].update(loss.item(), batch_size)
        for extra_loss_key, extra_loss in extra_losses.items():
            if extra_loss_key not in extra_loss_mts:
                extra_loss_mts[extra_loss_key] = AverageMeter(
                    extra_loss_key, self.device, ":.2e"
                )
            extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)

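    # Synchronizes all meters across ranks, tracks the best value seen so far for any
    # meter exposing `is_better`, and saves extra "best" checkpoints for the meters
    # listed in `checkpoint_conf.save_best_meters`.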
    def _log_meters_and_save_best_ckpts(self, phases: List[str]):
        logging.info("Synchronizing meters")
        out_dict = {}
        checkpoint_save_keys = []
        for key, meter in self._get_meters(phases).items():
            meter_output = meter.compute_synced()
            is_better_check = getattr(meter, "is_better", None)

            for meter_subkey, meter_value in meter_output.items():
                out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value

                if is_better_check is None:
                    continue

                tracked_meter_key = os.path.join(key, meter_subkey)
                if tracked_meter_key not in self.best_meter_values or is_better_check(
                    meter_value,
                    self.best_meter_values[tracked_meter_key],
                ):
                    self.best_meter_values[tracked_meter_key] = meter_value

                    if (
                        self.checkpoint_conf.save_best_meters is not None
                        and key in self.checkpoint_conf.save_best_meters
                    ):
                        checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))

        if len(checkpoint_save_keys) > 0:
            self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)

        return out_dict

    def _log_timers(self, phase):
        time_remaining = 0
        epochs_remaining = self.max_epochs - self.epoch - 1
        val_epochs_remaining = sum(
            n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
        )

        # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
        # the end epoch.
        if (self.max_epochs - 1) % self.val_epoch_freq != 0:
            val_epochs_remaining += 1

        # Remove the current val run from estimate
        if phase == Phase.VAL:
            val_epochs_remaining -= 1

        time_remaining += (
            epochs_remaining * self.est_epoch_time[Phase.TRAIN]
            + val_epochs_remaining * self.est_epoch_time[Phase.VAL]
        )

        self.logger.log(
            os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
            self.time_elapsed_meter.val,
            self.steps[phase],
        )

        logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")

    def _reset_meters(self, phases: str) -> None:
        for meter in self._get_meters(phases).values():
            meter.reset()

    def _check_val_key_match(self, val_keys, phase):
        if val_keys is not None:
            # Check if there are any duplicates
            assert len(val_keys) == len(
                set(val_keys)
            ), f"Duplicate keys in val datasets, keys: {val_keys}"

            # Check that the keys match the meter keys
            if self.meters_conf is not None and phase in self.meters_conf:
                assert set(val_keys) == set(self.meters_conf[phase].keys()), (
                    f"Keys in val datasets do not match the keys in meters."
                    f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
                    f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
                )

            if self.loss_conf is not None:
                loss_keys = set(self.loss_conf.keys()) - set(["all"])
                assert all([k in loss_keys for k in val_keys]), (
                    f"Keys in val datasets do not match the keys in losses."
                    f"\nMissing in losses: {set(val_keys) - loss_keys}"
                    f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
                )

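    # Builds everything except the optimizer (see `_construct_optimizers`): logger,
    # model, per-dataset losses, meters, the AMP grad scaler, and the optional
    # gradient clipper / gradient logger, after validating that the val dataset keys
    # match the configured meters and losses.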
    def _setup_components(self):

        # Get the keys for all the val datasets, if any
        val_phase = Phase.VAL
        val_keys = None
        if self.data_conf.get(val_phase, None) is not None:
            val_keys = collect_dict_keys(self.data_conf[val_phase])
        # Additional checks on the sanity of the config for val datasets
        self._check_val_key_match(val_keys, phase=val_phase)

        logging.info("Setting up components: Model, loss, optim, meters etc.")
        self.epoch = 0
        self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}

        self.logger = Logger(self.logging_conf)

        self.model = instantiate(self.model_conf, _convert_="all")
        print_model_summary(self.model)

        self.loss = None
        if self.loss_conf:
            self.loss = {
                key: el  # wrap_base_loss(el)
                for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
            }
            self.loss = nn.ModuleDict(self.loss)

        self.meters = {}
        self.best_meter_values = {}
        if self.meters_conf:
            self.meters = instantiate(self.meters_conf, _convert_="all")

        self.scaler = torch.amp.GradScaler(
            self.device,
            enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
        )

        self.gradient_clipper = (
            instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
        )
        self.gradient_logger = (
            instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
        )

        logging.info("Finished setting up components: Model, loss, optim, meters etc.")

    def _construct_optimizers(self):
        self.optim = construct_optimizer(
            self.model,
            self.optim_conf.optimizer,
            self.optim_conf.options,
            self.optim_conf.param_group_modifiers,
        )

    def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
        core_loss = loss.pop(CORE_LOSS_KEY)
        if step % self.logging_conf.log_scalar_frequency == 0:
            for k in loss:
                log_str = os.path.join(loss_str, k)
                self.logger.log(log_str, loss[k], step)
        return core_loss


def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
    """
    Prints the model and the number of parameters in the model.
    # Multiple packages provide this info in a nice table format
    # However, they need us to provide an `input` (as they also write down the output sizes)
    # Our models are complex, and a single input is restrictive.
    # https://github.com/sksq96/pytorch-summary
    # https://github.com/nmhkahn/torchsummaryX
    """
    if get_rank() != 0:
        return
    param_kwargs = {}
    trainable_parameters = sum(
        p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
    )
    total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
    non_trainable_parameters = total_parameters - trainable_parameters
    logging.info("==" * 10)
    logging.info(f"Summary for model {type(model)}")
    logging.info(f"Model is {model}")
    logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
    logging.info(
        f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
    )
    logging.info(
        f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
    )
    logging.info("==" * 10)

    if log_dir:
        output_fpath = os.path.join(log_dir, "model.txt")
        with g_pathmgr.open(output_fpath, "w") as f:
            print(model, file=f)


PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]


def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.
    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.2 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2.0 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3.0 B'
        >>> get_human_readable_count(4e14)  # (four hundred trillion)
        '400 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'
    Args:
        number: a positive integer number
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = PARAMETER_NUM_UNITS
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    if index < 1 or number >= 100:
        return f"{int(number):,d} {labels[index]}"
    else:
        return f"{number:,.1f} {labels[index]}"
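

# Illustrative usage sketch: this Trainer is meant to be constructed by a separate
# launch script from a Hydra config, with torch.distributed already initialized
# (see the assert in `Trainer.__init__`). The `cfg.trainer` node and its layout
# below are assumptions for illustration only:
#
#     trainer = instantiate(cfg.trainer, _recursive_=False)
#     trainer.run()  # dispatches to the train/val loops depending on `mode`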