init commit of samurai

This commit is contained in:
Cheng-Yen Yang
2024-11-19 22:12:54 -08:00
parent f65f4ba181
commit c17e4cecc0
679 changed files with 123982 additions and 0 deletions

116
sam2/training/README.md Normal file
View File

@@ -0,0 +1,116 @@
# Training Code for SAM 2
This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos.
The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both).
## Structure
The training code is organized into the following subfolders:
* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms.
* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
* `utils`: This folder contains training utils such as loggers and distributed training utils.
* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training.
* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training.
* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers.
* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop.
* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`
## Getting Started
To get started with the training code, we provide a simple example to fine-tune our checkpoints on [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets.
#### Requirements:
- We assume training on A100 GPUs with **80 GB** of memory.
- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download).
#### Steps to fine-tune on MOSE:
- Install the packages required for training by running `pip install -e ".[dev]"`.
- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`.
```yaml
dataset:
# PATHS to Dataset
img_folder: null # PATH to MOSE JPEGImages folder
gt_folder: null # PATH to MOSE Annotations folder
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
```
- To fine-tune the base model on MOSE using 8 GPUs, run
```python
python training/train.py \
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
--use-cluster 0 \
--num-gpus 8
```
We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html), for example, you can train on 2 nodes by running
```python
python training/train.py \
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
--use-cluster 1 \
--num-gpus 8 \
--num-nodes 2
--partition $PARTITION \
--qos $QOS \
--account $ACCOUNT
```
where partition, qos, and account are optional and depend on your SLURM configuration.
By default, the checkpoint and logs will be saved under `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:
```yaml
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
```
The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split]( ../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4.
After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)).
## Training on images and videos
The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets:
```yaml
data:
train:
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases
batch_sizes: # List of batch sizes corresponding to each dataset
- ${bs1} # Batch size of dataset 1
- ${bs2} # Batch size of dataset 2
datasets:
# SA1B as an example of an image dataset
- _target_: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
_target_: training.dataset.vos_raw_dataset.SA1BRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
sampler:
_target_: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 1
max_num_objects: ${max_num_objects_per_image}
transforms: ${image_transforms}
# SA-V as an example of a video dataset
- _target_: training.dataset.vos_dataset.VOSDataset
training: true
video_dataset:
_target_: training.dataset.vos_raw_dataset.JSONRawDataset
img_folder: ${path_to_img_folder}
gt_folder: ${path_to_gt_folder}
file_list_txt: ${path_to_train_filelist} # Optional
ann_every: 4
sampler:
_target_: training.dataset.vos_sampler.RandomUniformSampler
num_frames: 8 # Number of frames per video
max_num_objects: ${max_num_objects_per_video}
reverse_time_prob: ${reverse_time_prob} # probability to reverse video
transforms: ${video_transforms}
shuffle: True
num_workers: ${num_train_workers}
pin_memory: True
drop_last: True
collate_fn:
_target_: training.utils.data_utils.collate_fn
_partial_: true
dict_key: all
```

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,200 @@
32e5d721
5bad0bab
267bfd6c
0a43a414
56c56ca9
9a1146b3
c6ad7aaf
78a1f4b1
fc455e73
072e7b3f
77ccb57d
a76ee415
8cdcfc17
5d518b42
376dd830
0e843fc8
2af0e766
2bd4e845
de2f2a6a
ade9ee91
001ca3cb
fc4c1c67
8ef55579
b84ce852
4cc8528a
767ffaaa
112a2ef0
a338c8aa
cbd144f5
5ff72128
86a949e2
9f2323ac
1fab1d1c
75924351
ef55817b
02deca50
4d979d99
4d65f873
28470fa0
0d1575fe
06ea172e
29a6ddc2
797f1bec
780e7a99
b9ed5b44
02a236b4
607d8ff5
af5666b2
0558d0ed
a938c6b2
103df575
77110e80
739e5a07
6763a576
06ebc138
ba4b3b09
b35cc2f3
4e0597a0
5949ee84
5348d547
323c4236
b3b51117
55727ddd
ab2714f3
d2878895
c0734cb3
94f7c53e
2a2745e5
442ffb54
3592425a
50ae03b0
5f150435
3067f9fa
9ffb2818
adeaf5aa
31caacec
1cd99b86
aa22f9d0
8fa50320
e6348d2c
42ff84a5
8c8b7913
c96adcbc
495be321
db735509
ee113fc4
a678cdab
c409ca4d
68d2b259
592b4dee
4e2b4dc7
eb4d26e1
2009a00f
bec5c89d
67191f24
a3e85b4b
da7080cd
80d978e9
36dcb93f
a41e8c44
12fdc864
46d140ea
657c9dd9
a86f84ee
90c1c43d
33015509
afc7664d
23df06e1
291d4799
0ab75563
251bf059
bcefdcc4
ce9a2796
94d3403a
8f2e04bc
f9cda066
9dfa2cc5
66924c91
e765a09e
15654ee1
48e0bd39
ee095221
2463609b
544d0d1f
51b8c2e1
d321dde4
4cb11a5f
d7058a0d
37af282a
fabae187
7be91184
181ec185
2d16ceeb
b56be4b1
6699eff0
79acac96
d61c4665
0c13e1e7
100f6ecf
71217dfc
82df0888
4c42c747
c9fdf703
d2efeb4b
69ed9d14
64914fb6
255bedbc
4ea934d8
a034feb2
e4f4ddae
e36a3026
c1489591
111bb373
e1d9fb32
93e22d48
c1ec4b26
d9638e69
60ab04c5
cfe7773a
62132822
2f5fb2a3
7bdd197d
033333fd
130fcdbe
12e509c2
67138c33
6f90cc5f
4e3020fe
bbdd8bb7
b399ccdb
fecd10d2
2e0967f7
f509054f
792c6ff7
48e2afc5
d904c048
111e0a5c
b83024e2
e6a7b79c
bdc5ccf7
b8146d00
9d394f1a
645b84f9
95ab2d0f
e6f8a31d
b4f876fb
dc2c570d
3afd02d7
5c80c82c
b1b32ddd
9f25fc61
ba538072
f8916fef
43c04ad2
a658e949
2861dd53
f6e40aba
09d305d1
aac33bff
8d9d4c08

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler
class MixedDataLoader:
def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
"""
Args:
dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
"""
assert len(dataloaders) == mixing_prob.shape[0]
self.dataloaders = dataloaders
self.mixing_prob = mixing_prob
# Iterator state
self._iter_dls = None
self._iter_mixing_prob = None
self.random_generator = torch.Generator()
def __len__(self):
return sum([len(d) for d in self.dataloaders])
def __iter__(self):
# Synchronize dataloader seeds
self.random_generator.manual_seed(42)
self._iter_dls = [iter(loader) for loader in self.dataloaders]
self._iter_mixing_prob = self.mixing_prob.clone()
return self
def __next__(self):
"""
Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
"""
if self._iter_dls is None:
raise TypeError(f"{type(self).__name__} object is not an iterator")
while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob.
dataset_idx = self._iter_mixing_prob.multinomial(
1, generator=self.random_generator
).item()
try:
item = next(self._iter_dls[dataset_idx])
return item
except StopIteration:
# No more iterations for this dataset, set it's mixing probability to zero and try again.
self._iter_mixing_prob[dataset_idx] = 0
except Exception as e:
# log and raise any other unexpected error.
logging.error(e)
raise e
# Exhausted all iterators
raise StopIteration
class TorchTrainMixedDataset:
def __init__(
self,
datasets: List[Dataset],
batch_sizes: List[int],
num_workers: int,
shuffle: bool,
pin_memory: bool,
drop_last: bool,
collate_fn: Optional[Callable] = None,
worker_init_fn: Optional[Callable] = None,
phases_per_epoch: int = 1,
dataset_prob: Optional[List[float]] = None,
) -> None:
"""
Args:
datasets (List[Dataset]): List of Datasets to be mixed.
batch_sizes (List[int]): Batch sizes for each dataset in the list.
num_workers (int): Number of workers per dataloader.
shuffle (bool): Whether or not to shuffle data.
pin_memory (bool): If True, use pinned memory when loading tensors from disk.
drop_last (bool): Whether or not to drop the last batch of data.
collate_fn (Callable): Function to merge a list of samples into a mini-batch.
worker_init_fn (Callable): Function to init each dataloader worker.
phases_per_epoch (int): Number of phases per epoch.
dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
"""
self.datasets = datasets
self.batch_sizes = batch_sizes
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
self.collate_fn = collate_fn
self.worker_init_fn = worker_init_fn
assert len(self.datasets) > 0
for dataset in self.datasets:
assert not isinstance(dataset, IterableDataset), "Not supported"
# `RepeatFactorWrapper` requires calling set_epoch first to get its length
self._set_dataset_epoch(dataset, 0)
self.phases_per_epoch = phases_per_epoch
self.chunks = [None] * len(datasets)
if dataset_prob is None:
# If not provided, assign each dataset a probability proportional to its length.
dataset_lens = [
(math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
for d, bs in zip(datasets, batch_sizes)
]
total_len = sum(dataset_lens)
dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
else:
assert len(dataset_prob) == len(datasets)
dataset_prob = torch.tensor(dataset_prob)
logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
self.dataset_prob = dataset_prob
def _set_dataset_epoch(self, dataset, epoch: int) -> None:
if hasattr(dataset, "epoch"):
dataset.epoch = epoch
if hasattr(dataset, "set_epoch"):
dataset.set_epoch(epoch)
def get_loader(self, epoch) -> Iterable:
dataloaders = []
for d_idx, (dataset, batch_size) in enumerate(
zip(self.datasets, self.batch_sizes)
):
if self.phases_per_epoch > 1:
# Major epoch that looops over entire dataset
# len(main_epoch) == phases_per_epoch * len(epoch)
main_epoch = epoch // self.phases_per_epoch
# Phase with in the main epoch
local_phase = epoch % self.phases_per_epoch
# Start of new data-epoch or job is resumed after preemtion.
if local_phase == 0 or self.chunks[d_idx] is None:
# set seed for dataset epoch
# If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.
self._set_dataset_epoch(dataset, main_epoch)
# Separate random generator for subset sampling
g = torch.Generator()
g.manual_seed(main_epoch)
self.chunks[d_idx] = torch.chunk(
torch.randperm(len(dataset), generator=g),
self.phases_per_epoch,
)
dataset = Subset(dataset, self.chunks[d_idx][local_phase])
else:
self._set_dataset_epoch(dataset, epoch)
sampler = DistributedSampler(dataset, shuffle=self.shuffle)
sampler.set_epoch(epoch)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
dataloaders.append(
DataLoader(
dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
worker_init_fn=self.worker_init_fn,
)
)
return MixedDataLoader(dataloaders, self.dataset_prob)

View File

@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Transforms and data augmentation for both image + bbox.
"""
import logging
import random
from typing import Iterable
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode
from training.utils.data_utils import VideoDatapoint
def hflip(datapoint, index):
datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
obj.segment = F.hflip(obj.segment)
return datapoint
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = max_size * min_original_size / max_original_size
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = int(round(size))
oh = int(round(size * h / w))
else:
oh = int(round(size))
ow = int(round(size * w / h))
return (oh, ow)
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
# size can be min_size (scalar) or (w, h) tuple
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
if square:
size = size, size
else:
cur_size = (
datapoint.frames[index].data.size()[-2:][::-1]
if v2
else datapoint.frames[index].data.size
)
size = get_size(cur_size, size, max_size)
old_size = (
datapoint.frames[index].data.size()[-2:][::-1]
if v2
else datapoint.frames[index].data.size
)
if v2:
datapoint.frames[index].data = Fv2.resize(
datapoint.frames[index].data, size, antialias=True
)
else:
datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)
new_size = (
datapoint.frames[index].data.size()[-2:][::-1]
if v2
else datapoint.frames[index].data.size
)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
obj.segment = F.resize(obj.segment[None, None], size).squeeze()
h, w = size
datapoint.frames[index].size = (h, w)
return datapoint
def pad(datapoint, index, padding, v2=False):
old_h, old_w = datapoint.frames[index].size
h, w = old_h, old_w
if len(padding) == 2:
# assumes that we only pad on the bottom right corners
datapoint.frames[index].data = F.pad(
datapoint.frames[index].data, (0, 0, padding[0], padding[1])
)
h += padding[1]
w += padding[0]
else:
# left, top, right, bottom
datapoint.frames[index].data = F.pad(
datapoint.frames[index].data,
(padding[0], padding[1], padding[2], padding[3]),
)
h += padding[1] + padding[3]
w += padding[0] + padding[2]
datapoint.frames[index].size = (h, w)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
if v2:
if len(padding) == 2:
obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
else:
obj.segment = Fv2.pad(obj.segment, tuple(padding))
else:
if len(padding) == 2:
obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
else:
obj.segment = F.pad(obj.segment, tuple(padding))
return datapoint
class RandomHorizontalFlip:
def __init__(self, consistent_transform, p=0.5):
self.p = p
self.consistent_transform = consistent_transform
def __call__(self, datapoint, **kwargs):
if self.consistent_transform:
if random.random() < self.p:
for i in range(len(datapoint.frames)):
datapoint = hflip(datapoint, i)
return datapoint
for i in range(len(datapoint.frames)):
if random.random() < self.p:
datapoint = hflip(datapoint, i)
return datapoint
class RandomResizeAPI:
def __init__(
self, sizes, consistent_transform, max_size=None, square=False, v2=False
):
if isinstance(sizes, int):
sizes = (sizes,)
assert isinstance(sizes, Iterable)
self.sizes = list(sizes)
self.max_size = max_size
self.square = square
self.consistent_transform = consistent_transform
self.v2 = v2
def __call__(self, datapoint, **kwargs):
if self.consistent_transform:
size = random.choice(self.sizes)
for i in range(len(datapoint.frames)):
datapoint = resize(
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
)
return datapoint
for i in range(len(datapoint.frames)):
size = random.choice(self.sizes)
datapoint = resize(
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
)
return datapoint
class ToTensorAPI:
def __init__(self, v2=False):
self.v2 = v2
def __call__(self, datapoint: VideoDatapoint, **kwargs):
for img in datapoint.frames:
if self.v2:
img.data = Fv2.to_image_tensor(img.data)
else:
img.data = F.to_tensor(img.data)
return datapoint
class NormalizeAPI:
def __init__(self, mean, std, v2=False):
self.mean = mean
self.std = std
self.v2 = v2
def __call__(self, datapoint: VideoDatapoint, **kwargs):
for img in datapoint.frames:
if self.v2:
img.data = Fv2.convert_image_dtype(img.data, torch.float32)
img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
else:
img.data = F.normalize(img.data, mean=self.mean, std=self.std)
return datapoint
class ComposeAPI:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, datapoint, **kwargs):
for t in self.transforms:
datapoint = t(datapoint, **kwargs)
return datapoint
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
class RandomGrayscale:
def __init__(self, consistent_transform, p=0.5):
self.p = p
self.consistent_transform = consistent_transform
self.Grayscale = T.Grayscale(num_output_channels=3)
def __call__(self, datapoint: VideoDatapoint, **kwargs):
if self.consistent_transform:
if random.random() < self.p:
for img in datapoint.frames:
img.data = self.Grayscale(img.data)
return datapoint
for img in datapoint.frames:
if random.random() < self.p:
img.data = self.Grayscale(img.data)
return datapoint
class ColorJitter:
def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
self.consistent_transform = consistent_transform
self.brightness = (
brightness
if isinstance(brightness, list)
else [max(0, 1 - brightness), 1 + brightness]
)
self.contrast = (
contrast
if isinstance(contrast, list)
else [max(0, 1 - contrast), 1 + contrast]
)
self.saturation = (
saturation
if isinstance(saturation, list)
else [max(0, 1 - saturation), 1 + saturation]
)
self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
def __call__(self, datapoint: VideoDatapoint, **kwargs):
if self.consistent_transform:
# Create a color jitter transformation params
(
fn_idx,
brightness_factor,
contrast_factor,
saturation_factor,
hue_factor,
) = T.ColorJitter.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for img in datapoint.frames:
if not self.consistent_transform:
(
fn_idx,
brightness_factor,
contrast_factor,
saturation_factor,
hue_factor,
) = T.ColorJitter.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
img.data = F.adjust_brightness(img.data, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
img.data = F.adjust_contrast(img.data, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
img.data = F.adjust_saturation(img.data, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img.data = F.adjust_hue(img.data, hue_factor)
return datapoint
class RandomAffine:
def __init__(
self,
degrees,
consistent_transform,
scale=None,
translate=None,
shear=None,
image_mean=(123, 116, 103),
log_warning=True,
num_tentatives=1,
image_interpolation="bicubic",
):
"""
The mask is required for this transform.
if consistent_transform if True, then the same random affine is applied to all frames and masks.
"""
self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
self.scale = scale
self.shear = (
shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
)
self.translate = translate
self.fill_img = image_mean
self.consistent_transform = consistent_transform
self.log_warning = log_warning
self.num_tentatives = num_tentatives
if image_interpolation == "bicubic":
self.image_interpolation = InterpolationMode.BICUBIC
elif image_interpolation == "bilinear":
self.image_interpolation = InterpolationMode.BILINEAR
else:
raise NotImplementedError
def __call__(self, datapoint: VideoDatapoint, **kwargs):
for _tentative in range(self.num_tentatives):
res = self.transform_datapoint(datapoint)
if res is not None:
return res
if self.log_warning:
logging.warning(
f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
)
return datapoint
def transform_datapoint(self, datapoint: VideoDatapoint):
_, height, width = F.get_dimensions(datapoint.frames[0].data)
img_size = [width, height]
if self.consistent_transform:
# Create a random affine transformation
affine_params = T.RandomAffine.get_params(
degrees=self.degrees,
translate=self.translate,
scale_ranges=self.scale,
shears=self.shear,
img_size=img_size,
)
for img_idx, img in enumerate(datapoint.frames):
this_masks = [
obj.segment.unsqueeze(0) if obj.segment is not None else None
for obj in img.objects
]
if not self.consistent_transform:
# if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation
affine_params = T.RandomAffine.get_params(
degrees=self.degrees,
translate=self.translate,
scale_ranges=self.scale,
shears=self.shear,
img_size=img_size,
)
transformed_bboxes, transformed_masks = [], []
for i in range(len(img.objects)):
if this_masks[i] is None:
transformed_masks.append(None)
# Dummy bbox for a dummy target
transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
else:
transformed_mask = F.affine(
this_masks[i],
*affine_params,
interpolation=InterpolationMode.NEAREST,
fill=0.0,
)
if img_idx == 0 and transformed_mask.max() == 0:
# We are dealing with a video and the object is not visible in the first frame
# Return the datapoint without transformation
return None
transformed_masks.append(transformed_mask.squeeze())
for i in range(len(img.objects)):
img.objects[i].segment = transformed_masks[i]
img.data = F.affine(
img.data,
*affine_params,
interpolation=self.image_interpolation,
fill=self.fill_img,
)
return datapoint
def random_mosaic_frame(
datapoint,
index,
grid_h,
grid_w,
target_grid_y,
target_grid_x,
should_hflip,
):
# Step 1: downsize the images and paste them into a mosaic
image_data = datapoint.frames[index].data
is_pil = isinstance(image_data, PILImage.Image)
if is_pil:
H_im = image_data.height
W_im = image_data.width
image_data_output = PILImage.new("RGB", (W_im, H_im))
else:
H_im = image_data.size(-2)
W_im = image_data.size(-1)
image_data_output = torch.zeros_like(image_data)
downsize_cache = {}
for grid_y in range(grid_h):
for grid_x in range(grid_w):
y_offset_b = grid_y * H_im // grid_h
x_offset_b = grid_x * W_im // grid_w
y_offset_e = (grid_y + 1) * H_im // grid_h
x_offset_e = (grid_x + 1) * W_im // grid_w
H_im_downsize = y_offset_e - y_offset_b
W_im_downsize = x_offset_e - x_offset_b
if (H_im_downsize, W_im_downsize) in downsize_cache:
image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
else:
image_data_downsize = F.resize(
image_data,
size=(H_im_downsize, W_im_downsize),
interpolation=InterpolationMode.BILINEAR,
antialias=True, # antialiasing for downsizing
)
downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
if should_hflip[grid_y, grid_x].item():
image_data_downsize = F.hflip(image_data_downsize)
if is_pil:
image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
else:
image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
image_data_downsize
)
datapoint.frames[index].data = image_data_output
# Step 2: downsize the masks and paste them into the target grid of the mosaic
for obj in datapoint.frames[index].objects:
if obj.segment is None:
continue
assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
segment_output = torch.zeros_like(obj.segment)
target_y_offset_b = target_grid_y * H_im // grid_h
target_x_offset_b = target_grid_x * W_im // grid_w
target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
target_H_im_downsize = target_y_offset_e - target_y_offset_b
target_W_im_downsize = target_x_offset_e - target_x_offset_b
segment_downsize = F.resize(
obj.segment[None, None],
size=(target_H_im_downsize, target_W_im_downsize),
interpolation=InterpolationMode.BILINEAR,
antialias=True, # antialiasing for downsizing
)[0, 0]
if should_hflip[target_grid_y, target_grid_x].item():
segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
segment_output[
target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
] = segment_downsize
obj.segment = segment_output
return datapoint
class RandomMosaicVideoAPI:
def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
self.prob = prob
self.grid_h = grid_h
self.grid_w = grid_w
self.use_random_hflip = use_random_hflip
def __call__(self, datapoint, **kwargs):
if random.random() > self.prob:
return datapoint
# select a random location to place the target mask in the mosaic
target_grid_y = random.randint(0, self.grid_h - 1)
target_grid_x = random.randint(0, self.grid_w - 1)
# whether to flip each grid in the mosaic horizontally
if self.use_random_hflip:
should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
else:
should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
for i in range(len(datapoint.frames)):
datapoint = random_mosaic_frame(
datapoint,
i,
grid_h=self.grid_h,
grid_w=self.grid_w,
target_grid_y=target_grid_y,
target_grid_x=target_grid_x,
should_hflip=should_hflip,
)
return datapoint

View File

@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
from typing import Iterable
import torch
from torch.utils.data import (
ConcatDataset as TorchConcatDataset,
Dataset,
Subset as TorchSubset,
)
class ConcatDataset(TorchConcatDataset):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__(datasets)
self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])
def set_epoch(self, epoch: int):
for dataset in self.datasets:
if hasattr(dataset, "epoch"):
dataset.epoch = epoch
if hasattr(dataset, "set_epoch"):
dataset.set_epoch(epoch)
class Subset(TorchSubset):
def __init__(self, dataset, indices) -> None:
super(Subset, self).__init__(dataset, indices)
self.repeat_factors = dataset.repeat_factors[indices]
assert len(indices) == len(self.repeat_factors)
# Adapted from Detectron2
class RepeatFactorWrapper(Dataset):
"""
Thin wrapper around a dataset to implement repeat factor sampling.
The underlying dataset must have a repeat_factors member to indicate the per-image factor.
Set it to uniformly ones to disable repeat factor sampling
"""
def __init__(self, dataset, seed: int = 0):
self.dataset = dataset
self.epoch_ids = None
self._seed = seed
# Split into whole number (_int_part) and fractional (_frac_part) parts.
self._int_part = torch.trunc(dataset.repeat_factors)
self._frac_part = dataset.repeat_factors - self._int_part
def _get_epoch_indices(self, generator):
"""
Create a list of dataset indices (with repeats) to use for one epoch.
Args:
generator (torch.Generator): pseudo random number generator used for
stochastic rounding.
Returns:
torch.Tensor: list of dataset indices to use in one epoch. Each index
is repeated based on its calculated repeat factor.
"""
# Since repeat factors are fractional, we use stochastic rounding so
# that the target repeat factor is achieved in expectation over the
# course of training
rands = torch.rand(len(self._frac_part), generator=generator)
rep_factors = self._int_part + (rands < self._frac_part).float()
# Construct a list of indices in which we repeat images as specified
indices = []
for dataset_index, rep_factor in enumerate(rep_factors):
indices.extend([dataset_index] * int(rep_factor.item()))
return torch.tensor(indices, dtype=torch.int64)
def __len__(self):
if self.epoch_ids is None:
# Here we raise an error instead of returning original len(self.dataset) avoid
# accidentally using unwrapped length. Otherwise it's error-prone since the
# length changes to `len(self.epoch_ids)`changes after set_epoch is called.
raise RuntimeError("please call set_epoch first to get wrapped length")
# return len(self.dataset)
return len(self.epoch_ids)
def set_epoch(self, epoch: int):
g = torch.Generator()
g.manual_seed(self._seed + epoch)
self.epoch_ids = self._get_epoch_indices(g)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
def __getitem__(self, idx):
if self.epoch_ids is None:
raise RuntimeError(
"Repeat ids haven't been computed. Did you forget to call set_epoch?"
)
return self.dataset[self.epoch_ids[idx]]

View File

@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
from copy import deepcopy
import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset
from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint
MAX_RETRIES = 100
class VOSDataset(VisionDataset):
def __init__(
self,
transforms,
training: bool,
video_dataset: VOSRawDataset,
sampler: VOSSampler,
multiplier: int,
always_target=True,
target_segments_available=True,
):
self._transforms = transforms
self.training = training
self.video_dataset = video_dataset
self.sampler = sampler
self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
self.repeat_factors *= multiplier
print(f"Raw dataset length = {len(self.video_dataset)}")
self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
self.always_target = always_target
self.target_segments_available = target_segments_available
def _get_datapoint(self, idx):
for retry in range(MAX_RETRIES):
try:
if isinstance(idx, torch.Tensor):
idx = idx.item()
# sample a video
video, segment_loader = self.video_dataset.get_video(idx)
# sample frames and object indices to be used in a datapoint
sampled_frms_and_objs = self.sampler.sample(
video, segment_loader, epoch=self.curr_epoch
)
break # Succesfully loaded video
except Exception as e:
if self.training:
logging.warning(
f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
)
idx = random.randrange(0, len(self.video_dataset))
else:
# Shouldn't fail to load a val video
raise e
datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
for transform in self._transforms:
datapoint = transform(datapoint, epoch=self.curr_epoch)
return datapoint
def construct(self, video, sampled_frms_and_objs, segment_loader):
"""
Constructs a VideoDatapoint sample to pass to transforms
"""
sampled_frames = sampled_frms_and_objs.frames
sampled_object_ids = sampled_frms_and_objs.object_ids
images = []
rgb_images = load_images(sampled_frames)
# Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
for frame_idx, frame in enumerate(sampled_frames):
w, h = rgb_images[frame_idx].size
images.append(
Frame(
data=rgb_images[frame_idx],
objects=[],
)
)
# We load the gt segments associated with the current frame
if isinstance(segment_loader, JSONSegmentLoader):
segments = segment_loader.load(
frame.frame_idx, obj_ids=sampled_object_ids
)
else:
segments = segment_loader.load(frame.frame_idx)
for obj_id in sampled_object_ids:
# Extract the segment
if obj_id in segments:
assert (
segments[obj_id] is not None
), "None targets are not supported"
# segment is uint8 and remains uint8 throughout the transforms
segment = segments[obj_id].to(torch.uint8)
else:
# There is no target, we either use a zero mask target or drop this object
if not self.always_target:
continue
segment = torch.zeros(h, w, dtype=torch.uint8)
images[frame_idx].objects.append(
Object(
object_id=obj_id,
frame_index=frame.frame_idx,
segment=segment,
)
)
return VideoDatapoint(
frames=images,
video_id=video.video_id,
size=(h, w),
)
def __getitem__(self, idx):
return self._get_datapoint(idx)
def __len__(self):
return len(self.video_dataset)
def load_images(frames):
all_images = []
cache = {}
for frame in frames:
if frame.data is None:
# Load the frame rgb data from file
path = frame.image_path
if path in cache:
all_images.append(deepcopy(all_images[cache[path]]))
continue
with g_pathmgr.open(path, "rb") as fopen:
all_images.append(PILImage.open(fopen).convert("RGB"))
cache[path] = len(all_images) - 1
else:
# The frame rgb data has already been loaded
# Convert it to a PILImage
all_images.append(tensor_2_PIL(frame.data))
return all_images
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
data = data.astype(np.uint8)
return PILImage.fromarray(data)

View File

@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import logging
import os
from dataclasses import dataclass
from typing import List, Optional
import pandas as pd
import torch
from iopath.common.file_io import g_pathmgr
from omegaconf.listconfig import ListConfig
from training.dataset.vos_segment_loader import (
JSONSegmentLoader,
MultiplePNGSegmentLoader,
PalettisedPNGSegmentLoader,
SA1BSegmentLoader,
)
@dataclass
class VOSFrame:
frame_idx: int
image_path: str
data: Optional[torch.Tensor] = None
is_conditioning_only: Optional[bool] = False
@dataclass
class VOSVideo:
video_name: str
video_id: int
frames: List[VOSFrame]
def __len__(self):
return len(self.frames)
class VOSRawDataset:
def __init__(self):
pass
def get_video(self, idx):
raise NotImplementedError()
class PNGRawDataset(VOSRawDataset):
def __init__(
self,
img_folder,
gt_folder,
file_list_txt=None,
excluded_videos_list_txt=None,
sample_rate=1,
is_palette=True,
single_object_mode=False,
truncate_video=-1,
frames_sampling_mult=False,
):
self.img_folder = img_folder
self.gt_folder = gt_folder
self.sample_rate = sample_rate
self.is_palette = is_palette
self.single_object_mode = single_object_mode
self.truncate_video = truncate_video
# Read the subset defined in file_list_txt
if file_list_txt is not None:
with g_pathmgr.open(file_list_txt, "r") as f:
subset = [os.path.splitext(line.strip())[0] for line in f]
else:
subset = os.listdir(self.img_folder)
# Read and process excluded files if provided
if excluded_videos_list_txt is not None:
with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
excluded_files = [os.path.splitext(line.strip())[0] for line in f]
else:
excluded_files = []
# Check if it's not in excluded_files
self.video_names = sorted(
[video_name for video_name in subset if video_name not in excluded_files]
)
if self.single_object_mode:
# single object mode
self.video_names = sorted(
[
os.path.join(video_name, obj)
for video_name in self.video_names
for obj in os.listdir(os.path.join(self.gt_folder, video_name))
]
)
if frames_sampling_mult:
video_names_mult = []
for video_name in self.video_names:
num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
video_names_mult.extend([video_name] * num_frames)
self.video_names = video_names_mult
def get_video(self, idx):
"""
Given a VOSVideo object, return the mask tensors.
"""
video_name = self.video_names[idx]
if self.single_object_mode:
video_frame_root = os.path.join(
self.img_folder, os.path.dirname(video_name)
)
else:
video_frame_root = os.path.join(self.img_folder, video_name)
video_mask_root = os.path.join(self.gt_folder, video_name)
if self.is_palette:
segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
else:
segment_loader = MultiplePNGSegmentLoader(
video_mask_root, self.single_object_mode
)
all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
if self.truncate_video > 0:
all_frames = all_frames[: self.truncate_video]
frames = []
for _, fpath in enumerate(all_frames[:: self.sample_rate]):
fid = int(os.path.basename(fpath).split(".")[0])
frames.append(VOSFrame(fid, image_path=fpath))
video = VOSVideo(video_name, idx, frames)
return video, segment_loader
def __len__(self):
return len(self.video_names)
class SA1BRawDataset(VOSRawDataset):
def __init__(
self,
img_folder,
gt_folder,
file_list_txt=None,
excluded_videos_list_txt=None,
num_frames=1,
mask_area_frac_thresh=1.1, # no filtering by default
uncertain_iou=-1, # no filtering by default
):
self.img_folder = img_folder
self.gt_folder = gt_folder
self.num_frames = num_frames
self.mask_area_frac_thresh = mask_area_frac_thresh
self.uncertain_iou = uncertain_iou # stability score
# Read the subset defined in file_list_txt
if file_list_txt is not None:
with g_pathmgr.open(file_list_txt, "r") as f:
subset = [os.path.splitext(line.strip())[0] for line in f]
else:
subset = os.listdir(self.img_folder)
subset = [
path.split(".")[0] for path in subset if path.endswith(".jpg")
] # remove extension
# Read and process excluded files if provided
if excluded_videos_list_txt is not None:
with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
excluded_files = [os.path.splitext(line.strip())[0] for line in f]
else:
excluded_files = []
# Check if it's not in excluded_files and it exists
self.video_names = [
video_name for video_name in subset if video_name not in excluded_files
]
def get_video(self, idx):
"""
Given a VOSVideo object, return the mask tensors.
"""
video_name = self.video_names[idx]
video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
video_mask_path = os.path.join(self.gt_folder, video_name + ".json")
segment_loader = SA1BSegmentLoader(
video_mask_path,
mask_area_frac_thresh=self.mask_area_frac_thresh,
video_frame_path=video_frame_path,
uncertain_iou=self.uncertain_iou,
)
frames = []
for frame_idx in range(self.num_frames):
frames.append(VOSFrame(frame_idx, image_path=video_frame_path))
video_name = video_name.split("_")[-1] # filename is sa_{int}
# video id needs to be image_id to be able to load correct annotation file during eval
video = VOSVideo(video_name, int(video_name), frames)
return video, segment_loader
def __len__(self):
return len(self.video_names)
class JSONRawDataset(VOSRawDataset):
"""
Dataset where the annotation in the format of SA-V json files
"""
def __init__(
self,
img_folder,
gt_folder,
file_list_txt=None,
excluded_videos_list_txt=None,
sample_rate=1,
rm_unannotated=True,
ann_every=1,
frames_fps=24,
):
self.gt_folder = gt_folder
self.img_folder = img_folder
self.sample_rate = sample_rate
self.rm_unannotated = rm_unannotated
self.ann_every = ann_every
self.frames_fps = frames_fps
# Read and process excluded files if provided
excluded_files = []
if excluded_videos_list_txt is not None:
if isinstance(excluded_videos_list_txt, str):
excluded_videos_lists = [excluded_videos_list_txt]
elif isinstance(excluded_videos_list_txt, ListConfig):
excluded_videos_lists = list(excluded_videos_list_txt)
else:
raise NotImplementedError
for excluded_videos_list_txt in excluded_videos_lists:
with open(excluded_videos_list_txt, "r") as f:
excluded_files.extend(
[os.path.splitext(line.strip())[0] for line in f]
)
excluded_files = set(excluded_files)
# Read the subset defined in file_list_txt
if file_list_txt is not None:
with g_pathmgr.open(file_list_txt, "r") as f:
subset = [os.path.splitext(line.strip())[0] for line in f]
else:
subset = os.listdir(self.img_folder)
self.video_names = sorted(
[video_name for video_name in subset if video_name not in excluded_files]
)
def get_video(self, video_idx):
"""
Given a VOSVideo object, return the mask tensors.
"""
video_name = self.video_names[video_idx]
video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
segment_loader = JSONSegmentLoader(
video_json_path=video_json_path,
ann_every=self.ann_every,
frames_fps=self.frames_fps,
)
frame_ids = [
int(os.path.splitext(frame_name)[0])
for frame_name in sorted(
os.listdir(os.path.join(self.img_folder, video_name))
)
]
frames = [
VOSFrame(
frame_id,
image_path=os.path.join(
self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)
),
)
for frame_id in frame_ids[:: self.sample_rate]
]
if self.rm_unannotated:
# Eliminate the frames that have not been annotated
valid_frame_ids = [
i * segment_loader.ann_every
for i, annot in enumerate(segment_loader.frame_annots)
if annot is not None and None not in annot
]
frames = [f for f in frames if f.frame_idx in valid_frame_ids]
video = VOSVideo(video_name, video_idx, frames)
return video, segment_loader
def __len__(self):
return len(self.video_names)

View File

@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from dataclasses import dataclass
from typing import List
from training.dataset.vos_segment_loader import LazySegments
MAX_RETRIES = 1000
@dataclass
class SampledFramesAndObjects:
frames: List[int]
object_ids: List[int]
class VOSSampler:
def __init__(self, sort_frames=True):
# frames are ordered by frame id when sort_frames is True
self.sort_frames = sort_frames
def sample(self, video):
raise NotImplementedError()
class RandomUniformSampler(VOSSampler):
def __init__(
self,
num_frames,
max_num_objects,
reverse_time_prob=0.0,
):
self.num_frames = num_frames
self.max_num_objects = max_num_objects
self.reverse_time_prob = reverse_time_prob
def sample(self, video, segment_loader, epoch=None):
for retry in range(MAX_RETRIES):
if len(video.frames) < self.num_frames:
raise Exception(
f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
)
start = random.randrange(0, len(video.frames) - self.num_frames + 1)
frames = [video.frames[start + step] for step in range(self.num_frames)]
if random.uniform(0, 1) < self.reverse_time_prob:
# Reverse time
frames = frames[::-1]
# Get first frame object ids
visible_object_ids = []
loaded_segms = segment_loader.load(frames[0].frame_idx)
if isinstance(loaded_segms, LazySegments):
# LazySegments for SA1BRawDataset
visible_object_ids = list(loaded_segms.keys())
else:
for object_id, segment in segment_loader.load(
frames[0].frame_idx
).items():
if segment.sum():
visible_object_ids.append(object_id)
# First frame needs to have at least a target to track
if len(visible_object_ids) > 0:
break
if retry >= MAX_RETRIES - 1:
raise Exception("No visible objects")
object_ids = random.sample(
visible_object_ids,
min(len(visible_object_ids), self.max_num_objects),
)
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
class EvalSampler(VOSSampler):
"""
VOS Sampler for evaluation: sampling all the frames and all the objects in a video
"""
def __init__(
self,
):
super().__init__()
def sample(self, video, segment_loader, epoch=None):
"""
Sampling all the frames and all the objects
"""
if self.sort_frames:
# ordered by frame id
frames = sorted(video.frames, key=lambda x: x.frame_idx)
else:
# use the original order
frames = video.frames
object_ids = segment_loader.load(frames[0].frame_idx).keys()
if len(object_ids) == 0:
raise Exception("First frame of the video has no objects")
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)

View File

@@ -0,0 +1,300 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import json
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage
try:
from pycocotools import mask as mask_utils
except:
pass
class JSONSegmentLoader:
def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
# Annotations in the json are provided every ann_every th frame
self.ann_every = ann_every
# Ids of the objects to consider when sampling this video
self.valid_obj_ids = valid_obj_ids
with open(video_json_path, "r") as f:
data = json.load(f)
if isinstance(data, list):
self.frame_annots = data
elif isinstance(data, dict):
masklet_field_name = "masklet" if "masklet" in data else "masks"
self.frame_annots = data[masklet_field_name]
if "fps" in data:
if isinstance(data["fps"], list):
annotations_fps = int(data["fps"][0])
else:
annotations_fps = int(data["fps"])
assert frames_fps % annotations_fps == 0
self.ann_every = frames_fps // annotations_fps
else:
raise NotImplementedError
def load(self, frame_id, obj_ids=None):
assert frame_id % self.ann_every == 0
rle_mask = self.frame_annots[frame_id // self.ann_every]
valid_objs_ids = set(range(len(rle_mask)))
if self.valid_obj_ids is not None:
# Remove the masklets that have been filtered out for this video
valid_objs_ids &= set(self.valid_obj_ids)
if obj_ids is not None:
# Only keep the objects that have been sampled
valid_objs_ids &= set(obj_ids)
valid_objs_ids = sorted(list(valid_objs_ids))
# Construct rle_masks_filtered that only contains the rle masks we are interested in
id_2_idx = {}
rle_mask_filtered = []
for obj_id in valid_objs_ids:
if rle_mask[obj_id] is not None:
id_2_idx[obj_id] = len(rle_mask_filtered)
rle_mask_filtered.append(rle_mask[obj_id])
else:
id_2_idx[obj_id] = None
# Decode the masks
raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
2, 0, 1
) # num_obj, h, w
segments = {}
for obj_id in valid_objs_ids:
if id_2_idx[obj_id] is None:
segments[obj_id] = None
else:
idx = id_2_idx[obj_id]
segments[obj_id] = raw_segments[idx]
return segments
def get_valid_obj_frames_ids(self, num_frames_min=None):
# For each object, find all the frames with a valid (not None) mask
num_objects = len(self.frame_annots[0])
# The result dict associates each obj_id with the id of its valid frames
res = {obj_id: [] for obj_id in range(num_objects)}
for annot_idx, annot in enumerate(self.frame_annots):
for obj_id in range(num_objects):
if annot[obj_id] is not None:
res[obj_id].append(int(annot_idx * self.ann_every))
if num_frames_min is not None:
# Remove masklets that have less than num_frames_min valid masks
for obj_id, valid_frames in list(res.items()):
if len(valid_frames) < num_frames_min:
res.pop(obj_id)
return res
class PalettisedPNGSegmentLoader:
def __init__(self, video_png_root):
"""
SegmentLoader for datasets with masks stored as palettised PNGs.
video_png_root: the folder contains all the masks stored in png
"""
self.video_png_root = video_png_root
# build a mapping from frame id to their PNG mask path
# note that in some datasets, the PNG paths could have more
# than 5 digits, e.g. "00000000.png" instead of "00000.png"
png_filenames = os.listdir(self.video_png_root)
self.frame_id_to_png_filename = {}
for filename in png_filenames:
frame_id, _ = os.path.splitext(filename)
self.frame_id_to_png_filename[int(frame_id)] = filename
def load(self, frame_id):
"""
load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
Args:
frame_id: int, define the mask path
Return:
binary_segments: dict
"""
# check the path
mask_path = os.path.join(
self.video_png_root, self.frame_id_to_png_filename[frame_id]
)
# load the mask
masks = PILImage.open(mask_path).convert("P")
masks = np.array(masks)
object_id = pd.unique(masks.flatten())
object_id = object_id[object_id != 0] # remove background (0)
# convert into N binary segmentation masks
binary_segments = {}
for i in object_id:
bs = masks == i
binary_segments[i] = torch.from_numpy(bs)
return binary_segments
def __len__(self):
return
class MultiplePNGSegmentLoader:
def __init__(self, video_png_root, single_object_mode=False):
"""
video_png_root: the folder contains all the masks stored in png
single_object_mode: whether to load only a single object at a time
"""
self.video_png_root = video_png_root
self.single_object_mode = single_object_mode
# read a mask to know the resolution of the video
if self.single_object_mode:
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
else:
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
tmp_mask = np.array(PILImage.open(tmp_mask_path))
self.H = tmp_mask.shape[0]
self.W = tmp_mask.shape[1]
if self.single_object_mode:
self.obj_id = (
int(video_png_root.split("/")[-1]) + 1
) # offset by 1 as bg is 0
else:
self.obj_id = None
def load(self, frame_id):
if self.single_object_mode:
return self._load_single_png(frame_id)
else:
return self._load_multiple_pngs(frame_id)
def _load_single_png(self, frame_id):
"""
load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
Args:
frame_id: int, define the mask path
Return:
binary_segments: dict
"""
mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
binary_segments = {}
if os.path.exists(mask_path):
mask = np.array(PILImage.open(mask_path))
else:
# if png doesn't exist, empty mask
mask = np.zeros((self.H, self.W), dtype=bool)
binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
return binary_segments
def _load_multiple_pngs(self, frame_id):
"""
load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
Args:
frame_id: int, define the mask path
Return:
binary_segments: dict
"""
# get the path
all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
num_objects = len(all_objects)
assert num_objects > 0
# load the masks
binary_segments = {}
for obj_folder in all_objects:
# obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
obj_id = int(obj_folder.split("/")[-1])
obj_id = obj_id + 1 # offset 1 as bg is 0
mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
if os.path.exists(mask_path):
mask = np.array(PILImage.open(mask_path))
else:
mask = np.zeros((self.H, self.W), dtype=bool)
binary_segments[obj_id] = torch.from_numpy(mask > 0)
return binary_segments
def __len__(self):
return
class LazySegments:
"""
Only decodes segments that are actually used.
"""
def __init__(self):
self.segments = {}
self.cache = {}
def __setitem__(self, key, item):
self.segments[key] = item
def __getitem__(self, key):
if key in self.cache:
return self.cache[key]
rle = self.segments[key]
mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
self.cache[key] = mask
return mask
def __contains__(self, key):
return key in self.segments
def __len__(self):
return len(self.segments)
def keys(self):
return self.segments.keys()
class SA1BSegmentLoader:
def __init__(
self,
video_mask_path,
mask_area_frac_thresh=1.1,
video_frame_path=None,
uncertain_iou=-1,
):
with open(video_mask_path, "r") as f:
self.frame_annots = json.load(f)
if mask_area_frac_thresh <= 1.0:
# Lazily read frame
orig_w, orig_h = PILImage.open(video_frame_path).size
area = orig_w * orig_h
self.frame_annots = self.frame_annots["annotations"]
rle_masks = []
for frame_annot in self.frame_annots:
if not frame_annot["area"] > 0:
continue
if ("uncertain_iou" in frame_annot) and (
frame_annot["uncertain_iou"] < uncertain_iou
):
# uncertain_iou is stability score
continue
if (
mask_area_frac_thresh <= 1.0
and (frame_annot["area"] / area) >= mask_area_frac_thresh
):
continue
rle_masks.append(frame_annot["segmentation"])
self.segments = LazySegments()
for i, rle in enumerate(rle_masks):
self.segments[i] = rle
def load(self, frame_idx):
return self.segments

307
sam2/training/loss_fns.py Normal file
View File

@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from typing import Dict, List
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from training.trainer import CORE_LOSS_KEY
from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
"""
Compute the DICE loss, similar to generalized IOU for masks
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
num_objects: Number of objects in the batch
loss_on_multimask: True if multimask prediction is enabled
Returns:
Dice loss tensor
"""
inputs = inputs.sigmoid()
if loss_on_multimask:
# inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
assert inputs.dim() == 4 and targets.dim() == 4
# flatten spatial dimension while keeping multimask channel dimension
inputs = inputs.flatten(2)
targets = targets.flatten(2)
numerator = 2 * (inputs * targets).sum(-1)
else:
inputs = inputs.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
if loss_on_multimask:
return loss / num_objects
return loss.sum() / num_objects
def sigmoid_focal_loss(
inputs,
targets,
num_objects,
alpha: float = 0.25,
gamma: float = 2,
loss_on_multimask=False,
):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
num_objects: Number of objects in the batch
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
loss_on_multimask: True if multimask prediction is enabled
Returns:
focal loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if loss_on_multimask:
# loss is [N, M, H, W] where M corresponds to multiple predicted masks
assert loss.dim() == 4
return loss.flatten(2).mean(-1) / num_objects # average over spatial dims
return loss.mean(1).sum() / num_objects
def iou_loss(
inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
"""
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
pred_ious: A float tensor containing the predicted IoUs scores per mask
num_objects: Number of objects in the batch
loss_on_multimask: True if multimask prediction is enabled
use_l1_loss: Whether to use L1 loss is used instead of MSE loss
Returns:
IoU loss tensor
"""
assert inputs.dim() == 4 and targets.dim() == 4
pred_mask = inputs.flatten(2) > 0
gt_mask = targets.flatten(2) > 0
area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
actual_ious = area_i / torch.clamp(area_u, min=1.0)
if use_l1_loss:
loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
else:
loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
if loss_on_multimask:
return loss / num_objects
return loss.sum() / num_objects
class MultiStepMultiMasksAndIous(nn.Module):
def __init__(
self,
weight_dict,
focal_alpha=0.25,
focal_gamma=2,
supervise_all_iou=False,
iou_use_l1_loss=False,
pred_obj_scores=False,
focal_gamma_obj_score=0.0,
focal_alpha_obj_score=-1,
):
"""
This class computes the multi-step multi-mask and IoU losses.
Args:
weight_dict: dict containing weights for focal, dice, iou losses
focal_alpha: alpha for sigmoid focal loss
focal_gamma: gamma for sigmoid focal loss
supervise_all_iou: if True, back-prop iou losses for all predicted masks
iou_use_l1_loss: use L1 loss instead of MSE loss for iou
pred_obj_scores: if True, compute loss for object scores
focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
"""
super().__init__()
self.weight_dict = weight_dict
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
assert "loss_mask" in self.weight_dict
assert "loss_dice" in self.weight_dict
assert "loss_iou" in self.weight_dict
if "loss_class" not in self.weight_dict:
self.weight_dict["loss_class"] = 0.0
self.focal_alpha_obj_score = focal_alpha_obj_score
self.focal_gamma_obj_score = focal_gamma_obj_score
self.supervise_all_iou = supervise_all_iou
self.iou_use_l1_loss = iou_use_l1_loss
self.pred_obj_scores = pred_obj_scores
def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
assert len(outs_batch) == len(targets_batch)
num_objects = torch.tensor(
(targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
) # Number of objects is fixed within a batch
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_objects)
num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()
losses = defaultdict(int)
for outs, targets in zip(outs_batch, targets_batch):
cur_losses = self._forward(outs, targets, num_objects)
for k, v in cur_losses.items():
losses[k] += v
return losses
def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
"""
Compute the losses related to the masks: the focal loss and the dice loss.
and also the MAE or MSE loss between predicted IoUs and actual IoUs.
Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
of shape [N, M, H, W], where M could be 1 or larger, corresponding to
one or multiple predicted masks from a click.
We back-propagate focal, dice losses only on the prediction channel
with the lowest focal+dice loss between predicted mask and ground-truth.
If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.
"""
target_masks = targets.unsqueeze(1).float()
assert target_masks.dim() == 4 # [N, 1, H, W]
src_masks_list = outputs["multistep_pred_multimasks_high_res"]
ious_list = outputs["multistep_pred_ious"]
object_score_logits_list = outputs["multistep_object_score_logits"]
assert len(src_masks_list) == len(ious_list)
assert len(object_score_logits_list) == len(ious_list)
# accumulate the loss over prediction steps
losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
for src_masks, ious, object_score_logits in zip(
src_masks_list, ious_list, object_score_logits_list
):
self._update_losses(
losses, src_masks, target_masks, ious, num_objects, object_score_logits
)
losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
return losses
def _update_losses(
self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
):
target_masks = target_masks.expand_as(src_masks)
# get focal, dice and iou loss on all output masks in a prediction step
loss_multimask = sigmoid_focal_loss(
src_masks,
target_masks,
num_objects,
alpha=self.focal_alpha,
gamma=self.focal_gamma,
loss_on_multimask=True,
)
loss_multidice = dice_loss(
src_masks, target_masks, num_objects, loss_on_multimask=True
)
if not self.pred_obj_scores:
loss_class = torch.tensor(
0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
)
target_obj = torch.ones(
loss_multimask.shape[0],
1,
dtype=loss_multimask.dtype,
device=loss_multimask.device,
)
else:
target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
..., None
].float()
loss_class = sigmoid_focal_loss(
object_score_logits,
target_obj,
num_objects,
alpha=self.focal_alpha_obj_score,
gamma=self.focal_gamma_obj_score,
)
loss_multiiou = iou_loss(
src_masks,
target_masks,
ious,
num_objects,
loss_on_multimask=True,
use_l1_loss=self.iou_use_l1_loss,
)
assert loss_multimask.dim() == 2
assert loss_multidice.dim() == 2
assert loss_multiiou.dim() == 2
if loss_multimask.size(1) > 1:
# take the mask indices with the smallest focal + dice loss for back propagation
loss_combo = (
loss_multimask * self.weight_dict["loss_mask"]
+ loss_multidice * self.weight_dict["loss_dice"]
)
best_loss_inds = torch.argmin(loss_combo, dim=-1)
batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
# calculate the iou prediction and slot losses only in the index
# with the minimum loss for each mask (to be consistent w/ SAM)
if self.supervise_all_iou:
loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
else:
loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
else:
loss_mask = loss_multimask
loss_dice = loss_multidice
loss_iou = loss_multiiou
# backprop focal, dice and iou loss only if obj present
loss_mask = loss_mask * target_obj
loss_dice = loss_dice * target_obj
loss_iou = loss_iou * target_obj
# sum over batch dimension (note that the losses are already divided by num_objects)
losses["loss_mask"] += loss_mask.sum()
losses["loss_dice"] += loss_dice.sum()
losses["loss_iou"] += loss_iou.sum()
losses["loss_class"] += loss_class
def reduce_loss(self, losses):
reduced_loss = 0.0
for loss_key, weight in self.weight_dict.items():
if loss_key not in losses:
raise ValueError(f"{type(self)} doesn't compute {loss_key}")
if weight != 0:
reduced_loss += losses[loss_key] * weight
return reduced_loss

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

541
sam2/training/model/sam2.py Normal file
View File

@@ -0,0 +1,541 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
get_1d_sine_pe,
get_next_point,
sample_box_points,
select_closest_cond_frames,
)
from sam2.utils.misc import concat_points
from training.utils.data_utils import BatchedVideoDatapoint
class SAM2Train(SAM2Base):
def __init__(
self,
image_encoder,
memory_attention=None,
memory_encoder=None,
prob_to_use_pt_input_for_train=0.0,
prob_to_use_pt_input_for_eval=0.0,
prob_to_use_box_input_for_train=0.0,
prob_to_use_box_input_for_eval=0.0,
# if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames
num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
rand_frames_to_correct_for_train=False,
rand_frames_to_correct_for_eval=False,
# how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
# - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
# - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
# note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
# these are initial conditioning frames because as we track the video, more conditioning frames might be added
# when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader)
rand_init_cond_frames_for_eval=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False,
# how many additional correction points to sample (on each frame selected to be corrected)
# note that the first frame receives an initial input click (in addition to any correction clicks)
num_correction_pt_per_frame=7,
# method for point sampling during evaluation
# "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
# default to "center" to be consistent with evaluation in the SAM paper
pt_sampling_for_eval="center",
# During training, we optionally allow sampling the correction points from GT regions
# instead of the prediction error regions with a small probability. This might allow the
# model to overfit less to the error regions in training datasets
prob_to_sample_from_gt_for_train=0.0,
use_act_ckpt_iterative_pt_sampling=False,
# whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
# of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
forward_backbone_per_frame_for_eval=False,
freeze_image_encoder=False,
**kwargs,
):
super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
# Point sampler and conditioning frames
self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
logging.info(
f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
)
assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval
self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
# Initial multi-conditioning frames
self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.num_correction_pt_per_frame = num_correction_pt_per_frame
self.pt_sampling_for_eval = pt_sampling_for_eval
self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
# A random number generator with a fixed initial seed across GPUs
self.rng = np.random.default_rng(seed=42)
if freeze_image_encoder:
for p in self.image_encoder.parameters():
p.requires_grad = False
def forward(self, input: BatchedVideoDatapoint):
if self.training or not self.forward_backbone_per_frame_for_eval:
# precompute image features on all frames before tracking
backbone_out = self.forward_image(input.flat_img_batch)
else:
# defer image feature computation on a frame until it's being tracked
backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
backbone_out = self.prepare_prompt_inputs(backbone_out, input)
previous_stages_out = self.forward_tracking(backbone_out, input)
return previous_stages_out
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
"""Compute the image backbone features on the fly for the given img_ids."""
# Only forward backbone on unique image ids to avoid repetitive computation
# (if `img_ids` has only one element, it's already unique so we skip this step).
if img_ids.numel() > 1:
unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
else:
unique_img_ids, inv_ids = img_ids, None
# Compute the image features on those unique image ids
image = img_batch[unique_img_ids]
backbone_out = self.forward_image(image)
(
_,
vision_feats,
vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features(backbone_out)
# Inverse-map image features for `unique_img_ids` to the final image features
# for the original input `img_ids`.
if inv_ids is not None:
image = image[inv_ids]
vision_feats = [x[:, inv_ids] for x in vision_feats]
vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
return image, vision_feats, vision_pos_embeds, feat_sizes
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
"""
Prepare input mask, point or box prompts. Optionally, we allow tracking from
a custom `start_frame_idx` to the end of the video (for evaluation purposes).
"""
# Load the ground-truth masks on all frames (so that we can later
# sample correction points from them)
# gt_masks_per_frame = {
# stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im]
# for stage_id, targets in enumerate(input.find_targets)
# }
gt_masks_per_frame = {
stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im]
for stage_id, masks in enumerate(input.masks)
}
# gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form
backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
num_frames = input.num_frames
backbone_out["num_frames"] = num_frames
# Randomly decide whether to use point inputs or mask inputs
if self.training:
prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
prob_to_use_box_input = self.prob_to_use_box_input_for_train
num_frames_to_correct = self.num_frames_to_correct_for_train
rand_frames_to_correct = self.rand_frames_to_correct_for_train
num_init_cond_frames = self.num_init_cond_frames_for_train
rand_init_cond_frames = self.rand_init_cond_frames_for_train
else:
prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
prob_to_use_box_input = self.prob_to_use_box_input_for_eval
num_frames_to_correct = self.num_frames_to_correct_for_eval
rand_frames_to_correct = self.rand_frames_to_correct_for_eval
num_init_cond_frames = self.num_init_cond_frames_for_eval
rand_init_cond_frames = self.rand_init_cond_frames_for_eval
if num_frames == 1:
# here we handle a special case for mixing video + SAM on image training,
# where we force using point input for the SAM task on static images
prob_to_use_pt_input = 1.0
num_frames_to_correct = 1
num_init_cond_frames = 1
assert num_init_cond_frames >= 1
# (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
use_pt_input = self.rng.random() < prob_to_use_pt_input
if rand_init_cond_frames and num_init_cond_frames > 1:
# randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
num_init_cond_frames = self.rng.integers(
1, num_init_cond_frames, endpoint=True
)
if (
use_pt_input
and rand_frames_to_correct
and num_frames_to_correct > num_init_cond_frames
):
# randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
# correction clicks (only for the case of point input)
num_frames_to_correct = self.rng.integers(
num_init_cond_frames, num_frames_to_correct, endpoint=True
)
backbone_out["use_pt_input"] = use_pt_input
# Sample initial conditioning frames
if num_init_cond_frames == 1:
init_cond_frames = [start_frame_idx] # starting frame
else:
# starting frame + randomly selected remaining frames (without replacement)
init_cond_frames = [start_frame_idx] + self.rng.choice(
range(start_frame_idx + 1, num_frames),
num_init_cond_frames - 1,
replace=False,
).tolist()
backbone_out["init_cond_frames"] = init_cond_frames
backbone_out["frames_not_in_init_cond"] = [
t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
]
# Prepare mask or point inputs on initial conditioning frames
backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: <input_masks>}
backbone_out["point_inputs_per_frame"] = {} # {frame_idx: <input_points>}
for t in init_cond_frames:
if not use_pt_input:
backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
else:
# During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
use_box_input = self.rng.random() < prob_to_use_box_input
if use_box_input:
points, labels = sample_box_points(
gt_masks_per_frame[t],
)
else:
# (here we only sample **one initial point** on initial conditioning frames from the
# ground-truth mask; we may sample more correction points on the fly)
points, labels = get_next_point(
gt_masks=gt_masks_per_frame[t],
pred_masks=None,
method=(
"uniform" if self.training else self.pt_sampling_for_eval
),
)
point_inputs = {"point_coords": points, "point_labels": labels}
backbone_out["point_inputs_per_frame"][t] = point_inputs
# Sample frames where we will add correction clicks on the fly
# based on the error between prediction and ground-truth masks
if not use_pt_input:
# no correction points will be sampled when using mask inputs
frames_to_add_correction_pt = []
elif num_frames_to_correct == num_init_cond_frames:
frames_to_add_correction_pt = init_cond_frames
else:
assert num_frames_to_correct > num_init_cond_frames
# initial cond frame + randomly selected remaining frames (without replacement)
extra_num = num_frames_to_correct - num_init_cond_frames
frames_to_add_correction_pt = (
init_cond_frames
+ self.rng.choice(
backbone_out["frames_not_in_init_cond"], extra_num, replace=False
).tolist()
)
backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
return backbone_out
def forward_tracking(
self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
):
"""Forward video tracking on each frame (and sample correction clicks)."""
img_feats_already_computed = backbone_out["backbone_fpn"] is not None
if img_feats_already_computed:
# Prepare the backbone features
# - vision_feats and vision_pos_embeds are in (HW)BC format
(
_,
vision_feats,
vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features(backbone_out)
# Starting the stage loop
num_frames = backbone_out["num_frames"]
init_cond_frames = backbone_out["init_cond_frames"]
frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
# first process all the initial conditioning frames to encode them as memory,
# and then conditioning on them to track the remaining frames
processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
output_dict = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
for stage_id in processing_order:
# Get the image features for the current frames
# img_ids = input.find_inputs[stage_id].img_ids
img_ids = input.flat_obj_to_img_idx[stage_id]
if img_feats_already_computed:
# Retrieve image features according to img_ids (if they are already computed).
current_vision_feats = [x[:, img_ids] for x in vision_feats]
current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
else:
# Otherwise, compute the image features on the fly for the given img_ids
# (this might be used for evaluation on long videos to avoid backbone OOM).
(
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features_per_frame(
input.flat_img_batch, img_ids
)
# Get output masks based on this frame's prompts and previous memory
current_out = self.track_step(
frame_idx=stage_id,
is_init_cond_frame=stage_id in init_cond_frames,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
frames_to_add_correction_pt=frames_to_add_correction_pt,
output_dict=output_dict,
num_frames=num_frames,
)
# Append the output, depending on whether it's a conditioning frame
add_output_as_cond_frame = stage_id in init_cond_frames or (
self.add_all_frames_to_correct_as_cond
and stage_id in frames_to_add_correction_pt
)
if add_output_as_cond_frame:
output_dict["cond_frame_outputs"][stage_id] = current_out
else:
output_dict["non_cond_frame_outputs"][stage_id] = current_out
if return_dict:
return output_dict
# turn `output_dict` into a list for loss function
all_frame_outputs = {}
all_frame_outputs.update(output_dict["cond_frame_outputs"])
all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
# Make DDP happy with activation checkpointing by removing unused keys
all_frame_outputs = [
{k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
]
return all_frame_outputs
def track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
frames_to_add_correction_pt=None,
gt_masks=None,
):
if frames_to_add_correction_pt is None:
frames_to_add_correction_pt = []
current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
)
(
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = sam_outputs
current_out["multistep_pred_masks"] = low_res_masks
current_out["multistep_pred_masks_high_res"] = high_res_masks
current_out["multistep_pred_multimasks"] = [low_res_multimasks]
current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
current_out["multistep_pred_ious"] = [ious]
current_out["multistep_point_inputs"] = [point_inputs]
current_out["multistep_object_score_logits"] = [object_score_logits]
# Optionally, sample correction points iteratively to correct the mask
if frame_idx in frames_to_add_correction_pt:
point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
is_init_cond_frame,
point_inputs,
gt_masks,
high_res_features,
pix_feat,
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
object_score_logits,
current_out,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = final_sam_outputs
# Use the final prediction (after all correction steps for output and eval)
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
)
return current_out
def _iter_correct_pt_sampling(
self,
is_init_cond_frame,
point_inputs,
gt_masks,
high_res_features,
pix_feat_with_mem,
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
object_score_logits,
current_out,
):
assert gt_masks is not None
all_pred_masks = [low_res_masks]
all_pred_high_res_masks = [high_res_masks]
all_pred_multimasks = [low_res_multimasks]
all_pred_high_res_multimasks = [high_res_multimasks]
all_pred_ious = [ious]
all_point_inputs = [point_inputs]
all_object_score_logits = [object_score_logits]
for _ in range(self.num_correction_pt_per_frame):
# sample a new point from the error between prediction and ground-truth
# (with a small probability, directly sample from GT masks instead of errors)
if self.training and self.prob_to_sample_from_gt_for_train > 0:
sample_from_gt = (
self.rng.random() < self.prob_to_sample_from_gt_for_train
)
else:
sample_from_gt = False
# if `pred_for_new_pt` is None, only GT masks will be used for point sampling
pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
new_points, new_labels = get_next_point(
gt_masks=gt_masks,
pred_masks=pred_for_new_pt,
method="uniform" if self.training else self.pt_sampling_for_eval,
)
point_inputs = concat_points(point_inputs, new_points, new_labels)
# Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
# For tracking, this means that when the user adds a correction click, we also feed
# the tracking output mask logits along with the click as input to the SAM decoder.
mask_inputs = low_res_masks
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
sam_outputs = torch.utils.checkpoint.checkpoint(
self._forward_sam_heads,
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
use_reentrant=False,
)
else:
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
(
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
_,
object_score_logits,
) = sam_outputs
all_pred_masks.append(low_res_masks)
all_pred_high_res_masks.append(high_res_masks)
all_pred_multimasks.append(low_res_multimasks)
all_pred_high_res_multimasks.append(high_res_multimasks)
all_pred_ious.append(ious)
all_point_inputs.append(point_inputs)
all_object_score_logits.append(object_score_logits)
# Concatenate the masks along channel (to compute losses on all of them,
# using `MultiStepIteractiveMasks`)
current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
current_out["multistep_pred_masks_high_res"] = torch.cat(
all_pred_high_res_masks, dim=1
)
current_out["multistep_pred_multimasks"] = all_pred_multimasks
current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
current_out["multistep_pred_ious"] = all_pred_ious
current_out["multistep_point_inputs"] = all_point_inputs
current_out["multistep_object_score_logits"] = all_object_score_logits
return point_inputs, sam_outputs

502
sam2/training/optimizer.py Normal file
View File

@@ -0,0 +1,502 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import fnmatch
import inspect
import itertools
import logging
import types
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)
import hydra
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
class Optimizer:
def __init__(self, optimizer, schedulers=None) -> None:
self.optimizer = optimizer
self.schedulers = schedulers
self._validate_optimizer_schedulers()
self.step_schedulers(0.0, 0)
def _validate_optimizer_schedulers(self):
if self.schedulers is None:
return
for _, set_of_schedulers in enumerate(self.schedulers):
for option, _ in set_of_schedulers.items():
assert option in self.optimizer.defaults, (
"Optimizer option "
f"{option} not found in {self.optimizer}. Valid options are "
f"{self.optimizer.defaults.keys()}"
)
def step_schedulers(self, where: float, step: int) -> None:
if self.schedulers is None:
return
for i, param_group in enumerate(self.optimizer.param_groups):
for option, scheduler in self.schedulers[i].items():
if "step" in inspect.signature(scheduler.__call__).parameters:
new_value = scheduler(step=step, where=where)
elif (
hasattr(scheduler, "scheduler")
and "step"
in inspect.signature(scheduler.scheduler.__call__).parameters
):
# To handle ValueScaler wrappers
new_value = scheduler(step=step, where=where)
else:
new_value = scheduler(where)
param_group[option] = new_value
def step(self, where, step, closure=None):
self.step_schedulers(where, step)
return self.optimizer.step(closure)
def zero_grad(self, *args, **kwargs):
return self.optimizer.zero_grad(*args, **kwargs)
def set_default_parameters(
scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
"""Set up the "default" scheduler with the right parameters.
Args:
scheduler_cgfs: A list of scheduler configs, where each scheduler also
specifies which parameters it applies to, based on the names of parameters
or the class of the modules. At most one scheduler is allowed to skip this
specification, which is used as a "default" specification for any remaining
parameters.
all_parameter_names: Names of all the parameters to consider.
"""
constraints = [
scheduler_cfg.parameter_names
for scheduler_cfg in scheduler_cfgs
if scheduler_cfg.parameter_names is not None
]
if len(constraints) == 0:
default_params = set(all_parameter_names)
else:
default_params = all_parameter_names - set.union(*constraints)
default_count = 0
for scheduler_cfg in scheduler_cfgs:
if scheduler_cfg.parameter_names is None:
scheduler_cfg.parameter_names = default_params
default_count += 1
assert default_count <= 1, "Only one scheduler per option can be default"
if default_count == 0:
# No default scheduler specified, add a default, but without any scheduler
# for that option
scheduler_cfgs.append({"parameter_names": default_params})
def name_constraints_to_parameters(
param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
"""Return parameters which match the intersection of parameter constraints.
Note that this returns the parameters themselves, not their names.
Args:
param_constraints: A list, with each element being a set of allowed parameters.
named_parameters: Mapping from a parameter name to the parameter itself.
Returns:
A list containing the parameters which overlap with _each_ constraint set from
param_constraints.
"""
matching_names = set.intersection(*param_constraints)
return [value for name, value in named_parameters.items() if name in matching_names]
def map_scheduler_cfgs_to_param_groups(
all_scheduler_cfgs: Iterable[List[Dict]],
named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
"""Produce parameter groups corresponding to all the scheduler configs.
Takes all the scheduler configs, each of which applies to a specific optimizer
option (like "lr" or "weight_decay") and has a set of parameter names which it
applies to, and produces a final set of param groups where each param group
covers all the options which apply to a particular set of parameters.
Args:
all_scheduler_cfgs: All the scheduler configs covering every option.
named_parameters: Mapping from a parameter name to the parameter itself.
Returns:
Tuple of lists of schedulers and param_groups, where schedulers[i]
applies to param_groups[i].
"""
scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
schedulers = []
param_groups = []
for scheduler_cfgs in scheduler_cfgs_per_param_group:
param_constraints = [
scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
]
matching_parameters = name_constraints_to_parameters(
param_constraints, named_parameters
)
if len(matching_parameters) == 0: # If no overlap of parameters, skip
continue
schedulers_for_group = {
scheduler_cfg["option"]: scheduler_cfg["scheduler"]
for scheduler_cfg in scheduler_cfgs
if "option" in scheduler_cfg
}
schedulers.append(schedulers_for_group)
param_groups.append({"params": matching_parameters})
return schedulers, param_groups
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
"""Check that the param groups are non-overlapping and cover all the parameters.
Args:
param_groups: List of all param groups
model: Model to validate against. The check ensures that all the model
parameters are part of param_groups
"""
for pg in param_groups:
# no param should be repeated within a group
assert len(pg["params"]) == len(set(pg["params"]))
parameters = [set(param_group["params"]) for param_group in param_groups]
model_parameters = {parameter for _, parameter in model.named_parameters()}
for p1, p2 in itertools.permutations(parameters, 2):
assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
assert set.union(*parameters) == model_parameters, (
"Scheduler generated param_groups must include all parameters of the model."
f" Found {len(set.union(*parameters))} params whereas model has"
f" {len(model_parameters)} params"
)
def unix_module_cls_pattern_to_parameter_names(
filter_module_cls_names: List[str],
module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
"""Returns param names which pass the filters specified in filter_module_cls_names.
Args:
filter_module_cls_names: A list of filter strings containing class names, like
["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
module_cls_to_param_names: Mapping from module classes to the parameter names
they contain. See `get_module_cls_to_param_names`.
"""
if filter_module_cls_names is None:
return set()
allowed_parameter_names = []
for module_cls_name in filter_module_cls_names:
module_cls = hydra.utils.get_class(module_cls_name)
if module_cls not in module_cls_to_param_names:
raise AssertionError(
f"module_cls_name {module_cls_name} does not "
"match any classes in the model"
)
matching_parameters = module_cls_to_param_names[module_cls]
assert (
len(matching_parameters) > 0
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
logging.info(
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
)
allowed_parameter_names.append(matching_parameters)
return set.union(*allowed_parameter_names)
def unix_param_pattern_to_parameter_names(
filter_param_names: Optional[List[str]],
parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
"""Returns param names which pass the filters specified in filter_param_names.
Args:
filter_param_names: A list of unix-style filter strings with optional
wildcards, like ["block.2.*", "block.2.linear.weight"]
module_cls_to_param_names: Mapping from module classes to the parameter names
they contain. See `get_module_cls_to_param_names`.
"""
if filter_param_names is None:
return set()
allowed_parameter_names = []
for param_name in filter_param_names:
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
assert (
len(matching_parameters) >= 1
), f"param_name {param_name} does not match any parameters in the model"
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
allowed_parameter_names.append(matching_parameters)
return set.union(*allowed_parameter_names)
def _unix_pattern_to_parameter_names(
scheduler_cfg: DictConfig,
parameter_names: Set[str],
module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
"""Returns param names which pass the filters specified in scheduler_cfg.
Args:
scheduler_cfg: The config for the scheduler
parameter_names: The set of all parameter names which will be filtered
"""
if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
return None
return unix_param_pattern_to_parameter_names(
scheduler_cfg.get("param_names"), parameter_names
).union(
unix_module_cls_pattern_to_parameter_names(
scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
)
)
def get_module_cls_to_param_names(
model: nn.Module, param_allowlist: Set[str] = None
) -> Dict[Type, str]:
"""Produce a mapping from all the modules classes to the names of parames they own.
Only counts a parameter as part of the immediate parent module, i.e. recursive
parents do not count.
Args:
model: Model to iterate over
param_allowlist: If specified, only these param names will be processed
"""
module_cls_to_params = {}
for module_name, module in model.named_modules():
module_cls = type(module)
module_cls_to_params.setdefault(module_cls, set())
for param_name, _ in module.named_parameters(recurse=False):
full_param_name = get_full_parameter_name(module_name, param_name)
if param_allowlist is None or full_param_name in param_allowlist:
module_cls_to_params[module_cls].add(full_param_name)
return module_cls_to_params
def construct_optimizer(
model: torch.nn.Module,
optimizer_conf: Any,
options_conf: Mapping[str, List] = None,
param_group_modifiers_conf: List[Callable] = None,
param_allowlist: Optional[Set[str]] = None,
validate_param_groups=True,
) -> Optimizer:
"""
Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer
with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay
Batchnorm and/or no-update 1-D parameters support, based on the config.
Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
(LARS): https://arxiv.org/abs/1708.03888
Args:
model: model to perform stochastic gradient descent
optimization or ADAM optimization.
optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
ADAM, still missing the params argument which this function provides to
produce the final optimizer
param_group_modifiers_conf: Optional user specified functions which can modify
the final scheduler configs before the optimizer's param groups are built
param_allowlist: The parameters to optimize. Parameters which are not part of
this allowlist will be skipped.
validate_param_groups: If enabled, valides that the produced param_groups don't
overlap and cover all the model parameters.
"""
if param_allowlist is None:
param_allowlist = {name for name, _ in model.named_parameters()}
named_parameters = {
name: param
for name, param in model.named_parameters()
if name in param_allowlist
}
if not options_conf:
optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
return Optimizer(optimizer)
all_parameter_names = {
name for name, _ in model.named_parameters() if name in param_allowlist
}
module_cls_to_all_param_names = get_module_cls_to_param_names(
model, param_allowlist
)
scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
all_scheduler_cfgs = []
for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
for config in scheduler_cfgs:
config.option = option
config.parameter_names = _unix_pattern_to_parameter_names(
config, all_parameter_names, module_cls_to_all_param_names
)
set_default_parameters(scheduler_cfgs, all_parameter_names)
all_scheduler_cfgs.append(scheduler_cfgs)
if param_group_modifiers_conf:
for custom_param_modifier in param_group_modifiers_conf:
custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
all_scheduler_cfgs = custom_param_modifier(
scheduler_cfgs=all_scheduler_cfgs, model=model
)
schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
all_scheduler_cfgs, named_parameters
)
if validate_param_groups:
validate_param_group_params(param_groups, model)
optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
return Optimizer(optimizer, schedulers)
def get_full_parameter_name(module_name, param_name):
if module_name == "":
return param_name
return f"{module_name}.{param_name}"
class GradientClipper:
"""
Gradient clipping utils that works for DDP
"""
def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
assert isinstance(max_norm, (int, float)) or max_norm is None
self.max_norm = max_norm if max_norm is None else float(max_norm)
self.norm_type = norm_type
def __call__(self, model: nn.Module):
if self.max_norm is None:
return # no-op
nn.utils.clip_grad_norm_(
model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
)
class ValueScaler:
def __init__(self, scheduler, mult_val: float):
self.scheduler = scheduler
self.mult_val = mult_val
def __call__(self, *args, **kwargs):
val = self.scheduler(*args, **kwargs)
return val * self.mult_val
def rgetattr(obj, rattrs: str = None):
"""
Like getattr(), but supports dotted notation for nested objects.
rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
"""
if rattrs is None:
return obj
attrs = rattrs.split(".")
for attr in attrs:
obj = getattr(obj, attr)
return obj
def layer_decay_param_modifier(
scheduler_cfgs: List[List[Dict]],
model,
layer_decay_value: float,
layer_decay_min: Optional[float] = None,
apply_to: Optional[str] = None,
overrides: List[Dict] = (),
) -> List[List[Dict]]:
"""
Args
- scheduler_cfgs: a list of omegaconf.ListConfigs.
Each element in the list is a omegaconfg.DictConfig with the following structure
{
"scheduler": <some fvcore scheduler>
"option": <value> possible options are "lr", "weight_decay" etc.
"parameter_names": Set of str indicating param names that this scheduler applies to
}
- model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
and a method get_num_layers.
Alternatively, use apply_to argument to select a specific component of the model.
- layer_decay_value: float
- layer_decay_min: min val for layer decay
- apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
- overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".
Returns
- scheduler_configs: same structure as the input, elements can be modified
"""
model = rgetattr(model, apply_to)
num_layers = model.get_num_layers() + 1
layer_decays = [
layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
]
if layer_decay_min is not None:
layer_decays = [max(val, layer_decay_min) for val in layer_decays]
final_scheduler_cfgs = []
# scheduler_cfgs is a list of lists
for scheduler_cfg_group in scheduler_cfgs:
curr_cfg_group = []
# scheduler_cfg_group is a list of dictionaries
for scheduler_cfg in scheduler_cfg_group:
if scheduler_cfg["option"] != "lr":
curr_cfg_group.append(scheduler_cfg)
continue
# Need sorted so that the list of parameter names is deterministic and consistent
# across re-runs of this job. Else it was causing issues with loading the optimizer
# state during a job restart (D38591759)
parameter_names = sorted(scheduler_cfg["parameter_names"])
# Only want one cfg group per layer
layer_cfg_groups = {}
for param_name in parameter_names:
layer_id = num_layers
this_scale = layer_decays[layer_id]
if param_name.startswith(apply_to):
layer_id = model.get_layer_id(param_name)
this_scale = layer_decays[layer_id]
# Overrides
for override in overrides:
if fnmatch.fnmatchcase(param_name, override["pattern"]):
this_scale = float(override["value"])
layer_id = override["pattern"]
break
if layer_id not in layer_cfg_groups:
curr_param = {
"option": scheduler_cfg["option"],
"scheduler": ValueScaler(
scheduler_cfg["scheduler"], this_scale
),
"parameter_names": {param_name},
}
else:
curr_param = layer_cfg_groups[layer_id]
curr_param["parameter_names"].add(param_name)
layer_cfg_groups[layer_id] = curr_param
for layer_cfg in layer_cfg_groups.values():
curr_cfg_group.append(layer_cfg)
final_scheduler_cfgs.append(curr_cfg_group)
return final_scheduler_cfgs

View File

@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import submitit
import tqdm
def get_args_parser():
parser = argparse.ArgumentParser(
description="[SA-V Preprocessing] Extracting JPEG frames",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# ------------
# DATA
# ------------
data_parser = parser.add_argument_group(
title="SA-V dataset data root",
description="What data to load and how to process it.",
)
data_parser.add_argument(
"--sav-vid-dir",
type=str,
required=True,
help=("Where to find the SAV videos"),
)
data_parser.add_argument(
"--sav-frame-sample-rate",
type=int,
default=4,
help="Rate at which to sub-sample frames",
)
# ------------
# LAUNCH
# ------------
launch_parser = parser.add_argument_group(
title="Cluster launch settings",
description="Number of jobs and retry settings.",
)
launch_parser.add_argument(
"--n-jobs",
type=int,
required=True,
help="Shard the run over this many jobs.",
)
launch_parser.add_argument(
"--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
)
launch_parser.add_argument(
"--partition", type=str, required=True, help="Partition to launch on."
)
launch_parser.add_argument(
"--account", type=str, required=True, help="Partition to launch on."
)
launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")
# ------------
# OUTPUT
# ------------
output_parser = parser.add_argument_group(
title="Setting for results output", description="Where and how to save results."
)
output_parser.add_argument(
"--output-dir",
type=str,
required=True,
help=("Where to dump the extracted jpeg frames"),
)
output_parser.add_argument(
"--slurm-output-root-dir",
type=str,
required=True,
help=("Where to save slurm outputs"),
)
return parser
def decode_video(video_path: str):
assert os.path.exists(video_path)
video = cv2.VideoCapture(video_path)
video_frames = []
while video.isOpened():
ret, frame = video.read()
if ret:
video_frames.append(frame)
else:
break
return video_frames
def extract_frames(video_path, sample_rate):
frames = decode_video(video_path)
return frames[::sample_rate]
def submitit_launch(video_paths, sample_rate, save_root):
for path in tqdm.tqdm(video_paths):
frames = extract_frames(path, sample_rate)
output_folder = os.path.join(save_root, Path(path).stem)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
for fid, frame in enumerate(frames):
frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
cv2.imwrite(frame_path, frame)
print(f"Saved output to {save_root}")
if __name__ == "__main__":
parser = get_args_parser()
args = parser.parse_args()
sav_vid_dir = args.sav_vid_dir
save_root = args.output_dir
sample_rate = args.sav_frame_sample_rate
# List all SA-V videos
mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
mp4_files = np.array(mp4_files)
chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]
print(f"Processing videos in: {sav_vid_dir}")
print(f"Processing {len(mp4_files)} files")
print(f"Beginning processing in {args.n_jobs} processes")
# Submitit params
jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
cpus_per_task = 4
executor = submitit.AutoExecutor(folder=jobs_dir)
executor.update_parameters(
timeout_min=args.timeout,
gpus_per_node=0,
tasks_per_node=1,
slurm_array_parallelism=args.n_jobs,
cpus_per_task=cpus_per_task,
slurm_partition=args.partition,
slurm_account=args.account,
slurm_qos=args.qos,
)
executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])
# Launch
jobs = []
with executor.batch():
for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
job = executor.submit(
submitit_launch,
video_paths=mp4_chunk,
sample_rate=sample_rate,
save_root=save_root,
)
jobs.append(job)
for j in jobs:
print(f"Slurm JobID: {j.job_id}")
print(f"Saving outputs to {save_root}")
print(f"Slurm outputs at {args.slurm_output_root_dir}")

270
sam2/training/train.py Normal file
View File

@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import random
import sys
import traceback
from argparse import ArgumentParser
import submitit
import torch
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
from training.utils.train_utils import makedir, register_omegaconf_resolvers
os.environ["HYDRA_FULL_ERROR"] = "1"
def single_proc_run(local_rank, main_port, cfg, world_size):
"""Single GPU process"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(main_port)
os.environ["RANK"] = str(local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
try:
register_omegaconf_resolvers()
except Exception as e:
logging.info(e)
trainer = instantiate(cfg.trainer, _recursive_=False)
trainer.run()
def single_node_runner(cfg, main_port: int):
assert cfg.launcher.num_nodes == 1
num_proc = cfg.launcher.gpus_per_node
torch.multiprocessing.set_start_method(
"spawn"
) # CUDA runtime does not support `fork`
if num_proc == 1:
# directly call single_proc so we can easily set breakpoints
# mp.spawn does not let us set breakpoints
single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
else:
mp_runner = torch.multiprocessing.start_processes
args = (main_port, cfg, num_proc)
# Note: using "fork" below, "spawn" causes time and error regressions. Using
# spawn changes the default multiprocessing context to spawn, which doesn't
# interact well with the dataloaders (likely due to the use of OpenCV).
mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
def format_exception(e: Exception, limit=20):
traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"
class SubmititRunner(submitit.helpers.Checkpointable):
"""A callable which is passed to submitit to launch the jobs."""
def __init__(self, port, cfg):
self.cfg = cfg
self.port = port
self.has_setup = False
def run_trainer(self):
job_env = submitit.JobEnvironment()
# Need to add this again so the hydra.job.set_env PYTHONPATH
# is also set when launching jobs.
add_pythonpath_to_sys_path()
os.environ["MASTER_ADDR"] = job_env.hostnames[0]
os.environ["MASTER_PORT"] = str(self.port)
os.environ["RANK"] = str(job_env.global_rank)
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
register_omegaconf_resolvers()
cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
cfg_resolved = OmegaConf.create(cfg_resolved)
trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
trainer.run()
def __call__(self):
job_env = submitit.JobEnvironment()
self.setup_job_info(job_env.job_id, job_env.global_rank)
try:
self.run_trainer()
except Exception as e:
# Log the exception. Then raise it again (as what SubmititRunner currently does).
message = format_exception(e)
logging.error(message)
raise e
def setup_job_info(self, job_id, rank):
"""Set up slurm job info"""
self.job_info = {
"job_id": job_id,
"rank": rank,
"cluster": self.cfg.get("cluster", None),
"experiment_log_dir": self.cfg.launcher.experiment_log_dir,
}
self.has_setup = True
def add_pythonpath_to_sys_path():
if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
return
sys.path = os.environ["PYTHONPATH"].split(":") + sys.path
def main(args) -> None:
cfg = compose(config_name=args.config)
if cfg.launcher.experiment_log_dir is None:
cfg.launcher.experiment_log_dir = os.path.join(
os.getcwd(), "sam2_logs", args.config
)
print("###################### Train App Config ####################")
print(OmegaConf.to_yaml(cfg))
print("############################################################")
add_pythonpath_to_sys_path()
makedir(cfg.launcher.experiment_log_dir)
with g_pathmgr.open(
os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
) as f:
f.write(OmegaConf.to_yaml(cfg))
cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
cfg_resolved = OmegaConf.create(cfg_resolved)
with g_pathmgr.open(
os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
) as f:
f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))
submitit_conf = cfg.get("submitit", None)
assert submitit_conf is not None, "Missing submitit config"
submitit_dir = cfg.launcher.experiment_log_dir
submitit_dir = os.path.join(submitit_dir, "submitit_logs")
# Priotrize cmd line args
cfg.launcher.gpus_per_node = (
args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
)
cfg.launcher.num_nodes = (
args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
)
submitit_conf.use_cluster = (
args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
)
if submitit_conf.use_cluster:
executor = submitit.AutoExecutor(folder=submitit_dir)
submitit_conf.partition = (
args.partition
if args.partition is not None
else submitit_conf.get("partition", None)
)
submitit_conf.account = (
args.account
if args.account is not None
else submitit_conf.get("account", None)
)
submitit_conf.qos = (
args.qos if args.qos is not None else submitit_conf.get("qos", None)
)
job_kwargs = {
"timeout_min": 60 * submitit_conf.timeout_hour,
"name": (
submitit_conf.name if hasattr(submitit_conf, "name") else args.config
),
"slurm_partition": submitit_conf.partition,
"gpus_per_node": cfg.launcher.gpus_per_node,
"tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU
"cpus_per_task": submitit_conf.cpus_per_task,
"nodes": cfg.launcher.num_nodes,
"slurm_additional_parameters": {
"exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
},
}
if "include_nodes" in submitit_conf:
assert (
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
), "Not enough nodes"
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
submitit_conf["include_nodes"]
)
if submitit_conf.account is not None:
job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
if submitit_conf.qos is not None:
job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
if submitit_conf.get("mem_gb", None) is not None:
job_kwargs["mem_gb"] = submitit_conf.mem_gb
elif submitit_conf.get("mem", None) is not None:
job_kwargs["slurm_mem"] = submitit_conf.mem
if submitit_conf.get("constraints", None) is not None:
job_kwargs["slurm_constraint"] = submitit_conf.constraints
if submitit_conf.get("comment", None) is not None:
job_kwargs["slurm_comment"] = submitit_conf.comment
# Supports only cpu-bind option within srun_args. New options can be added here
if submitit_conf.get("srun_args", None) is not None:
job_kwargs["slurm_srun_args"] = []
if submitit_conf.srun_args.get("cpu_bind", None) is not None:
job_kwargs["slurm_srun_args"].extend(
["--cpu-bind", submitit_conf.srun_args.cpu_bind]
)
print("###################### SLURM Config ####################")
print(job_kwargs)
print("##########################################")
executor.update_parameters(**job_kwargs)
main_port = random.randint(
submitit_conf.port_range[0], submitit_conf.port_range[1]
)
runner = SubmititRunner(main_port, cfg)
job = executor.submit(runner)
print(f"Submitit Job ID: {job.job_id}")
runner.setup_job_info(job.job_id, rank=0)
else:
cfg.launcher.num_nodes = 1
main_port = random.randint(
submitit_conf.port_range[0], submitit_conf.port_range[1]
)
single_node_runner(cfg, main_port)
if __name__ == "__main__":
initialize_config_module("sam2", version_base="1.2")
parser = ArgumentParser()
parser.add_argument(
"-c",
"--config",
required=True,
type=str,
help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
)
parser.add_argument(
"--use-cluster",
type=int,
default=None,
help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
)
parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
parser.add_argument("--account", type=str, default=None, help="SLURM account")
parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
parser.add_argument(
"--num-gpus", type=int, default=None, help="number of GPUS per node"
)
parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
args = parser.parse_args()
args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
register_omegaconf_resolvers()
main(args)

1113
sam2/training/trainer.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,361 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import fnmatch
import logging
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import numpy as np
import torch
import torch.nn as nn
from iopath.common.file_io import g_pathmgr
from torch.jit._script import RecursiveScriptModule
def unix_pattern_to_parameter_names(
constraints: List[str], all_parameter_names: Sequence[str]
) -> Union[None, Set[str]]:
"""
Go through the list of parameter names and select those that match
any of the provided constraints
"""
parameter_names = []
for param_name in constraints:
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
assert (
len(matching_parameters) > 0
), f"param_names {param_name} don't match any param in the given names."
parameter_names.append(matching_parameters)
return set.union(*parameter_names)
def filter_params_matching_unix_pattern(
patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Remove from the state dictionary the parameters matching the provided unix patterns
Args:
patterns: the list of unix patterns to exclude
state_dict: the dictionary to filter
Returns:
A new state dictionary
"""
if len(patterns) == 0:
return {}
all_keys = list(state_dict.keys())
included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
return {k: state_dict[k] for k in included_keys}
def exclude_params_matching_unix_pattern(
patterns: List[str], state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Remove from the state dictionary the parameters matching the provided unix patterns
Args:
patterns: the list of unix patterns to exclude
state_dict: the dictionary to filter
Returns:
A new state dictionary
"""
if len(patterns) == 0:
return state_dict
all_keys = list(state_dict.keys())
excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys)
return {k: v for k, v in state_dict.items() if k not in excluded_keys}
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
keys = []
trace = []
for k, v in state_dict.items():
keys.append(k)
trace.append(v.sum().item())
trace = np.array(trace)[np.argsort(keys)]
return trace
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
"""
Verifies that all the parameters matching the provided patterns
are frozen - this acts as a safeguard when ignoring parameter
when saving checkpoints - if the parameters are in fact trainable
"""
if not patterns:
return
frozen_state_dict = filter_params_matching_unix_pattern(
patterns=patterns, state_dict=model.state_dict()
)
non_frozen_keys = {
n
for n, p in model.named_parameters()
if n in frozen_state_dict and p.requires_grad
}
if non_frozen_keys:
raise ValueError(
f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
)
@contextlib.contextmanager
def with_check_parameter_frozen(
model: nn.Module, patterns: List[str], disabled: bool = True
):
"""
Context manager that inspects a model surrounding a piece of code
and verifies if the model has been updated by this piece of code
The function will raise an exception if the model has been updated
on at least one of the parameter that matches one of the pattern
Args:
model: the model that might have been updated
patterns: for the parameters we want to observe
allowed:
"""
if not patterns or disabled:
yield
return
frozen_state_dict = filter_params_matching_unix_pattern(
patterns=patterns, state_dict=model.state_dict()
)
summary_before = _get_state_dict_summary(frozen_state_dict)
yield
frozen_state_dict = filter_params_matching_unix_pattern(
patterns=patterns, state_dict=model.state_dict()
)
summary_after = _get_state_dict_summary(frozen_state_dict)
if not np.allclose(summary_before, summary_after, atol=1e-6):
raise ValueError(
f"""
The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
You can resolve this error by either initializing those parameters from within the model definition
or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
"""
)
class CkptExcludeKernel:
"""
Removes the keys from the given model state_dict that match the key_pattern.
Args:
key_pattern: Patterns used to select the keys in the state_dict
that are eligible for this kernel.
"""
def __init__(self, key_pattern: List[str]):
self.key_pattern = key_pattern
def __call__(self, state_dict: Dict):
"""
Args:
state_dict: A dictionary representing the given checkpoint's state dict.
"""
if len(self.key_pattern) == 0:
return state_dict
exclude_keys = unix_pattern_to_parameter_names(
self.key_pattern, state_dict.keys()
)
return {k: v for k, v in state_dict.items() if k not in exclude_keys}
def load_checkpoint(
path_list: List[str],
pick_recursive_keys: Optional[List[str]] = None,
map_location: str = "cpu",
) -> Any:
"""
Loads a checkpoint from the specified path.
Args:
path_list: A list of paths which contain the checkpoint. Each element
is tried (in order) until a file that exists is found. That file is then
used to read the checkpoint.
pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
map_location (str): a function, torch.device, string or a dict specifying how to
remap storage locations
Returns: Model with the matchin pre-trained weights loaded.
"""
path_exists = False
for path in path_list:
if g_pathmgr.exists(path):
path_exists = True
break
if not path_exists:
raise ValueError(f"No path exists in {path_list}")
with g_pathmgr.open(path, "rb") as f:
checkpoint = torch.load(f, map_location=map_location)
logging.info(f"Loaded checkpoint from {path}")
if pick_recursive_keys is not None:
for key in pick_recursive_keys:
checkpoint = checkpoint[key]
return checkpoint
def get_state_dict(checkpoint, ckpt_state_dict_keys):
if isinstance(checkpoint, RecursiveScriptModule):
# This is a torchscript JIT model
return checkpoint.state_dict()
pre_train_dict = checkpoint
for i, key in enumerate(ckpt_state_dict_keys):
if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
):
key_str = (
'["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]'
)
raise KeyError(
f"'{key}' not found in checkpoint{key_str} "
f"with keys: {pre_train_dict.keys()}"
)
pre_train_dict = pre_train_dict[key]
return pre_train_dict
def load_checkpoint_and_apply_kernels(
checkpoint_path: str,
checkpoint_kernels: List[Callable] = None,
ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
map_location: str = "cpu",
) -> nn.Module:
"""
Performs checkpoint loading with a variety of pre-processing kernel applied in
sequence.
Args:
checkpoint_path (str): Path to the checkpoint.
checkpoint_kernels List(Callable): A list of checkpoint processing kernels
to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
`CkptExcludeKernel`, etc. These kernels are applied in the
given order.
ckpt_state_dict_keys (str): Keys containing the model state dict.
map_location (str): a function, torch.device, string or a dict specifying how to
remap storage locations
Returns: Model with the matchin pre-trained weights loaded.
"""
assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
checkpoint_path
)
# Load the checkpoint on CPU to avoid GPU mem spike.
with g_pathmgr.open(checkpoint_path, "rb") as f:
checkpoint = torch.load(f, map_location=map_location)
pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)
# Not logging into info etc since it's a huge log
logging.debug(
"Loaded Checkpoint State Dict pre-kernel application: %s"
% str(", ".join(list(pre_train_dict.keys())))
)
# Apply kernels
if checkpoint_kernels is not None:
for f in checkpoint_kernels:
pre_train_dict = f(state_dict=pre_train_dict)
logging.debug(
"Loaded Checkpoint State Dict Post-kernel application %s"
% str(", ".join(list(pre_train_dict.keys())))
)
return pre_train_dict
def check_load_state_dict_errors(
missing_keys,
unexpected_keys,
strict: bool,
ignore_missing_keys: List[str] = None,
ignore_unexpected_keys: List[str] = None,
):
if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
ignored_keys = unix_pattern_to_parameter_names(
ignore_missing_keys, missing_keys
)
missing_keys = [key for key in missing_keys if key not in ignored_keys]
if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
ignored_unexpected_keys = unix_pattern_to_parameter_names(
ignore_unexpected_keys, unexpected_keys
)
unexpected_keys = [
key for key in unexpected_keys if key not in ignored_unexpected_keys
]
err = "State key mismatch."
if unexpected_keys:
err += f" Unexpected keys: {unexpected_keys}."
if missing_keys:
err += f" Missing keys: {missing_keys}."
if unexpected_keys or missing_keys:
logging.warning(err)
if unexpected_keys or strict:
raise KeyError(err)
def load_state_dict_into_model(
state_dict: Dict,
model: nn.Module,
strict: bool = True,
ignore_missing_keys: List[str] = None,
ignore_unexpected_keys: List[str] = None,
checkpoint_kernels: List[Callable] = None,
):
"""
Loads a state dict into the given model.
Args:
state_dict: A dictionary containing the model's
state dict, or a subset if strict is False
model: Model to load the checkpoint weights into
strict: raise if the state_dict has missing state keys
ignore_missing_keys: unix pattern of keys to ignore
"""
# Apply kernels
if checkpoint_kernels is not None:
for f in checkpoint_kernels:
state_dict = f(state_dict=state_dict)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
check_load_state_dict_errors(
missing_keys,
unexpected_keys,
strict=strict,
ignore_missing_keys=ignore_missing_keys,
ignore_unexpected_keys=ignore_unexpected_keys,
)
return model

View File

@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from PIL import Image as PILImage
from tensordict import tensorclass
@tensorclass
class BatchedVideoMetaData:
"""
This class represents metadata about a batch of videos.
Attributes:
unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
"""
unique_objects_identifier: torch.LongTensor
frame_orig_size: torch.LongTensor
@tensorclass
class BatchedVideoDatapoint:
"""
This class represents a batch of videos with associated annotations and metadata.
Attributes:
img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
dict_key: A string key used to identify the batch.
"""
img_batch: torch.FloatTensor
obj_to_frame_idx: torch.IntTensor
masks: torch.BoolTensor
metadata: BatchedVideoMetaData
dict_key: str
def pin_memory(self, device=None):
return self.apply(torch.Tensor.pin_memory, device=device)
@property
def num_frames(self) -> int:
"""
Returns the number of frames per video.
"""
return self.batch_size[0]
@property
def num_videos(self) -> int:
"""
Returns the number of videos in the batch.
"""
return self.img_batch.shape[1]
@property
def flat_obj_to_img_idx(self) -> torch.IntTensor:
"""
Returns a flattened tensor containing the object to img index.
The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
"""
frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
flat_idx = video_idx * self.num_frames + frame_idx
return flat_idx
@property
def flat_img_batch(self) -> torch.FloatTensor:
"""
Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
"""
return self.img_batch.transpose(0, 1).flatten(0, 1)
@dataclass
class Object:
# Id of the object in the media
object_id: int
# Index of the frame in the media (0 if single image)
frame_index: int
segment: Union[torch.Tensor, dict] # RLE dict or binary mask
@dataclass
class Frame:
data: Union[torch.Tensor, PILImage.Image]
objects: List[Object]
@dataclass
class VideoDatapoint:
"""Refers to an image/video and all its annotations"""
frames: List[Frame]
video_id: int
size: Tuple[int, int]
def collate_fn(
batch: List[VideoDatapoint],
dict_key,
) -> BatchedVideoDatapoint:
"""
Args:
batch: A list of VideoDatapoint instances.
dict_key (str): A string key used to identify the batch.
"""
img_batch = []
for video in batch:
img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
T = img_batch.shape[0]
# Prepare data structures for sequential processing. Per-frame processing but batched across videos.
step_t_objects_identifier = [[] for _ in range(T)]
step_t_frame_orig_size = [[] for _ in range(T)]
step_t_masks = [[] for _ in range(T)]
step_t_obj_to_frame_idx = [
[] for _ in range(T)
] # List to store frame indices for each time step
for video_idx, video in enumerate(batch):
orig_video_id = video.video_id
orig_frame_size = video.size
for t, frame in enumerate(video.frames):
objects = frame.objects
for obj in objects:
orig_obj_id = obj.object_id
orig_frame_idx = obj.frame_index
step_t_obj_to_frame_idx[t].append(
torch.tensor([t, video_idx], dtype=torch.int)
)
step_t_masks[t].append(obj.segment.to(torch.bool))
step_t_objects_identifier[t].append(
torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
)
step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
obj_to_frame_idx = torch.stack(
[
torch.stack(obj_to_frame_idx, dim=0)
for obj_to_frame_idx in step_t_obj_to_frame_idx
],
dim=0,
)
masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
objects_identifier = torch.stack(
[torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0
)
frame_orig_size = torch.stack(
[torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0
)
return BatchedVideoDatapoint(
img_batch=img_batch,
obj_to_frame_idx=obj_to_frame_idx,
masks=masks,
metadata=BatchedVideoMetaData(
unique_objects_identifier=objects_identifier,
frame_orig_size=frame_orig_size,
),
dict_key=dict_key,
batch_size=[T],
)

View File

@@ -0,0 +1,576 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import functools
import io
import logging
import os
import random
import tempfile
import time
from typing import Any, Callable, List, Tuple
import torch
import torch.autograd as autograd
import torch.distributed as dist
# Default to GPU 0
_cuda_device_index: int = 0
# Setting _cuda_device_index to -1 internally implies that we should use CPU
_CPU_DEVICE_INDEX = -1
_PRIMARY_RANK = 0
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
# Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
# being much slower than others causing a timeout (which can happen in relation
# or LVIS class mAP evaluation).
timeout = 43200
return dist.new_group(
backend="gloo",
timeout=datetime.timedelta(seconds=timeout),
)
return dist.group.WORLD
def is_main_process():
"""Return true if the current process is the main one"""
return get_rank() == 0
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
`all_gather` above, but using filesystem instead of collective ops.
If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
(and other ranks will have an empty list).
"""
world_size = get_world_size()
if world_size == 1:
return [data]
print("gathering via files")
cpu_group = _get_global_gloo_group()
# if unspecified, we will save to the current python file dir
if filesys_save_dir is not None:
save_dir = filesys_save_dir
elif "EXP_DIR" in os.environ:
save_dir = os.environ["EXP_DIR"]
else:
# try the same directory where the code is stored
save_dir = filesys_save_dir or os.path.dirname(__file__)
save_dir = os.path.join(save_dir, "all_gather_via_filesys")
if is_main_process():
os.makedirs(save_dir, exist_ok=True)
# use a timestamp and salt to distinguish different all_gather
timestamp = int(time.time()) if is_main_process() else 0
salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
# broadcast the timestamp and salt across ranks
# (all-reduce will do the broadcasting since only rank 0 is non-zero)
timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
dist.all_reduce(timestamp_and_salt, group=cpu_group)
timestamp, salt = timestamp_and_salt.tolist()
# save the data to a file on the disk
rank_save = get_rank()
save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
save_data_path = os.path.join(save_dir, save_data_filename)
assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
torch.save(data, save_data_path)
dist.barrier(group=cpu_group)
# read the data from the files
data_list = []
if rank_save == 0 or not gather_to_rank_0_only:
for rank_load in range(world_size):
load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
load_data_path = os.path.join(save_dir, load_data_filename)
assert os.path.exists(load_data_path), f"cannot read {save_data_path}"
data_list.append(torch.load(load_data_path))
dist.barrier(group=cpu_group)
# delete the saved file
os.remove(save_data_path)
return data_list
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
return all_gather_via_filesys(
data, filesys_save_dir, gather_to_rank_0_only=True
)
if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
return all_gather_via_filesys(data, filesys_save_dir)
cpu_group = None
if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
cpu_group = _get_global_gloo_group()
buffer = io.BytesIO()
torch.save(data, buffer)
data_view = buffer.getbuffer()
device = "cuda" if cpu_group is None else "cpu"
tensor = torch.ByteTensor(data_view).to(device)
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
size_list = [
torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
]
if cpu_group is None:
dist.all_gather(size_list, local_size)
else:
print("gathering on cpu")
dist.all_gather(size_list, local_size, group=cpu_group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
assert isinstance(local_size.item(), int)
local_size = int(local_size.item())
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
if local_size != max_size:
padding = torch.empty(
size=(max_size - local_size,), dtype=torch.uint8, device=device
)
tensor = torch.cat((tensor, padding), dim=0)
if cpu_group is None:
dist.all_gather(tensor_list, tensor)
else:
dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer)
data_list.append(obj)
return data_list
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
"""
For some backends, such as NCCL, communication only works if the
tensor is on the GPU. This helper function converts to the correct
device and returns the tensor + original device.
"""
orig_device = "cpu" if not tensor.is_cuda else "gpu"
if (
torch.distributed.is_available()
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
and not tensor.is_cuda
):
tensor = tensor.cuda()
return (tensor, orig_device)
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
"""
For some backends, such as NCCL, communication only works if the
tensor is on the GPU. This converts the tensor back to original device.
"""
if tensor.is_cuda and orig_device == "cpu":
tensor = tensor.cpu()
return tensor
def is_distributed_training_run() -> bool:
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and (torch.distributed.get_world_size() > 1)
)
def is_primary() -> bool:
"""
Returns True if this is rank 0 of a distributed training job OR if it is
a single trainer job. Otherwise False.
"""
return get_rank() == _PRIMARY_RANK
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
"""
Wrapper over torch.distributed.all_reduce for performing mean reduction
of tensor over all processes.
"""
return all_reduce_op(
tensor,
torch.distributed.ReduceOp.SUM,
lambda t: t / torch.distributed.get_world_size(),
)
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
"""
Wrapper over torch.distributed.all_reduce for performing sum
reduction of tensor over all processes in both distributed /
non-distributed scenarios.
"""
return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
"""
Wrapper over torch.distributed.all_reduce for performing min
reduction of tensor over all processes in both distributed /
non-distributed scenarios.
"""
return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
"""
Wrapper over torch.distributed.all_reduce for performing min
reduction of tensor over all processes in both distributed /
non-distributed scenarios.
"""
return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
def all_reduce_op(
tensor: torch.Tensor,
op: torch.distributed.ReduceOp,
after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:
"""
Wrapper over torch.distributed.all_reduce for performing
reduction of tensor over all processes in both distributed /
non-distributed scenarios.
"""
if is_distributed_training_run():
tensor, orig_device = convert_to_distributed_tensor(tensor)
torch.distributed.all_reduce(tensor, op)
if after_op_func is not None:
tensor = after_op_func(tensor)
tensor = convert_to_normal_tensor(tensor, orig_device)
return tensor
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
"""
Wrapper over torch.distributed.all_gather for performing
'gather' of 'tensor' over all processes in both distributed /
non-distributed scenarios.
"""
if tensor.ndim == 0:
# 0 dim tensors cannot be gathered. so unsqueeze
tensor = tensor.unsqueeze(0)
if is_distributed_training_run():
tensor, orig_device = convert_to_distributed_tensor(tensor)
gathered_tensors = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(gathered_tensors, tensor)
gathered_tensors = [
convert_to_normal_tensor(_tensor, orig_device)
for _tensor in gathered_tensors
]
else:
gathered_tensors = [tensor]
return gathered_tensors
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
gathered_tensors = gather_tensors_from_all(tensor)
gathered_tensor = torch.cat(gathered_tensors, 0)
return gathered_tensor
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""
Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
to all processes in both distributed / non-distributed scenarios.
"""
if is_distributed_training_run():
tensor, orig_device = convert_to_distributed_tensor(tensor)
torch.distributed.broadcast(tensor, src)
tensor = convert_to_normal_tensor(tensor, orig_device)
return tensor
def barrier() -> None:
"""
Wrapper over torch.distributed.barrier, returns without waiting
if the distributed process group is not initialized instead of throwing error.
"""
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
return
torch.distributed.barrier()
def get_world_size() -> int:
"""
Simple wrapper for correctly getting worldsize in both distributed
/ non-distributed settings
"""
return (
torch.distributed.get_world_size()
if torch.distributed.is_available() and torch.distributed.is_initialized()
else 1
)
def get_rank() -> int:
"""
Simple wrapper for correctly getting rank in both distributed
/ non-distributed settings
"""
return (
torch.distributed.get_rank()
if torch.distributed.is_available() and torch.distributed.is_initialized()
else 0
)
def get_primary_rank() -> int:
return _PRIMARY_RANK
def set_cuda_device_index(idx: int) -> None:
global _cuda_device_index
_cuda_device_index = idx
torch.cuda.set_device(_cuda_device_index)
def set_cpu_device() -> None:
global _cuda_device_index
_cuda_device_index = _CPU_DEVICE_INDEX
def get_cuda_device_index() -> int:
return _cuda_device_index
def init_distributed_data_parallel_model(
model: torch.nn.Module,
broadcast_buffers: bool = False,
find_unused_parameters: bool = True,
bucket_cap_mb: int = 25,
) -> torch.nn.parallel.DistributedDataParallel:
global _cuda_device_index
if _cuda_device_index == _CPU_DEVICE_INDEX:
# CPU-only model, don't specify device
return torch.nn.parallel.DistributedDataParallel(
model,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
bucket_cap_mb=bucket_cap_mb,
)
else:
# GPU model
return torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[_cuda_device_index],
output_device=_cuda_device_index,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
bucket_cap_mb=bucket_cap_mb,
)
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
"""Broadcast an object from a source to all workers.
Args:
obj: Object to broadcast, must be serializable
src: Source rank for broadcast (default is primary)
use_disk: If enabled, removes redundant CPU memory copies by writing to
disk
"""
# Either broadcast from primary to the fleet (default),
# or use the src setting as the original rank
if get_rank() == src:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
data_view = buffer.getbuffer()
length_tensor = torch.LongTensor([len(data_view)])
length_tensor = broadcast(length_tensor, src=src)
data_tensor = torch.ByteTensor(data_view)
data_tensor = broadcast(data_tensor, src=src)
else:
# Fetch from the source
length_tensor = torch.LongTensor([0])
length_tensor = broadcast(length_tensor, src=src)
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
data_tensor = broadcast(data_tensor, src=src)
if use_disk:
with tempfile.TemporaryFile("r+b") as f:
f.write(data_tensor.numpy())
# remove reference to the data tensor and hope that Python garbage
# collects it
del data_tensor
f.seek(0)
obj = torch.load(f)
else:
buffer = io.BytesIO(data_tensor.numpy())
obj = torch.load(buffer)
return obj
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
if world_size is None:
world_size = get_world_size()
# make contiguous because NCCL won't gather the tensor otherwise
assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
tensor, orig_device = convert_to_distributed_tensor(tensor)
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_all, tensor, async_op=False) # performance opt
tensor_all = [
convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
]
return tensor_all
def all_gather_batch(tensors: List[torch.Tensor]):
"""
Performs all_gather operation on the provided tensors.
"""
# Queue the gathered tensors
world_size = get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
tensor_list = []
output_tensor = []
for tensor in tensors:
tensor_all = all_gather_tensor(tensor, world_size)
tensor_list.append(tensor_all)
for tensor_all in tensor_list:
output_tensor.append(torch.cat(tensor_all, dim=0))
return output_tensor
class GatherLayer(autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
@staticmethod
def forward(ctx, x):
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
dist.all_reduce(all_gradients)
return all_gradients[dist.get_rank()]
def all_gather_batch_with_grad(tensors):
"""
Performs all_gather operation on the provided tensors.
Graph remains connected for backward grad computation.
"""
# Queue the gathered tensors
world_size = get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
tensor_list = []
output_tensor = []
for tensor in tensors:
tensor_all = GatherLayer.apply(tensor)
tensor_list.append(tensor_all)
for tensor_all in tensor_list:
output_tensor.append(torch.cat(tensor_all, dim=0))
return output_tensor
def unwrap_ddp_if_wrapped(model):
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
return model.module
return model
def create_new_process_group(group_size):
"""
Creates process groups of a gives `group_size` and returns
process group that current GPU participates in.
`group_size` must divide the total number of GPUs (world_size).
Modified from
https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
Args:
group_size (int): number of GPU's to collaborate for sync bn
"""
assert group_size > 0
world_size = torch.distributed.get_world_size()
if world_size <= 8:
if group_size > world_size:
logging.warning(
f"Requested group size [{group_size}] > world size [{world_size}]. "
"Assuming local debug run and capping it to world size."
)
group_size = world_size
assert world_size >= group_size
assert world_size % group_size == 0
group = None
for group_num in range(world_size // group_size):
group_ids = range(group_num * group_size, (group_num + 1) * group_size)
cur_group = torch.distributed.new_group(ranks=group_ids)
if torch.distributed.get_rank() // group_size == group_num:
group = cur_group
# can not drop out and return here, every process must go through creation of all subgroups
assert group is not None
return group
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True

View File

@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
Scalar = Union[Tensor, ndarray, int, float]
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
makedir(log_dir)
summary_writer_method = SummaryWriter
return TensorBoardLogger(
path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
)
class TensorBoardWriterWrapper:
"""
A wrapper around a SummaryWriter object.
"""
def __init__(
self,
path: str,
*args: Any,
filename_suffix: str = None,
summary_writer_method: Any = SummaryWriter,
**kwargs: Any,
) -> None:
"""Create a new TensorBoard logger.
On construction, the logger creates a new events file that logs
will be written to. If the environment variable `RANK` is defined,
logger will only log if RANK = 0.
NOTE: If using the logger with distributed training:
- This logger can call collective operations
- Logs will be written on rank 0 only
- Logger must be constructed synchronously *after* initializing distributed process group.
Args:
path (str): path to write logs to
*args, **kwargs: Extra arguments to pass to SummaryWriter
"""
self._writer: Optional[SummaryWriter] = None
_, self._rank = get_machine_local_and_dist_rank()
self._path: str = path
if self._rank == 0:
logging.info(
f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
)
self._writer = summary_writer_method(
log_dir=path,
*args,
filename_suffix=filename_suffix or str(uuid.uuid4()),
**kwargs,
)
else:
logging.debug(
f"Not logging meters on this host because env RANK: {self._rank} != 0"
)
atexit.register(self.close)
@property
def writer(self) -> Optional[SummaryWriter]:
return self._writer
@property
def path(self) -> str:
return self._path
def flush(self) -> None:
"""Writes pending logs to disk."""
if not self._writer:
return
self._writer.flush()
def close(self) -> None:
"""Close writer, flushing pending logs to disk.
Logs cannot be written after `close` is called.
"""
if not self._writer:
return
self._writer.close()
self._writer = None
class TensorBoardLogger(TensorBoardWriterWrapper):
"""
A simple logger for TensorBoard.
"""
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
"""Add multiple scalar values to TensorBoard.
Args:
payload (dict): dictionary of tag name and scalar value
step (int, Optional): step value to record
"""
if not self._writer:
return
for k, v in payload.items():
self.log(k, v, step)
def log(self, name: str, data: Scalar, step: int) -> None:
"""Add scalar data to TensorBoard.
Args:
name (string): tag name used to group scalars
data (float/int/Tensor): scalar data to log
step (int, optional): step value to record
"""
if not self._writer:
return
self._writer.add_scalar(name, data, global_step=step, new_style=True)
def log_hparams(
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
) -> None:
"""Add hyperparameter data to TensorBoard.
Args:
hparams (dict): dictionary of hyperparameter names and corresponding values
meters (dict): dictionary of name of meter and corersponding values
"""
if not self._writer:
return
self._writer.add_hparams(hparams, meters)
class Logger:
"""
A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
"""
def __init__(self, logging_conf):
# allow turning off TensorBoard with "should_log: false" in config
tb_config = logging_conf.tensorboard_writer
tb_should_log = tb_config and tb_config.pop("should_log", True)
self.tb_logger = instantiate(tb_config) if tb_should_log else None
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
if self.tb_logger:
self.tb_logger.log_dict(payload, step)
def log(self, name: str, data: Scalar, step: int) -> None:
if self.tb_logger:
self.tb_logger.log(name, data, step)
def log_hparams(
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
) -> None:
if self.tb_logger:
self.tb_logger.log_hparams(hparams, meters)
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
# we tune the buffering value so that the logs are updated
# frequently.
log_buffer_kb = 10 * 1024 # 10KB
io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
atexit.register(io.close)
return io
def setup_logging(
name,
output_dir=None,
rank=0,
log_level_primary="INFO",
log_level_secondary="ERROR",
):
"""
Setup various logging streams: stdout and file handlers.
For file handlers, we only setup for the master gpu.
"""
# get the filename if we want to log to the file as well
log_filename = None
if output_dir:
makedir(output_dir)
if rank == 0:
log_filename = f"{output_dir}/log.txt"
logger = logging.getLogger(name)
logger.setLevel(log_level_primary)
# create formatter
FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
formatter = logging.Formatter(FORMAT)
# Cleanup any existing handlers
for h in logger.handlers:
logger.removeHandler(h)
logger.root.handlers = []
# setup the console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
if rank == 0:
console_handler.setLevel(log_level_primary)
else:
console_handler.setLevel(log_level_secondary)
# we log to file as well if user wants
if log_filename and rank == 0:
file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
file_handler.setLevel(log_level_primary)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logging.root = logger
def shutdown_logging():
"""
After training is done, we ensure to shut down all the logger streams.
"""
logging.info("Shutting down loggers...")
handlers = logging.root.handlers
for handler in handlers:
handler.close()

View File

@@ -0,0 +1,288 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import os
import random
import re
from datetime import timedelta
from typing import Optional
import hydra
import numpy as np
import omegaconf
import torch
import torch.distributed as dist
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf
def multiply_all(*args):
return np.prod(np.array(args)).item()
def collect_dict_keys(config):
"""This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined"""
val_keys = []
# If the this config points to the collate function, then it has a key
if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
val_keys.append(config["dict_key"])
else:
# Recursively proceed
for v in config.values():
if isinstance(v, type(config)):
val_keys.extend(collect_dict_keys(v))
elif isinstance(v, omegaconf.listconfig.ListConfig):
for item in v:
if isinstance(item, type(config)):
val_keys.extend(collect_dict_keys(item))
return val_keys
class Phase:
TRAIN = "train"
VAL = "val"
def register_omegaconf_resolvers():
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
OmegaConf.register_new_resolver("add", lambda x, y: x + y)
OmegaConf.register_new_resolver("times", multiply_all)
OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
OmegaConf.register_new_resolver("int", lambda x: int(x))
OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
def setup_distributed_backend(backend, timeout_mins):
"""
Initialize torch.distributed and set the CUDA device.
Expects environment variables to be set as per
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
"""
# enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
# of waiting
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
return dist.get_rank()
def get_machine_local_and_dist_rank():
"""
Get the distributed and local rank of the current gpu.
"""
local_rank = int(os.environ.get("LOCAL_RANK", None))
distributed_rank = int(os.environ.get("RANK", None))
assert (
local_rank is not None and distributed_rank is not None
), "Please the set the RANK and LOCAL_RANK environment variables."
return local_rank, distributed_rank
def print_cfg(cfg):
"""
Supports printing both Hydra DictConfig and also the AttrDict config
"""
logging.info("Training with config:")
logging.info(OmegaConf.to_yaml(cfg))
def set_seeds(seed_value, max_epochs, dist_rank):
"""
Set the python random, numpy and torch seed for each gpu. Also set the CUDA
seeds if the CUDA is available. This ensures deterministic nature of the training.
"""
# Since in the pytorch sampler, we increment the seed by 1 for every epoch.
seed_value = (seed_value + dist_rank) * max_epochs
logging.info(f"MACHINE SEED: {seed_value}")
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed_value)
def makedir(dir_path):
"""
Create the directory if it does not exist.
"""
is_success = False
try:
if not g_pathmgr.exists(dir_path):
g_pathmgr.mkdirs(dir_path)
is_success = True
except BaseException:
logging.info(f"Error creating directory: {dir_path}")
return is_success
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_amp_type(amp_type: Optional[str] = None):
if amp_type is None:
return None
assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
if amp_type == "bfloat16":
return torch.bfloat16
else:
return torch.float16
def log_env_variables():
env_keys = sorted(list(os.environ.keys()))
st = ""
for k in env_keys:
v = os.environ[k]
st += f"{k}={v}\n"
logging.info("Logging ENV_VARIABLES")
logging.info(st)
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, device, fmt=":f"):
self.name = name
self.fmt = fmt
self.device = device
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self._allow_updates = True
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
class MemMeter:
"""Computes and stores the current, avg, and max of peak Mem usage per iteration"""
def __init__(self, name, device, fmt=":f"):
self.name = name
self.fmt = fmt
self.device = device
self.reset()
def reset(self):
self.val = 0 # Per iteration max usage
self.avg = 0 # Avg per iteration max usage
self.peak = 0 # Peak usage for lifetime of program
self.sum = 0
self.count = 0
self._allow_updates = True
def update(self, n=1, reset_peak_usage=True):
self.val = torch.cuda.max_memory_allocated() // 1e9
self.sum += self.val * n
self.count += n
self.avg = self.sum / self.count
self.peak = max(self.peak, self.val)
if reset_peak_usage:
torch.cuda.reset_peak_memory_stats()
def __str__(self):
fmtstr = (
"{name}: {val"
+ self.fmt
+ "} ({avg"
+ self.fmt
+ "}/{peak"
+ self.fmt
+ "})"
)
return fmtstr.format(**self.__dict__)
def human_readable_time(time_seconds):
time = int(time_seconds)
minutes, seconds = divmod(time, 60)
hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24)
return f"{days:02}d {hours:02}h {minutes:02}m"
class DurationMeter:
def __init__(self, name, device, fmt=":f"):
self.name = name
self.device = device
self.fmt = fmt
self.val = 0
def reset(self):
self.val = 0
def update(self, val):
self.val = val
def add(self, val):
self.val += val
def __str__(self):
return f"{self.name}: {human_readable_time(self.val)}"
class ProgressMeter:
def __init__(self, num_batches, meters, real_meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.real_meters = real_meters
self.prefix = prefix
def display(self, batch, enable_print=False):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
entries += [
" | ".join(
[
f"{os.path.join(name, subname)}: {val:.4f}"
for subname, val in meter.compute().items()
]
)
for name, meter in self.real_meters.items()
]
logging.info(" | ".join(entries))
if enable_print:
print(" | ".join(entries))
def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = "{:" + str(num_digits) + "d}"
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def get_resume_checkpoint(checkpoint_save_dir):
if not g_pathmgr.isdir(checkpoint_save_dir):
return None
ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
if not g_pathmgr.isfile(ckpt_file):
return None
return ckpt_file