[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
104
training/dataset/utils.py
Normal file
104
training/dataset/utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
ConcatDataset as TorchConcatDataset,
|
||||
Dataset,
|
||||
Subset as TorchSubset,
|
||||
)
|
||||
|
||||
|
||||
class ConcatDataset(TorchConcatDataset):
    """Concatenation of datasets that preserves per-sample repeat factors.

    Each wrapped dataset must expose a 1-D ``repeat_factors`` tensor with one
    entry per sample; the concatenated tensor is re-exposed here so the result
    can itself be wrapped for repeat-factor sampling.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

        # NOTE: iterate self.datasets (the list materialized by the parent
        # constructor), not the `datasets` argument — a generator argument
        # would already be exhausted at this point.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Propagate the current epoch number to every constituent dataset."""
        for dataset in self.datasets:
            # Support both conventions used by the wrapped datasets: a plain
            # `epoch` attribute and/or a `set_epoch` method.
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
class Subset(TorchSubset):
    """Subset of a dataset that carries the matching repeat factors.

    Selects ``repeat_factors`` for the retained indices so the subset stays
    compatible with repeat-factor sampling wrappers.
    """

    def __init__(self, dataset, indices) -> None:
        super().__init__(dataset, indices)

        # Keep only the repeat factors of the retained samples; the tensor is
        # advanced-indexed so it stays aligned with `indices` order.
        self.repeat_factors = dataset.repeat_factors[indices]
        # Sanity check: one repeat factor per retained sample.
        assert len(indices) == len(self.repeat_factors)
|
||||
|
||||
|
||||
# Adapted from Detectron2
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.

    The underlying dataset must have a ``repeat_factors`` member to indicate
    the per-image factor. Set it to uniformly ones to disable repeat factor
    sampling.
    """

    def __init__(self, dataset, seed: int = 0):
        """
        Args:
            dataset: dataset exposing a 1-D ``repeat_factors`` tensor
                (one factor per sample).
            seed (int): base seed combined with the epoch number so every
                epoch draws a different (but reproducible) set of repeats.
        """
        self.dataset = dataset
        # Computed lazily by set_epoch(); guards below enforce the call order.
        self.epoch_ids = None
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
            is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __len__(self):
        if self.epoch_ids is None:
            # Raise instead of returning len(self.dataset) to avoid
            # accidentally using the unwrapped length: the effective length
            # becomes len(self.epoch_ids) only after set_epoch is called.
            raise RuntimeError("please call set_epoch first to get wrapped length")

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Recompute the repeated index list for `epoch` and propagate it."""
        g = torch.Generator()
        # Seeding with seed + epoch keeps each epoch's rounding reproducible
        # yet different across epochs.
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )

        return self.dataset[self.epoch_ids[idx]]
|
Reference in New Issue
Block a user