[New Feature] Support SAM 2.1 (#59)
* support sam 2.1
* refine config path and ckpt path
* update README
training/dataset/sam2_datasets.py (new file, 180 lines added)
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence

import torch

from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler


class MixedDataLoader:
    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Iterator state
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        return sum([len(d) for d in self.dataloaders])

    def __iter__(self):
        # Synchronize dataloader seeds
        self.random_generator.manual_seed(42)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Sample a dataloader to sample from based on mixing probabilities. If one of the
        dataloaders is exhausted, we continue sampling from the other loaders until all
        are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one D-Loader with non-zero prob.
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                item = next(self._iter_dls[dataset_idx])
                return item
            except StopIteration:
                # No more iterations for this dataset, set its mixing probability to zero and try again.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # log and raise any other unexpected error.
                logging.error(e)
                raise e

        # Exhausted all iterators
        raise StopIteration


class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
        """

        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to its length.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch

                # Start of a new data-epoch, or the job is resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # set seed for dataset epoch
                    # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
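For reference, below is a minimal sketch of how these two classes might be driven from a training script. It is not part of this commit: the toy TensorDatasets and the single-process gloo process group are assumptions made for illustration, since get_loader builds a DistributedSampler and therefore requires torch.distributed to be initialized.

import os
import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset

from training.dataset.sam2_datasets import TorchTrainMixedDataset

# Single-process process group so DistributedSampler can be constructed
# (assumption for illustration; real training runs under torchrun/SLURM).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Two toy datasets standing in for the real image/video datasets (assumption).
ds_a = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
ds_b = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))

mixed = TorchTrainMixedDataset(
    datasets=[ds_a, ds_b],
    batch_sizes=[8, 4],
    num_workers=0,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
    phases_per_epoch=1,
    dataset_prob=[0.5, 0.5],  # must sum to 1.0
)

for epoch in range(2):
    loader = mixed.get_loader(epoch)  # returns a MixedDataLoader
    for xs, ys in loader:
        # training step would go here
        pass

dist.destroy_process_group()

Each call to get_loader(epoch) rebuilds the per-dataset DataLoaders with an epoch-seeded DistributedSampler, and the returned MixedDataLoader draws batches from the individual loaders according to dataset_prob until all of them are exhausted.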