[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
116
training/README.md
Normal file
116
training/README.md
Normal file
@@ -0,0 +1,116 @@
|
||||
# Training Code for SAM 2
|
||||
|
||||
This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos.
|
||||
The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both).
|
||||
|
||||
## Structure
|
||||
|
||||
The training code is organized into the following subfolders:
|
||||
|
||||
* `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms.
|
||||
* `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling).
|
||||
* `utils`: This folder contains training utils such as loggers and distributed training utils.
|
||||
* `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training.
|
||||
* `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training.
|
||||
* `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers.
|
||||
* `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop.
|
||||
* `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`
|
||||
|
||||
## Getting Started
|
||||
|
||||
To get started with the training code, we provide a simple example to fine-tune our checkpoints on [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets.
|
||||
|
||||
#### Requirements:
|
||||
- We assume training on A100 GPUs with **80 GB** of memory.
|
||||
- Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download).
|
||||
|
||||
#### Steps to fine-tune on MOSE:
|
||||
- Install the packages required for training by running `pip install -e ".[dev]"`.
|
||||
- Set the paths for MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`.
|
||||
```yaml
|
||||
dataset:
|
||||
# PATHS to Dataset
|
||||
img_folder: null # PATH to MOSE JPEGImages folder
|
||||
gt_folder: null # PATH to MOSE Annotations folder
|
||||
file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training
|
||||
```
|
||||
- To fine-tune the base model on MOSE using 8 GPUs, run
|
||||
|
||||
```python
|
||||
python training/train.py \
|
||||
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
|
||||
--use-cluster 0 \
|
||||
--num-gpus 8
|
||||
```
|
||||
|
||||
We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html), for example, you can train on 2 nodes by running
|
||||
|
||||
```python
|
||||
python training/train.py \
|
||||
-c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \
|
||||
--use-cluster 1 \
|
||||
--num-gpus 8 \
|
||||
--num-nodes 2
|
||||
--partition $PARTITION \
|
||||
--qos $QOS \
|
||||
--account $ACCOUNT
|
||||
```
|
||||
where partition, qos, and account are optional and depend on your SLURM configuration.
|
||||
By default, the checkpoint and logs will be saved under `sam2_logs` directory in the root of the repo. Alternatively, you can set the experiment log directory in the config file as follows:
|
||||
|
||||
```yaml
|
||||
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
||||
```
|
||||
The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split]( ../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run the `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base plus model is 79.4.
|
||||
|
||||
|
||||
After training/fine-tuning, you can then use the new checkpoint (saved in `checkpoints/` in the experiment log directory) similar to SAM 2 released checkpoints (as illustrated [here](../README.md#image-prediction)).
|
||||
## Training on images and videos
|
||||
The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, as well as any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py). Below is an example of how to setup the datasets in your config to train on a mix of image and video datasets:
|
||||
|
||||
```yaml
|
||||
data:
|
||||
train:
|
||||
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
||||
phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases
|
||||
batch_sizes: # List of batch sizes corresponding to each dataset
|
||||
- ${bs1} # Batch size of dataset 1
|
||||
- ${bs2} # Batch size of dataset 2
|
||||
datasets:
|
||||
# SA1B as an example of an image dataset
|
||||
- _target_: training.dataset.vos_dataset.VOSDataset
|
||||
training: true
|
||||
video_dataset:
|
||||
_target_: training.dataset.vos_raw_dataset.SA1BRawDataset
|
||||
img_folder: ${path_to_img_folder}
|
||||
gt_folder: ${path_to_gt_folder}
|
||||
file_list_txt: ${path_to_train_filelist} # Optional
|
||||
sampler:
|
||||
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
||||
num_frames: 1
|
||||
max_num_objects: ${max_num_objects_per_image}
|
||||
transforms: ${image_transforms}
|
||||
# SA-V as an example of a video dataset
|
||||
- _target_: training.dataset.vos_dataset.VOSDataset
|
||||
training: true
|
||||
video_dataset:
|
||||
_target_: training.dataset.vos_raw_dataset.JSONRawDataset
|
||||
img_folder: ${path_to_img_folder}
|
||||
gt_folder: ${path_to_gt_folder}
|
||||
file_list_txt: ${path_to_train_filelist} # Optional
|
||||
ann_every: 4
|
||||
sampler:
|
||||
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
||||
num_frames: 8 # Number of frames per video
|
||||
max_num_objects: ${max_num_objects_per_video}
|
||||
reverse_time_prob: ${reverse_time_prob} # probability to reverse video
|
||||
transforms: ${video_transforms}
|
||||
shuffle: True
|
||||
num_workers: ${num_train_workers}
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
collate_fn:
|
||||
_target_: training.utils.data_utils.collate_fn
|
||||
_partial_: true
|
||||
dict_key: all
|
||||
```
|
5
training/__init__.py
Normal file
5
training/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
1246
training/assets/MOSE_sample_train_list.txt
Normal file
1246
training/assets/MOSE_sample_train_list.txt
Normal file
File diff suppressed because it is too large
Load Diff
200
training/assets/MOSE_sample_val_list.txt
Normal file
200
training/assets/MOSE_sample_val_list.txt
Normal file
@@ -0,0 +1,200 @@
|
||||
32e5d721
|
||||
5bad0bab
|
||||
267bfd6c
|
||||
0a43a414
|
||||
56c56ca9
|
||||
9a1146b3
|
||||
c6ad7aaf
|
||||
78a1f4b1
|
||||
fc455e73
|
||||
072e7b3f
|
||||
77ccb57d
|
||||
a76ee415
|
||||
8cdcfc17
|
||||
5d518b42
|
||||
376dd830
|
||||
0e843fc8
|
||||
2af0e766
|
||||
2bd4e845
|
||||
de2f2a6a
|
||||
ade9ee91
|
||||
001ca3cb
|
||||
fc4c1c67
|
||||
8ef55579
|
||||
b84ce852
|
||||
4cc8528a
|
||||
767ffaaa
|
||||
112a2ef0
|
||||
a338c8aa
|
||||
cbd144f5
|
||||
5ff72128
|
||||
86a949e2
|
||||
9f2323ac
|
||||
1fab1d1c
|
||||
75924351
|
||||
ef55817b
|
||||
02deca50
|
||||
4d979d99
|
||||
4d65f873
|
||||
28470fa0
|
||||
0d1575fe
|
||||
06ea172e
|
||||
29a6ddc2
|
||||
797f1bec
|
||||
780e7a99
|
||||
b9ed5b44
|
||||
02a236b4
|
||||
607d8ff5
|
||||
af5666b2
|
||||
0558d0ed
|
||||
a938c6b2
|
||||
103df575
|
||||
77110e80
|
||||
739e5a07
|
||||
6763a576
|
||||
06ebc138
|
||||
ba4b3b09
|
||||
b35cc2f3
|
||||
4e0597a0
|
||||
5949ee84
|
||||
5348d547
|
||||
323c4236
|
||||
b3b51117
|
||||
55727ddd
|
||||
ab2714f3
|
||||
d2878895
|
||||
c0734cb3
|
||||
94f7c53e
|
||||
2a2745e5
|
||||
442ffb54
|
||||
3592425a
|
||||
50ae03b0
|
||||
5f150435
|
||||
3067f9fa
|
||||
9ffb2818
|
||||
adeaf5aa
|
||||
31caacec
|
||||
1cd99b86
|
||||
aa22f9d0
|
||||
8fa50320
|
||||
e6348d2c
|
||||
42ff84a5
|
||||
8c8b7913
|
||||
c96adcbc
|
||||
495be321
|
||||
db735509
|
||||
ee113fc4
|
||||
a678cdab
|
||||
c409ca4d
|
||||
68d2b259
|
||||
592b4dee
|
||||
4e2b4dc7
|
||||
eb4d26e1
|
||||
2009a00f
|
||||
bec5c89d
|
||||
67191f24
|
||||
a3e85b4b
|
||||
da7080cd
|
||||
80d978e9
|
||||
36dcb93f
|
||||
a41e8c44
|
||||
12fdc864
|
||||
46d140ea
|
||||
657c9dd9
|
||||
a86f84ee
|
||||
90c1c43d
|
||||
33015509
|
||||
afc7664d
|
||||
23df06e1
|
||||
291d4799
|
||||
0ab75563
|
||||
251bf059
|
||||
bcefdcc4
|
||||
ce9a2796
|
||||
94d3403a
|
||||
8f2e04bc
|
||||
f9cda066
|
||||
9dfa2cc5
|
||||
66924c91
|
||||
e765a09e
|
||||
15654ee1
|
||||
48e0bd39
|
||||
ee095221
|
||||
2463609b
|
||||
544d0d1f
|
||||
51b8c2e1
|
||||
d321dde4
|
||||
4cb11a5f
|
||||
d7058a0d
|
||||
37af282a
|
||||
fabae187
|
||||
7be91184
|
||||
181ec185
|
||||
2d16ceeb
|
||||
b56be4b1
|
||||
6699eff0
|
||||
79acac96
|
||||
d61c4665
|
||||
0c13e1e7
|
||||
100f6ecf
|
||||
71217dfc
|
||||
82df0888
|
||||
4c42c747
|
||||
c9fdf703
|
||||
d2efeb4b
|
||||
69ed9d14
|
||||
64914fb6
|
||||
255bedbc
|
||||
4ea934d8
|
||||
a034feb2
|
||||
e4f4ddae
|
||||
e36a3026
|
||||
c1489591
|
||||
111bb373
|
||||
e1d9fb32
|
||||
93e22d48
|
||||
c1ec4b26
|
||||
d9638e69
|
||||
60ab04c5
|
||||
cfe7773a
|
||||
62132822
|
||||
2f5fb2a3
|
||||
7bdd197d
|
||||
033333fd
|
||||
130fcdbe
|
||||
12e509c2
|
||||
67138c33
|
||||
6f90cc5f
|
||||
4e3020fe
|
||||
bbdd8bb7
|
||||
b399ccdb
|
||||
fecd10d2
|
||||
2e0967f7
|
||||
f509054f
|
||||
792c6ff7
|
||||
48e2afc5
|
||||
d904c048
|
||||
111e0a5c
|
||||
b83024e2
|
||||
e6a7b79c
|
||||
bdc5ccf7
|
||||
b8146d00
|
||||
9d394f1a
|
||||
645b84f9
|
||||
95ab2d0f
|
||||
e6f8a31d
|
||||
b4f876fb
|
||||
dc2c570d
|
||||
3afd02d7
|
||||
5c80c82c
|
||||
b1b32ddd
|
||||
9f25fc61
|
||||
ba538072
|
||||
f8916fef
|
||||
43c04ad2
|
||||
a658e949
|
||||
2861dd53
|
||||
f6e40aba
|
||||
09d305d1
|
||||
aac33bff
|
||||
8d9d4c08
|
5
training/dataset/__init__.py
Normal file
5
training/dataset/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
180
training/dataset/sam2_datasets.py
Normal file
180
training/dataset/sam2_datasets.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Callable, Iterable, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
|
||||
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class MixedDataLoader:
|
||||
def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
|
||||
"""
|
||||
Args:
|
||||
dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
|
||||
mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
|
||||
|
||||
"""
|
||||
assert len(dataloaders) == mixing_prob.shape[0]
|
||||
self.dataloaders = dataloaders
|
||||
self.mixing_prob = mixing_prob
|
||||
# Iterator state
|
||||
self._iter_dls = None
|
||||
self._iter_mixing_prob = None
|
||||
self.random_generator = torch.Generator()
|
||||
|
||||
def __len__(self):
|
||||
return sum([len(d) for d in self.dataloaders])
|
||||
|
||||
def __iter__(self):
|
||||
# Synchronize dataloader seeds
|
||||
self.random_generator.manual_seed(42)
|
||||
self._iter_dls = [iter(loader) for loader in self.dataloaders]
|
||||
self._iter_mixing_prob = self.mixing_prob.clone()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""
|
||||
Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
|
||||
"""
|
||||
if self._iter_dls is None:
|
||||
raise TypeError(f"{type(self).__name__} object is not an iterator")
|
||||
|
||||
while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob.
|
||||
dataset_idx = self._iter_mixing_prob.multinomial(
|
||||
1, generator=self.random_generator
|
||||
).item()
|
||||
try:
|
||||
item = next(self._iter_dls[dataset_idx])
|
||||
return item
|
||||
except StopIteration:
|
||||
# No more iterations for this dataset, set it's mixing probability to zero and try again.
|
||||
self._iter_mixing_prob[dataset_idx] = 0
|
||||
except Exception as e:
|
||||
# log and raise any other unexpected error.
|
||||
logging.error(e)
|
||||
raise e
|
||||
|
||||
# Exhausted all iterators
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class TorchTrainMixedDataset:
|
||||
def __init__(
|
||||
self,
|
||||
datasets: List[Dataset],
|
||||
batch_sizes: List[int],
|
||||
num_workers: int,
|
||||
shuffle: bool,
|
||||
pin_memory: bool,
|
||||
drop_last: bool,
|
||||
collate_fn: Optional[Callable] = None,
|
||||
worker_init_fn: Optional[Callable] = None,
|
||||
phases_per_epoch: int = 1,
|
||||
dataset_prob: Optional[List[float]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
datasets (List[Dataset]): List of Datasets to be mixed.
|
||||
batch_sizes (List[int]): Batch sizes for each dataset in the list.
|
||||
num_workers (int): Number of workers per dataloader.
|
||||
shuffle (bool): Whether or not to shuffle data.
|
||||
pin_memory (bool): If True, use pinned memory when loading tensors from disk.
|
||||
drop_last (bool): Whether or not to drop the last batch of data.
|
||||
collate_fn (Callable): Function to merge a list of samples into a mini-batch.
|
||||
worker_init_fn (Callable): Function to init each dataloader worker.
|
||||
phases_per_epoch (int): Number of phases per epoch.
|
||||
dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
|
||||
"""
|
||||
|
||||
self.datasets = datasets
|
||||
self.batch_sizes = batch_sizes
|
||||
self.num_workers = num_workers
|
||||
self.shuffle = shuffle
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
self.collate_fn = collate_fn
|
||||
self.worker_init_fn = worker_init_fn
|
||||
assert len(self.datasets) > 0
|
||||
for dataset in self.datasets:
|
||||
assert not isinstance(dataset, IterableDataset), "Not supported"
|
||||
# `RepeatFactorWrapper` requires calling set_epoch first to get its length
|
||||
self._set_dataset_epoch(dataset, 0)
|
||||
self.phases_per_epoch = phases_per_epoch
|
||||
self.chunks = [None] * len(datasets)
|
||||
if dataset_prob is None:
|
||||
# If not provided, assign each dataset a probability proportional to its length.
|
||||
dataset_lens = [
|
||||
(math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
|
||||
for d, bs in zip(datasets, batch_sizes)
|
||||
]
|
||||
total_len = sum(dataset_lens)
|
||||
dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
|
||||
else:
|
||||
assert len(dataset_prob) == len(datasets)
|
||||
dataset_prob = torch.tensor(dataset_prob)
|
||||
|
||||
logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
|
||||
assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
|
||||
self.dataset_prob = dataset_prob
|
||||
|
||||
def _set_dataset_epoch(self, dataset, epoch: int) -> None:
|
||||
if hasattr(dataset, "epoch"):
|
||||
dataset.epoch = epoch
|
||||
if hasattr(dataset, "set_epoch"):
|
||||
dataset.set_epoch(epoch)
|
||||
|
||||
def get_loader(self, epoch) -> Iterable:
|
||||
dataloaders = []
|
||||
for d_idx, (dataset, batch_size) in enumerate(
|
||||
zip(self.datasets, self.batch_sizes)
|
||||
):
|
||||
if self.phases_per_epoch > 1:
|
||||
# Major epoch that looops over entire dataset
|
||||
# len(main_epoch) == phases_per_epoch * len(epoch)
|
||||
main_epoch = epoch // self.phases_per_epoch
|
||||
|
||||
# Phase with in the main epoch
|
||||
local_phase = epoch % self.phases_per_epoch
|
||||
|
||||
# Start of new data-epoch or job is resumed after preemtion.
|
||||
if local_phase == 0 or self.chunks[d_idx] is None:
|
||||
# set seed for dataset epoch
|
||||
# If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.
|
||||
self._set_dataset_epoch(dataset, main_epoch)
|
||||
|
||||
# Separate random generator for subset sampling
|
||||
g = torch.Generator()
|
||||
g.manual_seed(main_epoch)
|
||||
self.chunks[d_idx] = torch.chunk(
|
||||
torch.randperm(len(dataset), generator=g),
|
||||
self.phases_per_epoch,
|
||||
)
|
||||
|
||||
dataset = Subset(dataset, self.chunks[d_idx][local_phase])
|
||||
else:
|
||||
self._set_dataset_epoch(dataset, epoch)
|
||||
|
||||
sampler = DistributedSampler(dataset, shuffle=self.shuffle)
|
||||
sampler.set_epoch(epoch)
|
||||
|
||||
batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
|
||||
dataloaders.append(
|
||||
DataLoader(
|
||||
dataset,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=self.pin_memory,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=self.collate_fn,
|
||||
worker_init_fn=self.worker_init_fn,
|
||||
)
|
||||
)
|
||||
return MixedDataLoader(dataloaders, self.dataset_prob)
|
528
training/dataset/transforms.py
Normal file
528
training/dataset/transforms.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms.v2.functional as Fv2
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from training.utils.data_utils import VideoDatapoint
|
||||
|
||||
|
||||
def hflip(datapoint, index):
|
||||
|
||||
datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
|
||||
for obj in datapoint.frames[index].objects:
|
||||
if obj.segment is not None:
|
||||
obj.segment = F.hflip(obj.segment)
|
||||
|
||||
return datapoint
|
||||
|
||||
|
||||
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
||||
w, h = image_size
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((w, h)))
|
||||
max_original_size = float(max((w, h)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
size = max_size * min_original_size / max_original_size
|
||||
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return (h, w)
|
||||
|
||||
if w < h:
|
||||
ow = int(round(size))
|
||||
oh = int(round(size * h / w))
|
||||
else:
|
||||
oh = int(round(size))
|
||||
ow = int(round(size * w / h))
|
||||
|
||||
return (oh, ow)
|
||||
|
||||
|
||||
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
|
||||
# size can be min_size (scalar) or (w, h) tuple
|
||||
|
||||
def get_size(image_size, size, max_size=None):
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size[::-1]
|
||||
else:
|
||||
return get_size_with_aspect_ratio(image_size, size, max_size)
|
||||
|
||||
if square:
|
||||
size = size, size
|
||||
else:
|
||||
cur_size = (
|
||||
datapoint.frames[index].data.size()[-2:][::-1]
|
||||
if v2
|
||||
else datapoint.frames[index].data.size
|
||||
)
|
||||
size = get_size(cur_size, size, max_size)
|
||||
|
||||
old_size = (
|
||||
datapoint.frames[index].data.size()[-2:][::-1]
|
||||
if v2
|
||||
else datapoint.frames[index].data.size
|
||||
)
|
||||
if v2:
|
||||
datapoint.frames[index].data = Fv2.resize(
|
||||
datapoint.frames[index].data, size, antialias=True
|
||||
)
|
||||
else:
|
||||
datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)
|
||||
|
||||
new_size = (
|
||||
datapoint.frames[index].data.size()[-2:][::-1]
|
||||
if v2
|
||||
else datapoint.frames[index].data.size
|
||||
)
|
||||
|
||||
for obj in datapoint.frames[index].objects:
|
||||
if obj.segment is not None:
|
||||
obj.segment = F.resize(obj.segment[None, None], size).squeeze()
|
||||
|
||||
h, w = size
|
||||
datapoint.frames[index].size = (h, w)
|
||||
return datapoint
|
||||
|
||||
|
||||
def pad(datapoint, index, padding, v2=False):
|
||||
old_h, old_w = datapoint.frames[index].size
|
||||
h, w = old_h, old_w
|
||||
if len(padding) == 2:
|
||||
# assumes that we only pad on the bottom right corners
|
||||
datapoint.frames[index].data = F.pad(
|
||||
datapoint.frames[index].data, (0, 0, padding[0], padding[1])
|
||||
)
|
||||
h += padding[1]
|
||||
w += padding[0]
|
||||
else:
|
||||
# left, top, right, bottom
|
||||
datapoint.frames[index].data = F.pad(
|
||||
datapoint.frames[index].data,
|
||||
(padding[0], padding[1], padding[2], padding[3]),
|
||||
)
|
||||
h += padding[1] + padding[3]
|
||||
w += padding[0] + padding[2]
|
||||
|
||||
datapoint.frames[index].size = (h, w)
|
||||
|
||||
for obj in datapoint.frames[index].objects:
|
||||
if obj.segment is not None:
|
||||
if v2:
|
||||
if len(padding) == 2:
|
||||
obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
|
||||
else:
|
||||
obj.segment = Fv2.pad(obj.segment, tuple(padding))
|
||||
else:
|
||||
if len(padding) == 2:
|
||||
obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
|
||||
else:
|
||||
obj.segment = F.pad(obj.segment, tuple(padding))
|
||||
return datapoint
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
|
||||
def __init__(self, consistent_transform, p=0.5):
|
||||
self.p = p
|
||||
self.consistent_transform = consistent_transform
|
||||
|
||||
def __call__(self, datapoint, **kwargs):
|
||||
if self.consistent_transform:
|
||||
if random.random() < self.p:
|
||||
for i in range(len(datapoint.frames)):
|
||||
datapoint = hflip(datapoint, i)
|
||||
return datapoint
|
||||
for i in range(len(datapoint.frames)):
|
||||
if random.random() < self.p:
|
||||
datapoint = hflip(datapoint, i)
|
||||
return datapoint
|
||||
|
||||
|
||||
class RandomResizeAPI:
|
||||
def __init__(
|
||||
self, sizes, consistent_transform, max_size=None, square=False, v2=False
|
||||
):
|
||||
if isinstance(sizes, int):
|
||||
sizes = (sizes,)
|
||||
assert isinstance(sizes, Iterable)
|
||||
self.sizes = list(sizes)
|
||||
self.max_size = max_size
|
||||
self.square = square
|
||||
self.consistent_transform = consistent_transform
|
||||
self.v2 = v2
|
||||
|
||||
def __call__(self, datapoint, **kwargs):
|
||||
if self.consistent_transform:
|
||||
size = random.choice(self.sizes)
|
||||
for i in range(len(datapoint.frames)):
|
||||
datapoint = resize(
|
||||
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
|
||||
)
|
||||
return datapoint
|
||||
for i in range(len(datapoint.frames)):
|
||||
size = random.choice(self.sizes)
|
||||
datapoint = resize(
|
||||
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
|
||||
)
|
||||
return datapoint
|
||||
|
||||
|
||||
class ToTensorAPI:
|
||||
def __init__(self, v2=False):
|
||||
self.v2 = v2
|
||||
|
||||
def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
||||
for img in datapoint.frames:
|
||||
if self.v2:
|
||||
img.data = Fv2.to_image_tensor(img.data)
|
||||
else:
|
||||
img.data = F.to_tensor(img.data)
|
||||
return datapoint
|
||||
|
||||
|
||||
class NormalizeAPI:
|
||||
def __init__(self, mean, std, v2=False):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.v2 = v2
|
||||
|
||||
def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
||||
for img in datapoint.frames:
|
||||
if self.v2:
|
||||
img.data = Fv2.convert_image_dtype(img.data, torch.float32)
|
||||
img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
|
||||
else:
|
||||
img.data = F.normalize(img.data, mean=self.mean, std=self.std)
|
||||
|
||||
return datapoint
|
||||
|
||||
|
||||
class ComposeAPI:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, datapoint, **kwargs):
|
||||
for t in self.transforms:
|
||||
datapoint = t(datapoint, **kwargs)
|
||||
return datapoint
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += " {0}".format(t)
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
|
||||
|
||||
class RandomGrayscale:
|
||||
def __init__(self, consistent_transform, p=0.5):
|
||||
self.p = p
|
||||
self.consistent_transform = consistent_transform
|
||||
self.Grayscale = T.Grayscale(num_output_channels=3)
|
||||
|
||||
def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
||||
if self.consistent_transform:
|
||||
if random.random() < self.p:
|
||||
for img in datapoint.frames:
|
||||
img.data = self.Grayscale(img.data)
|
||||
return datapoint
|
||||
for img in datapoint.frames:
|
||||
if random.random() < self.p:
|
||||
img.data = self.Grayscale(img.data)
|
||||
return datapoint
|
||||
|
||||
|
||||
class ColorJitter:
|
||||
def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
|
||||
self.consistent_transform = consistent_transform
|
||||
self.brightness = (
|
||||
brightness
|
||||
if isinstance(brightness, list)
|
||||
else [max(0, 1 - brightness), 1 + brightness]
|
||||
)
|
||||
self.contrast = (
|
||||
contrast
|
||||
if isinstance(contrast, list)
|
||||
else [max(0, 1 - contrast), 1 + contrast]
|
||||
)
|
||||
self.saturation = (
|
||||
saturation
|
||||
if isinstance(saturation, list)
|
||||
else [max(0, 1 - saturation), 1 + saturation]
|
||||
)
|
||||
self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
|
||||
|
||||
def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
||||
if self.consistent_transform:
|
||||
# Create a color jitter transformation params
|
||||
(
|
||||
fn_idx,
|
||||
brightness_factor,
|
||||
contrast_factor,
|
||||
saturation_factor,
|
||||
hue_factor,
|
||||
) = T.ColorJitter.get_params(
|
||||
self.brightness, self.contrast, self.saturation, self.hue
|
||||
)
|
||||
for img in datapoint.frames:
|
||||
if not self.consistent_transform:
|
||||
(
|
||||
fn_idx,
|
||||
brightness_factor,
|
||||
contrast_factor,
|
||||
saturation_factor,
|
||||
hue_factor,
|
||||
) = T.ColorJitter.get_params(
|
||||
self.brightness, self.contrast, self.saturation, self.hue
|
||||
)
|
||||
for fn_id in fn_idx:
|
||||
if fn_id == 0 and brightness_factor is not None:
|
||||
img.data = F.adjust_brightness(img.data, brightness_factor)
|
||||
elif fn_id == 1 and contrast_factor is not None:
|
||||
img.data = F.adjust_contrast(img.data, contrast_factor)
|
||||
elif fn_id == 2 and saturation_factor is not None:
|
||||
img.data = F.adjust_saturation(img.data, saturation_factor)
|
||||
elif fn_id == 3 and hue_factor is not None:
|
||||
img.data = F.adjust_hue(img.data, hue_factor)
|
||||
return datapoint
|
||||
|
||||
|
||||
class RandomAffine:
|
||||
def __init__(
|
||||
self,
|
||||
degrees,
|
||||
consistent_transform,
|
||||
scale=None,
|
||||
translate=None,
|
||||
shear=None,
|
||||
image_mean=(123, 116, 103),
|
||||
log_warning=True,
|
||||
num_tentatives=1,
|
||||
image_interpolation="bicubic",
|
||||
):
|
||||
"""
|
||||
The mask is required for this transform.
|
||||
if consistent_transform if True, then the same random affine is applied to all frames and masks.
|
||||
"""
|
||||
self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
|
||||
self.scale = scale
|
||||
self.shear = (
|
||||
shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
|
||||
)
|
||||
self.translate = translate
|
||||
self.fill_img = image_mean
|
||||
self.consistent_transform = consistent_transform
|
||||
self.log_warning = log_warning
|
||||
self.num_tentatives = num_tentatives
|
||||
|
||||
if image_interpolation == "bicubic":
|
||||
self.image_interpolation = InterpolationMode.BICUBIC
|
||||
elif image_interpolation == "bilinear":
|
||||
self.image_interpolation = InterpolationMode.BILINEAR
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
||||
for _tentative in range(self.num_tentatives):
|
||||
res = self.transform_datapoint(datapoint)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
if self.log_warning:
|
||||
logging.warning(
|
||||
f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
|
||||
)
|
||||
return datapoint
|
||||
|
||||
def transform_datapoint(self, datapoint: VideoDatapoint):
|
||||
_, height, width = F.get_dimensions(datapoint.frames[0].data)
|
||||
img_size = [width, height]
|
||||
|
||||
if self.consistent_transform:
|
||||
# Create a random affine transformation
|
||||
affine_params = T.RandomAffine.get_params(
|
||||
degrees=self.degrees,
|
||||
translate=self.translate,
|
||||
scale_ranges=self.scale,
|
||||
shears=self.shear,
|
||||
img_size=img_size,
|
||||
)
|
||||
|
||||
for img_idx, img in enumerate(datapoint.frames):
|
||||
this_masks = [
|
||||
obj.segment.unsqueeze(0) if obj.segment is not None else None
|
||||
for obj in img.objects
|
||||
]
|
||||
if not self.consistent_transform:
|
||||
# if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation
|
||||
affine_params = T.RandomAffine.get_params(
|
||||
degrees=self.degrees,
|
||||
translate=self.translate,
|
||||
scale_ranges=self.scale,
|
||||
shears=self.shear,
|
||||
img_size=img_size,
|
||||
)
|
||||
|
||||
transformed_bboxes, transformed_masks = [], []
|
||||
for i in range(len(img.objects)):
|
||||
if this_masks[i] is None:
|
||||
transformed_masks.append(None)
|
||||
# Dummy bbox for a dummy target
|
||||
transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
|
||||
else:
|
||||
transformed_mask = F.affine(
|
||||
this_masks[i],
|
||||
*affine_params,
|
||||
interpolation=InterpolationMode.NEAREST,
|
||||
fill=0.0,
|
||||
)
|
||||
if img_idx == 0 and transformed_mask.max() == 0:
|
||||
# We are dealing with a video and the object is not visible in the first frame
|
||||
# Return the datapoint without transformation
|
||||
return None
|
||||
transformed_masks.append(transformed_mask.squeeze())
|
||||
|
||||
for i in range(len(img.objects)):
|
||||
img.objects[i].segment = transformed_masks[i]
|
||||
|
||||
img.data = F.affine(
|
||||
img.data,
|
||||
*affine_params,
|
||||
interpolation=self.image_interpolation,
|
||||
fill=self.fill_img,
|
||||
)
|
||||
return datapoint
|
||||
|
||||
|
||||
def random_mosaic_frame(
|
||||
datapoint,
|
||||
index,
|
||||
grid_h,
|
||||
grid_w,
|
||||
target_grid_y,
|
||||
target_grid_x,
|
||||
should_hflip,
|
||||
):
|
||||
# Step 1: downsize the images and paste them into a mosaic
|
||||
image_data = datapoint.frames[index].data
|
||||
is_pil = isinstance(image_data, PILImage.Image)
|
||||
if is_pil:
|
||||
H_im = image_data.height
|
||||
W_im = image_data.width
|
||||
image_data_output = PILImage.new("RGB", (W_im, H_im))
|
||||
else:
|
||||
H_im = image_data.size(-2)
|
||||
W_im = image_data.size(-1)
|
||||
image_data_output = torch.zeros_like(image_data)
|
||||
|
||||
downsize_cache = {}
|
||||
for grid_y in range(grid_h):
|
||||
for grid_x in range(grid_w):
|
||||
y_offset_b = grid_y * H_im // grid_h
|
||||
x_offset_b = grid_x * W_im // grid_w
|
||||
y_offset_e = (grid_y + 1) * H_im // grid_h
|
||||
x_offset_e = (grid_x + 1) * W_im // grid_w
|
||||
H_im_downsize = y_offset_e - y_offset_b
|
||||
W_im_downsize = x_offset_e - x_offset_b
|
||||
|
||||
if (H_im_downsize, W_im_downsize) in downsize_cache:
|
||||
image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
|
||||
else:
|
||||
image_data_downsize = F.resize(
|
||||
image_data,
|
||||
size=(H_im_downsize, W_im_downsize),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
antialias=True, # antialiasing for downsizing
|
||||
)
|
||||
downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
|
||||
if should_hflip[grid_y, grid_x].item():
|
||||
image_data_downsize = F.hflip(image_data_downsize)
|
||||
|
||||
if is_pil:
|
||||
image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
|
||||
else:
|
||||
image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
|
||||
image_data_downsize
|
||||
)
|
||||
|
||||
datapoint.frames[index].data = image_data_output
|
||||
|
||||
# Step 2: downsize the masks and paste them into the target grid of the mosaic
|
||||
for obj in datapoint.frames[index].objects:
|
||||
if obj.segment is None:
|
||||
continue
|
||||
assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
|
||||
segment_output = torch.zeros_like(obj.segment)
|
||||
|
||||
target_y_offset_b = target_grid_y * H_im // grid_h
|
||||
target_x_offset_b = target_grid_x * W_im // grid_w
|
||||
target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
|
||||
target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
|
||||
target_H_im_downsize = target_y_offset_e - target_y_offset_b
|
||||
target_W_im_downsize = target_x_offset_e - target_x_offset_b
|
||||
|
||||
segment_downsize = F.resize(
|
||||
obj.segment[None, None],
|
||||
size=(target_H_im_downsize, target_W_im_downsize),
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
antialias=True, # antialiasing for downsizing
|
||||
)[0, 0]
|
||||
if should_hflip[target_grid_y, target_grid_x].item():
|
||||
segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
|
||||
|
||||
segment_output[
|
||||
target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
|
||||
] = segment_downsize
|
||||
obj.segment = segment_output
|
||||
|
||||
return datapoint
|
||||
|
||||
|
||||
class RandomMosaicVideoAPI:
|
||||
def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
|
||||
self.prob = prob
|
||||
self.grid_h = grid_h
|
||||
self.grid_w = grid_w
|
||||
self.use_random_hflip = use_random_hflip
|
||||
|
||||
def __call__(self, datapoint, **kwargs):
|
||||
if random.random() > self.prob:
|
||||
return datapoint
|
||||
|
||||
# select a random location to place the target mask in the mosaic
|
||||
target_grid_y = random.randint(0, self.grid_h - 1)
|
||||
target_grid_x = random.randint(0, self.grid_w - 1)
|
||||
# whether to flip each grid in the mosaic horizontally
|
||||
if self.use_random_hflip:
|
||||
should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
|
||||
else:
|
||||
should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
|
||||
for i in range(len(datapoint.frames)):
|
||||
datapoint = random_mosaic_frame(
|
||||
datapoint,
|
||||
i,
|
||||
grid_h=self.grid_h,
|
||||
grid_w=self.grid_w,
|
||||
target_grid_y=target_grid_y,
|
||||
target_grid_x=target_grid_x,
|
||||
should_hflip=should_hflip,
|
||||
)
|
||||
|
||||
return datapoint
|
104
training/dataset/utils.py
Normal file
104
training/dataset/utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
ConcatDataset as TorchConcatDataset,
|
||||
Dataset,
|
||||
Subset as TorchSubset,
|
||||
)
|
||||
|
||||
|
||||
class ConcatDataset(TorchConcatDataset):
|
||||
def __init__(self, datasets: Iterable[Dataset]) -> None:
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
|
||||
self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
for dataset in self.datasets:
|
||||
if hasattr(dataset, "epoch"):
|
||||
dataset.epoch = epoch
|
||||
if hasattr(dataset, "set_epoch"):
|
||||
dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
class Subset(TorchSubset):
|
||||
def __init__(self, dataset, indices) -> None:
|
||||
super(Subset, self).__init__(dataset, indices)
|
||||
|
||||
self.repeat_factors = dataset.repeat_factors[indices]
|
||||
assert len(indices) == len(self.repeat_factors)
|
||||
|
||||
|
||||
# Adapted from Detectron2
|
||||
class RepeatFactorWrapper(Dataset):
|
||||
"""
|
||||
Thin wrapper around a dataset to implement repeat factor sampling.
|
||||
The underlying dataset must have a repeat_factors member to indicate the per-image factor.
|
||||
Set it to uniformly ones to disable repeat factor sampling
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, seed: int = 0):
|
||||
self.dataset = dataset
|
||||
self.epoch_ids = None
|
||||
self._seed = seed
|
||||
|
||||
# Split into whole number (_int_part) and fractional (_frac_part) parts.
|
||||
self._int_part = torch.trunc(dataset.repeat_factors)
|
||||
self._frac_part = dataset.repeat_factors - self._int_part
|
||||
|
||||
def _get_epoch_indices(self, generator):
|
||||
"""
|
||||
Create a list of dataset indices (with repeats) to use for one epoch.
|
||||
|
||||
Args:
|
||||
generator (torch.Generator): pseudo random number generator used for
|
||||
stochastic rounding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: list of dataset indices to use in one epoch. Each index
|
||||
is repeated based on its calculated repeat factor.
|
||||
"""
|
||||
# Since repeat factors are fractional, we use stochastic rounding so
|
||||
# that the target repeat factor is achieved in expectation over the
|
||||
# course of training
|
||||
rands = torch.rand(len(self._frac_part), generator=generator)
|
||||
rep_factors = self._int_part + (rands < self._frac_part).float()
|
||||
# Construct a list of indices in which we repeat images as specified
|
||||
indices = []
|
||||
for dataset_index, rep_factor in enumerate(rep_factors):
|
||||
indices.extend([dataset_index] * int(rep_factor.item()))
|
||||
return torch.tensor(indices, dtype=torch.int64)
|
||||
|
||||
def __len__(self):
|
||||
if self.epoch_ids is None:
|
||||
# Here we raise an error instead of returning original len(self.dataset) avoid
|
||||
# accidentally using unwrapped length. Otherwise it's error-prone since the
|
||||
# length changes to `len(self.epoch_ids)`changes after set_epoch is called.
|
||||
raise RuntimeError("please call set_epoch first to get wrapped length")
|
||||
# return len(self.dataset)
|
||||
|
||||
return len(self.epoch_ids)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self._seed + epoch)
|
||||
self.epoch_ids = self._get_epoch_indices(g)
|
||||
if hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.epoch_ids is None:
|
||||
raise RuntimeError(
|
||||
"Repeat ids haven't been computed. Did you forget to call set_epoch?"
|
||||
)
|
||||
|
||||
return self.dataset[self.epoch_ids[idx]]
|
162
training/dataset/vos_dataset.py
Normal file
162
training/dataset/vos_dataset.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from PIL import Image as PILImage
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from training.dataset.vos_raw_dataset import VOSRawDataset
|
||||
from training.dataset.vos_sampler import VOSSampler
|
||||
from training.dataset.vos_segment_loader import JSONSegmentLoader
|
||||
|
||||
from training.utils.data_utils import Frame, Object, VideoDatapoint
|
||||
|
||||
MAX_RETRIES = 100
|
||||
|
||||
|
||||
class VOSDataset(VisionDataset):
|
||||
def __init__(
|
||||
self,
|
||||
transforms,
|
||||
training: bool,
|
||||
video_dataset: VOSRawDataset,
|
||||
sampler: VOSSampler,
|
||||
multiplier: int,
|
||||
always_target=True,
|
||||
target_segments_available=True,
|
||||
):
|
||||
self._transforms = transforms
|
||||
self.training = training
|
||||
self.video_dataset = video_dataset
|
||||
self.sampler = sampler
|
||||
|
||||
self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
|
||||
self.repeat_factors *= multiplier
|
||||
print(f"Raw dataset length = {len(self.video_dataset)}")
|
||||
|
||||
self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
|
||||
self.always_target = always_target
|
||||
self.target_segments_available = target_segments_available
|
||||
|
||||
def _get_datapoint(self, idx):
|
||||
|
||||
for retry in range(MAX_RETRIES):
|
||||
try:
|
||||
if isinstance(idx, torch.Tensor):
|
||||
idx = idx.item()
|
||||
# sample a video
|
||||
video, segment_loader = self.video_dataset.get_video(idx)
|
||||
# sample frames and object indices to be used in a datapoint
|
||||
sampled_frms_and_objs = self.sampler.sample(
|
||||
video, segment_loader, epoch=self.curr_epoch
|
||||
)
|
||||
break # Succesfully loaded video
|
||||
except Exception as e:
|
||||
if self.training:
|
||||
logging.warning(
|
||||
f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
|
||||
)
|
||||
idx = random.randrange(0, len(self.video_dataset))
|
||||
else:
|
||||
# Shouldn't fail to load a val video
|
||||
raise e
|
||||
|
||||
datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
|
||||
for transform in self._transforms:
|
||||
datapoint = transform(datapoint, epoch=self.curr_epoch)
|
||||
return datapoint
|
||||
|
||||
def construct(self, video, sampled_frms_and_objs, segment_loader):
|
||||
"""
|
||||
Constructs a VideoDatapoint sample to pass to transforms
|
||||
"""
|
||||
sampled_frames = sampled_frms_and_objs.frames
|
||||
sampled_object_ids = sampled_frms_and_objs.object_ids
|
||||
|
||||
images = []
|
||||
rgb_images = load_images(sampled_frames)
|
||||
# Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
|
||||
for frame_idx, frame in enumerate(sampled_frames):
|
||||
w, h = rgb_images[frame_idx].size
|
||||
images.append(
|
||||
Frame(
|
||||
data=rgb_images[frame_idx],
|
||||
objects=[],
|
||||
)
|
||||
)
|
||||
# We load the gt segments associated with the current frame
|
||||
if isinstance(segment_loader, JSONSegmentLoader):
|
||||
segments = segment_loader.load(
|
||||
frame.frame_idx, obj_ids=sampled_object_ids
|
||||
)
|
||||
else:
|
||||
segments = segment_loader.load(frame.frame_idx)
|
||||
for obj_id in sampled_object_ids:
|
||||
# Extract the segment
|
||||
if obj_id in segments:
|
||||
assert (
|
||||
segments[obj_id] is not None
|
||||
), "None targets are not supported"
|
||||
# segment is uint8 and remains uint8 throughout the transforms
|
||||
segment = segments[obj_id].to(torch.uint8)
|
||||
else:
|
||||
# There is no target, we either use a zero mask target or drop this object
|
||||
if not self.always_target:
|
||||
continue
|
||||
segment = torch.zeros(h, w, dtype=torch.uint8)
|
||||
|
||||
images[frame_idx].objects.append(
|
||||
Object(
|
||||
object_id=obj_id,
|
||||
frame_index=frame.frame_idx,
|
||||
segment=segment,
|
||||
)
|
||||
)
|
||||
return VideoDatapoint(
|
||||
frames=images,
|
||||
video_id=video.video_id,
|
||||
size=(h, w),
|
||||
)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._get_datapoint(idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_dataset)
|
||||
|
||||
|
||||
def load_images(frames):
|
||||
all_images = []
|
||||
cache = {}
|
||||
for frame in frames:
|
||||
if frame.data is None:
|
||||
# Load the frame rgb data from file
|
||||
path = frame.image_path
|
||||
if path in cache:
|
||||
all_images.append(deepcopy(all_images[cache[path]]))
|
||||
continue
|
||||
with g_pathmgr.open(path, "rb") as fopen:
|
||||
all_images.append(PILImage.open(fopen).convert("RGB"))
|
||||
cache[path] = len(all_images) - 1
|
||||
else:
|
||||
# The frame rgb data has already been loaded
|
||||
# Convert it to a PILImage
|
||||
all_images.append(tensor_2_PIL(frame.data))
|
||||
|
||||
return all_images
|
||||
|
||||
|
||||
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
|
||||
data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
|
||||
data = data.astype(np.uint8)
|
||||
return PILImage.fromarray(data)
|
308
training/dataset/vos_raw_dataset.py
Normal file
308
training/dataset/vos_raw_dataset.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from omegaconf.listconfig import ListConfig
|
||||
|
||||
from training.dataset.vos_segment_loader import (
|
||||
JSONSegmentLoader,
|
||||
MultiplePNGSegmentLoader,
|
||||
PalettisedPNGSegmentLoader,
|
||||
SA1BSegmentLoader,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VOSFrame:
|
||||
frame_idx: int
|
||||
image_path: str
|
||||
data: Optional[torch.Tensor] = None
|
||||
is_conditioning_only: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VOSVideo:
|
||||
video_name: str
|
||||
video_id: int
|
||||
frames: List[VOSFrame]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frames)
|
||||
|
||||
|
||||
class VOSRawDataset:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_video(self, idx):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PNGRawDataset(VOSRawDataset):
|
||||
def __init__(
|
||||
self,
|
||||
img_folder,
|
||||
gt_folder,
|
||||
file_list_txt=None,
|
||||
excluded_videos_list_txt=None,
|
||||
sample_rate=1,
|
||||
is_palette=True,
|
||||
single_object_mode=False,
|
||||
truncate_video=-1,
|
||||
frames_sampling_mult=False,
|
||||
):
|
||||
self.img_folder = img_folder
|
||||
self.gt_folder = gt_folder
|
||||
self.sample_rate = sample_rate
|
||||
self.is_palette = is_palette
|
||||
self.single_object_mode = single_object_mode
|
||||
self.truncate_video = truncate_video
|
||||
|
||||
# Read the subset defined in file_list_txt
|
||||
if file_list_txt is not None:
|
||||
with g_pathmgr.open(file_list_txt, "r") as f:
|
||||
subset = [os.path.splitext(line.strip())[0] for line in f]
|
||||
else:
|
||||
subset = os.listdir(self.img_folder)
|
||||
|
||||
# Read and process excluded files if provided
|
||||
if excluded_videos_list_txt is not None:
|
||||
with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
|
||||
excluded_files = [os.path.splitext(line.strip())[0] for line in f]
|
||||
else:
|
||||
excluded_files = []
|
||||
|
||||
# Check if it's not in excluded_files
|
||||
self.video_names = sorted(
|
||||
[video_name for video_name in subset if video_name not in excluded_files]
|
||||
)
|
||||
|
||||
if self.single_object_mode:
|
||||
# single object mode
|
||||
self.video_names = sorted(
|
||||
[
|
||||
os.path.join(video_name, obj)
|
||||
for video_name in self.video_names
|
||||
for obj in os.listdir(os.path.join(self.gt_folder, video_name))
|
||||
]
|
||||
)
|
||||
|
||||
if frames_sampling_mult:
|
||||
video_names_mult = []
|
||||
for video_name in self.video_names:
|
||||
num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
|
||||
video_names_mult.extend([video_name] * num_frames)
|
||||
self.video_names = video_names_mult
|
||||
|
||||
def get_video(self, idx):
|
||||
"""
|
||||
Given a VOSVideo object, return the mask tensors.
|
||||
"""
|
||||
video_name = self.video_names[idx]
|
||||
|
||||
if self.single_object_mode:
|
||||
video_frame_root = os.path.join(
|
||||
self.img_folder, os.path.dirname(video_name)
|
||||
)
|
||||
else:
|
||||
video_frame_root = os.path.join(self.img_folder, video_name)
|
||||
|
||||
video_mask_root = os.path.join(self.gt_folder, video_name)
|
||||
|
||||
if self.is_palette:
|
||||
segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
|
||||
else:
|
||||
segment_loader = MultiplePNGSegmentLoader(
|
||||
video_mask_root, self.single_object_mode
|
||||
)
|
||||
|
||||
all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
|
||||
if self.truncate_video > 0:
|
||||
all_frames = all_frames[: self.truncate_video]
|
||||
frames = []
|
||||
for _, fpath in enumerate(all_frames[:: self.sample_rate]):
|
||||
fid = int(os.path.basename(fpath).split(".")[0])
|
||||
frames.append(VOSFrame(fid, image_path=fpath))
|
||||
video = VOSVideo(video_name, idx, frames)
|
||||
return video, segment_loader
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_names)
|
||||
|
||||
|
||||
class SA1BRawDataset(VOSRawDataset):
|
||||
def __init__(
|
||||
self,
|
||||
img_folder,
|
||||
gt_folder,
|
||||
file_list_txt=None,
|
||||
excluded_videos_list_txt=None,
|
||||
num_frames=1,
|
||||
mask_area_frac_thresh=1.1, # no filtering by default
|
||||
uncertain_iou=-1, # no filtering by default
|
||||
):
|
||||
self.img_folder = img_folder
|
||||
self.gt_folder = gt_folder
|
||||
self.num_frames = num_frames
|
||||
self.mask_area_frac_thresh = mask_area_frac_thresh
|
||||
self.uncertain_iou = uncertain_iou # stability score
|
||||
|
||||
# Read the subset defined in file_list_txt
|
||||
if file_list_txt is not None:
|
||||
with g_pathmgr.open(file_list_txt, "r") as f:
|
||||
subset = [os.path.splitext(line.strip())[0] for line in f]
|
||||
else:
|
||||
subset = os.listdir(self.img_folder)
|
||||
subset = [
|
||||
path.split(".")[0] for path in subset if path.endswith(".jpg")
|
||||
] # remove extension
|
||||
|
||||
# Read and process excluded files if provided
|
||||
if excluded_videos_list_txt is not None:
|
||||
with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
|
||||
excluded_files = [os.path.splitext(line.strip())[0] for line in f]
|
||||
else:
|
||||
excluded_files = []
|
||||
|
||||
# Check if it's not in excluded_files and it exists
|
||||
self.video_names = [
|
||||
video_name for video_name in subset if video_name not in excluded_files
|
||||
]
|
||||
|
||||
def get_video(self, idx):
|
||||
"""
|
||||
Given a VOSVideo object, return the mask tensors.
|
||||
"""
|
||||
video_name = self.video_names[idx]
|
||||
|
||||
video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
|
||||
video_mask_path = os.path.join(self.gt_folder, video_name + ".json")
|
||||
|
||||
segment_loader = SA1BSegmentLoader(
|
||||
video_mask_path,
|
||||
mask_area_frac_thresh=self.mask_area_frac_thresh,
|
||||
video_frame_path=video_frame_path,
|
||||
uncertain_iou=self.uncertain_iou,
|
||||
)
|
||||
|
||||
frames = []
|
||||
for frame_idx in range(self.num_frames):
|
||||
frames.append(VOSFrame(frame_idx, image_path=video_frame_path))
|
||||
video_name = video_name.split("_")[-1] # filename is sa_{int}
|
||||
# video id needs to be image_id to be able to load correct annotation file during eval
|
||||
video = VOSVideo(video_name, int(video_name), frames)
|
||||
return video, segment_loader
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_names)
|
||||
|
||||
|
||||
class JSONRawDataset(VOSRawDataset):
|
||||
"""
|
||||
Dataset where the annotation in the format of SA-V json files
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_folder,
|
||||
gt_folder,
|
||||
file_list_txt=None,
|
||||
excluded_videos_list_txt=None,
|
||||
sample_rate=1,
|
||||
rm_unannotated=True,
|
||||
ann_every=1,
|
||||
frames_fps=24,
|
||||
):
|
||||
self.gt_folder = gt_folder
|
||||
self.img_folder = img_folder
|
||||
self.sample_rate = sample_rate
|
||||
self.rm_unannotated = rm_unannotated
|
||||
self.ann_every = ann_every
|
||||
self.frames_fps = frames_fps
|
||||
|
||||
# Read and process excluded files if provided
|
||||
excluded_files = []
|
||||
if excluded_videos_list_txt is not None:
|
||||
if isinstance(excluded_videos_list_txt, str):
|
||||
excluded_videos_lists = [excluded_videos_list_txt]
|
||||
elif isinstance(excluded_videos_list_txt, ListConfig):
|
||||
excluded_videos_lists = list(excluded_videos_list_txt)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for excluded_videos_list_txt in excluded_videos_lists:
|
||||
with open(excluded_videos_list_txt, "r") as f:
|
||||
excluded_files.extend(
|
||||
[os.path.splitext(line.strip())[0] for line in f]
|
||||
)
|
||||
excluded_files = set(excluded_files)
|
||||
|
||||
# Read the subset defined in file_list_txt
|
||||
if file_list_txt is not None:
|
||||
with g_pathmgr.open(file_list_txt, "r") as f:
|
||||
subset = [os.path.splitext(line.strip())[0] for line in f]
|
||||
else:
|
||||
subset = os.listdir(self.img_folder)
|
||||
|
||||
self.video_names = sorted(
|
||||
[video_name for video_name in subset if video_name not in excluded_files]
|
||||
)
|
||||
|
||||
def get_video(self, video_idx):
|
||||
"""
|
||||
Given a VOSVideo object, return the mask tensors.
|
||||
"""
|
||||
video_name = self.video_names[video_idx]
|
||||
video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
|
||||
segment_loader = JSONSegmentLoader(
|
||||
video_json_path=video_json_path,
|
||||
ann_every=self.ann_every,
|
||||
frames_fps=self.frames_fps,
|
||||
)
|
||||
|
||||
frame_ids = [
|
||||
int(os.path.splitext(frame_name)[0])
|
||||
for frame_name in sorted(
|
||||
os.listdir(os.path.join(self.img_folder, video_name))
|
||||
)
|
||||
]
|
||||
|
||||
frames = [
|
||||
VOSFrame(
|
||||
frame_id,
|
||||
image_path=os.path.join(
|
||||
self.img_folder, f"{video_name}/%05d.jpg" % (frame_id)
|
||||
),
|
||||
)
|
||||
for frame_id in frame_ids[:: self.sample_rate]
|
||||
]
|
||||
|
||||
if self.rm_unannotated:
|
||||
# Eliminate the frames that have not been annotated
|
||||
valid_frame_ids = [
|
||||
i * segment_loader.ann_every
|
||||
for i, annot in enumerate(segment_loader.frame_annots)
|
||||
if annot is not None and None not in annot
|
||||
]
|
||||
frames = [f for f in frames if f.frame_idx in valid_frame_ids]
|
||||
|
||||
video = VOSVideo(video_name, video_idx, frames)
|
||||
return video, segment_loader
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_names)
|
105
training/dataset/vos_sampler.py
Normal file
105
training/dataset/vos_sampler.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from training.dataset.vos_segment_loader import LazySegments
|
||||
|
||||
MAX_RETRIES = 1000
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampledFramesAndObjects:
|
||||
frames: List[int]
|
||||
object_ids: List[int]
|
||||
|
||||
|
||||
class VOSSampler:
|
||||
def __init__(self, sort_frames=True):
|
||||
# frames are ordered by frame id when sort_frames is True
|
||||
self.sort_frames = sort_frames
|
||||
|
||||
def sample(self, video):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RandomUniformSampler(VOSSampler):
|
||||
def __init__(
|
||||
self,
|
||||
num_frames,
|
||||
max_num_objects,
|
||||
reverse_time_prob=0.0,
|
||||
):
|
||||
self.num_frames = num_frames
|
||||
self.max_num_objects = max_num_objects
|
||||
self.reverse_time_prob = reverse_time_prob
|
||||
|
||||
def sample(self, video, segment_loader, epoch=None):
|
||||
|
||||
for retry in range(MAX_RETRIES):
|
||||
if len(video.frames) < self.num_frames:
|
||||
raise Exception(
|
||||
f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
|
||||
)
|
||||
start = random.randrange(0, len(video.frames) - self.num_frames + 1)
|
||||
frames = [video.frames[start + step] for step in range(self.num_frames)]
|
||||
if random.uniform(0, 1) < self.reverse_time_prob:
|
||||
# Reverse time
|
||||
frames = frames[::-1]
|
||||
|
||||
# Get first frame object ids
|
||||
visible_object_ids = []
|
||||
loaded_segms = segment_loader.load(frames[0].frame_idx)
|
||||
if isinstance(loaded_segms, LazySegments):
|
||||
# LazySegments for SA1BRawDataset
|
||||
visible_object_ids = list(loaded_segms.keys())
|
||||
else:
|
||||
for object_id, segment in segment_loader.load(
|
||||
frames[0].frame_idx
|
||||
).items():
|
||||
if segment.sum():
|
||||
visible_object_ids.append(object_id)
|
||||
|
||||
# First frame needs to have at least a target to track
|
||||
if len(visible_object_ids) > 0:
|
||||
break
|
||||
if retry >= MAX_RETRIES - 1:
|
||||
raise Exception("No visible objects")
|
||||
|
||||
object_ids = random.sample(
|
||||
visible_object_ids,
|
||||
min(len(visible_object_ids), self.max_num_objects),
|
||||
)
|
||||
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
||||
|
||||
|
||||
class EvalSampler(VOSSampler):
|
||||
"""
|
||||
VOS Sampler for evaluation: sampling all the frames and all the objects in a video
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def sample(self, video, segment_loader, epoch=None):
|
||||
"""
|
||||
Sampling all the frames and all the objects
|
||||
"""
|
||||
if self.sort_frames:
|
||||
# ordered by frame id
|
||||
frames = sorted(video.frames, key=lambda x: x.frame_idx)
|
||||
else:
|
||||
# use the original order
|
||||
frames = video.frames
|
||||
object_ids = segment_loader.load(frames[0].frame_idx).keys()
|
||||
if len(object_ids) == 0:
|
||||
raise Exception("First frame of the video has no objects")
|
||||
|
||||
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
300
training/dataset/vos_segment_loader.py
Normal file
300
training/dataset/vos_segment_loader.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
try:
|
||||
from pycocotools import mask as mask_utils
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class JSONSegmentLoader:
|
||||
def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
|
||||
# Annotations in the json are provided every ann_every th frame
|
||||
self.ann_every = ann_every
|
||||
# Ids of the objects to consider when sampling this video
|
||||
self.valid_obj_ids = valid_obj_ids
|
||||
with open(video_json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
self.frame_annots = data
|
||||
elif isinstance(data, dict):
|
||||
masklet_field_name = "masklet" if "masklet" in data else "masks"
|
||||
self.frame_annots = data[masklet_field_name]
|
||||
if "fps" in data:
|
||||
if isinstance(data["fps"], list):
|
||||
annotations_fps = int(data["fps"][0])
|
||||
else:
|
||||
annotations_fps = int(data["fps"])
|
||||
assert frames_fps % annotations_fps == 0
|
||||
self.ann_every = frames_fps // annotations_fps
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def load(self, frame_id, obj_ids=None):
|
||||
assert frame_id % self.ann_every == 0
|
||||
rle_mask = self.frame_annots[frame_id // self.ann_every]
|
||||
|
||||
valid_objs_ids = set(range(len(rle_mask)))
|
||||
if self.valid_obj_ids is not None:
|
||||
# Remove the masklets that have been filtered out for this video
|
||||
valid_objs_ids &= set(self.valid_obj_ids)
|
||||
if obj_ids is not None:
|
||||
# Only keep the objects that have been sampled
|
||||
valid_objs_ids &= set(obj_ids)
|
||||
valid_objs_ids = sorted(list(valid_objs_ids))
|
||||
|
||||
# Construct rle_masks_filtered that only contains the rle masks we are interested in
|
||||
id_2_idx = {}
|
||||
rle_mask_filtered = []
|
||||
for obj_id in valid_objs_ids:
|
||||
if rle_mask[obj_id] is not None:
|
||||
id_2_idx[obj_id] = len(rle_mask_filtered)
|
||||
rle_mask_filtered.append(rle_mask[obj_id])
|
||||
else:
|
||||
id_2_idx[obj_id] = None
|
||||
|
||||
# Decode the masks
|
||||
raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
|
||||
2, 0, 1
|
||||
) # (num_obj, h, w)
|
||||
segments = {}
|
||||
for obj_id in valid_objs_ids:
|
||||
if id_2_idx[obj_id] is None:
|
||||
segments[obj_id] = None
|
||||
else:
|
||||
idx = id_2_idx[obj_id]
|
||||
segments[obj_id] = raw_segments[idx]
|
||||
return segments
|
||||
|
||||
def get_valid_obj_frames_ids(self, num_frames_min=None):
|
||||
# For each object, find all the frames with a valid (not None) mask
|
||||
num_objects = len(self.frame_annots[0])
|
||||
|
||||
# The result dict associates each obj_id with the id of its valid frames
|
||||
res = {obj_id: [] for obj_id in range(num_objects)}
|
||||
|
||||
for annot_idx, annot in enumerate(self.frame_annots):
|
||||
for obj_id in range(num_objects):
|
||||
if annot[obj_id] is not None:
|
||||
res[obj_id].append(int(annot_idx * self.ann_every))
|
||||
|
||||
if num_frames_min is not None:
|
||||
# Remove masklets that have less than num_frames_min valid masks
|
||||
for obj_id, valid_frames in list(res.items()):
|
||||
if len(valid_frames) < num_frames_min:
|
||||
res.pop(obj_id)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class PalettisedPNGSegmentLoader:
|
||||
def __init__(self, video_png_root):
|
||||
"""
|
||||
SegmentLoader for datasets with masks stored as palettised PNGs.
|
||||
video_png_root: the folder contains all the masks stored in png
|
||||
"""
|
||||
self.video_png_root = video_png_root
|
||||
# build a mapping from frame id to their PNG mask path
|
||||
# note that in some datasets, the PNG paths could have more
|
||||
# than 5 digits, e.g. "00000000.png" instead of "00000.png"
|
||||
png_filenames = os.listdir(self.video_png_root)
|
||||
self.frame_id_to_png_filename = {}
|
||||
for filename in png_filenames:
|
||||
frame_id, _ = os.path.splitext(filename)
|
||||
self.frame_id_to_png_filename[int(frame_id)] = filename
|
||||
|
||||
def load(self, frame_id):
|
||||
"""
|
||||
load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
# check the path
|
||||
mask_path = os.path.join(
|
||||
self.video_png_root, self.frame_id_to_png_filename[frame_id]
|
||||
)
|
||||
|
||||
# load the mask
|
||||
masks = PILImage.open(mask_path).convert("P")
|
||||
masks = np.array(masks)
|
||||
|
||||
object_id = pd.unique(masks.flatten())
|
||||
object_id = object_id[object_id != 0] # remove background (0)
|
||||
|
||||
# convert into N binary segmentation masks
|
||||
binary_segments = {}
|
||||
for i in object_id:
|
||||
bs = masks == i
|
||||
binary_segments[i] = torch.from_numpy(bs)
|
||||
|
||||
return binary_segments
|
||||
|
||||
def __len__(self):
|
||||
return
|
||||
|
||||
|
||||
class MultiplePNGSegmentLoader:
|
||||
def __init__(self, video_png_root, single_object_mode=False):
|
||||
"""
|
||||
video_png_root: the folder contains all the masks stored in png
|
||||
single_object_mode: whether to load only a single object at a time
|
||||
"""
|
||||
self.video_png_root = video_png_root
|
||||
self.single_object_mode = single_object_mode
|
||||
# read a mask to know the resolution of the video
|
||||
if self.single_object_mode:
|
||||
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
|
||||
else:
|
||||
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
|
||||
tmp_mask = np.array(PILImage.open(tmp_mask_path))
|
||||
self.H = tmp_mask.shape[0]
|
||||
self.W = tmp_mask.shape[1]
|
||||
if self.single_object_mode:
|
||||
self.obj_id = (
|
||||
int(video_png_root.split("/")[-1]) + 1
|
||||
) # offset by 1 as bg is 0
|
||||
else:
|
||||
self.obj_id = None
|
||||
|
||||
def load(self, frame_id):
|
||||
if self.single_object_mode:
|
||||
return self._load_single_png(frame_id)
|
||||
else:
|
||||
return self._load_multiple_pngs(frame_id)
|
||||
|
||||
def _load_single_png(self, frame_id):
|
||||
"""
|
||||
load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
|
||||
binary_segments = {}
|
||||
|
||||
if os.path.exists(mask_path):
|
||||
mask = np.array(PILImage.open(mask_path))
|
||||
else:
|
||||
# if png doesn't exist, empty mask
|
||||
mask = np.zeros((self.H, self.W), dtype=bool)
|
||||
binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
|
||||
return binary_segments
|
||||
|
||||
def _load_multiple_pngs(self, frame_id):
|
||||
"""
|
||||
load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
# get the path
|
||||
all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
|
||||
num_objects = len(all_objects)
|
||||
assert num_objects > 0
|
||||
|
||||
# load the masks
|
||||
binary_segments = {}
|
||||
for obj_folder in all_objects:
|
||||
# obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
|
||||
obj_id = int(obj_folder.split("/")[-1])
|
||||
obj_id = obj_id + 1 # offset 1 as bg is 0
|
||||
mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
|
||||
if os.path.exists(mask_path):
|
||||
mask = np.array(PILImage.open(mask_path))
|
||||
else:
|
||||
mask = np.zeros((self.H, self.W), dtype=bool)
|
||||
binary_segments[obj_id] = torch.from_numpy(mask > 0)
|
||||
|
||||
return binary_segments
|
||||
|
||||
def __len__(self):
|
||||
return
|
||||
|
||||
|
||||
class LazySegments:
|
||||
"""
|
||||
Only decodes segments that are actually used.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.segments = {}
|
||||
self.cache = {}
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
self.segments[key] = item
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
rle = self.segments[key]
|
||||
mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
|
||||
self.cache[key] = mask
|
||||
return mask
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.segments
|
||||
|
||||
def __len__(self):
|
||||
return len(self.segments)
|
||||
|
||||
def keys(self):
|
||||
return self.segments.keys()
|
||||
|
||||
|
||||
class SA1BSegmentLoader:
|
||||
def __init__(
|
||||
self,
|
||||
video_mask_path,
|
||||
mask_area_frac_thresh=1.1,
|
||||
video_frame_path=None,
|
||||
uncertain_iou=-1,
|
||||
):
|
||||
with open(video_mask_path, "r") as f:
|
||||
self.frame_annots = json.load(f)
|
||||
|
||||
if mask_area_frac_thresh <= 1.0:
|
||||
# Lazily read frame
|
||||
orig_w, orig_h = PILImage.open(video_frame_path).size
|
||||
area = orig_w * orig_h
|
||||
|
||||
self.frame_annots = self.frame_annots["annotations"]
|
||||
|
||||
rle_masks = []
|
||||
for frame_annot in self.frame_annots:
|
||||
if not frame_annot["area"] > 0:
|
||||
continue
|
||||
if ("uncertain_iou" in frame_annot) and (
|
||||
frame_annot["uncertain_iou"] < uncertain_iou
|
||||
):
|
||||
# uncertain_iou is stability score
|
||||
continue
|
||||
if (
|
||||
mask_area_frac_thresh <= 1.0
|
||||
and (frame_annot["area"] / area) >= mask_area_frac_thresh
|
||||
):
|
||||
continue
|
||||
rle_masks.append(frame_annot["segmentation"])
|
||||
|
||||
self.segments = LazySegments()
|
||||
for i, rle in enumerate(rle_masks):
|
||||
self.segments[i] = rle
|
||||
|
||||
def load(self, frame_idx):
|
||||
return self.segments
|
307
training/loss_fns.py
Normal file
307
training/loss_fns.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from training.trainer import CORE_LOSS_KEY
|
||||
|
||||
from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
|
||||
|
||||
|
||||
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
|
||||
"""
|
||||
Compute the DICE loss, similar to generalized IOU for masks
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs
|
||||
(0 for the negative class and 1 for the positive class).
|
||||
num_objects: Number of objects in the batch
|
||||
loss_on_multimask: True if multimask prediction is enabled
|
||||
Returns:
|
||||
Dice loss tensor
|
||||
"""
|
||||
inputs = inputs.sigmoid()
|
||||
if loss_on_multimask:
|
||||
# inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
|
||||
assert inputs.dim() == 4 and targets.dim() == 4
|
||||
# flatten spatial dimension while keeping multimask channel dimension
|
||||
inputs = inputs.flatten(2)
|
||||
targets = targets.flatten(2)
|
||||
numerator = 2 * (inputs * targets).sum(-1)
|
||||
else:
|
||||
inputs = inputs.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
if loss_on_multimask:
|
||||
return loss / num_objects
|
||||
return loss.sum() / num_objects
|
||||
|
||||
|
||||
def sigmoid_focal_loss(
|
||||
inputs,
|
||||
targets,
|
||||
num_objects,
|
||||
alpha: float = 0.25,
|
||||
gamma: float = 2,
|
||||
loss_on_multimask=False,
|
||||
):
|
||||
"""
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs
|
||||
(0 for the negative class and 1 for the positive class).
|
||||
num_objects: Number of objects in the batch
|
||||
alpha: (optional) Weighting factor in range (0,1) to balance
|
||||
positive vs negative examples. Default = -1 (no weighting).
|
||||
gamma: Exponent of the modulating factor (1 - p_t) to
|
||||
balance easy vs hard examples.
|
||||
loss_on_multimask: True if multimask prediction is enabled
|
||||
Returns:
|
||||
focal loss tensor
|
||||
"""
|
||||
prob = inputs.sigmoid()
|
||||
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
||||
p_t = prob * targets + (1 - prob) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
if loss_on_multimask:
|
||||
# loss is [N, M, H, W] where M corresponds to multiple predicted masks
|
||||
assert loss.dim() == 4
|
||||
return loss.flatten(2).mean(-1) / num_objects # average over spatial dims
|
||||
return loss.mean(1).sum() / num_objects
|
||||
|
||||
|
||||
def iou_loss(
|
||||
inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs
|
||||
(0 for the negative class and 1 for the positive class).
|
||||
pred_ious: A float tensor containing the predicted IoUs scores per mask
|
||||
num_objects: Number of objects in the batch
|
||||
loss_on_multimask: True if multimask prediction is enabled
|
||||
use_l1_loss: Whether to use L1 loss is used instead of MSE loss
|
||||
Returns:
|
||||
IoU loss tensor
|
||||
"""
|
||||
assert inputs.dim() == 4 and targets.dim() == 4
|
||||
pred_mask = inputs.flatten(2) > 0
|
||||
gt_mask = targets.flatten(2) > 0
|
||||
area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
|
||||
area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
|
||||
actual_ious = area_i / torch.clamp(area_u, min=1.0)
|
||||
|
||||
if use_l1_loss:
|
||||
loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
|
||||
else:
|
||||
loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
|
||||
if loss_on_multimask:
|
||||
return loss / num_objects
|
||||
return loss.sum() / num_objects
|
||||
|
||||
|
||||
class MultiStepMultiMasksAndIous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight_dict,
|
||||
focal_alpha=0.25,
|
||||
focal_gamma=2,
|
||||
supervise_all_iou=False,
|
||||
iou_use_l1_loss=False,
|
||||
pred_obj_scores=False,
|
||||
focal_gamma_obj_score=0.0,
|
||||
focal_alpha_obj_score=-1,
|
||||
):
|
||||
"""
|
||||
This class computes the multi-step multi-mask and IoU losses.
|
||||
Args:
|
||||
weight_dict: dict containing weights for focal, dice, iou losses
|
||||
focal_alpha: alpha for sigmoid focal loss
|
||||
focal_gamma: gamma for sigmoid focal loss
|
||||
supervise_all_iou: if True, back-prop iou losses for all predicted masks
|
||||
iou_use_l1_loss: use L1 loss instead of MSE loss for iou
|
||||
pred_obj_scores: if True, compute loss for object scores
|
||||
focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
|
||||
focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.weight_dict = weight_dict
|
||||
self.focal_alpha = focal_alpha
|
||||
self.focal_gamma = focal_gamma
|
||||
assert "loss_mask" in self.weight_dict
|
||||
assert "loss_dice" in self.weight_dict
|
||||
assert "loss_iou" in self.weight_dict
|
||||
if "loss_class" not in self.weight_dict:
|
||||
self.weight_dict["loss_class"] = 0.0
|
||||
|
||||
self.focal_alpha_obj_score = focal_alpha_obj_score
|
||||
self.focal_gamma_obj_score = focal_gamma_obj_score
|
||||
self.supervise_all_iou = supervise_all_iou
|
||||
self.iou_use_l1_loss = iou_use_l1_loss
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
|
||||
def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
|
||||
assert len(outs_batch) == len(targets_batch)
|
||||
num_objects = torch.tensor(
|
||||
(targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
|
||||
) # Number of objects is fixed within a batch
|
||||
if is_dist_avail_and_initialized():
|
||||
torch.distributed.all_reduce(num_objects)
|
||||
num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()
|
||||
|
||||
losses = defaultdict(int)
|
||||
for outs, targets in zip(outs_batch, targets_batch):
|
||||
cur_losses = self._forward(outs, targets, num_objects)
|
||||
for k, v in cur_losses.items():
|
||||
losses[k] += v
|
||||
|
||||
return losses
|
||||
|
||||
def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
|
||||
"""
|
||||
Compute the losses related to the masks: the focal loss and the dice loss.
|
||||
and also the MAE or MSE loss between predicted IoUs and actual IoUs.
|
||||
|
||||
Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
|
||||
of shape [N, M, H, W], where M could be 1 or larger, corresponding to
|
||||
one or multiple predicted masks from a click.
|
||||
|
||||
We back-propagate focal, dice losses only on the prediction channel
|
||||
with the lowest focal+dice loss between predicted mask and ground-truth.
|
||||
If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.
|
||||
"""
|
||||
|
||||
target_masks = targets.unsqueeze(1).float()
|
||||
assert target_masks.dim() == 4 # [N, 1, H, W]
|
||||
src_masks_list = outputs["multistep_pred_multimasks_high_res"]
|
||||
ious_list = outputs["multistep_pred_ious"]
|
||||
object_score_logits_list = outputs["multistep_object_score_logits"]
|
||||
|
||||
assert len(src_masks_list) == len(ious_list)
|
||||
assert len(object_score_logits_list) == len(ious_list)
|
||||
|
||||
# accumulate the loss over prediction steps
|
||||
losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
|
||||
for src_masks, ious, object_score_logits in zip(
|
||||
src_masks_list, ious_list, object_score_logits_list
|
||||
):
|
||||
self._update_losses(
|
||||
losses, src_masks, target_masks, ious, num_objects, object_score_logits
|
||||
)
|
||||
losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
|
||||
return losses
|
||||
|
||||
def _update_losses(
|
||||
self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
|
||||
):
|
||||
target_masks = target_masks.expand_as(src_masks)
|
||||
# get focal, dice and iou loss on all output masks in a prediction step
|
||||
loss_multimask = sigmoid_focal_loss(
|
||||
src_masks,
|
||||
target_masks,
|
||||
num_objects,
|
||||
alpha=self.focal_alpha,
|
||||
gamma=self.focal_gamma,
|
||||
loss_on_multimask=True,
|
||||
)
|
||||
loss_multidice = dice_loss(
|
||||
src_masks, target_masks, num_objects, loss_on_multimask=True
|
||||
)
|
||||
if not self.pred_obj_scores:
|
||||
loss_class = torch.tensor(
|
||||
0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
|
||||
)
|
||||
target_obj = torch.ones(
|
||||
loss_multimask.shape[0],
|
||||
1,
|
||||
dtype=loss_multimask.dtype,
|
||||
device=loss_multimask.device,
|
||||
)
|
||||
else:
|
||||
target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
|
||||
..., None
|
||||
].float()
|
||||
loss_class = sigmoid_focal_loss(
|
||||
object_score_logits,
|
||||
target_obj,
|
||||
num_objects,
|
||||
alpha=self.focal_alpha_obj_score,
|
||||
gamma=self.focal_gamma_obj_score,
|
||||
)
|
||||
|
||||
loss_multiiou = iou_loss(
|
||||
src_masks,
|
||||
target_masks,
|
||||
ious,
|
||||
num_objects,
|
||||
loss_on_multimask=True,
|
||||
use_l1_loss=self.iou_use_l1_loss,
|
||||
)
|
||||
assert loss_multimask.dim() == 2
|
||||
assert loss_multidice.dim() == 2
|
||||
assert loss_multiiou.dim() == 2
|
||||
if loss_multimask.size(1) > 1:
|
||||
# take the mask indices with the smallest focal + dice loss for back propagation
|
||||
loss_combo = (
|
||||
loss_multimask * self.weight_dict["loss_mask"]
|
||||
+ loss_multidice * self.weight_dict["loss_dice"]
|
||||
)
|
||||
best_loss_inds = torch.argmin(loss_combo, dim=-1)
|
||||
batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
|
||||
loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
|
||||
loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
|
||||
# calculate the iou prediction and slot losses only in the index
|
||||
# with the minimum loss for each mask (to be consistent w/ SAM)
|
||||
if self.supervise_all_iou:
|
||||
loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
|
||||
else:
|
||||
loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
|
||||
else:
|
||||
loss_mask = loss_multimask
|
||||
loss_dice = loss_multidice
|
||||
loss_iou = loss_multiiou
|
||||
|
||||
# backprop focal, dice and iou loss only if obj present
|
||||
loss_mask = loss_mask * target_obj
|
||||
loss_dice = loss_dice * target_obj
|
||||
loss_iou = loss_iou * target_obj
|
||||
|
||||
# sum over batch dimension (note that the losses are already divided by num_objects)
|
||||
losses["loss_mask"] += loss_mask.sum()
|
||||
losses["loss_dice"] += loss_dice.sum()
|
||||
losses["loss_iou"] += loss_iou.sum()
|
||||
losses["loss_class"] += loss_class
|
||||
|
||||
def reduce_loss(self, losses):
|
||||
reduced_loss = 0.0
|
||||
for loss_key, weight in self.weight_dict.items():
|
||||
if loss_key not in losses:
|
||||
raise ValueError(f"{type(self)} doesn't compute {loss_key}")
|
||||
if weight != 0:
|
||||
reduced_loss += losses[loss_key] * weight
|
||||
|
||||
return reduced_loss
|
5
training/model/__init__.py
Normal file
5
training/model/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
541
training/model/sam2.py
Normal file
541
training/model/sam2.py
Normal file
@@ -0,0 +1,541 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2.modeling.sam2_utils import (
|
||||
get_1d_sine_pe,
|
||||
get_next_point,
|
||||
sample_box_points,
|
||||
select_closest_cond_frames,
|
||||
)
|
||||
|
||||
from sam2.utils.misc import concat_points
|
||||
|
||||
from training.utils.data_utils import BatchedVideoDatapoint
|
||||
|
||||
|
||||
class SAM2Train(SAM2Base):
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder,
|
||||
memory_attention=None,
|
||||
memory_encoder=None,
|
||||
prob_to_use_pt_input_for_train=0.0,
|
||||
prob_to_use_pt_input_for_eval=0.0,
|
||||
prob_to_use_box_input_for_train=0.0,
|
||||
prob_to_use_box_input_for_eval=0.0,
|
||||
# if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames
|
||||
num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
|
||||
num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
|
||||
rand_frames_to_correct_for_train=False,
|
||||
rand_frames_to_correct_for_eval=False,
|
||||
# how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
|
||||
# - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
|
||||
# - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
|
||||
# note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
|
||||
# these are initial conditioning frames because as we track the video, more conditioning frames might be added
|
||||
# when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
|
||||
num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
|
||||
num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
|
||||
rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader)
|
||||
rand_init_cond_frames_for_eval=False,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
# how many additional correction points to sample (on each frame selected to be corrected)
|
||||
# note that the first frame receives an initial input click (in addition to any correction clicks)
|
||||
num_correction_pt_per_frame=7,
|
||||
# method for point sampling during evaluation
|
||||
# "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
|
||||
# default to "center" to be consistent with evaluation in the SAM paper
|
||||
pt_sampling_for_eval="center",
|
||||
# During training, we optionally allow sampling the correction points from GT regions
|
||||
# instead of the prediction error regions with a small probability. This might allow the
|
||||
# model to overfit less to the error regions in training datasets
|
||||
prob_to_sample_from_gt_for_train=0.0,
|
||||
use_act_ckpt_iterative_pt_sampling=False,
|
||||
# whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
|
||||
# of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
|
||||
forward_backbone_per_frame_for_eval=False,
|
||||
freeze_image_encoder=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
|
||||
self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
|
||||
self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
|
||||
|
||||
# Point sampler and conditioning frames
|
||||
self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
|
||||
self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
|
||||
self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
|
||||
self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
|
||||
if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
|
||||
logging.info(
|
||||
f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
|
||||
)
|
||||
assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
|
||||
assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval
|
||||
|
||||
self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
|
||||
self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
|
||||
self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
|
||||
self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
|
||||
# Initial multi-conditioning frames
|
||||
self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
|
||||
self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
|
||||
self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
|
||||
self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
self.num_correction_pt_per_frame = num_correction_pt_per_frame
|
||||
self.pt_sampling_for_eval = pt_sampling_for_eval
|
||||
self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
|
||||
# A random number generator with a fixed initial seed across GPUs
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
if freeze_image_encoder:
|
||||
for p in self.image_encoder.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def forward(self, input: BatchedVideoDatapoint):
|
||||
if self.training or not self.forward_backbone_per_frame_for_eval:
|
||||
# precompute image features on all frames before tracking
|
||||
backbone_out = self.forward_image(input.flat_img_batch)
|
||||
else:
|
||||
# defer image feature computation on a frame until it's being tracked
|
||||
backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
|
||||
backbone_out = self.prepare_prompt_inputs(backbone_out, input)
|
||||
previous_stages_out = self.forward_tracking(backbone_out, input)
|
||||
|
||||
return previous_stages_out
|
||||
|
||||
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
|
||||
"""Compute the image backbone features on the fly for the given img_ids."""
|
||||
# Only forward backbone on unique image ids to avoid repetitive computation
|
||||
# (if `img_ids` has only one element, it's already unique so we skip this step).
|
||||
if img_ids.numel() > 1:
|
||||
unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
|
||||
else:
|
||||
unique_img_ids, inv_ids = img_ids, None
|
||||
|
||||
# Compute the image features on those unique image ids
|
||||
image = img_batch[unique_img_ids]
|
||||
backbone_out = self.forward_image(image)
|
||||
(
|
||||
_,
|
||||
vision_feats,
|
||||
vision_pos_embeds,
|
||||
feat_sizes,
|
||||
) = self._prepare_backbone_features(backbone_out)
|
||||
# Inverse-map image features for `unique_img_ids` to the final image features
|
||||
# for the original input `img_ids`.
|
||||
if inv_ids is not None:
|
||||
image = image[inv_ids]
|
||||
vision_feats = [x[:, inv_ids] for x in vision_feats]
|
||||
vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
|
||||
|
||||
return image, vision_feats, vision_pos_embeds, feat_sizes
|
||||
|
||||
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
|
||||
"""
|
||||
Prepare input mask, point or box prompts. Optionally, we allow tracking from
|
||||
a custom `start_frame_idx` to the end of the video (for evaluation purposes).
|
||||
"""
|
||||
# Load the ground-truth masks on all frames (so that we can later
|
||||
# sample correction points from them)
|
||||
# gt_masks_per_frame = {
|
||||
# stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im]
|
||||
# for stage_id, targets in enumerate(input.find_targets)
|
||||
# }
|
||||
gt_masks_per_frame = {
|
||||
stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im]
|
||||
for stage_id, masks in enumerate(input.masks)
|
||||
}
|
||||
# gt_masks_per_frame = input.masks.unsqueeze(2) # [T,B,1,H_im,W_im] keep everything in tensor form
|
||||
backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
|
||||
num_frames = input.num_frames
|
||||
backbone_out["num_frames"] = num_frames
|
||||
|
||||
# Randomly decide whether to use point inputs or mask inputs
|
||||
if self.training:
|
||||
prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
|
||||
prob_to_use_box_input = self.prob_to_use_box_input_for_train
|
||||
num_frames_to_correct = self.num_frames_to_correct_for_train
|
||||
rand_frames_to_correct = self.rand_frames_to_correct_for_train
|
||||
num_init_cond_frames = self.num_init_cond_frames_for_train
|
||||
rand_init_cond_frames = self.rand_init_cond_frames_for_train
|
||||
else:
|
||||
prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
|
||||
prob_to_use_box_input = self.prob_to_use_box_input_for_eval
|
||||
num_frames_to_correct = self.num_frames_to_correct_for_eval
|
||||
rand_frames_to_correct = self.rand_frames_to_correct_for_eval
|
||||
num_init_cond_frames = self.num_init_cond_frames_for_eval
|
||||
rand_init_cond_frames = self.rand_init_cond_frames_for_eval
|
||||
if num_frames == 1:
|
||||
# here we handle a special case for mixing video + SAM on image training,
|
||||
# where we force using point input for the SAM task on static images
|
||||
prob_to_use_pt_input = 1.0
|
||||
num_frames_to_correct = 1
|
||||
num_init_cond_frames = 1
|
||||
assert num_init_cond_frames >= 1
|
||||
# (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
|
||||
use_pt_input = self.rng.random() < prob_to_use_pt_input
|
||||
if rand_init_cond_frames and num_init_cond_frames > 1:
|
||||
# randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
|
||||
num_init_cond_frames = self.rng.integers(
|
||||
1, num_init_cond_frames, endpoint=True
|
||||
)
|
||||
if (
|
||||
use_pt_input
|
||||
and rand_frames_to_correct
|
||||
and num_frames_to_correct > num_init_cond_frames
|
||||
):
|
||||
# randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
|
||||
# correction clicks (only for the case of point input)
|
||||
num_frames_to_correct = self.rng.integers(
|
||||
num_init_cond_frames, num_frames_to_correct, endpoint=True
|
||||
)
|
||||
backbone_out["use_pt_input"] = use_pt_input
|
||||
|
||||
# Sample initial conditioning frames
|
||||
if num_init_cond_frames == 1:
|
||||
init_cond_frames = [start_frame_idx] # starting frame
|
||||
else:
|
||||
# starting frame + randomly selected remaining frames (without replacement)
|
||||
init_cond_frames = [start_frame_idx] + self.rng.choice(
|
||||
range(start_frame_idx + 1, num_frames),
|
||||
num_init_cond_frames - 1,
|
||||
replace=False,
|
||||
).tolist()
|
||||
backbone_out["init_cond_frames"] = init_cond_frames
|
||||
backbone_out["frames_not_in_init_cond"] = [
|
||||
t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
|
||||
]
|
||||
# Prepare mask or point inputs on initial conditioning frames
|
||||
backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: <input_masks>}
|
||||
backbone_out["point_inputs_per_frame"] = {} # {frame_idx: <input_points>}
|
||||
for t in init_cond_frames:
|
||||
if not use_pt_input:
|
||||
backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
|
||||
else:
|
||||
# During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
|
||||
use_box_input = self.rng.random() < prob_to_use_box_input
|
||||
if use_box_input:
|
||||
points, labels = sample_box_points(
|
||||
gt_masks_per_frame[t],
|
||||
)
|
||||
else:
|
||||
# (here we only sample **one initial point** on initial conditioning frames from the
|
||||
# ground-truth mask; we may sample more correction points on the fly)
|
||||
points, labels = get_next_point(
|
||||
gt_masks=gt_masks_per_frame[t],
|
||||
pred_masks=None,
|
||||
method=(
|
||||
"uniform" if self.training else self.pt_sampling_for_eval
|
||||
),
|
||||
)
|
||||
|
||||
point_inputs = {"point_coords": points, "point_labels": labels}
|
||||
backbone_out["point_inputs_per_frame"][t] = point_inputs
|
||||
|
||||
# Sample frames where we will add correction clicks on the fly
|
||||
# based on the error between prediction and ground-truth masks
|
||||
if not use_pt_input:
|
||||
# no correction points will be sampled when using mask inputs
|
||||
frames_to_add_correction_pt = []
|
||||
elif num_frames_to_correct == num_init_cond_frames:
|
||||
frames_to_add_correction_pt = init_cond_frames
|
||||
else:
|
||||
assert num_frames_to_correct > num_init_cond_frames
|
||||
# initial cond frame + randomly selected remaining frames (without replacement)
|
||||
extra_num = num_frames_to_correct - num_init_cond_frames
|
||||
frames_to_add_correction_pt = (
|
||||
init_cond_frames
|
||||
+ self.rng.choice(
|
||||
backbone_out["frames_not_in_init_cond"], extra_num, replace=False
|
||||
).tolist()
|
||||
)
|
||||
backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
|
||||
|
||||
return backbone_out
|
||||
|
||||
def forward_tracking(
|
||||
self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
|
||||
):
|
||||
"""Forward video tracking on each frame (and sample correction clicks)."""
|
||||
img_feats_already_computed = backbone_out["backbone_fpn"] is not None
|
||||
if img_feats_already_computed:
|
||||
# Prepare the backbone features
|
||||
# - vision_feats and vision_pos_embeds are in (HW)BC format
|
||||
(
|
||||
_,
|
||||
vision_feats,
|
||||
vision_pos_embeds,
|
||||
feat_sizes,
|
||||
) = self._prepare_backbone_features(backbone_out)
|
||||
|
||||
# Starting the stage loop
|
||||
num_frames = backbone_out["num_frames"]
|
||||
init_cond_frames = backbone_out["init_cond_frames"]
|
||||
frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
|
||||
# first process all the initial conditioning frames to encode them as memory,
|
||||
# and then conditioning on them to track the remaining frames
|
||||
processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
|
||||
output_dict = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
for stage_id in processing_order:
|
||||
# Get the image features for the current frames
|
||||
# img_ids = input.find_inputs[stage_id].img_ids
|
||||
img_ids = input.flat_obj_to_img_idx[stage_id]
|
||||
if img_feats_already_computed:
|
||||
# Retrieve image features according to img_ids (if they are already computed).
|
||||
current_vision_feats = [x[:, img_ids] for x in vision_feats]
|
||||
current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
|
||||
else:
|
||||
# Otherwise, compute the image features on the fly for the given img_ids
|
||||
# (this might be used for evaluation on long videos to avoid backbone OOM).
|
||||
(
|
||||
_,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
) = self._prepare_backbone_features_per_frame(
|
||||
input.flat_img_batch, img_ids
|
||||
)
|
||||
|
||||
# Get output masks based on this frame's prompts and previous memory
|
||||
current_out = self.track_step(
|
||||
frame_idx=stage_id,
|
||||
is_init_cond_frame=stage_id in init_cond_frames,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
|
||||
mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
|
||||
gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
|
||||
frames_to_add_correction_pt=frames_to_add_correction_pt,
|
||||
output_dict=output_dict,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
# Append the output, depending on whether it's a conditioning frame
|
||||
add_output_as_cond_frame = stage_id in init_cond_frames or (
|
||||
self.add_all_frames_to_correct_as_cond
|
||||
and stage_id in frames_to_add_correction_pt
|
||||
)
|
||||
if add_output_as_cond_frame:
|
||||
output_dict["cond_frame_outputs"][stage_id] = current_out
|
||||
else:
|
||||
output_dict["non_cond_frame_outputs"][stage_id] = current_out
|
||||
|
||||
if return_dict:
|
||||
return output_dict
|
||||
# turn `output_dict` into a list for loss function
|
||||
all_frame_outputs = {}
|
||||
all_frame_outputs.update(output_dict["cond_frame_outputs"])
|
||||
all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
|
||||
all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
|
||||
# Make DDP happy with activation checkpointing by removing unused keys
|
||||
all_frame_outputs = [
|
||||
{k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
|
||||
]
|
||||
|
||||
return all_frame_outputs
|
||||
|
||||
def track_step(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
||||
run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
|
||||
prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
|
||||
frames_to_add_correction_pt=None,
|
||||
gt_masks=None,
|
||||
):
|
||||
if frames_to_add_correction_pt is None:
|
||||
frames_to_add_correction_pt = []
|
||||
current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse,
|
||||
prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
(
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
) = sam_outputs
|
||||
|
||||
current_out["multistep_pred_masks"] = low_res_masks
|
||||
current_out["multistep_pred_masks_high_res"] = high_res_masks
|
||||
current_out["multistep_pred_multimasks"] = [low_res_multimasks]
|
||||
current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
|
||||
current_out["multistep_pred_ious"] = [ious]
|
||||
current_out["multistep_point_inputs"] = [point_inputs]
|
||||
current_out["multistep_object_score_logits"] = [object_score_logits]
|
||||
|
||||
# Optionally, sample correction points iteratively to correct the mask
|
||||
if frame_idx in frames_to_add_correction_pt:
|
||||
point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
|
||||
is_init_cond_frame,
|
||||
point_inputs,
|
||||
gt_masks,
|
||||
high_res_features,
|
||||
pix_feat,
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
current_out,
|
||||
)
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
) = final_sam_outputs
|
||||
|
||||
# Use the final prediction (after all correction steps for output and eval)
|
||||
current_out["pred_masks"] = low_res_masks
|
||||
current_out["pred_masks_high_res"] = high_res_masks
|
||||
current_out["obj_ptr"] = obj_ptr
|
||||
|
||||
# Finally run the memory encoder on the predicted mask to encode
|
||||
# it into a new memory feature (that can be used in future frames)
|
||||
self._encode_memory_in_output(
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
run_mem_encoder,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
current_out,
|
||||
)
|
||||
return current_out
|
||||
|
||||
def _iter_correct_pt_sampling(
|
||||
self,
|
||||
is_init_cond_frame,
|
||||
point_inputs,
|
||||
gt_masks,
|
||||
high_res_features,
|
||||
pix_feat_with_mem,
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
current_out,
|
||||
):
|
||||
|
||||
assert gt_masks is not None
|
||||
all_pred_masks = [low_res_masks]
|
||||
all_pred_high_res_masks = [high_res_masks]
|
||||
all_pred_multimasks = [low_res_multimasks]
|
||||
all_pred_high_res_multimasks = [high_res_multimasks]
|
||||
all_pred_ious = [ious]
|
||||
all_point_inputs = [point_inputs]
|
||||
all_object_score_logits = [object_score_logits]
|
||||
for _ in range(self.num_correction_pt_per_frame):
|
||||
# sample a new point from the error between prediction and ground-truth
|
||||
# (with a small probability, directly sample from GT masks instead of errors)
|
||||
if self.training and self.prob_to_sample_from_gt_for_train > 0:
|
||||
sample_from_gt = (
|
||||
self.rng.random() < self.prob_to_sample_from_gt_for_train
|
||||
)
|
||||
else:
|
||||
sample_from_gt = False
|
||||
# if `pred_for_new_pt` is None, only GT masks will be used for point sampling
|
||||
pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
|
||||
new_points, new_labels = get_next_point(
|
||||
gt_masks=gt_masks,
|
||||
pred_masks=pred_for_new_pt,
|
||||
method="uniform" if self.training else self.pt_sampling_for_eval,
|
||||
)
|
||||
point_inputs = concat_points(point_inputs, new_points, new_labels)
|
||||
# Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
|
||||
# For tracking, this means that when the user adds a correction click, we also feed
|
||||
# the tracking output mask logits along with the click as input to the SAM decoder.
|
||||
mask_inputs = low_res_masks
|
||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
||||
if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
|
||||
sam_outputs = torch.utils.checkpoint.checkpoint(
|
||||
self._forward_sam_heads,
|
||||
backbone_features=pix_feat_with_mem,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
sam_outputs = self._forward_sam_heads(
|
||||
backbone_features=pix_feat_with_mem,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
(
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
_,
|
||||
object_score_logits,
|
||||
) = sam_outputs
|
||||
all_pred_masks.append(low_res_masks)
|
||||
all_pred_high_res_masks.append(high_res_masks)
|
||||
all_pred_multimasks.append(low_res_multimasks)
|
||||
all_pred_high_res_multimasks.append(high_res_multimasks)
|
||||
all_pred_ious.append(ious)
|
||||
all_point_inputs.append(point_inputs)
|
||||
all_object_score_logits.append(object_score_logits)
|
||||
|
||||
# Concatenate the masks along channel (to compute losses on all of them,
|
||||
# using `MultiStepIteractiveMasks`)
|
||||
current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
|
||||
current_out["multistep_pred_masks_high_res"] = torch.cat(
|
||||
all_pred_high_res_masks, dim=1
|
||||
)
|
||||
current_out["multistep_pred_multimasks"] = all_pred_multimasks
|
||||
current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
|
||||
current_out["multistep_pred_ious"] = all_pred_ious
|
||||
current_out["multistep_point_inputs"] = all_point_inputs
|
||||
current_out["multistep_object_score_logits"] = all_object_score_logits
|
||||
|
||||
return point_inputs, sam_outputs
|
502
training/optimizer.py
Normal file
502
training/optimizer.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import fnmatch
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, optimizer, schedulers=None) -> None:
|
||||
self.optimizer = optimizer
|
||||
self.schedulers = schedulers
|
||||
self._validate_optimizer_schedulers()
|
||||
self.step_schedulers(0.0, 0)
|
||||
|
||||
def _validate_optimizer_schedulers(self):
|
||||
if self.schedulers is None:
|
||||
return
|
||||
for _, set_of_schedulers in enumerate(self.schedulers):
|
||||
for option, _ in set_of_schedulers.items():
|
||||
assert option in self.optimizer.defaults, (
|
||||
"Optimizer option "
|
||||
f"{option} not found in {self.optimizer}. Valid options are "
|
||||
f"{self.optimizer.defaults.keys()}"
|
||||
)
|
||||
|
||||
def step_schedulers(self, where: float, step: int) -> None:
|
||||
if self.schedulers is None:
|
||||
return
|
||||
for i, param_group in enumerate(self.optimizer.param_groups):
|
||||
for option, scheduler in self.schedulers[i].items():
|
||||
if "step" in inspect.signature(scheduler.__call__).parameters:
|
||||
new_value = scheduler(step=step, where=where)
|
||||
elif (
|
||||
hasattr(scheduler, "scheduler")
|
||||
and "step"
|
||||
in inspect.signature(scheduler.scheduler.__call__).parameters
|
||||
):
|
||||
# To handle ValueScaler wrappers
|
||||
new_value = scheduler(step=step, where=where)
|
||||
else:
|
||||
new_value = scheduler(where)
|
||||
param_group[option] = new_value
|
||||
|
||||
def step(self, where, step, closure=None):
|
||||
self.step_schedulers(where, step)
|
||||
return self.optimizer.step(closure)
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
return self.optimizer.zero_grad(*args, **kwargs)
|
||||
|
||||
|
||||
def set_default_parameters(
|
||||
scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
|
||||
) -> None:
|
||||
"""Set up the "default" scheduler with the right parameters.
|
||||
|
||||
Args:
|
||||
scheduler_cgfs: A list of scheduler configs, where each scheduler also
|
||||
specifies which parameters it applies to, based on the names of parameters
|
||||
or the class of the modules. At most one scheduler is allowed to skip this
|
||||
specification, which is used as a "default" specification for any remaining
|
||||
parameters.
|
||||
all_parameter_names: Names of all the parameters to consider.
|
||||
"""
|
||||
constraints = [
|
||||
scheduler_cfg.parameter_names
|
||||
for scheduler_cfg in scheduler_cfgs
|
||||
if scheduler_cfg.parameter_names is not None
|
||||
]
|
||||
if len(constraints) == 0:
|
||||
default_params = set(all_parameter_names)
|
||||
else:
|
||||
default_params = all_parameter_names - set.union(*constraints)
|
||||
default_count = 0
|
||||
for scheduler_cfg in scheduler_cfgs:
|
||||
if scheduler_cfg.parameter_names is None:
|
||||
scheduler_cfg.parameter_names = default_params
|
||||
default_count += 1
|
||||
assert default_count <= 1, "Only one scheduler per option can be default"
|
||||
if default_count == 0:
|
||||
# No default scheduler specified, add a default, but without any scheduler
|
||||
# for that option
|
||||
scheduler_cfgs.append({"parameter_names": default_params})
|
||||
|
||||
|
||||
def name_constraints_to_parameters(
|
||||
param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
|
||||
) -> List[torch.nn.Parameter]:
|
||||
"""Return parameters which match the intersection of parameter constraints.
|
||||
|
||||
Note that this returns the parameters themselves, not their names.
|
||||
|
||||
Args:
|
||||
param_constraints: A list, with each element being a set of allowed parameters.
|
||||
named_parameters: Mapping from a parameter name to the parameter itself.
|
||||
|
||||
Returns:
|
||||
A list containing the parameters which overlap with _each_ constraint set from
|
||||
param_constraints.
|
||||
"""
|
||||
matching_names = set.intersection(*param_constraints)
|
||||
return [value for name, value in named_parameters.items() if name in matching_names]
|
||||
|
||||
|
||||
def map_scheduler_cfgs_to_param_groups(
|
||||
all_scheduler_cfgs: Iterable[List[Dict]],
|
||||
named_parameters: Dict[str, Tensor],
|
||||
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
|
||||
"""Produce parameter groups corresponding to all the scheduler configs.
|
||||
|
||||
Takes all the scheduler configs, each of which applies to a specific optimizer
|
||||
option (like "lr" or "weight_decay") and has a set of parameter names which it
|
||||
applies to, and produces a final set of param groups where each param group
|
||||
covers all the options which apply to a particular set of parameters.
|
||||
|
||||
Args:
|
||||
all_scheduler_cfgs: All the scheduler configs covering every option.
|
||||
named_parameters: Mapping from a parameter name to the parameter itself.
|
||||
Returns:
|
||||
Tuple of lists of schedulers and param_groups, where schedulers[i]
|
||||
applies to param_groups[i].
|
||||
"""
|
||||
|
||||
scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs)
|
||||
schedulers = []
|
||||
param_groups = []
|
||||
for scheduler_cfgs in scheduler_cfgs_per_param_group:
|
||||
param_constraints = [
|
||||
scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
|
||||
]
|
||||
matching_parameters = name_constraints_to_parameters(
|
||||
param_constraints, named_parameters
|
||||
)
|
||||
if len(matching_parameters) == 0: # If no overlap of parameters, skip
|
||||
continue
|
||||
schedulers_for_group = {
|
||||
scheduler_cfg["option"]: scheduler_cfg["scheduler"]
|
||||
for scheduler_cfg in scheduler_cfgs
|
||||
if "option" in scheduler_cfg
|
||||
}
|
||||
schedulers.append(schedulers_for_group)
|
||||
param_groups.append({"params": matching_parameters})
|
||||
return schedulers, param_groups
|
||||
|
||||
|
||||
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
|
||||
"""Check that the param groups are non-overlapping and cover all the parameters.
|
||||
|
||||
Args:
|
||||
param_groups: List of all param groups
|
||||
model: Model to validate against. The check ensures that all the model
|
||||
parameters are part of param_groups
|
||||
"""
|
||||
for pg in param_groups:
|
||||
# no param should be repeated within a group
|
||||
assert len(pg["params"]) == len(set(pg["params"]))
|
||||
parameters = [set(param_group["params"]) for param_group in param_groups]
|
||||
model_parameters = {parameter for _, parameter in model.named_parameters()}
|
||||
for p1, p2 in itertools.permutations(parameters, 2):
|
||||
assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
|
||||
assert set.union(*parameters) == model_parameters, (
|
||||
"Scheduler generated param_groups must include all parameters of the model."
|
||||
f" Found {len(set.union(*parameters))} params whereas model has"
|
||||
f" {len(model_parameters)} params"
|
||||
)
|
||||
|
||||
|
||||
def unix_module_cls_pattern_to_parameter_names(
|
||||
filter_module_cls_names: List[str],
|
||||
module_cls_to_param_names: Dict[Type, str],
|
||||
) -> Union[None, Set[str]]:
|
||||
"""Returns param names which pass the filters specified in filter_module_cls_names.
|
||||
|
||||
Args:
|
||||
filter_module_cls_names: A list of filter strings containing class names, like
|
||||
["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]
|
||||
module_cls_to_param_names: Mapping from module classes to the parameter names
|
||||
they contain. See `get_module_cls_to_param_names`.
|
||||
"""
|
||||
if filter_module_cls_names is None:
|
||||
return set()
|
||||
allowed_parameter_names = []
|
||||
for module_cls_name in filter_module_cls_names:
|
||||
module_cls = hydra.utils.get_class(module_cls_name)
|
||||
if module_cls not in module_cls_to_param_names:
|
||||
raise AssertionError(
|
||||
f"module_cls_name {module_cls_name} does not "
|
||||
"match any classes in the model"
|
||||
)
|
||||
matching_parameters = module_cls_to_param_names[module_cls]
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||
logging.info(
|
||||
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
|
||||
)
|
||||
allowed_parameter_names.append(matching_parameters)
|
||||
return set.union(*allowed_parameter_names)
|
||||
|
||||
|
||||
def unix_param_pattern_to_parameter_names(
|
||||
filter_param_names: Optional[List[str]],
|
||||
parameter_names: Dict[str, torch.Tensor],
|
||||
) -> Union[None, Set[str]]:
|
||||
"""Returns param names which pass the filters specified in filter_param_names.
|
||||
|
||||
Args:
|
||||
filter_param_names: A list of unix-style filter strings with optional
|
||||
wildcards, like ["block.2.*", "block.2.linear.weight"]
|
||||
module_cls_to_param_names: Mapping from module classes to the parameter names
|
||||
they contain. See `get_module_cls_to_param_names`.
|
||||
"""
|
||||
|
||||
if filter_param_names is None:
|
||||
return set()
|
||||
allowed_parameter_names = []
|
||||
for param_name in filter_param_names:
|
||||
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) >= 1
|
||||
), f"param_name {param_name} does not match any parameters in the model"
|
||||
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
|
||||
allowed_parameter_names.append(matching_parameters)
|
||||
return set.union(*allowed_parameter_names)
|
||||
|
||||
|
||||
def _unix_pattern_to_parameter_names(
|
||||
scheduler_cfg: DictConfig,
|
||||
parameter_names: Set[str],
|
||||
module_cls_to_param_names: Dict[Type, str],
|
||||
) -> Union[None, Set[str]]:
|
||||
"""Returns param names which pass the filters specified in scheduler_cfg.
|
||||
|
||||
Args:
|
||||
scheduler_cfg: The config for the scheduler
|
||||
parameter_names: The set of all parameter names which will be filtered
|
||||
"""
|
||||
if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
|
||||
return None
|
||||
return unix_param_pattern_to_parameter_names(
|
||||
scheduler_cfg.get("param_names"), parameter_names
|
||||
).union(
|
||||
unix_module_cls_pattern_to_parameter_names(
|
||||
scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_module_cls_to_param_names(
|
||||
model: nn.Module, param_allowlist: Set[str] = None
|
||||
) -> Dict[Type, str]:
|
||||
"""Produce a mapping from all the modules classes to the names of parames they own.
|
||||
|
||||
Only counts a parameter as part of the immediate parent module, i.e. recursive
|
||||
parents do not count.
|
||||
|
||||
Args:
|
||||
model: Model to iterate over
|
||||
param_allowlist: If specified, only these param names will be processed
|
||||
"""
|
||||
|
||||
module_cls_to_params = {}
|
||||
for module_name, module in model.named_modules():
|
||||
module_cls = type(module)
|
||||
module_cls_to_params.setdefault(module_cls, set())
|
||||
for param_name, _ in module.named_parameters(recurse=False):
|
||||
full_param_name = get_full_parameter_name(module_name, param_name)
|
||||
if param_allowlist is None or full_param_name in param_allowlist:
|
||||
module_cls_to_params[module_cls].add(full_param_name)
|
||||
return module_cls_to_params
|
||||
|
||||
|
||||
def construct_optimizer(
|
||||
model: torch.nn.Module,
|
||||
optimizer_conf: Any,
|
||||
options_conf: Mapping[str, List] = None,
|
||||
param_group_modifiers_conf: List[Callable] = None,
|
||||
param_allowlist: Optional[Set[str]] = None,
|
||||
validate_param_groups=True,
|
||||
) -> Optimizer:
|
||||
"""
|
||||
Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer
|
||||
with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay
|
||||
Batchnorm and/or no-update 1-D parameters support, based on the config.
|
||||
|
||||
Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling
|
||||
(LARS): https://arxiv.org/abs/1708.03888
|
||||
|
||||
Args:
|
||||
model: model to perform stochastic gradient descent
|
||||
optimization or ADAM optimization.
|
||||
optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
|
||||
ADAM, still missing the params argument which this function provides to
|
||||
produce the final optimizer
|
||||
param_group_modifiers_conf: Optional user specified functions which can modify
|
||||
the final scheduler configs before the optimizer's param groups are built
|
||||
param_allowlist: The parameters to optimize. Parameters which are not part of
|
||||
this allowlist will be skipped.
|
||||
validate_param_groups: If enabled, valides that the produced param_groups don't
|
||||
overlap and cover all the model parameters.
|
||||
"""
|
||||
if param_allowlist is None:
|
||||
param_allowlist = {name for name, _ in model.named_parameters()}
|
||||
|
||||
named_parameters = {
|
||||
name: param
|
||||
for name, param in model.named_parameters()
|
||||
if name in param_allowlist
|
||||
}
|
||||
|
||||
if not options_conf:
|
||||
optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
|
||||
return Optimizer(optimizer)
|
||||
|
||||
all_parameter_names = {
|
||||
name for name, _ in model.named_parameters() if name in param_allowlist
|
||||
}
|
||||
module_cls_to_all_param_names = get_module_cls_to_param_names(
|
||||
model, param_allowlist
|
||||
)
|
||||
|
||||
scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
|
||||
all_scheduler_cfgs = []
|
||||
for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
|
||||
for config in scheduler_cfgs:
|
||||
config.option = option
|
||||
config.parameter_names = _unix_pattern_to_parameter_names(
|
||||
config, all_parameter_names, module_cls_to_all_param_names
|
||||
)
|
||||
set_default_parameters(scheduler_cfgs, all_parameter_names)
|
||||
all_scheduler_cfgs.append(scheduler_cfgs)
|
||||
|
||||
if param_group_modifiers_conf:
|
||||
for custom_param_modifier in param_group_modifiers_conf:
|
||||
custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
|
||||
all_scheduler_cfgs = custom_param_modifier(
|
||||
scheduler_cfgs=all_scheduler_cfgs, model=model
|
||||
)
|
||||
schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
|
||||
all_scheduler_cfgs, named_parameters
|
||||
)
|
||||
if validate_param_groups:
|
||||
validate_param_group_params(param_groups, model)
|
||||
optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
|
||||
return Optimizer(optimizer, schedulers)
|
||||
|
||||
|
||||
def get_full_parameter_name(module_name, param_name):
|
||||
if module_name == "":
|
||||
return param_name
|
||||
return f"{module_name}.{param_name}"
|
||||
|
||||
|
||||
class GradientClipper:
|
||||
"""
|
||||
Gradient clipping utils that works for DDP
|
||||
"""
|
||||
|
||||
def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
|
||||
assert isinstance(max_norm, (int, float)) or max_norm is None
|
||||
self.max_norm = max_norm if max_norm is None else float(max_norm)
|
||||
self.norm_type = norm_type
|
||||
|
||||
def __call__(self, model: nn.Module):
|
||||
if self.max_norm is None:
|
||||
return # no-op
|
||||
|
||||
nn.utils.clip_grad_norm_(
|
||||
model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
|
||||
)
|
||||
|
||||
|
||||
class ValueScaler:
|
||||
def __init__(self, scheduler, mult_val: float):
|
||||
self.scheduler = scheduler
|
||||
self.mult_val = mult_val
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
val = self.scheduler(*args, **kwargs)
|
||||
return val * self.mult_val
|
||||
|
||||
|
||||
def rgetattr(obj, rattrs: str = None):
|
||||
"""
|
||||
Like getattr(), but supports dotted notation for nested objects.
|
||||
rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2
|
||||
"""
|
||||
if rattrs is None:
|
||||
return obj
|
||||
attrs = rattrs.split(".")
|
||||
for attr in attrs:
|
||||
obj = getattr(obj, attr)
|
||||
return obj
|
||||
|
||||
|
||||
def layer_decay_param_modifier(
|
||||
scheduler_cfgs: List[List[Dict]],
|
||||
model,
|
||||
layer_decay_value: float,
|
||||
layer_decay_min: Optional[float] = None,
|
||||
apply_to: Optional[str] = None,
|
||||
overrides: List[Dict] = (),
|
||||
) -> List[List[Dict]]:
|
||||
"""
|
||||
Args
|
||||
- scheduler_cfgs: a list of omegaconf.ListConfigs.
|
||||
Each element in the list is a omegaconfg.DictConfig with the following structure
|
||||
{
|
||||
"scheduler": <some fvcore scheduler>
|
||||
"option": <value> possible options are "lr", "weight_decay" etc.
|
||||
"parameter_names": Set of str indicating param names that this scheduler applies to
|
||||
}
|
||||
- model: a model that implements a method `get_layer_id` that maps layer_name to an integer and
|
||||
and a method get_num_layers.
|
||||
Alternatively, use apply_to argument to select a specific component of the model.
|
||||
- layer_decay_value: float
|
||||
- layer_decay_min: min val for layer decay
|
||||
- apply_to: optional arg to select which component of the model to apply the the layer decay modifier to
|
||||
- overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value".
|
||||
Returns
|
||||
- scheduler_configs: same structure as the input, elements can be modified
|
||||
"""
|
||||
model = rgetattr(model, apply_to)
|
||||
num_layers = model.get_num_layers() + 1
|
||||
layer_decays = [
|
||||
layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
|
||||
]
|
||||
if layer_decay_min is not None:
|
||||
layer_decays = [max(val, layer_decay_min) for val in layer_decays]
|
||||
final_scheduler_cfgs = []
|
||||
# scheduler_cfgs is a list of lists
|
||||
for scheduler_cfg_group in scheduler_cfgs:
|
||||
curr_cfg_group = []
|
||||
# scheduler_cfg_group is a list of dictionaries
|
||||
for scheduler_cfg in scheduler_cfg_group:
|
||||
if scheduler_cfg["option"] != "lr":
|
||||
curr_cfg_group.append(scheduler_cfg)
|
||||
continue
|
||||
# Need sorted so that the list of parameter names is deterministic and consistent
|
||||
# across re-runs of this job. Else it was causing issues with loading the optimizer
|
||||
# state during a job restart (D38591759)
|
||||
parameter_names = sorted(scheduler_cfg["parameter_names"])
|
||||
|
||||
# Only want one cfg group per layer
|
||||
layer_cfg_groups = {}
|
||||
for param_name in parameter_names:
|
||||
layer_id = num_layers
|
||||
this_scale = layer_decays[layer_id]
|
||||
if param_name.startswith(apply_to):
|
||||
layer_id = model.get_layer_id(param_name)
|
||||
this_scale = layer_decays[layer_id]
|
||||
# Overrides
|
||||
for override in overrides:
|
||||
if fnmatch.fnmatchcase(param_name, override["pattern"]):
|
||||
this_scale = float(override["value"])
|
||||
layer_id = override["pattern"]
|
||||
break
|
||||
|
||||
if layer_id not in layer_cfg_groups:
|
||||
curr_param = {
|
||||
"option": scheduler_cfg["option"],
|
||||
"scheduler": ValueScaler(
|
||||
scheduler_cfg["scheduler"], this_scale
|
||||
),
|
||||
"parameter_names": {param_name},
|
||||
}
|
||||
else:
|
||||
curr_param = layer_cfg_groups[layer_id]
|
||||
curr_param["parameter_names"].add(param_name)
|
||||
layer_cfg_groups[layer_id] = curr_param
|
||||
|
||||
for layer_cfg in layer_cfg_groups.values():
|
||||
curr_cfg_group.append(layer_cfg)
|
||||
|
||||
final_scheduler_cfgs.append(curr_cfg_group)
|
||||
return final_scheduler_cfgs
|
163
training/scripts/sav_frame_extraction_submitit.py
Normal file
163
training/scripts/sav_frame_extraction_submitit.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
import submitit
|
||||
import tqdm
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="[SA-V Preprocessing] Extracting JPEG frames",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# ------------
|
||||
# DATA
|
||||
# ------------
|
||||
data_parser = parser.add_argument_group(
|
||||
title="SA-V dataset data root",
|
||||
description="What data to load and how to process it.",
|
||||
)
|
||||
data_parser.add_argument(
|
||||
"--sav-vid-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help=("Where to find the SAV videos"),
|
||||
)
|
||||
data_parser.add_argument(
|
||||
"--sav-frame-sample-rate",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Rate at which to sub-sample frames",
|
||||
)
|
||||
|
||||
# ------------
|
||||
# LAUNCH
|
||||
# ------------
|
||||
launch_parser = parser.add_argument_group(
|
||||
title="Cluster launch settings",
|
||||
description="Number of jobs and retry settings.",
|
||||
)
|
||||
launch_parser.add_argument(
|
||||
"--n-jobs",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Shard the run over this many jobs.",
|
||||
)
|
||||
launch_parser.add_argument(
|
||||
"--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
|
||||
)
|
||||
launch_parser.add_argument(
|
||||
"--partition", type=str, required=True, help="Partition to launch on."
|
||||
)
|
||||
launch_parser.add_argument(
|
||||
"--account", type=str, required=True, help="Partition to launch on."
|
||||
)
|
||||
launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")
|
||||
|
||||
# ------------
|
||||
# OUTPUT
|
||||
# ------------
|
||||
output_parser = parser.add_argument_group(
|
||||
title="Setting for results output", description="Where and how to save results."
|
||||
)
|
||||
output_parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help=("Where to dump the extracted jpeg frames"),
|
||||
)
|
||||
output_parser.add_argument(
|
||||
"--slurm-output-root-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help=("Where to save slurm outputs"),
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def decode_video(video_path: str):
|
||||
assert os.path.exists(video_path)
|
||||
video = cv2.VideoCapture(video_path)
|
||||
video_frames = []
|
||||
while video.isOpened():
|
||||
ret, frame = video.read()
|
||||
if ret:
|
||||
video_frames.append(frame)
|
||||
else:
|
||||
break
|
||||
return video_frames
|
||||
|
||||
|
||||
def extract_frames(video_path, sample_rate):
|
||||
frames = decode_video(video_path)
|
||||
return frames[::sample_rate]
|
||||
|
||||
|
||||
def submitit_launch(video_paths, sample_rate, save_root):
|
||||
for path in tqdm.tqdm(video_paths):
|
||||
frames = extract_frames(path, sample_rate)
|
||||
output_folder = os.path.join(save_root, Path(path).stem)
|
||||
if not os.path.exists(output_folder):
|
||||
os.makedirs(output_folder)
|
||||
for fid, frame in enumerate(frames):
|
||||
frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
|
||||
cv2.imwrite(frame_path, frame)
|
||||
print(f"Saved output to {save_root}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_args_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
sav_vid_dir = args.sav_vid_dir
|
||||
save_root = args.output_dir
|
||||
sample_rate = args.sav_frame_sample_rate
|
||||
|
||||
# List all SA-V videos
|
||||
mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
|
||||
mp4_files = np.array(mp4_files)
|
||||
chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]
|
||||
|
||||
print(f"Processing videos in: {sav_vid_dir}")
|
||||
print(f"Processing {len(mp4_files)} files")
|
||||
print(f"Beginning processing in {args.n_jobs} processes")
|
||||
|
||||
# Submitit params
|
||||
jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
|
||||
cpus_per_task = 4
|
||||
executor = submitit.AutoExecutor(folder=jobs_dir)
|
||||
executor.update_parameters(
|
||||
timeout_min=args.timeout,
|
||||
gpus_per_node=0,
|
||||
tasks_per_node=1,
|
||||
slurm_array_parallelism=args.n_jobs,
|
||||
cpus_per_task=cpus_per_task,
|
||||
slurm_partition=args.partition,
|
||||
slurm_account=args.account,
|
||||
slurm_qos=args.qos,
|
||||
)
|
||||
executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])
|
||||
|
||||
# Launch
|
||||
jobs = []
|
||||
with executor.batch():
|
||||
for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
|
||||
job = executor.submit(
|
||||
submitit_launch,
|
||||
video_paths=mp4_chunk,
|
||||
sample_rate=sample_rate,
|
||||
save_root=save_root,
|
||||
)
|
||||
jobs.append(job)
|
||||
|
||||
for j in jobs:
|
||||
print(f"Slurm JobID: {j.job_id}")
|
||||
print(f"Saving outputs to {save_root}")
|
||||
print(f"Slurm outputs at {args.slurm_output_root_dir}")
|
270
training/train.py
Normal file
270
training/train.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import submitit
|
||||
import torch
|
||||
|
||||
from hydra import compose, initialize_config_module
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from training.utils.train_utils import makedir, register_omegaconf_resolvers
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
|
||||
|
||||
def single_proc_run(local_rank, main_port, cfg, world_size):
|
||||
"""Single GPU process"""
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(main_port)
|
||||
os.environ["RANK"] = str(local_rank)
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
try:
|
||||
register_omegaconf_resolvers()
|
||||
except Exception as e:
|
||||
logging.info(e)
|
||||
|
||||
trainer = instantiate(cfg.trainer, _recursive_=False)
|
||||
trainer.run()
|
||||
|
||||
|
||||
def single_node_runner(cfg, main_port: int):
|
||||
assert cfg.launcher.num_nodes == 1
|
||||
num_proc = cfg.launcher.gpus_per_node
|
||||
torch.multiprocessing.set_start_method(
|
||||
"spawn"
|
||||
) # CUDA runtime does not support `fork`
|
||||
if num_proc == 1:
|
||||
# directly call single_proc so we can easily set breakpoints
|
||||
# mp.spawn does not let us set breakpoints
|
||||
single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
|
||||
else:
|
||||
mp_runner = torch.multiprocessing.start_processes
|
||||
args = (main_port, cfg, num_proc)
|
||||
# Note: using "fork" below, "spawn" causes time and error regressions. Using
|
||||
# spawn changes the default multiprocessing context to spawn, which doesn't
|
||||
# interact well with the dataloaders (likely due to the use of OpenCV).
|
||||
mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
|
||||
|
||||
|
||||
def format_exception(e: Exception, limit=20):
|
||||
traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
|
||||
return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"
|
||||
|
||||
|
||||
class SubmititRunner(submitit.helpers.Checkpointable):
|
||||
"""A callable which is passed to submitit to launch the jobs."""
|
||||
|
||||
def __init__(self, port, cfg):
|
||||
self.cfg = cfg
|
||||
self.port = port
|
||||
self.has_setup = False
|
||||
|
||||
def run_trainer(self):
|
||||
job_env = submitit.JobEnvironment()
|
||||
# Need to add this again so the hydra.job.set_env PYTHONPATH
|
||||
# is also set when launching jobs.
|
||||
add_pythonpath_to_sys_path()
|
||||
os.environ["MASTER_ADDR"] = job_env.hostnames[0]
|
||||
os.environ["MASTER_PORT"] = str(self.port)
|
||||
os.environ["RANK"] = str(job_env.global_rank)
|
||||
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
|
||||
|
||||
register_omegaconf_resolvers()
|
||||
cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
|
||||
cfg_resolved = OmegaConf.create(cfg_resolved)
|
||||
|
||||
trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
|
||||
trainer.run()
|
||||
|
||||
def __call__(self):
|
||||
job_env = submitit.JobEnvironment()
|
||||
self.setup_job_info(job_env.job_id, job_env.global_rank)
|
||||
try:
|
||||
self.run_trainer()
|
||||
except Exception as e:
|
||||
# Log the exception. Then raise it again (as what SubmititRunner currently does).
|
||||
message = format_exception(e)
|
||||
logging.error(message)
|
||||
raise e
|
||||
|
||||
def setup_job_info(self, job_id, rank):
|
||||
"""Set up slurm job info"""
|
||||
self.job_info = {
|
||||
"job_id": job_id,
|
||||
"rank": rank,
|
||||
"cluster": self.cfg.get("cluster", None),
|
||||
"experiment_log_dir": self.cfg.launcher.experiment_log_dir,
|
||||
}
|
||||
|
||||
self.has_setup = True
|
||||
|
||||
|
||||
def add_pythonpath_to_sys_path():
|
||||
if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
|
||||
return
|
||||
sys.path = os.environ["PYTHONPATH"].split(":") + sys.path
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
cfg = compose(config_name=args.config)
|
||||
if cfg.launcher.experiment_log_dir is None:
|
||||
cfg.launcher.experiment_log_dir = os.path.join(
|
||||
os.getcwd(), "sam2_logs", args.config
|
||||
)
|
||||
print("###################### Train App Config ####################")
|
||||
print(OmegaConf.to_yaml(cfg))
|
||||
print("############################################################")
|
||||
|
||||
add_pythonpath_to_sys_path()
|
||||
makedir(cfg.launcher.experiment_log_dir)
|
||||
with g_pathmgr.open(
|
||||
os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
|
||||
) as f:
|
||||
f.write(OmegaConf.to_yaml(cfg))
|
||||
|
||||
cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
|
||||
cfg_resolved = OmegaConf.create(cfg_resolved)
|
||||
|
||||
with g_pathmgr.open(
|
||||
os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
|
||||
) as f:
|
||||
f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))
|
||||
|
||||
submitit_conf = cfg.get("submitit", None)
|
||||
assert submitit_conf is not None, "Missing submitit config"
|
||||
|
||||
submitit_dir = cfg.launcher.experiment_log_dir
|
||||
submitit_dir = os.path.join(submitit_dir, "submitit_logs")
|
||||
# Priotrize cmd line args
|
||||
cfg.launcher.gpus_per_node = (
|
||||
args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
|
||||
)
|
||||
cfg.launcher.num_nodes = (
|
||||
args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
|
||||
)
|
||||
submitit_conf.use_cluster = (
|
||||
args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
|
||||
)
|
||||
if submitit_conf.use_cluster:
|
||||
executor = submitit.AutoExecutor(folder=submitit_dir)
|
||||
submitit_conf.partition = (
|
||||
args.partition
|
||||
if args.partition is not None
|
||||
else submitit_conf.get("partition", None)
|
||||
)
|
||||
submitit_conf.account = (
|
||||
args.account
|
||||
if args.account is not None
|
||||
else submitit_conf.get("account", None)
|
||||
)
|
||||
submitit_conf.qos = (
|
||||
args.qos if args.qos is not None else submitit_conf.get("qos", None)
|
||||
)
|
||||
job_kwargs = {
|
||||
"timeout_min": 60 * submitit_conf.timeout_hour,
|
||||
"name": (
|
||||
submitit_conf.name if hasattr(submitit_conf, "name") else args.config
|
||||
),
|
||||
"slurm_partition": submitit_conf.partition,
|
||||
"gpus_per_node": cfg.launcher.gpus_per_node,
|
||||
"tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU
|
||||
"cpus_per_task": submitit_conf.cpus_per_task,
|
||||
"nodes": cfg.launcher.num_nodes,
|
||||
"slurm_additional_parameters": {
|
||||
"exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
|
||||
},
|
||||
}
|
||||
if "include_nodes" in submitit_conf:
|
||||
assert (
|
||||
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
|
||||
), "Not enough nodes"
|
||||
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
|
||||
submitit_conf["include_nodes"]
|
||||
)
|
||||
if submitit_conf.account is not None:
|
||||
job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
|
||||
if submitit_conf.qos is not None:
|
||||
job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
|
||||
|
||||
if submitit_conf.get("mem_gb", None) is not None:
|
||||
job_kwargs["mem_gb"] = submitit_conf.mem_gb
|
||||
elif submitit_conf.get("mem", None) is not None:
|
||||
job_kwargs["slurm_mem"] = submitit_conf.mem
|
||||
|
||||
if submitit_conf.get("constraints", None) is not None:
|
||||
job_kwargs["slurm_constraint"] = submitit_conf.constraints
|
||||
|
||||
if submitit_conf.get("comment", None) is not None:
|
||||
job_kwargs["slurm_comment"] = submitit_conf.comment
|
||||
|
||||
# Supports only cpu-bind option within srun_args. New options can be added here
|
||||
if submitit_conf.get("srun_args", None) is not None:
|
||||
job_kwargs["slurm_srun_args"] = []
|
||||
if submitit_conf.srun_args.get("cpu_bind", None) is not None:
|
||||
job_kwargs["slurm_srun_args"].extend(
|
||||
["--cpu-bind", submitit_conf.srun_args.cpu_bind]
|
||||
)
|
||||
|
||||
print("###################### SLURM Config ####################")
|
||||
print(job_kwargs)
|
||||
print("##########################################")
|
||||
executor.update_parameters(**job_kwargs)
|
||||
|
||||
main_port = random.randint(
|
||||
submitit_conf.port_range[0], submitit_conf.port_range[1]
|
||||
)
|
||||
runner = SubmititRunner(main_port, cfg)
|
||||
job = executor.submit(runner)
|
||||
print(f"Submitit Job ID: {job.job_id}")
|
||||
runner.setup_job_info(job.job_id, rank=0)
|
||||
else:
|
||||
cfg.launcher.num_nodes = 1
|
||||
main_port = random.randint(
|
||||
submitit_conf.port_range[0], submitit_conf.port_range[1]
|
||||
)
|
||||
single_node_runner(cfg, main_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
initialize_config_module("sam2", version_base="1.2")
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
required=True,
|
||||
type=str,
|
||||
help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-cluster",
|
||||
type=int,
|
||||
default=None,
|
||||
help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
|
||||
)
|
||||
parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
|
||||
parser.add_argument("--account", type=str, default=None, help="SLURM account")
|
||||
parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
|
||||
parser.add_argument(
|
||||
"--num-gpus", type=int, default=None, help="number of GPUS per node"
|
||||
)
|
||||
parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
|
||||
args = parser.parse_args()
|
||||
args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
|
||||
register_omegaconf_resolvers()
|
||||
main(args)
|
1113
training/trainer.py
Normal file
1113
training/trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
5
training/utils/__init__.py
Normal file
5
training/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
361
training/utils/checkpoint_utils.py
Normal file
361
training/utils/checkpoint_utils.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import fnmatch
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from torch.jit._script import RecursiveScriptModule
|
||||
|
||||
|
||||
def unix_pattern_to_parameter_names(
|
||||
constraints: List[str], all_parameter_names: Sequence[str]
|
||||
) -> Union[None, Set[str]]:
|
||||
"""
|
||||
Go through the list of parameter names and select those that match
|
||||
any of the provided constraints
|
||||
"""
|
||||
parameter_names = []
|
||||
for param_name in constraints:
|
||||
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"param_names {param_name} don't match any param in the given names."
|
||||
parameter_names.append(matching_parameters)
|
||||
return set.union(*parameter_names)
|
||||
|
||||
|
||||
def filter_params_matching_unix_pattern(
|
||||
patterns: List[str], state_dict: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Remove from the state dictionary the parameters matching the provided unix patterns
|
||||
|
||||
Args:
|
||||
patterns: the list of unix patterns to exclude
|
||||
state_dict: the dictionary to filter
|
||||
|
||||
Returns:
|
||||
A new state dictionary
|
||||
"""
|
||||
if len(patterns) == 0:
|
||||
return {}
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
included_keys = unix_pattern_to_parameter_names(patterns, all_keys)
|
||||
return {k: state_dict[k] for k in included_keys}
|
||||
|
||||
|
||||
def exclude_params_matching_unix_pattern(
|
||||
patterns: List[str], state_dict: Dict[str, torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Remove from the state dictionary the parameters matching the provided unix patterns
|
||||
|
||||
Args:
|
||||
patterns: the list of unix patterns to exclude
|
||||
state_dict: the dictionary to filter
|
||||
|
||||
Returns:
|
||||
A new state dictionary
|
||||
"""
|
||||
if len(patterns) == 0:
|
||||
return state_dict
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys)
|
||||
return {k: v for k, v in state_dict.items() if k not in excluded_keys}
|
||||
|
||||
|
||||
def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]):
|
||||
keys = []
|
||||
trace = []
|
||||
for k, v in state_dict.items():
|
||||
keys.append(k)
|
||||
trace.append(v.sum().item())
|
||||
trace = np.array(trace)[np.argsort(keys)]
|
||||
return trace
|
||||
|
||||
|
||||
def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]):
|
||||
"""
|
||||
Verifies that all the parameters matching the provided patterns
|
||||
are frozen - this acts as a safeguard when ignoring parameter
|
||||
when saving checkpoints - if the parameters are in fact trainable
|
||||
"""
|
||||
if not patterns:
|
||||
return
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
non_frozen_keys = {
|
||||
n
|
||||
for n, p in model.named_parameters()
|
||||
if n in frozen_state_dict and p.requires_grad
|
||||
}
|
||||
if non_frozen_keys:
|
||||
raise ValueError(
|
||||
f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}"
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_check_parameter_frozen(
|
||||
model: nn.Module, patterns: List[str], disabled: bool = True
|
||||
):
|
||||
"""
|
||||
Context manager that inspects a model surrounding a piece of code
|
||||
and verifies if the model has been updated by this piece of code
|
||||
|
||||
The function will raise an exception if the model has been updated
|
||||
on at least one of the parameter that matches one of the pattern
|
||||
|
||||
Args:
|
||||
model: the model that might have been updated
|
||||
patterns: for the parameters we want to observe
|
||||
allowed:
|
||||
"""
|
||||
if not patterns or disabled:
|
||||
yield
|
||||
return
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
summary_before = _get_state_dict_summary(frozen_state_dict)
|
||||
|
||||
yield
|
||||
|
||||
frozen_state_dict = filter_params_matching_unix_pattern(
|
||||
patterns=patterns, state_dict=model.state_dict()
|
||||
)
|
||||
summary_after = _get_state_dict_summary(frozen_state_dict)
|
||||
|
||||
if not np.allclose(summary_before, summary_after, atol=1e-6):
|
||||
raise ValueError(
|
||||
f"""
|
||||
The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`.
|
||||
You can resolve this error by either initializing those parameters from within the model definition
|
||||
or using the flag `trainer.checkpoint.initialize_after_preemption` to True.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class CkptExcludeKernel:
|
||||
"""
|
||||
Removes the keys from the given model state_dict that match the key_pattern.
|
||||
|
||||
Args:
|
||||
key_pattern: Patterns used to select the keys in the state_dict
|
||||
that are eligible for this kernel.
|
||||
"""
|
||||
|
||||
def __init__(self, key_pattern: List[str]):
|
||||
self.key_pattern = key_pattern
|
||||
|
||||
def __call__(self, state_dict: Dict):
|
||||
"""
|
||||
Args:
|
||||
state_dict: A dictionary representing the given checkpoint's state dict.
|
||||
"""
|
||||
if len(self.key_pattern) == 0:
|
||||
return state_dict
|
||||
exclude_keys = unix_pattern_to_parameter_names(
|
||||
self.key_pattern, state_dict.keys()
|
||||
)
|
||||
return {k: v for k, v in state_dict.items() if k not in exclude_keys}
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
path_list: List[str],
|
||||
pick_recursive_keys: Optional[List[str]] = None,
|
||||
map_location: str = "cpu",
|
||||
) -> Any:
|
||||
"""
|
||||
Loads a checkpoint from the specified path.
|
||||
|
||||
Args:
|
||||
path_list: A list of paths which contain the checkpoint. Each element
|
||||
is tried (in order) until a file that exists is found. That file is then
|
||||
used to read the checkpoint.
|
||||
pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None.
|
||||
For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"]
|
||||
map_location (str): a function, torch.device, string or a dict specifying how to
|
||||
remap storage locations
|
||||
|
||||
Returns: Model with the matchin pre-trained weights loaded.
|
||||
"""
|
||||
path_exists = False
|
||||
for path in path_list:
|
||||
if g_pathmgr.exists(path):
|
||||
path_exists = True
|
||||
break
|
||||
|
||||
if not path_exists:
|
||||
raise ValueError(f"No path exists in {path_list}")
|
||||
|
||||
with g_pathmgr.open(path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location=map_location)
|
||||
|
||||
logging.info(f"Loaded checkpoint from {path}")
|
||||
if pick_recursive_keys is not None:
|
||||
for key in pick_recursive_keys:
|
||||
checkpoint = checkpoint[key]
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_state_dict(checkpoint, ckpt_state_dict_keys):
|
||||
if isinstance(checkpoint, RecursiveScriptModule):
|
||||
# This is a torchscript JIT model
|
||||
return checkpoint.state_dict()
|
||||
pre_train_dict = checkpoint
|
||||
for i, key in enumerate(ckpt_state_dict_keys):
|
||||
if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or (
|
||||
isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict)
|
||||
):
|
||||
key_str = (
|
||||
'["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]'
|
||||
)
|
||||
raise KeyError(
|
||||
f"'{key}' not found in checkpoint{key_str} "
|
||||
f"with keys: {pre_train_dict.keys()}"
|
||||
)
|
||||
pre_train_dict = pre_train_dict[key]
|
||||
return pre_train_dict
|
||||
|
||||
|
||||
def load_checkpoint_and_apply_kernels(
|
||||
checkpoint_path: str,
|
||||
checkpoint_kernels: List[Callable] = None,
|
||||
ckpt_state_dict_keys: Tuple[str] = ("state_dict",),
|
||||
map_location: str = "cpu",
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Performs checkpoint loading with a variety of pre-processing kernel applied in
|
||||
sequence.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): Path to the checkpoint.
|
||||
checkpoint_kernels List(Callable): A list of checkpoint processing kernels
|
||||
to apply in the specified order. Supported kernels include `CkptIncludeKernel`,
|
||||
`CkptExcludeKernel`, etc. These kernels are applied in the
|
||||
given order.
|
||||
ckpt_state_dict_keys (str): Keys containing the model state dict.
|
||||
map_location (str): a function, torch.device, string or a dict specifying how to
|
||||
remap storage locations
|
||||
|
||||
Returns: Model with the matchin pre-trained weights loaded.
|
||||
"""
|
||||
assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format(
|
||||
checkpoint_path
|
||||
)
|
||||
|
||||
# Load the checkpoint on CPU to avoid GPU mem spike.
|
||||
with g_pathmgr.open(checkpoint_path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location=map_location)
|
||||
|
||||
pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys)
|
||||
|
||||
# Not logging into info etc since it's a huge log
|
||||
logging.debug(
|
||||
"Loaded Checkpoint State Dict pre-kernel application: %s"
|
||||
% str(", ".join(list(pre_train_dict.keys())))
|
||||
)
|
||||
# Apply kernels
|
||||
if checkpoint_kernels is not None:
|
||||
for f in checkpoint_kernels:
|
||||
pre_train_dict = f(state_dict=pre_train_dict)
|
||||
|
||||
logging.debug(
|
||||
"Loaded Checkpoint State Dict Post-kernel application %s"
|
||||
% str(", ".join(list(pre_train_dict.keys())))
|
||||
)
|
||||
|
||||
return pre_train_dict
|
||||
|
||||
|
||||
def check_load_state_dict_errors(
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
strict: bool,
|
||||
ignore_missing_keys: List[str] = None,
|
||||
ignore_unexpected_keys: List[str] = None,
|
||||
):
|
||||
if ignore_missing_keys is not None and len(ignore_missing_keys) > 0:
|
||||
ignored_keys = unix_pattern_to_parameter_names(
|
||||
ignore_missing_keys, missing_keys
|
||||
)
|
||||
missing_keys = [key for key in missing_keys if key not in ignored_keys]
|
||||
|
||||
if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0:
|
||||
ignored_unexpected_keys = unix_pattern_to_parameter_names(
|
||||
ignore_unexpected_keys, unexpected_keys
|
||||
)
|
||||
unexpected_keys = [
|
||||
key for key in unexpected_keys if key not in ignored_unexpected_keys
|
||||
]
|
||||
|
||||
err = "State key mismatch."
|
||||
if unexpected_keys:
|
||||
err += f" Unexpected keys: {unexpected_keys}."
|
||||
if missing_keys:
|
||||
err += f" Missing keys: {missing_keys}."
|
||||
|
||||
if unexpected_keys or missing_keys:
|
||||
logging.warning(err)
|
||||
if unexpected_keys or strict:
|
||||
raise KeyError(err)
|
||||
|
||||
|
||||
def load_state_dict_into_model(
|
||||
state_dict: Dict,
|
||||
model: nn.Module,
|
||||
strict: bool = True,
|
||||
ignore_missing_keys: List[str] = None,
|
||||
ignore_unexpected_keys: List[str] = None,
|
||||
checkpoint_kernels: List[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Loads a state dict into the given model.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary containing the model's
|
||||
state dict, or a subset if strict is False
|
||||
model: Model to load the checkpoint weights into
|
||||
strict: raise if the state_dict has missing state keys
|
||||
ignore_missing_keys: unix pattern of keys to ignore
|
||||
"""
|
||||
# Apply kernels
|
||||
if checkpoint_kernels is not None:
|
||||
for f in checkpoint_kernels:
|
||||
state_dict = f(state_dict=state_dict)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
check_load_state_dict_errors(
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
strict=strict,
|
||||
ignore_missing_keys=ignore_missing_keys,
|
||||
ignore_unexpected_keys=ignore_unexpected_keys,
|
||||
)
|
||||
return model
|
179
training/utils/data_utils.py
Normal file
179
training/utils/data_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from tensordict import tensorclass
|
||||
|
||||
|
||||
@tensorclass
|
||||
class BatchedVideoMetaData:
|
||||
"""
|
||||
This class represents metadata about a batch of videos.
|
||||
Attributes:
|
||||
unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
|
||||
frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
|
||||
"""
|
||||
|
||||
unique_objects_identifier: torch.LongTensor
|
||||
frame_orig_size: torch.LongTensor
|
||||
|
||||
|
||||
@tensorclass
|
||||
class BatchedVideoDatapoint:
|
||||
"""
|
||||
This class represents a batch of videos with associated annotations and metadata.
|
||||
Attributes:
|
||||
img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
|
||||
obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch.
|
||||
masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
|
||||
metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
|
||||
dict_key: A string key used to identify the batch.
|
||||
"""
|
||||
|
||||
img_batch: torch.FloatTensor
|
||||
obj_to_frame_idx: torch.IntTensor
|
||||
masks: torch.BoolTensor
|
||||
metadata: BatchedVideoMetaData
|
||||
|
||||
dict_key: str
|
||||
|
||||
def pin_memory(self, device=None):
|
||||
return self.apply(torch.Tensor.pin_memory, device=device)
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""
|
||||
Returns the number of frames per video.
|
||||
"""
|
||||
return self.batch_size[0]
|
||||
|
||||
@property
|
||||
def num_videos(self) -> int:
|
||||
"""
|
||||
Returns the number of videos in the batch.
|
||||
"""
|
||||
return self.img_batch.shape[1]
|
||||
|
||||
@property
|
||||
def flat_obj_to_img_idx(self) -> torch.IntTensor:
|
||||
"""
|
||||
Returns a flattened tensor containing the object to img index.
|
||||
The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
|
||||
"""
|
||||
frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
|
||||
flat_idx = video_idx * self.num_frames + frame_idx
|
||||
return flat_idx
|
||||
|
||||
@property
|
||||
def flat_img_batch(self) -> torch.FloatTensor:
|
||||
"""
|
||||
Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
|
||||
"""
|
||||
|
||||
return self.img_batch.transpose(0, 1).flatten(0, 1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Object:
|
||||
# Id of the object in the media
|
||||
object_id: int
|
||||
# Index of the frame in the media (0 if single image)
|
||||
frame_index: int
|
||||
segment: Union[torch.Tensor, dict] # RLE dict or binary mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
data: Union[torch.Tensor, PILImage.Image]
|
||||
objects: List[Object]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoDatapoint:
|
||||
"""Refers to an image/video and all its annotations"""
|
||||
|
||||
frames: List[Frame]
|
||||
video_id: int
|
||||
size: Tuple[int, int]
|
||||
|
||||
|
||||
def collate_fn(
|
||||
batch: List[VideoDatapoint],
|
||||
dict_key,
|
||||
) -> BatchedVideoDatapoint:
|
||||
"""
|
||||
Args:
|
||||
batch: A list of VideoDatapoint instances.
|
||||
dict_key (str): A string key used to identify the batch.
|
||||
"""
|
||||
img_batch = []
|
||||
for video in batch:
|
||||
img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
|
||||
|
||||
img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
|
||||
T = img_batch.shape[0]
|
||||
# Prepare data structures for sequential processing. Per-frame processing but batched across videos.
|
||||
step_t_objects_identifier = [[] for _ in range(T)]
|
||||
step_t_frame_orig_size = [[] for _ in range(T)]
|
||||
|
||||
step_t_masks = [[] for _ in range(T)]
|
||||
step_t_obj_to_frame_idx = [
|
||||
[] for _ in range(T)
|
||||
] # List to store frame indices for each time step
|
||||
|
||||
for video_idx, video in enumerate(batch):
|
||||
orig_video_id = video.video_id
|
||||
orig_frame_size = video.size
|
||||
for t, frame in enumerate(video.frames):
|
||||
objects = frame.objects
|
||||
for obj in objects:
|
||||
orig_obj_id = obj.object_id
|
||||
orig_frame_idx = obj.frame_index
|
||||
step_t_obj_to_frame_idx[t].append(
|
||||
torch.tensor([t, video_idx], dtype=torch.int)
|
||||
)
|
||||
step_t_masks[t].append(obj.segment.to(torch.bool))
|
||||
step_t_objects_identifier[t].append(
|
||||
torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
|
||||
)
|
||||
step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
|
||||
|
||||
obj_to_frame_idx = torch.stack(
|
||||
[
|
||||
torch.stack(obj_to_frame_idx, dim=0)
|
||||
for obj_to_frame_idx in step_t_obj_to_frame_idx
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
|
||||
objects_identifier = torch.stack(
|
||||
[torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0
|
||||
)
|
||||
frame_orig_size = torch.stack(
|
||||
[torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0
|
||||
)
|
||||
return BatchedVideoDatapoint(
|
||||
img_batch=img_batch,
|
||||
obj_to_frame_idx=obj_to_frame_idx,
|
||||
masks=masks,
|
||||
metadata=BatchedVideoMetaData(
|
||||
unique_objects_identifier=objects_identifier,
|
||||
frame_orig_size=frame_orig_size,
|
||||
),
|
||||
dict_key=dict_key,
|
||||
batch_size=[T],
|
||||
)
|
576
training/utils/distributed.py
Normal file
576
training/utils/distributed.py
Normal file
@@ -0,0 +1,576 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Default to GPU 0
|
||||
_cuda_device_index: int = 0
|
||||
|
||||
# Setting _cuda_device_index to -1 internally implies that we should use CPU
|
||||
_CPU_DEVICE_INDEX = -1
|
||||
_PRIMARY_RANK = 0
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_global_gloo_group():
|
||||
"""
|
||||
Return a process group based on gloo backend, containing all the ranks
|
||||
The result is cached.
|
||||
"""
|
||||
|
||||
if dist.get_backend() == "nccl":
|
||||
# Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes
|
||||
# being much slower than others causing a timeout (which can happen in relation
|
||||
# or LVIS class mAP evaluation).
|
||||
timeout = 43200
|
||||
return dist.new_group(
|
||||
backend="gloo",
|
||||
timeout=datetime.timedelta(seconds=timeout),
|
||||
)
|
||||
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def is_main_process():
|
||||
"""Return true if the current process is the main one"""
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
|
||||
`all_gather` above, but using filesystem instead of collective ops.
|
||||
|
||||
If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
|
||||
(and other ranks will have an empty list).
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
print("gathering via files")
|
||||
cpu_group = _get_global_gloo_group()
|
||||
|
||||
# if unspecified, we will save to the current python file dir
|
||||
if filesys_save_dir is not None:
|
||||
save_dir = filesys_save_dir
|
||||
elif "EXP_DIR" in os.environ:
|
||||
save_dir = os.environ["EXP_DIR"]
|
||||
else:
|
||||
# try the same directory where the code is stored
|
||||
save_dir = filesys_save_dir or os.path.dirname(__file__)
|
||||
save_dir = os.path.join(save_dir, "all_gather_via_filesys")
|
||||
if is_main_process():
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# use a timestamp and salt to distinguish different all_gather
|
||||
timestamp = int(time.time()) if is_main_process() else 0
|
||||
salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
|
||||
# broadcast the timestamp and salt across ranks
|
||||
# (all-reduce will do the broadcasting since only rank 0 is non-zero)
|
||||
timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
|
||||
dist.all_reduce(timestamp_and_salt, group=cpu_group)
|
||||
timestamp, salt = timestamp_and_salt.tolist()
|
||||
|
||||
# save the data to a file on the disk
|
||||
rank_save = get_rank()
|
||||
save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
|
||||
save_data_path = os.path.join(save_dir, save_data_filename)
|
||||
assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
|
||||
torch.save(data, save_data_path)
|
||||
dist.barrier(group=cpu_group)
|
||||
|
||||
# read the data from the files
|
||||
data_list = []
|
||||
if rank_save == 0 or not gather_to_rank_0_only:
|
||||
for rank_load in range(world_size):
|
||||
load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
|
||||
load_data_path = os.path.join(save_dir, load_data_filename)
|
||||
assert os.path.exists(load_data_path), f"cannot read {save_data_path}"
|
||||
data_list.append(torch.load(load_data_path))
|
||||
dist.barrier(group=cpu_group)
|
||||
|
||||
# delete the saved file
|
||||
os.remove(save_data_path)
|
||||
return data_list
|
||||
|
||||
|
||||
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
|
||||
return all_gather_via_filesys(
|
||||
data, filesys_save_dir, gather_to_rank_0_only=True
|
||||
)
|
||||
|
||||
if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
|
||||
return all_gather_via_filesys(data, filesys_save_dir)
|
||||
|
||||
cpu_group = None
|
||||
if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
|
||||
cpu_group = _get_global_gloo_group()
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(data, buffer)
|
||||
data_view = buffer.getbuffer()
|
||||
device = "cuda" if cpu_group is None else "cpu"
|
||||
tensor = torch.ByteTensor(data_view).to(device)
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
|
||||
size_list = [
|
||||
torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
|
||||
]
|
||||
if cpu_group is None:
|
||||
dist.all_gather(size_list, local_size)
|
||||
else:
|
||||
print("gathering on cpu")
|
||||
dist.all_gather(size_list, local_size, group=cpu_group)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
assert isinstance(local_size.item(), int)
|
||||
local_size = int(local_size.item())
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(
|
||||
size=(max_size - local_size,), dtype=torch.uint8, device=device
|
||||
)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
if cpu_group is None:
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
else:
|
||||
dist.all_gather(tensor_list, tensor, group=cpu_group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
|
||||
buffer = io.BytesIO(tensor.cpu().numpy())
|
||||
obj = torch.load(buffer)
|
||||
data_list.append(obj)
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
|
||||
"""
|
||||
For some backends, such as NCCL, communication only works if the
|
||||
tensor is on the GPU. This helper function converts to the correct
|
||||
device and returns the tensor + original device.
|
||||
"""
|
||||
orig_device = "cpu" if not tensor.is_cuda else "gpu"
|
||||
if (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
|
||||
and not tensor.is_cuda
|
||||
):
|
||||
tensor = tensor.cuda()
|
||||
return (tensor, orig_device)
|
||||
|
||||
|
||||
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
|
||||
"""
|
||||
For some backends, such as NCCL, communication only works if the
|
||||
tensor is on the GPU. This converts the tensor back to original device.
|
||||
"""
|
||||
if tensor.is_cuda and orig_device == "cpu":
|
||||
tensor = tensor.cpu()
|
||||
return tensor
|
||||
|
||||
|
||||
def is_distributed_training_run() -> bool:
|
||||
return (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
and (torch.distributed.get_world_size() > 1)
|
||||
)
|
||||
|
||||
|
||||
def is_primary() -> bool:
|
||||
"""
|
||||
Returns True if this is rank 0 of a distributed training job OR if it is
|
||||
a single trainer job. Otherwise False.
|
||||
"""
|
||||
return get_rank() == _PRIMARY_RANK
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing mean reduction
|
||||
of tensor over all processes.
|
||||
"""
|
||||
return all_reduce_op(
|
||||
tensor,
|
||||
torch.distributed.ReduceOp.SUM,
|
||||
lambda t: t / torch.distributed.get_world_size(),
|
||||
)
|
||||
|
||||
|
||||
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing sum
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing min
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN)
|
||||
|
||||
|
||||
def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing min
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX)
|
||||
|
||||
|
||||
def all_reduce_op(
|
||||
tensor: torch.Tensor,
|
||||
op: torch.distributed.ReduceOp,
|
||||
after_op_func: Callable[[torch.Tensor], torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_reduce for performing
|
||||
reduction of tensor over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
torch.distributed.all_reduce(tensor, op)
|
||||
if after_op_func is not None:
|
||||
tensor = after_op_func(tensor)
|
||||
tensor = convert_to_normal_tensor(tensor, orig_device)
|
||||
return tensor
|
||||
|
||||
|
||||
def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
Wrapper over torch.distributed.all_gather for performing
|
||||
'gather' of 'tensor' over all processes in both distributed /
|
||||
non-distributed scenarios.
|
||||
"""
|
||||
if tensor.ndim == 0:
|
||||
# 0 dim tensors cannot be gathered. so unsqueeze
|
||||
tensor = tensor.unsqueeze(0)
|
||||
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
gathered_tensors = [
|
||||
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, tensor)
|
||||
gathered_tensors = [
|
||||
convert_to_normal_tensor(_tensor, orig_device)
|
||||
for _tensor in gathered_tensors
|
||||
]
|
||||
else:
|
||||
gathered_tensors = [tensor]
|
||||
|
||||
return gathered_tensors
|
||||
|
||||
|
||||
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
|
||||
gathered_tensors = gather_tensors_from_all(tensor)
|
||||
gathered_tensor = torch.cat(gathered_tensors, 0)
|
||||
return gathered_tensor
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source
|
||||
to all processes in both distributed / non-distributed scenarios.
|
||||
"""
|
||||
if is_distributed_training_run():
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
torch.distributed.broadcast(tensor, src)
|
||||
tensor = convert_to_normal_tensor(tensor, orig_device)
|
||||
return tensor
|
||||
|
||||
|
||||
def barrier() -> None:
|
||||
"""
|
||||
Wrapper over torch.distributed.barrier, returns without waiting
|
||||
if the distributed process group is not initialized instead of throwing error.
|
||||
"""
|
||||
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
|
||||
return
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""
|
||||
Simple wrapper for correctly getting worldsize in both distributed
|
||||
/ non-distributed settings
|
||||
"""
|
||||
return (
|
||||
torch.distributed.get_world_size()
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
else 1
|
||||
)
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
"""
|
||||
Simple wrapper for correctly getting rank in both distributed
|
||||
/ non-distributed settings
|
||||
"""
|
||||
return (
|
||||
torch.distributed.get_rank()
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
def get_primary_rank() -> int:
|
||||
return _PRIMARY_RANK
|
||||
|
||||
|
||||
def set_cuda_device_index(idx: int) -> None:
|
||||
global _cuda_device_index
|
||||
_cuda_device_index = idx
|
||||
torch.cuda.set_device(_cuda_device_index)
|
||||
|
||||
|
||||
def set_cpu_device() -> None:
|
||||
global _cuda_device_index
|
||||
_cuda_device_index = _CPU_DEVICE_INDEX
|
||||
|
||||
|
||||
def get_cuda_device_index() -> int:
|
||||
return _cuda_device_index
|
||||
|
||||
|
||||
def init_distributed_data_parallel_model(
|
||||
model: torch.nn.Module,
|
||||
broadcast_buffers: bool = False,
|
||||
find_unused_parameters: bool = True,
|
||||
bucket_cap_mb: int = 25,
|
||||
) -> torch.nn.parallel.DistributedDataParallel:
|
||||
global _cuda_device_index
|
||||
|
||||
if _cuda_device_index == _CPU_DEVICE_INDEX:
|
||||
# CPU-only model, don't specify device
|
||||
return torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
)
|
||||
else:
|
||||
# GPU model
|
||||
return torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[_cuda_device_index],
|
||||
output_device=_cuda_device_index,
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
)
|
||||
|
||||
|
||||
def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
|
||||
"""Broadcast an object from a source to all workers.
|
||||
|
||||
Args:
|
||||
obj: Object to broadcast, must be serializable
|
||||
src: Source rank for broadcast (default is primary)
|
||||
use_disk: If enabled, removes redundant CPU memory copies by writing to
|
||||
disk
|
||||
"""
|
||||
# Either broadcast from primary to the fleet (default),
|
||||
# or use the src setting as the original rank
|
||||
if get_rank() == src:
|
||||
# Emit data
|
||||
buffer = io.BytesIO()
|
||||
torch.save(obj, buffer)
|
||||
data_view = buffer.getbuffer()
|
||||
length_tensor = torch.LongTensor([len(data_view)])
|
||||
length_tensor = broadcast(length_tensor, src=src)
|
||||
data_tensor = torch.ByteTensor(data_view)
|
||||
data_tensor = broadcast(data_tensor, src=src)
|
||||
else:
|
||||
# Fetch from the source
|
||||
length_tensor = torch.LongTensor([0])
|
||||
length_tensor = broadcast(length_tensor, src=src)
|
||||
data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
|
||||
data_tensor = broadcast(data_tensor, src=src)
|
||||
if use_disk:
|
||||
with tempfile.TemporaryFile("r+b") as f:
|
||||
f.write(data_tensor.numpy())
|
||||
# remove reference to the data tensor and hope that Python garbage
|
||||
# collects it
|
||||
del data_tensor
|
||||
f.seek(0)
|
||||
obj = torch.load(f)
|
||||
else:
|
||||
buffer = io.BytesIO(data_tensor.numpy())
|
||||
obj = torch.load(buffer)
|
||||
return obj
|
||||
|
||||
|
||||
def all_gather_tensor(tensor: torch.Tensor, world_size=None):
|
||||
if world_size is None:
|
||||
world_size = get_world_size()
|
||||
# make contiguous because NCCL won't gather the tensor otherwise
|
||||
assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!"
|
||||
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
||||
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_all, tensor, async_op=False) # performance opt
|
||||
tensor_all = [
|
||||
convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all
|
||||
]
|
||||
return tensor_all
|
||||
|
||||
|
||||
def all_gather_batch(tensors: List[torch.Tensor]):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
tensor_list = []
|
||||
output_tensor = []
|
||||
for tensor in tensors:
|
||||
tensor_all = all_gather_tensor(tensor, world_size)
|
||||
tensor_list.append(tensor_all)
|
||||
|
||||
for tensor_all in tensor_list:
|
||||
output_tensor.append(torch.cat(tensor_all, dim=0))
|
||||
return output_tensor
|
||||
|
||||
|
||||
class GatherLayer(autograd.Function):
|
||||
"""
|
||||
Gather tensors from all workers with support for backward propagation:
|
||||
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output, x)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
all_gradients = torch.stack(grads)
|
||||
dist.all_reduce(all_gradients)
|
||||
return all_gradients[dist.get_rank()]
|
||||
|
||||
|
||||
def all_gather_batch_with_grad(tensors):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
Graph remains connected for backward grad computation.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
tensor_list = []
|
||||
output_tensor = []
|
||||
|
||||
for tensor in tensors:
|
||||
tensor_all = GatherLayer.apply(tensor)
|
||||
tensor_list.append(tensor_all)
|
||||
|
||||
for tensor_all in tensor_list:
|
||||
output_tensor.append(torch.cat(tensor_all, dim=0))
|
||||
return output_tensor
|
||||
|
||||
|
||||
def unwrap_ddp_if_wrapped(model):
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
return model.module
|
||||
return model
|
||||
|
||||
|
||||
def create_new_process_group(group_size):
|
||||
"""
|
||||
Creates process groups of a gives `group_size` and returns
|
||||
process group that current GPU participates in.
|
||||
|
||||
`group_size` must divide the total number of GPUs (world_size).
|
||||
|
||||
Modified from
|
||||
https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60
|
||||
|
||||
Args:
|
||||
group_size (int): number of GPU's to collaborate for sync bn
|
||||
"""
|
||||
|
||||
assert group_size > 0
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if world_size <= 8:
|
||||
if group_size > world_size:
|
||||
logging.warning(
|
||||
f"Requested group size [{group_size}] > world size [{world_size}]. "
|
||||
"Assuming local debug run and capping it to world size."
|
||||
)
|
||||
group_size = world_size
|
||||
assert world_size >= group_size
|
||||
assert world_size % group_size == 0
|
||||
|
||||
group = None
|
||||
for group_num in range(world_size // group_size):
|
||||
group_ids = range(group_num * group_size, (group_num + 1) * group_size)
|
||||
cur_group = torch.distributed.new_group(ranks=group_ids)
|
||||
if torch.distributed.get_rank() // group_size == group_num:
|
||||
group = cur_group
|
||||
# can not drop out and return here, every process must go through creation of all subgroups
|
||||
|
||||
assert group is not None
|
||||
return group
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
246
training/utils/logger.py
Normal file
246
training/utils/logger.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
|
||||
import atexit
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from numpy import ndarray
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||
|
||||
Scalar = Union[Tensor, ndarray, int, float]
|
||||
|
||||
|
||||
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
|
||||
makedir(log_dir)
|
||||
summary_writer_method = SummaryWriter
|
||||
return TensorBoardLogger(
|
||||
path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
|
||||
)
|
||||
|
||||
|
||||
class TensorBoardWriterWrapper:
|
||||
"""
|
||||
A wrapper around a SummaryWriter object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
*args: Any,
|
||||
filename_suffix: str = None,
|
||||
summary_writer_method: Any = SummaryWriter,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TensorBoard logger.
|
||||
On construction, the logger creates a new events file that logs
|
||||
will be written to. If the environment variable `RANK` is defined,
|
||||
logger will only log if RANK = 0.
|
||||
|
||||
NOTE: If using the logger with distributed training:
|
||||
- This logger can call collective operations
|
||||
- Logs will be written on rank 0 only
|
||||
- Logger must be constructed synchronously *after* initializing distributed process group.
|
||||
|
||||
Args:
|
||||
path (str): path to write logs to
|
||||
*args, **kwargs: Extra arguments to pass to SummaryWriter
|
||||
"""
|
||||
self._writer: Optional[SummaryWriter] = None
|
||||
_, self._rank = get_machine_local_and_dist_rank()
|
||||
self._path: str = path
|
||||
if self._rank == 0:
|
||||
logging.info(
|
||||
f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
|
||||
)
|
||||
self._writer = summary_writer_method(
|
||||
log_dir=path,
|
||||
*args,
|
||||
filename_suffix=filename_suffix or str(uuid.uuid4()),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Not logging meters on this host because env RANK: {self._rank} != 0"
|
||||
)
|
||||
atexit.register(self.close)
|
||||
|
||||
@property
|
||||
def writer(self) -> Optional[SummaryWriter]:
|
||||
return self._writer
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self._path
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Writes pending logs to disk."""
|
||||
|
||||
if not self._writer:
|
||||
return
|
||||
|
||||
self._writer.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close writer, flushing pending logs to disk.
|
||||
Logs cannot be written after `close` is called.
|
||||
"""
|
||||
|
||||
if not self._writer:
|
||||
return
|
||||
|
||||
self._writer.close()
|
||||
self._writer = None
|
||||
|
||||
|
||||
class TensorBoardLogger(TensorBoardWriterWrapper):
|
||||
"""
|
||||
A simple logger for TensorBoard.
|
||||
"""
|
||||
|
||||
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
|
||||
"""Add multiple scalar values to TensorBoard.
|
||||
|
||||
Args:
|
||||
payload (dict): dictionary of tag name and scalar value
|
||||
step (int, Optional): step value to record
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
for k, v in payload.items():
|
||||
self.log(k, v, step)
|
||||
|
||||
def log(self, name: str, data: Scalar, step: int) -> None:
|
||||
"""Add scalar data to TensorBoard.
|
||||
|
||||
Args:
|
||||
name (string): tag name used to group scalars
|
||||
data (float/int/Tensor): scalar data to log
|
||||
step (int, optional): step value to record
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
self._writer.add_scalar(name, data, global_step=step, new_style=True)
|
||||
|
||||
def log_hparams(
|
||||
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
|
||||
) -> None:
|
||||
"""Add hyperparameter data to TensorBoard.
|
||||
|
||||
Args:
|
||||
hparams (dict): dictionary of hyperparameter names and corresponding values
|
||||
meters (dict): dictionary of name of meter and corersponding values
|
||||
"""
|
||||
if not self._writer:
|
||||
return
|
||||
self._writer.add_hparams(hparams, meters)
|
||||
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger.
|
||||
"""
|
||||
|
||||
def __init__(self, logging_conf):
|
||||
# allow turning off TensorBoard with "should_log: false" in config
|
||||
tb_config = logging_conf.tensorboard_writer
|
||||
tb_should_log = tb_config and tb_config.pop("should_log", True)
|
||||
self.tb_logger = instantiate(tb_config) if tb_should_log else None
|
||||
|
||||
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_dict(payload, step)
|
||||
|
||||
def log(self, name: str, data: Scalar, step: int) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log(name, data, step)
|
||||
|
||||
def log_hparams(
|
||||
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
|
||||
) -> None:
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_hparams(hparams, meters)
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
# we tune the buffering value so that the logs are updated
|
||||
# frequently.
|
||||
log_buffer_kb = 10 * 1024 # 10KB
|
||||
io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
|
||||
atexit.register(io.close)
|
||||
return io
|
||||
|
||||
|
||||
def setup_logging(
|
||||
name,
|
||||
output_dir=None,
|
||||
rank=0,
|
||||
log_level_primary="INFO",
|
||||
log_level_secondary="ERROR",
|
||||
):
|
||||
"""
|
||||
Setup various logging streams: stdout and file handlers.
|
||||
For file handlers, we only setup for the master gpu.
|
||||
"""
|
||||
# get the filename if we want to log to the file as well
|
||||
log_filename = None
|
||||
if output_dir:
|
||||
makedir(output_dir)
|
||||
if rank == 0:
|
||||
log_filename = f"{output_dir}/log.txt"
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(log_level_primary)
|
||||
|
||||
# create formatter
|
||||
FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
|
||||
formatter = logging.Formatter(FORMAT)
|
||||
|
||||
# Cleanup any existing handlers
|
||||
for h in logger.handlers:
|
||||
logger.removeHandler(h)
|
||||
logger.root.handlers = []
|
||||
|
||||
# setup the console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
if rank == 0:
|
||||
console_handler.setLevel(log_level_primary)
|
||||
else:
|
||||
console_handler.setLevel(log_level_secondary)
|
||||
|
||||
# we log to file as well if user wants
|
||||
if log_filename and rank == 0:
|
||||
file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
|
||||
file_handler.setLevel(log_level_primary)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
logging.root = logger
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
"""
|
||||
After training is done, we ensure to shut down all the logger streams.
|
||||
"""
|
||||
logging.info("Shutting down loggers...")
|
||||
handlers = logging.root.handlers
|
||||
for handler in handlers:
|
||||
handler.close()
|
288
training/utils/train_utils.py
Normal file
288
training/utils/train_utils.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def multiply_all(*args):
|
||||
return np.prod(np.array(args)).item()
|
||||
|
||||
|
||||
def collect_dict_keys(config):
|
||||
"""This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined"""
|
||||
val_keys = []
|
||||
# If the this config points to the collate function, then it has a key
|
||||
if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
|
||||
val_keys.append(config["dict_key"])
|
||||
else:
|
||||
# Recursively proceed
|
||||
for v in config.values():
|
||||
if isinstance(v, type(config)):
|
||||
val_keys.extend(collect_dict_keys(v))
|
||||
elif isinstance(v, omegaconf.listconfig.ListConfig):
|
||||
for item in v:
|
||||
if isinstance(item, type(config)):
|
||||
val_keys.extend(collect_dict_keys(item))
|
||||
return val_keys
|
||||
|
||||
|
||||
class Phase:
|
||||
TRAIN = "train"
|
||||
VAL = "val"
|
||||
|
||||
|
||||
def register_omegaconf_resolvers():
|
||||
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
|
||||
OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
|
||||
OmegaConf.register_new_resolver("add", lambda x, y: x + y)
|
||||
OmegaConf.register_new_resolver("times", multiply_all)
|
||||
OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
|
||||
OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
|
||||
OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
|
||||
OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
|
||||
OmegaConf.register_new_resolver("int", lambda x: int(x))
|
||||
OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
|
||||
OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
|
||||
|
||||
|
||||
def setup_distributed_backend(backend, timeout_mins):
|
||||
"""
|
||||
Initialize torch.distributed and set the CUDA device.
|
||||
Expects environment variables to be set as per
|
||||
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
|
||||
along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
|
||||
"""
|
||||
# enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
|
||||
# of waiting
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
|
||||
dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def get_machine_local_and_dist_rank():
|
||||
"""
|
||||
Get the distributed and local rank of the current gpu.
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
||||
distributed_rank = int(os.environ.get("RANK", None))
|
||||
assert (
|
||||
local_rank is not None and distributed_rank is not None
|
||||
), "Please the set the RANK and LOCAL_RANK environment variables."
|
||||
return local_rank, distributed_rank
|
||||
|
||||
|
||||
def print_cfg(cfg):
|
||||
"""
|
||||
Supports printing both Hydra DictConfig and also the AttrDict config
|
||||
"""
|
||||
logging.info("Training with config:")
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
|
||||
def set_seeds(seed_value, max_epochs, dist_rank):
|
||||
"""
|
||||
Set the python random, numpy and torch seed for each gpu. Also set the CUDA
|
||||
seeds if the CUDA is available. This ensures deterministic nature of the training.
|
||||
"""
|
||||
# Since in the pytorch sampler, we increment the seed by 1 for every epoch.
|
||||
seed_value = (seed_value + dist_rank) * max_epochs
|
||||
logging.info(f"MACHINE SEED: {seed_value}")
|
||||
random.seed(seed_value)
|
||||
np.random.seed(seed_value)
|
||||
torch.manual_seed(seed_value)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed_value)
|
||||
|
||||
|
||||
def makedir(dir_path):
|
||||
"""
|
||||
Create the directory if it does not exist.
|
||||
"""
|
||||
is_success = False
|
||||
try:
|
||||
if not g_pathmgr.exists(dir_path):
|
||||
g_pathmgr.mkdirs(dir_path)
|
||||
is_success = True
|
||||
except BaseException:
|
||||
logging.info(f"Error creating directory: {dir_path}")
|
||||
return is_success
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_amp_type(amp_type: Optional[str] = None):
|
||||
if amp_type is None:
|
||||
return None
|
||||
assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
|
||||
if amp_type == "bfloat16":
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float16
|
||||
|
||||
|
||||
def log_env_variables():
|
||||
env_keys = sorted(list(os.environ.keys()))
|
||||
st = ""
|
||||
for k in env_keys:
|
||||
v = os.environ[k]
|
||||
st += f"{k}={v}\n"
|
||||
logging.info("Logging ENV_VARIABLES")
|
||||
logging.info(st)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.device = device
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
self._allow_updates = True
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class MemMeter:
|
||||
"""Computes and stores the current, avg, and max of peak Mem usage per iteration"""
|
||||
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.device = device
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0 # Per iteration max usage
|
||||
self.avg = 0 # Avg per iteration max usage
|
||||
self.peak = 0 # Peak usage for lifetime of program
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
self._allow_updates = True
|
||||
|
||||
def update(self, n=1, reset_peak_usage=True):
|
||||
self.val = torch.cuda.max_memory_allocated() // 1e9
|
||||
self.sum += self.val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
self.peak = max(self.peak, self.val)
|
||||
if reset_peak_usage:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = (
|
||||
"{name}: {val"
|
||||
+ self.fmt
|
||||
+ "} ({avg"
|
||||
+ self.fmt
|
||||
+ "}/{peak"
|
||||
+ self.fmt
|
||||
+ "})"
|
||||
)
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
def human_readable_time(time_seconds):
|
||||
time = int(time_seconds)
|
||||
minutes, seconds = divmod(time, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
days, hours = divmod(hours, 24)
|
||||
return f"{days:02}d {hours:02}h {minutes:02}m"
|
||||
|
||||
|
||||
class DurationMeter:
|
||||
def __init__(self, name, device, fmt=":f"):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.fmt = fmt
|
||||
self.val = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
|
||||
def update(self, val):
|
||||
self.val = val
|
||||
|
||||
def add(self, val):
|
||||
self.val += val
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}: {human_readable_time(self.val)}"
|
||||
|
||||
|
||||
class ProgressMeter:
|
||||
def __init__(self, num_batches, meters, real_meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.real_meters = real_meters
|
||||
self.prefix = prefix
|
||||
|
||||
def display(self, batch, enable_print=False):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
entries += [
|
||||
" | ".join(
|
||||
[
|
||||
f"{os.path.join(name, subname)}: {val:.4f}"
|
||||
for subname, val in meter.compute().items()
|
||||
]
|
||||
)
|
||||
for name, meter in self.real_meters.items()
|
||||
]
|
||||
logging.info(" | ".join(entries))
|
||||
if enable_print:
|
||||
print(" | ".join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = "{:" + str(num_digits) + "d}"
|
||||
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
||||
|
||||
|
||||
def get_resume_checkpoint(checkpoint_save_dir):
|
||||
if not g_pathmgr.isdir(checkpoint_save_dir):
|
||||
return None
|
||||
ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
|
||||
if not g_pathmgr.isfile(ckpt_file):
|
||||
return None
|
||||
|
||||
return ckpt_file
|
Reference in New Issue
Block a user