181 lines
7.2 KiB
Python
181 lines
7.2 KiB
Python
![]() |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# All rights reserved.
|
||
|
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import logging
|
||
|
import math
|
||
|
from typing import Callable, Iterable, List, Optional, Sequence
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
|
||
|
|
||
|
from torch.utils.data.distributed import DistributedSampler
|
||
|
|
||
|
|
||
|
class MixedDataLoader:
    """Interleave batches from several dataloaders, choosing the source at random.

    Each call to ``__next__`` draws a dataloader index from ``mixing_prob``
    (via ``Tensor.multinomial``) and returns the next item of that loader.
    When a loader is exhausted its probability is set to zero, so iteration
    continues over the remaining loaders until all of them are exhausted.
    """

    def __init__(
        self,
        dataloaders: List[DataLoader],
        mixing_prob: torch.FloatTensor,
        seed: int = 42,
    ):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability (sampling weight) of each
                dataloader to be sampled from; one entry per dataloader.
            seed (int): Seed applied to the mixing generator at the start of every
                iteration (``__iter__``). Defaults to 42, the previously
                hard-coded value.
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        self.seed = seed
        # Iterator state
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        """Return the sum of the lengths of the wrapped dataloaders."""
        return sum(len(d) for d in self.dataloaders)

    def __iter__(self):
        """Reset the iteration state and return ``self`` as the iterator."""
        # Re-seed with the same fixed seed on every iteration. Presumably this
        # keeps the sequence of source choices identical across processes
        # ("synchronize dataloader seeds") — confirm against the training setup.
        self.random_generator.manual_seed(self.seed)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        # Work on a copy so zeroing exhausted loaders does not modify
        # self.mixing_prob.
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Sample a dataloader to sample from based on mixing probabilities. If one
        of the dataloaders is exhausted, we continue sampling from the other
        loaders until all are exhausted.

        Raises:
            TypeError: if called before ``__iter__``.
            StopIteration: once every dataloader is exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one loader with non-zero prob.
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                return next(self._iter_dls[dataset_idx])
            except StopIteration:
                # No more items in this loader: set its mixing probability to
                # zero and draw again.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # Log with the traceback, then propagate the original error.
                logging.exception(e)
                raise

        # Exhausted all iterators
        raise StopIteration
|
||
|
|
||
|
|
||
|
class TorchTrainMixedDataset:
    """Build a ``MixedDataLoader`` over several map-style datasets for training.

    Each dataset gets its own ``DataLoader`` (with a ``DistributedSampler`` and
    its own batch size). Batches are interleaved according to ``dataset_prob``.
    With ``phases_per_epoch > 1``, one full pass over each dataset is spread
    across ``phases_per_epoch`` consecutive trainer epochs ("phases").
    """

    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to
                sample from. Should sum to 1.0 (checked with a small tolerance).
                If None, each dataset is weighted by its number of batches.

        Raises:
            AssertionError: on empty ``datasets``, an ``IterableDataset``, a length
                mismatch with ``dataset_prob``, datasets that yield no batches at
                all, or probabilities that do not sum to 1.0.
        """

        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        # Per-dataset index chunks for the current main epoch (phases_per_epoch > 1).
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to
            # its number of batches.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            assert total_len > 0, "Datasets yield no batches; cannot derive mixing probabilities"
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            # Force a floating dtype: integer weights (e.g. [1, 0]) would otherwise
            # produce an int tensor that `multinomial` rejects later.
            dataset_prob = torch.tensor(dataset_prob, dtype=torch.get_default_dtype())

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        # The probabilities are stored in a reduced-precision float tensor, so
        # their sum is generally not bit-exactly 1.0 even when they are
        # normalized. Compare with a tolerance instead of `==`.
        assert math.isclose(
            dataset_prob.sum().item(), 1.0, abs_tol=1e-5
        ), "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        """Propagate ``epoch`` to ``dataset`` via an ``epoch`` attribute and/or ``set_epoch``, if present."""
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def _chunk_indices(self, num_items: int, seed: int):
        """Randomly permute ``range(num_items)`` and split it into exactly ``self.phases_per_epoch`` chunks."""
        g = torch.Generator()
        g.manual_seed(seed)
        chunks = list(
            torch.chunk(torch.randperm(num_items, generator=g), self.phases_per_epoch)
        )
        # torch.chunk may return fewer chunks than requested (e.g. 5 items into
        # 4 chunks gives sizes 2, 2, 1). Pad with empty chunks so every phase
        # index is valid. The split is unchanged when enough chunks were returned.
        while len(chunks) < self.phases_per_epoch:
            chunks.append(torch.empty(0, dtype=torch.long))
        return tuple(chunks)

    def get_loader(self, epoch) -> Iterable:
        """Build the mixed dataloader for trainer epoch ``epoch``.

        Args:
            epoch (int): Trainer epoch index. It is also passed to each
                ``DistributedSampler.set_epoch``.

        Returns:
            MixedDataLoader: Loader interleaving batches from all datasets with
            probabilities ``self.dataset_prob``.
        """
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Main epoch that loops over the entire dataset:
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch

                # Start of a new data epoch, or the job was resumed after
                # preemption (chunks not yet computed in this process).
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set the seed for the dataset epoch.
                    # If using RepeatFactorWrapper, this step correctly
                    # re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling, seeded with
                    # main_epoch so the same split is recomputed after a restart.
                    self.chunks[d_idx] = self._chunk_indices(len(dataset), main_epoch)

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
|