init commit of samurai
This commit is contained in:
5
sam2/training/utils/__init__.py
Normal file
5
sam2/training/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
361
sam2/training/utils/checkpoint_utils.py
Normal file
361
sam2/training/utils/checkpoint_utils.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import fnmatch
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from torch.jit._script import RecursiveScriptModule
|
||||
|
||||
|
||||
def unix_pattern_to_parameter_names(
|
||||
constraints: List[str], all_parameter_names: Sequence[str]
|
||||
) -> Union[None, Set[str]]:
|
||||
"""
|
||||
Go through the list of parameter names and select those that match
|
||||
any of the provided constraints
|
||||
"""
|
||||
parameter_names = []
|
||||
for param_name in constraints:
|
||||
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"param_names {param_name} don't match any param in the given names."
|
||||
parameter_names.append(matching_parameters)
|
||||
return set.union(*parameter_names)
|
||||
|
||||
|
||||
def filter_params_matching_unix_pattern(
|
||||
patterns: List[str], state_dict: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Remove from the state dictionary the parameters matching the provided unix patterns
|
||||
|
||||
Args:
|
||||
patterns: the list of unix patterns to exclude
|
||||
state_dict: the dictionary to filter
|
||||
|
||||
Returns:
|
||||
A new state dictionary
|
||||
"""
|
||||
if len(patterns) == 0:
|
||||
return {}
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
|
||||
return {k: state_dict[k] for k in included_keys}
|
||||
|
||||
|
||||
def exclude_params_matching_unix_pattern(
|
||||
patterns: List[str], state_dict: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Remove from the state dictionary the parameters matching the provided unix patterns
|
||||
|
||||
Args:
|
||||
patterns: the list of unix patterns to exclude
|
||||
state_dict: the dictionary to filter
|
||||
|
||||
Returns:
|
||||
A new state dictionary
|
||||
"""
|
||||
if len(patterns) == 0:
|
||||
return state_dict
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys)
|
||||
return {k: v for k, v in state_dict.items() if k not in excluded_keys}
|
||||
|
||||
|
||||
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
|
||||
keys = []
|
||||
trace = []
|
||||
for k, v in state_dict.items():
|
||||
keys.append(k)
|
||||
trace.append(v.sum().item())
|
||||
trace = np.array(trace)[np.argsort(keys)]
|
||||
return trace
|
||||
|
||||
|
||||
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
|
||||
"""
|
||||
Verifies that all the parameters matching the provided patterns
|
||||
are frozen - this acts as a safeguard when ignoring parameter
|
||||
when saving checkpoints - if the parameters are in fact trainable
|
||||
"""
|
||||
if not patterns:
|
||||
return
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
non_frozen_keys = {
|
||||
n
|
||||
for n, p in model.named_parameters()
|
||||
if n in frozen_state_dict and p.requires_grad
|
||||
}
|
||||
if non_frozen_keys:
|
||||
raise ValueError(
|
||||
f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_check_parameter_frozen(
|
||||
model: nn.Module, patterns: List[str], disabled: bool = True
|
||||
):
|
||||
"""
|
||||
Context manager that inspects a model surrounding a piece of code
|
||||
and verifies if the model has been updated by this piece of code
|
||||
|
||||
The function will raise an exception if the model has been updated
|
||||
on at least one of the parameter that matches one of the pattern
|
||||
|
||||
Args:
|
||||
model: the model that might have been updated
|
||||
patterns: for the parameters we want to observe
|
||||
allowed:
|
||||
"""
|
||||
if not patterns or disabled:
|
||||
yield
|
||||
return
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
summary_before = _get_state_dict_summary(frozen_state_dict)
|
||||
|
||||
yield
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
summary_after = _get_state_dict_summary(frozen_state_dict)
|
||||
|
||||
if not np.allclose(summary_before, summary_after, atol=1e-6):
|
||||
raise ValueError(
|
||||
f"""
|
||||
The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
|
||||
You can resolve this error by either initializing those parameters from within the model definition
|
||||
or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class CkptExcludeKernel:
|
||||
"""
|
||||
Removes the keys from the given model state_dict that match the key_pattern.
|
||||
|
||||
Args:
|
||||
key_pattern: Patterns used to select the keys in the state_dict
|
||||
that are eligible for this kernel.
|
||||
"""
|
||||
|
||||
def __init__(self, key_pattern: List[str]):
|
||||
self.key_pattern = key_pattern
|
||||
|
||||
def __call__(self, state_dict: Dict):
|
||||
"""
|
||||
Args:
|
||||
state_dict: A dictionary representing the given checkpoint's state dict.
|
||||
"""
|
||||
if len(self.key_pattern) == 0:
|
||||
return state_dict
|
||||
exclude_keys = unix_pattern_to_parameter_names(
|
||||
self.key_pattern, state_dict.keys()
|
||||
)
|
||||
return {k: v for k, v in state_dict.items() if k not in exclude_keys}
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
path_list: List[str],
|
||||
pick_recursive_keys: Optional[List[str]] = None,
|
||||
map_location: str = "cpu",
|
||||
) -> Any:
|
||||
"""
|
||||
Loads a checkpoint from the specified path.
|
||||
|
||||
Args:
|
||||
path_list: A list of paths which contain the checkpoint. Each element
|
||||
is tried (in order) until a file that exists is found. That file is then
|
||||
used to read the checkpoint.
|
||||
pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
|
||||
For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
|
||||
map_location (str): a function, torch.device, string or a dict specifying how to
|
||||
remap storage locations
|
||||
|
||||
Returns: Model with the matchin pre-trained weights loaded.
|
||||
"""
|
||||
path_exists = False
|
||||
for path in path_list:
|
||||
if g_pathmgr.exists(path):
|
||||
path_exists = True
|
||||
break
|
||||
|
||||
if not path_exists:
|
||||
raise ValueError(f"No path exists in {path_list}")
|
||||
|
||||
with g_pathmgr.open(path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location=map_location)
|
||||
|
||||
logging.info(f"Loaded checkpoint from {path}")
|
||||
if pick_recursive_keys is not None:
|
||||
for key in pick_recursive_keys:
|
||||
checkpoint = checkpoint[key]
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_state_dict(checkpoint, ckpt_state_dict_keys):
|
||||
if isinstance(checkpoint, RecursiveScriptModule):
|
||||
# This is a torchscript JIT model
|
||||
return checkpoint.state_dict()
|
||||
pre_train_dict = checkpoint
|
||||
for i, key in enumerate(ckpt_state_dict_keys):
|
||||
if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
|
||||
isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
|
||||
):
|
||||
key_str = (
|
||||
'["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]'
|
||||
)
|
||||
raise KeyError(
|
||||
f"'{key}' not found in checkpoint{key_str} "
|
||||
f"with keys: {pre_train_dict.keys()}"
|
||||
)
|
||||
pre_train_dict = pre_train_dict[key]
|
||||
return pre_train_dict
|
||||
|
||||
|
||||
def load_checkpoint_and_apply_kernels(
|
||||
checkpoint_path: str,
|
||||
checkpoint_kernels: List[Callable] = None,
|
||||
ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
|
||||
map_location: str = "cpu",
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Performs checkpoint loading with a variety of pre-processing kernel applied in
|
||||
sequence.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Path to the checkpoint.
|
||||
checkpoint_kernels List(Callable): A list of checkpoint processing kernels
|
||||
to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
|
||||
`CkptExcludeKernel`, etc. These kernels are applied in the
|
||||
given order.
|
||||
ckpt_state_dict_keys (str): Keys containing the model state dict.
|
||||
map_location (str): a function, torch.device, string or a dict specifying how to
|
||||
remap storage locations
|
||||
|
||||
Returns: Model with the matchin pre-trained weights loaded.
|
||||
"""
|
||||
assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
|
||||
checkpoint_path
|
||||
)
|
||||
|
||||
# Load the checkpoint on CPU to avoid GPU mem spike.
|
||||
with g_pathmgr.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location=map_location)
|
||||
|
||||
pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)
|
||||
|
||||
# Not logging into info etc since it's a huge log
|
||||
logging.debug(
|
||||
"Loaded Checkpoint State Dict pre-kernel application: %s"
|
||||
% str(", ".join(list(pre_train_dict.keys())))
|
||||
)
|
||||
# Apply kernels
|
||||
if checkpoint_kernels is not None:
|
||||
for f in checkpoint_kernels:
|
||||
pre_train_dict = f(state_dict=pre_train_dict)
|
||||
|
||||
logging.debug(
|
||||
"Loaded Checkpoint State Dict Post-kernel application %s"
|
||||
% str(", ".join(list(pre_train_dict.keys())))
|
||||
)
|
||||
|
||||
return pre_train_dict
|
||||
|
||||
|
||||
def check_load_state_dict_errors(
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
strict: bool,
|
||||
ignore_missing_keys: List[str] = None,
|
||||
ignore_unexpected_keys: List[str] = None,
|
||||
):
|
||||
if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
|
||||
ignored_keys = unix_pattern_to_parameter_names(
|
||||
ignore_missing_keys, missing_keys
|
||||
)
|
||||
missing_keys = [key for key in missing_keys if key not in ignored_keys]
|
||||
|
||||
if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
|
||||
ignored_unexpected_keys = unix_pattern_to_parameter_names(
|
||||
ignore_unexpected_keys, unexpected_keys
|
||||
)
|
||||
unexpected_keys = [
|
||||
key for key in unexpected_keys if key not in ignored_unexpected_keys
|
||||
]
|
||||
|
||||
err = "State key mismatch."
|
||||
if unexpected_keys:
|
||||
err += f" Unexpected keys: {unexpected_keys}."
|
||||
if missing_keys:
|
||||
err += f" Missing keys: {missing_keys}."
|
||||
|
||||
if unexpected_keys or missing_keys:
|
||||
logging.warning(err)
|
||||
if unexpected_keys or strict:
|
||||
raise KeyError(err)
|
||||
|
||||
|
||||
def load_state_dict_into_model(
|
||||
state_dict: Dict,
|
||||
model: nn.Module,
|
||||
strict: bool = True,
|
||||
ignore_missing_keys: List[str] = None,
|
||||
ignore_unexpected_keys: List[str] = None,
|
||||
checkpoint_kernels: List[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Loads a state dict into the given model.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary containing the model's
|
||||
state dict, or a subset if strict is False
|
||||
model: Model to load the checkpoint weights into
|
||||
strict: raise if the state_dict has missing state keys
|
||||
ignore_missing_keys: unix pattern of keys to ignore
|
||||
"""
|
||||
# Apply kernels
|
||||
if checkpoint_kernels is not None:
|
||||
for f in checkpoint_kernels:
|
||||
state_dict = f(state_dict=state_dict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
check_load_state_dict_errors(
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
strict=strict,
|
||||
ignore_missing_keys=ignore_missing_keys,
|
||||
ignore_unexpected_keys=ignore_unexpected_keys,
|
||||
)
|
||||
return model
|
179
sam2/training/utils/data_utils.py
Normal file
179
sam2/training/utils/data_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from tensordict import tensorclass
|
||||
|
||||
|
||||
@tensorclass
|
||||
class BatchedVideoMetaData:
|
||||
"""
|
||||
This class represents metadata about a batch of videos.
|
||||
Attributes:
|
||||
unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
|
||||
frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
|
||||
"""
|
||||
|
||||
unique_objects_identifier: torch.LongTensor
|
||||
frame_orig_size: torch.LongTensor
|
||||
|
||||
|
||||
@tensorclass
|
||||
class BatchedVideoDatapoint:
|
||||
"""
|
||||
This class represents a batch of videos with associated annotations and metadata.
|
||||
Attributes:
|
||||
img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
|
||||
obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
|
||||
masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
|
||||
metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
|
||||
dict_key: A string key used to identify the batch.
|
||||
"""
|
||||
|
||||
img_batch: torch.FloatTensor
|
||||
obj_to_frame_idx: torch.IntTensor
|
||||
masks: torch.BoolTensor
|
||||
metadata: BatchedVideoMetaData
|
||||
|
||||
dict_key: str
|
||||
|
||||
def pin_memory(self, device=None):
|
||||
return self.apply(torch.Tensor.pin_memory, device=device)
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""
|
||||
Returns the number of frames per video.
|
||||
"""
|
||||
return self.batch_size[0]
|
||||
|
||||
@property
|
||||
def num_videos(self) -> int:
|
||||
"""
|
||||
Returns the number of videos in the batch.
|
||||
"""
|
||||
return self.img_batch.shape[1]
|
||||
|
||||
@property
|
||||
def flat_obj_to_img_idx(self) -> torch.IntTensor:
|
||||
"""
|
||||
Returns a flattened tensor containing the object to img index.
|
||||
The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
|
||||
"""
|
||||
frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
|
||||
flat_idx = video_idx * self.num_frames + frame_idx
|
||||
return flat_idx
|
||||
|
||||
@property
|
||||
def flat_img_batch(self) -> torch.FloatTensor:
|
||||
"""
|
||||
Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
|
||||
"""
|
||||
|
||||
return self.img_batch.transpose(0, 1).flatten(0, 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Object:
|
||||
# Id of the object in the media
|
||||
object_id: int
|
||||
# Index of the frame in the media (0 if single image)
|
||||
frame_index: int
|
||||
segment: Union[torch.Tensor, dict] # RLE dict or binary mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
data: Union[torch.Tensor, PILImage.Image]
|
||||
objects: List[Object]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoDatapoint:
|
||||
"""Refers to an image/video and all its annotations"""
|
||||
|
||||
frames: List[Frame]
|
||||
video_id: int
|
||||
size: Tuple[int, int]
|
||||
|
||||
|
||||
def collate_fn(
|
||||
batch: List[VideoDatapoint],
|
||||
dict_key,
|
||||
) -> BatchedVideoDatapoint:
|
||||
"""
|
||||
Args:
|
||||
batch: A list of VideoDatapoint instances.
|
||||
dict_key (str): A string key used to identify the batch.
|
||||
"""
|
||||
img_batch = []
|
||||
for video in batch:
|
||||
img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
|
||||
|
||||
img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
|
||||
T = img_batch.shape[0]
|
||||
# Prepare data structures for sequential processing. Per-frame processing but batched across videos.
|
||||
step_t_objects_identifier = [[] for _ in range(T)]
|
||||
step_t_frame_orig_size = [[] for _ in range(T)]
|
||||
|
||||
step_t_masks = [[] for _ in range(T)]
|
||||
step_t_obj_to_frame_idx = [
|
||||
[] for _ in range(T)
|
||||
] # List to store frame indices for each time step
|
||||
|
||||
for video_idx, video in enumerate(batch):
|
||||
orig_video_id = video.video_id
|
||||
orig_frame_size = video.size
|
||||
for t, frame in enumerate(video.frames):
|
||||
objects = frame.objects
|
||||
for obj in objects:
|
||||
orig_obj_id = obj.object_id
|
||||
orig_frame_idx = obj.frame_index
|
||||
step_t_obj_to_frame_idx[t].append(
|
||||
torch.tensor([t, video_idx], dtype=torch.int)
|
||||
)
|
||||
step_t_masks[t].append(obj.segment.to(torch.bool))
|
||||
step_t_objects_identifier[t].append(
|
||||
torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
|
||||
)
|
||||
step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
|
||||
|
||||
obj_to_frame_idx = torch.stack(
|
||||
[
|
||||
torch.stack(obj_to_frame_idx, dim=0)
|
||||
for obj_to_frame_idx in step_t_obj_to_frame_idx
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
|
||||
objects_identifier = torch.stack(
|
||||
[torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0
|
||||
)
|
||||
frame_orig_size = torch.stack(
|
||||
[torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0
|
||||
)
|
||||
return BatchedVideoDatapoint(
|
||||
img_batch=img_batch,
|
||||
obj_to_frame_idx=obj_to_frame_idx,
|
||||
masks=masks,
|
||||
metadata=BatchedVideoMetaData(
|
||||
unique_objects_identifier=objects_identifier,
|
||||
frame_orig_size=frame_orig_size,
|
||||
),
|
||||
dict_key=dict_key,
|
||||
batch_size=[T],
|
||||
)
|
576
sam2/training/utils/distributed.py
Normal file
576
sam2/training/utils/distributed.py
Normal file
@@ -0,0 +1,576 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Default to GPU 0
|
||||
_cuda_device_index: int = 0
|
||||
|
||||
# Setting _cuda_device_index to -1 internally implies that we should use CPU
|
||||
_CPU_DEVICE_INDEX = -1
|
||||
_PRIMARY_RANK = 0
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_global_gloo_group():
|
||||
"""
|
||||
Return a process group based on gloo backend, containing all the ranks
|
||||
The result is cached.
|
||||
"""
|
||||
|
||||
if dist.get_backend() == "nccl":
|
||||
# Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
|
||||
# being much slower than others causing a timeout (which can happen in relation
|
||||
# or LVIS class mAP evaluation).
|
||||
timeout = 43200
|
||||
return dist.new_group(
|
||||
backend="gloo",
|
||||
timeout=datetime.timedelta(seconds=timeout),
|
||||
)
|
||||
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def is_main_process():
|
||||
"""Return true if the current process is the main one"""
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
|
||||
`all_gather` above, but using filesystem instead of collective ops.
|
||||
|
||||
If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
|
||||
(and other ranks will have an empty list).
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
print("gathering via files")
|
||||
cpu_group = _get_global_gloo_group()
|
||||
|
||||
# if unspecified, we will save to the current python file dir
|
||||
if filesys_save_dir is not None:
|
||||
save_dir = filesys_save_dir
|
||||
elif "EXP_DIR" in os.environ:
|
||||
save_dir = os.environ["EXP_DIR"]
|
||||
else:
|
||||
# try the same directory where the code is stored
|
||||
save_dir = filesys_save_dir or os.path.dirname(__file__)
|
||||
save_dir = os.path.join(save_dir, "all_gather_via_filesys")
|
||||
if is_main_process():
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# use a timestamp and salt to distinguish different all_gather
|
||||
timestamp = int(time.time()) if is_main_process() else 0
|
||||
salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
|
||||
# broadcast the timestamp and salt across ranks
|
||||
# (all-reduce will do the broadcasting since only rank 0 is non-zero)
|
||||
timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
|
||||
dist.all_reduce(timestamp_and_salt, group=cpu_group)
|
||||
timestamp, salt = timestamp_and_salt.tolist()
|
||||
|
||||
# save the data to a file on the disk
|
||||
rank_save = get_rank()
|
||||
save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
|
||||
save_data_path = os.path.join(save_dir, save_data_filename)
|
||||
assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
|
||||
torch.save(data, save_data_path)
|
||||
dist.barrier(group=cpu_group)
|
||||
|
||||
# read the data from the files
|
||||
data_list = []
|
||||
if rank_save == 0 or not gather_to_rank_0_only:
|
||||
for rank_load in range(world_size):
|
||||
load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
|
||||
load_data_path = os.path.join(save_dir, load_data_filename)
|
||||
assert os.path.exists(load_data_path), f"cannot read {save_data_path}"
|
||||
data_list.append(torch.load(load_data_path))
|
||||
dist.barrier(group=cpu_group)
|
||||
|
||||
# delete the saved file
|
||||
os.remove(save_data_path)
|
||||
return data_list
|
||||
|
||||
|
||||
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
|
||||
return all_gather_via_filesys(
|
||||
data, filesys_save_dir, gather_to_rank_0_only=True
|
||||
)
|
||||
|
||||
if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
|
||||
return all_gather_via_filesys(data, filesys_save_dir)
|
||||
|
||||
cpu_group = None
|
||||
if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
|
||||
cpu_group = _get_global_gloo_group()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(data, buffer)
|
||||
data_view = buffer.getbuffer()
|
||||
device = "cuda" if cpu_group is None else "cpu"
|
||||
tensor = torch.ByteTensor(data_view).to(device)
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
|
||||
size_list = [
|
||||
torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
|
||||
]
|
||||
if cpu_group is None:
|
||||
dist.all_gather(size_list, local_size)
|
||||
else:
|
||||
print("gathering on cpu")
|
||||
dist.all_gather(size_list, local_size, group=cpu_group)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
assert isinstance(local_size.item(), int)
|
||||
local_size = int(local_size.item())
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(
|
||||
size=(max_size - local_size,), dtype=torch.uint8, device=device
|
||||
)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
if cpu_group is None:
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
else:
|
||||
dist.all_gather(tensor_list, tensor, group=cpu_group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
|
||||
buffer = io.BytesIO(tensor.cpu().numpy())
|
||||
obj = torch.load(buffer)
|
||||
data_list.append(obj)
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
|
||||
"""
|
||||
For some backends, such as NCCL, communication only works if the
|
||||
tensor is on the GPU. This helper function converts to the correct
|
||||
device and returns the tensor + original device.
|
||||
"""
|
||||
orig_device = "cpu" if not tensor.is_cuda else "gpu"
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
|
||||
and not tensor.is_cuda
|
||||
):
|
||||
tensor = tensor.cuda()
|
||||
return (tensor, orig_device)
|
||||
|
||||
|
||||
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
|
||||
"""
|
||||
For some backends, such as NCCL, communication only works if the
|
||||
tensor is on the GPU. This converts the tensor back to original device.
|
||||
"""
|
||||
if tensor.is_cuda and orig_device == "cpu":
|
||||
tensor = tensor.cpu()
|
||||
return tensor
|
||||
|
||||
|
||||
def is_distributed_training_run() -> bool:
|
||||
return (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
and (torch.distributed.get_world_size() > 1)
|
||||
)
|
||||
|
||||
|
||||
def is_primary() -> bool:
|
||||
"""
|
||||
Returns True if this is rank 0 of a distributed training job OR if it is
|
||||
a single trainer job. Otherwise False.
|
||||
"""
|
||||
return get_rank() == _PRIMARY_RANK
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing mean reduction
|
||||
of tensor over all processes.
|
||||
"""
|
||||
return all_reduce_op(
|
||||
tensor,
|
||||
torch.distributed.ReduceOp.SUM,
|
||||
lambda t: t / torch.distributed.get_world_size(),
|
||||
)
|
||||
|
||||
|
||||
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing sum
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing min
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
|
||||
|
||||
|
||||
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing min
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
|
||||
|
||||
|
||||
def all_reduce_op(
|
||||
tensor: torch.Tensor,
|
||||
op: torch.distributed.ReduceOp,
|
||||
after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
torch.distributed.all_reduce(tensor, op)
|
||||
if after_op_func is not None:
|
||||
tensor = after_op_func(tensor)
|
||||
tensor = convert_to_normal_tensor(tensor, orig_device)
|
||||
return tensor
|
||||
|
||||
|
||||
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_gather for performing
|
||||
'gather' of 'tensor' over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
if tensor.ndim == 0:
|
||||
# 0 dim tensors cannot be gathered. so unsqueeze
|
||||
tensor = tensor.unsqueeze(0)
|
||||
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
gathered_tensors = [
|
||||
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, tensor)
|
||||
gathered_tensors = [
|
||||
convert_to_normal_tensor(_tensor, orig_device)
|
||||
for _tensor in gathered_tensors
|
||||
]
|
||||
else:
|
||||
gathered_tensors = [tensor]
|
||||
|
||||
return gathered_tensors
|
||||
|
||||
|
||||
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
|
||||
gathered_tensors = gather_tensors_from_all(tensor)
|
||||
gathered_tensor = torch.cat(gathered_tensors, 0)
|
||||
return gathered_tensor
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
|
||||
to all processes in both distributed / non-distributed scenarios.
|
||||
"""
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
torch.distributed.broadcast(tensor, src)
|
||||
tensor = convert_to_normal_tensor(tensor, orig_device)
|
||||
return tensor
|
||||
|
||||
|
||||
def barrier() -> None:
|
||||
"""
|
||||
Wrapper over torch.distributed.barrier, returns without waiting
|
||||
if the distributed process group is not initialized instead of throwing error.
|
||||
"""
|
||||
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
|
||||
return
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""
|
||||
Simple wrapper for correctly getting worldsize in both distributed
|
||||
/ non-distributed settings
|
||||
"""
|
||||
return (
|
||||
torch.distributed.get_world_size()
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
else 1
|
||||
)
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""
|
||||
Simple wrapper for correctly getting rank in both distributed
|
||||
/ non-distributed settings
|
||||
"""
|
||||
return (
|
||||
torch.distributed.get_rank()
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
def get_primary_rank() -> int:
|
||||
return _PRIMARY_RANK
|
||||
|
||||
|
||||
def set_cuda_device_index(idx: int) -> None:
|
||||
global _cuda_device_index
|
||||
_cuda_device_index = idx
|
||||
torch.cuda.set_device(_cuda_device_index)
|
||||
|
||||
|
||||
def set_cpu_device() -> None:
|
||||
global _cuda_device_index
|
||||
_cuda_device_index = _CPU_DEVICE_INDEX
|
||||
|
||||
|
||||
def get_cuda_device_index() -> int:
|
||||
return _cuda_device_index
|
||||
|
||||
|
||||
def init_distributed_data_parallel_model(
|
||||
model: torch.nn.Module,
|
||||
broadcast_buffers: bool = False,
|
||||
find_unused_parameters: bool = True,
|
||||
bucket_cap_mb: int = 25,
|
||||
) -> torch.nn.parallel.DistributedDataParallel:
|
||||
global _cuda_device_index
|
||||
|
||||
if _cuda_device_index == _CPU_DEVICE_INDEX:
|
||||
# CPU-only model, don't specify device
|
||||
return torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
)
|
||||
else:
|
||||
# GPU model
|
||||
return torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[_cuda_device_index],
|
||||
output_device=_cuda_device_index,
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
)
|
||||
|
||||
|
||||
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
|
||||
"""Broadcast an object from a source to all workers.
|
||||
|
||||
Args:
|
||||
obj: Object to broadcast, must be serializable
|
||||
src: Source rank for broadcast (default is primary)
|
||||
use_disk: If enabled, removes redundant CPU memory copies by writing to
|
||||
disk
|
||||
"""
|
||||
# Either broadcast from primary to the fleet (default),
|
||||
# or use the src setting as the original rank
|
||||
if get_rank() == src:
|
||||
# Emit data
|
||||
buffer = io.BytesIO()
|
||||
torch.save(obj, buffer)
|
||||
data_view = buffer.getbuffer()
|
||||
length_tensor = torch.LongTensor([len(data_view)])
|
||||
length_tensor = broadcast(length_tensor, src=src)
|
||||
data_tensor = torch.ByteTensor(data_view)
|
||||
data_tensor = broadcast(data_tensor, src=src)
|
||||
else:
|
||||
# Fetch from the source
|
||||
length_tensor = torch.LongTensor([0])
|
||||
length_tensor = broadcast(length_tensor, src=src)
|
||||
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
|
||||
data_tensor = broadcast(data_tensor, src=src)
|
||||
if use_disk:
|
||||
with tempfile.TemporaryFile("r+b") as f:
|
||||
f.write(data_tensor.numpy())
|
||||
# remove reference to the data tensor and hope that Python garbage
|
||||
# collects it
|
||||
del data_tensor
|
||||
f.seek(0)
|
||||
obj = torch.load(f)
|
||||
else:
|
||||
buffer = io.BytesIO(data_tensor.numpy())
|
||||
obj = torch.load(buffer)
|
||||
return obj
|
||||
|
||||
|
||||
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
|
||||
if world_size is None:
|
||||
world_size = get_world_size()
|
||||
# make contiguous because NCCL won't gather the tensor otherwise
|
||||
assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_all, tensor, async_op=False) # performance opt
|
||||
tensor_all = [
|
||||
convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
|
||||
]
|
||||
return tensor_all
|
||||
|
||||
|
||||
def all_gather_batch(tensors: List[torch.Tensor]):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
tensor_list = []
|
||||
output_tensor = []
|
||||
for tensor in tensors:
|
||||
tensor_all = all_gather_tensor(tensor, world_size)
|
||||
tensor_list.append(tensor_all)
|
||||
|
||||
for tensor_all in tensor_list:
|
||||
output_tensor.append(torch.cat(tensor_all, dim=0))
|
||||
return output_tensor
|
||||
|
||||
|
||||
class GatherLayer(autograd.Function):
|
||||
"""
|
||||
Gather tensors from all workers with support for backward propagation:
|
||||
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output, x)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
all_gradients = torch.stack(grads)
|
||||
dist.all_reduce(all_gradients)
|
||||
return all_gradients[dist.get_rank()]
|
||||
|
||||
|
||||
def all_gather_batch_with_grad(tensors):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
Graph remains connected for backward grad computation.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
tensor_list = []
|
||||
output_tensor = []
|
||||
|
||||
for tensor in tensors:
|
||||
tensor_all = GatherLayer.apply(tensor)
|
||||
tensor_list.append(tensor_all)
|
||||
|
||||
for tensor_all in tensor_list:
|
||||
output_tensor.append(torch.cat(tensor_all, dim=0))
|
||||
return output_tensor
|
||||
|
||||
|
||||
def unwrap_ddp_if_wrapped(model):
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
return model.module
|
||||
return model
|
||||
|
||||
|
||||
def create_new_process_group(group_size):
|
||||
"""
|
||||
Creates process groups of a gives `group_size` and returns
|
||||
process group that current GPU participates in.
|
||||
|
||||
`group_size` must divide the total number of GPUs (world_size).
|
||||
|
||||
Modified from
|
||||
https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
|
||||
|
||||
Args:
|
||||
group_size (int): number of GPU's to collaborate for sync bn
|
||||
"""
|
||||
|
||||
assert group_size > 0
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if world_size <= 8:
|
||||
if group_size > world_size:
|
||||
logging.warning(
|
||||
f"Requested group size [{group_size}] > world size [{world_size}]. "
|
||||
"Assuming local debug run and capping it to world size."
|
||||
)
|
||||
group_size = world_size
|
||||
assert world_size >= group_size
|
||||
assert world_size % group_size == 0
|
||||
|
||||
group = None
|
||||
for group_num in range(world_size // group_size):
|
||||
group_ids = range(group_num * group_size, (group_num + 1) * group_size)
|
||||
cur_group = torch.distributed.new_group(ranks=group_ids)
|
||||
if torch.distributed.get_rank() // group_size == group_num:
|
||||
group = cur_group
|
||||
# can not drop out and return here, every process must go through creation of all subgroups
|
||||
|
||||
assert group is not None
|
||||
return group
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
246
sam2/training/utils/logger.py
Normal file
246
sam2/training/utils/logger.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
|
||||
import atexit
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from numpy import ndarray
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||
|
||||
Scalar = Union[Tensor, ndarray, int, float]
|
||||
|
||||
|
||||
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
|
||||
makedir(log_dir)
|
||||
summary_writer_method = SummaryWriter
|
||||
return TensorBoardLogger(
|
||||
path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
|
||||
)
|
||||
|
||||
|
||||
class TensorBoardWriterWrapper:
|
||||
"""
|
||||
A wrapper around a SummaryWriter object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
*args: Any,
|
||||
filename_suffix: str = None,
|
||||
summary_writer_method: Any = SummaryWriter,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TensorBoard logger.
|
||||
On construction, the logger creates a new events file that logs
|
||||
will be written to. If the environment variable `RANK` is defined,
|
||||
logger will only log if RANK = 0.
|
||||
|
||||
NOTE: If using the logger with distributed training:
|
||||
- This logger can call collective operations
|
||||
- Logs will be written on rank 0 only
|
||||
- Logger must be constructed synchronously *after* initializing distributed process group.
|
||||
|
||||
Args:
|
||||
path (str): path to write logs to
|
||||
*args, **kwargs: Extra arguments to pass to SummaryWriter
|
||||
"""
|
||||
self._writer: Optional[SummaryWriter] = None
|
||||
_, self._rank = get_machine_local_and_dist_rank()
|
||||
self._path: str = path
|
||||
if self._rank == 0:
|
||||
logging.info(
|
||||
f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
|
||||
)
|
||||
self._writer = summary_writer_method(
|
||||
log_dir=path,
|
||||
*args,
|
||||
filename_suffix=filename_suffix or str(uuid.uuid4()),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Not logging meters on this host because env RANK: {self._rank} != 0"
|
||||
)
|
||||
atexit.register(self.close)
|
||||
|
||||
@property
|
||||
def writer(self) -> Optional[SummaryWriter]:
|
||||
return self._writer
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._path
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Writes pending logs to disk."""
|
||||
|
||||
if not self._writer:
|
||||
return
|
||||
|
||||
self._writer.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close writer, flushing pending logs to disk.
|
||||
Logs cannot be written after `close` is called.
|
||||
"""
|
||||
|
||||
if not self._writer:
|
||||
return
|
||||
|
||||
self._writer.close()
|
||||
self._writer = None
|
||||
|
||||
|
||||
class TensorBoardLogger(TensorBoardWriterWrapper):
|
||||
"""
|
||||
A simple logger for TensorBoard.
|
||||
"""
|
||||
|
||||
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
|
||||
"""Add multiple scalar values to TensorBoard.
|
||||
|
||||
Args:
|
||||
payload (dict): dictionary of tag name and scalar value
|
||||
step (int, Optional): step value to record
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
for k, v in payload.items():
|
||||
self.log(k, v, step)
|
||||
|
||||
def log(self, name: str, data: Scalar, step: int) -> None:
|
||||
"""Add scalar data to TensorBoard.
|
||||
|
||||
Args:
|
||||
name (string): tag name used to group scalars
|
||||
data (float/int/Tensor): scalar data to log
|
||||
step (int, optional): step value to record
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
self._writer.add_scalar(name, data, global_step=step, new_style=True)
|
||||
|
||||
def log_hparams(
|
||||
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
|
||||
) -> None:
|
||||
"""Add hyperparameter data to TensorBoard.
|
||||
|
||||
Args:
|
||||
hparams (dict): dictionary of hyperparameter names and corresponding values
|
||||
meters (dict): dictionary of name of meter and corersponding values
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
self._writer.add_hparams(hparams, meters)
|
||||
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
|
||||
"""
|
||||
|
||||
def __init__(self, logging_conf):
|
||||
# allow turning off TensorBoard with "should_log: false" in config
|
||||
tb_config = logging_conf.tensorboard_writer
|
||||
tb_should_log = tb_config and tb_config.pop("should_log", True)
|
||||
self.tb_logger = instantiate(tb_config) if tb_should_log else None
|
||||
|
||||
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_dict(payload, step)
|
||||
|
||||
def log(self, name: str, data: Scalar, step: int) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log(name, data, step)
|
||||
|
||||
def log_hparams(
|
||||
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
|
||||
) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_hparams(hparams, meters)
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
# we tune the buffering value so that the logs are updated
|
||||
# frequently.
|
||||
log_buffer_kb = 10 * 1024 # 10KB
|
||||
io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
|
||||
atexit.register(io.close)
|
||||
return io
|
||||
|
||||
|
||||
def setup_logging(
|
||||
name,
|
||||
output_dir=None,
|
||||
rank=0,
|
||||
log_level_primary="INFO",
|
||||
log_level_secondary="ERROR",
|
||||
):
|
||||
"""
|
||||
Setup various logging streams: stdout and file handlers.
|
||||
For file handlers, we only setup for the master gpu.
|
||||
"""
|
||||
# get the filename if we want to log to the file as well
|
||||
log_filename = None
|
||||
if output_dir:
|
||||
makedir(output_dir)
|
||||
if rank == 0:
|
||||
log_filename = f"{output_dir}/log.txt"
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(log_level_primary)
|
||||
|
||||
# create formatter
|
||||
FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
|
||||
formatter = logging.Formatter(FORMAT)
|
||||
|
||||
# Cleanup any existing handlers
|
||||
for h in logger.handlers:
|
||||
logger.removeHandler(h)
|
||||
logger.root.handlers = []
|
||||
|
||||
# setup the console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
if rank == 0:
|
||||
console_handler.setLevel(log_level_primary)
|
||||
else:
|
||||
console_handler.setLevel(log_level_secondary)
|
||||
|
||||
# we log to file as well if user wants
|
||||
if log_filename and rank == 0:
|
||||
file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
|
||||
file_handler.setLevel(log_level_primary)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
logging.root = logger
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
"""
|
||||
After training is done, we ensure to shut down all the logger streams.
|
||||
"""
|
||||
logging.info("Shutting down loggers...")
|
||||
handlers = logging.root.handlers
|
||||
for handler in handlers:
|
||||
handler.close()
|
288
sam2/training/utils/train_utils.py
Normal file
288
sam2/training/utils/train_utils.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def multiply_all(*args):
|
||||
return np.prod(np.array(args)).item()
|
||||
|
||||
|
||||
def collect_dict_keys(config):
|
||||
"""This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined"""
|
||||
val_keys = []
|
||||
# If the this config points to the collate function, then it has a key
|
||||
if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
|
||||
val_keys.append(config["dict_key"])
|
||||
else:
|
||||
# Recursively proceed
|
||||
for v in config.values():
|
||||
if isinstance(v, type(config)):
|
||||
val_keys.extend(collect_dict_keys(v))
|
||||
elif isinstance(v, omegaconf.listconfig.ListConfig):
|
||||
for item in v:
|
||||
if isinstance(item, type(config)):
|
||||
val_keys.extend(collect_dict_keys(item))
|
||||
return val_keys
|
||||
|
||||
|
||||
class Phase:
|
||||
TRAIN = "train"
|
||||
VAL = "val"
|
||||
|
||||
|
||||
def register_omegaconf_resolvers():
|
||||
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
|
||||
OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
|
||||
OmegaConf.register_new_resolver("add", lambda x, y: x + y)
|
||||
OmegaConf.register_new_resolver("times", multiply_all)
|
||||
OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
|
||||
OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
|
||||
OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
|
||||
OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
|
||||
OmegaConf.register_new_resolver("int", lambda x: int(x))
|
||||
OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
|
||||
OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
|
||||
|
||||
|
||||
def setup_distributed_backend(backend, timeout_mins):
|
||||
"""
|
||||
Initialize torch.distributed and set the CUDA device.
|
||||
Expects environment variables to be set as per
|
||||
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
|
||||
along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
|
||||
"""
|
||||
# enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
|
||||
# of waiting
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
|
||||
dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def get_machine_local_and_dist_rank():
|
||||
"""
|
||||
Get the distributed and local rank of the current gpu.
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
||||
distributed_rank = int(os.environ.get("RANK", None))
|
||||
assert (
|
||||
local_rank is not None and distributed_rank is not None
|
||||
), "Please the set the RANK and LOCAL_RANK environment variables."
|
||||
return local_rank, distributed_rank
|
||||
|
||||
|
||||
def print_cfg(cfg):
|
||||
"""
|
||||
Supports printing both Hydra DictConfig and also the AttrDict config
|
||||
"""
|
||||
logging.info("Training with config:")
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
|
||||
def set_seeds(seed_value, max_epochs, dist_rank):
|
||||
"""
|
||||
Set the python random, numpy and torch seed for each gpu. Also set the CUDA
|
||||
seeds if the CUDA is available. This ensures deterministic nature of the training.
|
||||
"""
|
||||
# Since in the pytorch sampler, we increment the seed by 1 for every epoch.
|
||||
seed_value = (seed_value + dist_rank) * max_epochs
|
||||
logging.info(f"MACHINE SEED: {seed_value}")
|
||||
random.seed(seed_value)
|
||||
np.random.seed(seed_value)
|
||||
torch.manual_seed(seed_value)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed_value)
|
||||
|
||||
|
||||
def makedir(dir_path):
|
||||
"""
|
||||
Create the directory if it does not exist.
|
||||
"""
|
||||
is_success = False
|
||||
try:
|
||||
if not g_pathmgr.exists(dir_path):
|
||||
g_pathmgr.mkdirs(dir_path)
|
||||
is_success = True
|
||||
except BaseException:
|
||||
logging.info(f"Error creating directory: {dir_path}")
|
||||
return is_success
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_amp_type(amp_type: Optional[str] = None):
|
||||
if amp_type is None:
|
||||
return None
|
||||
assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
|
||||
if amp_type == "bfloat16":
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float16
|
||||
|
||||
|
||||
def log_env_variables():
|
||||
env_keys = sorted(list(os.environ.keys()))
|
||||
st = ""
|
||||
for k in env_keys:
|
||||
v = os.environ[k]
|
||||
st += f"{k}={v}\n"
|
||||
logging.info("Logging ENV_VARIABLES")
|
||||
logging.info(st)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.device = device
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
self._allow_updates = True
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class MemMeter:
|
||||
"""Computes and stores the current, avg, and max of peak Mem usage per iteration"""
|
||||
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.device = device
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0 # Per iteration max usage
|
||||
self.avg = 0 # Avg per iteration max usage
|
||||
self.peak = 0 # Peak usage for lifetime of program
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
self._allow_updates = True
|
||||
|
||||
def update(self, n=1, reset_peak_usage=True):
|
||||
self.val = torch.cuda.max_memory_allocated() // 1e9
|
||||
self.sum += self.val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
self.peak = max(self.peak, self.val)
|
||||
if reset_peak_usage:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = (
|
||||
"{name}: {val"
|
||||
+ self.fmt
|
||||
+ "} ({avg"
|
||||
+ self.fmt
|
||||
+ "}/{peak"
|
||||
+ self.fmt
|
||||
+ "})"
|
||||
)
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
def human_readable_time(time_seconds):
|
||||
time = int(time_seconds)
|
||||
minutes, seconds = divmod(time, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
days, hours = divmod(hours, 24)
|
||||
return f"{days:02}d {hours:02}h {minutes:02}m"
|
||||
|
||||
|
||||
class DurationMeter:
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.fmt = fmt
|
||||
self.val = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
|
||||
def update(self, val):
|
||||
self.val = val
|
||||
|
||||
def add(self, val):
|
||||
self.val += val
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}: {human_readable_time(self.val)}"
|
||||
|
||||
|
||||
class ProgressMeter:
|
||||
def __init__(self, num_batches, meters, real_meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.real_meters = real_meters
|
||||
self.prefix = prefix
|
||||
|
||||
def display(self, batch, enable_print=False):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
entries += [
|
||||
" | ".join(
|
||||
[
|
||||
f"{os.path.join(name, subname)}: {val:.4f}"
|
||||
for subname, val in meter.compute().items()
|
||||
]
|
||||
)
|
||||
for name, meter in self.real_meters.items()
|
||||
]
|
||||
logging.info(" | ".join(entries))
|
||||
if enable_print:
|
||||
print(" | ".join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = "{:" + str(num_digits) + "d}"
|
||||
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
||||
|
||||
|
||||
def get_resume_checkpoint(checkpoint_save_dir):
|
||||
if not g_pathmgr.isdir(checkpoint_save_dir):
|
||||
return None
|
||||
ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
|
||||
if not g_pathmgr.isfile(ckpt_file):
|
||||
return None
|
||||
|
||||
return ckpt_file
|
Reference in New Issue
Block a user