SAM2.1 checkpoints + training code + Demo
This commit is contained in:
Haitham Khedr
2024-09-28 08:20:56 -07:00
parent 7e1596c0b6
commit aa9b8722d0
325 changed files with 38174 additions and 223 deletions

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler
class MixedDataLoader:
    """Iterator that draws each item from one of several dataloaders chosen at random.

    The loader to draw from is sampled according to `mixing_prob`. A loader that
    runs out of items has its probability zeroed for the rest of the iteration,
    so iteration continues until every loader is exhausted.
    """

    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Per-iteration state; filled in by __iter__.
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        """Total number of items over all wrapped dataloaders."""
        return sum(len(loader) for loader in self.dataloaders)

    def __iter__(self):
        """Reset the sampling state and start a fresh pass over every dataloader."""
        # A fixed seed makes the sequence of chosen loaders reproducible.
        self.random_generator.manual_seed(42)
        self._iter_dls = list(map(iter, self.dataloaders))
        # Work on a copy so that zeroing exhausted loaders never touches `mixing_prob`.
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        remaining_prob = self._iter_mixing_prob
        while True:
            if not remaining_prob.any():
                # Every loader has been drained.
                raise StopIteration
            picked = remaining_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                return next(self._iter_dls[picked])
            except StopIteration:
                # This loader is done: never pick it again and draw another one.
                remaining_prob[picked] = 0
            except Exception as e:
                # Log and re-raise any other unexpected error.
                logging.error(e)
                raise e
class TorchTrainMixedDataset:
    """Builds, for each epoch, one DataLoader per dataset and mixes them.

    `get_loader(epoch)` wraps every dataset in a `DistributedSampler` and a
    `BatchSampler`, then returns a `MixedDataLoader` that draws batches from the
    per-dataset loaders according to `dataset_prob`.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        # Per-dataset index chunks, one chunk per phase (only used when phases_per_epoch > 1).
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to
            # its length measured in batches.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)
        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        # The tensor sum is accumulated in floating point, so compare with a
        # tolerance: an exact `== 1.0` can reject valid inputs on rounding alone.
        assert math.isclose(
            dataset_prob.sum().item(), 1.0, rel_tol=0.0, abs_tol=1e-5
        ), "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        """Tell `dataset` the current epoch through whichever hook it exposes."""
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        """Create the mixed loader for `epoch`.

        When `phases_per_epoch > 1`, each dataset is split into that many random
        chunks and `epoch` selects one chunk ("phase") of the current pass.
        """
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch
                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch
                # Start of a new data-epoch, or the job is resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set seed for dataset epoch.
                    # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)
                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )
                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)
            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)

View File

@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Transforms and data augmentation for both image + bbox.
"""
import logging
import random
from typing import Iterable
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode
from training.utils.data_utils import VideoDatapoint
def hflip(datapoint, index):
    """Horizontally flip frame `index` of `datapoint` in place, masks included.

    Objects whose `segment` is None are left untouched. Returns `datapoint`.
    """
    frame = datapoint.frames[index]
    frame.data = F.hflip(frame.data)
    for obj in frame.objects:
        if obj.segment is None:
            continue
        obj.segment = F.hflip(obj.segment)
    return datapoint
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Return the `(height, width)` that scales the shorter image side to `size`.

    Args:
        image_size: current `(width, height)` of the image.
        size: target length of the shorter side.
        max_size: optional cap; when scaling the shorter side to `size` would
            push the longer side past it, `size` is reduced accordingly.

    Returns:
        `(height, width)` tuple preserving the aspect ratio up to rounding.
    """
    w, h = image_size
    if max_size is not None:
        short_side = float(min((w, h)))
        long_side = float(max((w, h)))
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    # Nothing to do when the shorter side already has the requested length.
    shorter_side_matches = (w <= h and w == size) or (h <= w and h == size)
    if shorter_side_matches:
        return (h, w)

    if w < h:
        return (int(round(size * h / w)), int(round(size)))
    return (int(round(size)), int(round(size * w / h)))
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    """Resize frame `index` of `datapoint` (image and object masks) in place.

    Args:
        datapoint: object exposing `frames`; each frame has `data`, `objects`
            and `size` attributes.
        index: index of the frame to resize.
        size: either a scalar (length of the shorter side, or of both sides when
            `square` is True) or a `(w, h)` list/tuple.
        max_size: optional cap on the longer side; only used for a scalar `size`
            when `square` is False.
        square: when True, resize to `(size, size)` ignoring the aspect ratio.
        v2: when True, the image is resized with the torchvision v2 functional
            API and its size is read as a tensor shape rather than a PIL size.

    Returns:
        The same `datapoint`, with the frame's `size` set to `(h, w)`.
    """
    frame = datapoint.frames[index]

    if square:
        size = size, size
    elif isinstance(size, (list, tuple)):
        # Explicit (w, h) request -> (h, w) as expected by torchvision.
        size = size[::-1]
    else:
        cur_size = frame.data.size()[-2:][::-1] if v2 else frame.data.size
        size = get_size_with_aspect_ratio(cur_size, size, max_size)

    if v2:
        frame.data = Fv2.resize(frame.data, size, antialias=True)
    else:
        frame.data = F.resize(frame.data, size)

    for obj in frame.objects:
        if obj.segment is not None:
            # Add batch/channel dims for the resize, then drop exactly those two.
            # A bare `.squeeze()` would also collapse a height or width of 1.
            obj.segment = F.resize(obj.segment[None, None], size)[0, 0]

    h, w = size
    frame.size = (h, w)
    return datapoint
def pad(datapoint, index, padding, v2=False):
    """Pad frame `index` of `datapoint` (image and object masks) in place.

    Args:
        datapoint: object exposing `frames`; each frame has `data`, `objects`
            and `size` (h, w) attributes.
        index: index of the frame to pad.
        padding: either `(pad_x, pad_y)`, applied on the right and bottom only,
            or `(left, top, right, bottom)`.
        v2: when True, masks are padded with the torchvision v2 functional API.

    Returns:
        The same `datapoint`, with the frame's `size` updated.
    """
    frame = datapoint.frames[index]
    h, w = frame.size
    bottom_right_only = len(padding) == 2

    if bottom_right_only:
        # Assumes that we only pad on the bottom right corners.
        frame.data = F.pad(frame.data, (0, 0, padding[0], padding[1]))
        h += padding[1]
        w += padding[0]
    else:
        # left, top, right, bottom
        frame.data = F.pad(
            frame.data, (padding[0], padding[1], padding[2], padding[3])
        )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]
    frame.size = (h, w)

    mask_padding = (0, 0, padding[0], padding[1]) if bottom_right_only else tuple(padding)
    pad_mask = Fv2.pad if v2 else F.pad
    for obj in frame.objects:
        if obj.segment is not None:
            obj.segment = pad_mask(obj.segment, mask_padding)
    return datapoint
class RandomHorizontalFlip:
    """Horizontally flip the frames of a datapoint with probability `p`.

    With `consistent_transform` a single random draw decides for all frames;
    otherwise every frame gets its own draw.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # One decision shared by the whole clip.
            if random.random() >= self.p:
                return datapoint
            num_frames = len(datapoint.frames)
            for frame_idx in range(num_frames):
                datapoint = hflip(datapoint, frame_idx)
            return datapoint

        # Independent decision per frame.
        num_frames = len(datapoint.frames)
        for frame_idx in range(num_frames):
            if random.random() < self.p:
                datapoint = hflip(datapoint, frame_idx)
        return datapoint
class RandomResizeAPI:
    """Resize every frame of a datapoint to a size picked at random from `sizes`.

    With `consistent_transform` one size is drawn for all frames; otherwise a
    size is drawn per frame. `max_size`, `square` and `v2` are forwarded to
    `resize`.
    """

    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        # Accept a single int as shorthand for a one-element choice list.
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        # In consistent mode the size is drawn exactly once, before the loop.
        shared_size = random.choice(self.sizes) if self.consistent_transform else None
        for frame_idx in range(len(datapoint.frames)):
            if self.consistent_transform:
                size = shared_size
            else:
                size = random.choice(self.sizes)
            datapoint = resize(
                datapoint,
                frame_idx,
                size,
                self.max_size,
                square=self.square,
                v2=self.v2,
            )
        return datapoint
class ToTensorAPI:
    """Convert the image of every frame in a datapoint to a tensor.

    With `v2=False` the v1 `to_tensor` is used; with `v2=True` the torchvision
    v2 functional image conversion is used instead.
    """

    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.v2:
            # The v2 converter was renamed from `to_image_tensor` to `to_image`
            # in later torchvision releases; use whichever the installed
            # version provides.
            convert = getattr(Fv2, "to_image", None) or getattr(
                Fv2, "to_image_tensor"
            )
        else:
            convert = F.to_tensor
        for img in datapoint.frames:
            img.data = convert(img.data)
        return datapoint
class NormalizeAPI:
    """Normalize the image of every frame with the given `mean` and `std`.

    In v2 mode the image is first converted to float32 through the v2
    functional API before being normalized.
    """

    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if not self.v2:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
                continue
            as_float = Fv2.convert_image_dtype(img.data, torch.float32)
            img.data = as_float
            img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
        return datapoint
class ComposeAPI:
    """Chain several datapoint transforms, applying them in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        """Feed `datapoint` through every transform, forwarding `kwargs` to each."""
        result = datapoint
        for transform in self.transforms:
            result = transform(result, **kwargs)
        return result

    def __repr__(self):
        """Multi-line representation listing each transform on its own line."""
        body = "".join("\n" + " {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
class RandomGrayscale:
    """Convert frames to 3-channel grayscale with probability `p`.

    With `consistent_transform` one random draw decides for all frames;
    otherwise each frame gets its own draw.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        frames = datapoint.frames
        if self.consistent_transform:
            # Single draw: either every frame is converted or none is.
            selected = frames if random.random() < self.p else []
        else:
            # One draw per frame, in frame order.
            selected = [img for img in frames if random.random() < self.p]
        for img in selected:
            img.data = self.Grayscale(img.data)
        return datapoint
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of the frames.

    Each of `brightness`, `contrast` and `saturation` may be:
      * a scalar `v`, expanded to the range `[max(0, 1 - v), 1 + v]`;
      * a list or tuple giving the range explicitly;
      * None, which disables that adjustment.
    `hue` follows the same rules, a scalar `v` expanding to `[-v, v]`.

    With `consistent_transform` one set of jitter parameters is sampled for all
    frames; otherwise a new set is sampled per frame.
    """

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = self._factor_range(brightness)
        self.contrast = self._factor_range(contrast)
        self.saturation = self._factor_range(saturation)
        if hue is None or isinstance(hue, list):
            self.hue = hue
        elif isinstance(hue, tuple):
            self.hue = list(hue)
        else:
            self.hue = [-hue, hue]

    @staticmethod
    def _factor_range(value):
        """Normalize a multiplicative jitter setting to a `[low, high]` list or None."""
        if value is None or isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [max(0, 1 - value), 1 + value]

    def _sample_params(self):
        """Draw the op order and the four factors (None for disabled adjustments)."""
        return T.ColorJitter.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            # Create the color jitter params once and reuse them for every frame.
            params = self._sample_params()
        for img in datapoint.frames:
            if not self.consistent_transform:
                params = self._sample_params()
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = params
            # `fn_idx` gives the order in which the adjustments are applied.
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint
class RandomAffine:
    """Apply a random affine transform to the frames and masks of a datapoint.

    A sampled transform is rejected when it leaves any object mask of the first
    frame empty; up to `num_tentatives` samples are tried before the datapoint
    is returned as it was passed in.
    """

    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, then the same random affine is applied to all frames and masks.

        Args:
            degrees: rotation range as a list, or a scalar `d` meaning `[-d, d]`.
            consistent_transform: share one sampled transform across all frames.
            scale: scale range forwarded to the parameter sampler (or None).
            translate: translation fractions forwarded to the sampler (or None).
            shear: shear range as a list, a scalar `s` meaning `[-s, s]`, or a
                falsy value to disable shearing.
            image_mean: fill value used for image areas exposed by the transform.
            log_warning: log a warning when every tentative was rejected.
            num_tentatives: number of transforms to try before giving up.
            image_interpolation: "bicubic" or "bilinear".

        Raises:
            NotImplementedError: for any other `image_interpolation`.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _sample_affine_params(self, img_size):
        """Draw one set of affine parameters for an image of `img_size` (w, h)."""
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_datapoint(self, datapoint: VideoDatapoint):
        """Try one random transform.

        Returns:
            The transformed `datapoint`, or None when the transform empties an
            object mask in the first frame. The rejection happens before
            anything is written back, so a None result leaves `datapoint` intact.
        """
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # One random affine transformation shared by all frames.
            affine_params = self._sample_affine_params(img_size)

        for img_idx, img in enumerate(datapoint.frames):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # If not consistent, create new affine params for every frame & mask pair.
                affine_params = self._sample_affine_params(img_size)

            transformed_masks = []
            for mask in this_masks:
                if mask is None:
                    transformed_masks.append(None)
                    continue
                transformed_mask = F.affine(
                    mask,
                    *affine_params,
                    interpolation=InterpolationMode.NEAREST,
                    fill=0.0,
                )
                if img_idx == 0 and transformed_mask.max() == 0:
                    # We are dealing with a video and the object is not visible in the first frame.
                    # Signal the caller to keep the datapoint without transformation.
                    return None
                # Drop only the leading dim added above; a bare `.squeeze()`
                # would also collapse a height or width of 1.
                transformed_masks.append(transformed_mask.squeeze(0))

            for obj, new_mask in zip(img.objects, transformed_masks):
                obj.segment = new_mask

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    """Turn frame `index` into a `grid_h` x `grid_w` mosaic of downsized copies.

    The image is tiled into every grid cell, while the object masks are
    downsized into the single cell `(target_grid_y, target_grid_x)` and are zero
    elsewhere. `should_hflip[y, x]` says whether the content of cell `(y, x)` is
    flipped horizontally. The frame is modified in place and `datapoint` is
    returned.
    """
    frame = datapoint.frames[index]
    image_data = frame.data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im, W_im = image_data.height, image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im, W_im = image_data.size(-2), image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    def cell_bounds(grid_y, grid_x):
        """Pixel bounds (y_begin, y_end, x_begin, x_end) of one grid cell."""
        return (
            grid_y * H_im // grid_h,
            (grid_y + 1) * H_im // grid_h,
            grid_x * W_im // grid_w,
            (grid_x + 1) * W_im // grid_w,
        )

    # Step 1: downsize the image and paste it into every cell of the mosaic.
    # Cells can differ by a pixel due to integer division, so resized copies
    # are cached per cell shape.
    resized_by_shape = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_b, y_e, x_b, x_e = cell_bounds(grid_y, grid_x)
            cell_shape = (y_e - y_b, x_e - x_b)
            if cell_shape not in resized_by_shape:
                resized_by_shape[cell_shape] = F.resize(
                    image_data,
                    size=cell_shape,
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
            tile = resized_by_shape[cell_shape]
            if should_hflip[grid_y, grid_x].item():
                tile = F.hflip(tile)
            if is_pil:
                image_data_output.paste(tile, (x_b, y_b))
            else:
                image_data_output[:, y_b:y_e, x_b:x_e] = tile
    frame.data = image_data_output

    # Step 2: downsize the masks and paste them into the target cell only.
    for obj in frame.objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)
        t_y_b, t_y_e, t_x_b, t_x_e = cell_bounds(target_grid_y, target_grid_x)
        small_mask = F.resize(
            obj.segment[None, None],
            size=(t_y_e - t_y_b, t_x_e - t_x_b),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            small_mask = F.hflip(small_mask[None, None])[0, 0]
        segment_output[t_y_b:t_y_e, t_x_b:t_x_e] = small_mask
        obj.segment = segment_output
    return datapoint
class RandomMosaicVideoAPI:
    """With probability `prob`, replace every frame by a mosaic of itself.

    The grid cell holding the target masks and the per-cell flip pattern are
    drawn once and shared by all frames of the datapoint.
    """

    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # Random cell in which the target masks are placed.
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)

        # Per-cell decision on flipping the cell content horizontally.
        grid_shape = (self.grid_h, self.grid_w)
        if self.use_random_hflip:
            should_hflip = torch.rand(*grid_shape) < 0.5
        else:
            should_hflip = torch.zeros(*grid_shape, dtype=torch.bool)

        for frame_idx in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                frame_idx,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )
        return datapoint

104
training/dataset/utils.py Normal file
View File

@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
from typing import Iterable
import torch
from torch.utils.data import (
ConcatDataset as TorchConcatDataset,
Dataset,
Subset as TorchSubset,
)
class ConcatDataset(TorchConcatDataset):
    """`ConcatDataset` that also concatenates the `repeat_factors` of its parts.

    Every wrapped dataset must expose a `repeat_factors` tensor.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)
        # Read from `self.datasets`, the list materialized by the parent
        # constructor: `datasets` may be a one-shot iterable (e.g. a generator)
        # that is already exhausted at this point.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Forward the epoch to every wrapped dataset through the hooks it exposes."""
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
class Subset(TorchSubset):
    """`Subset` that also carries the `repeat_factors` of the selected indices.

    `dataset` must expose a `repeat_factors` member indexable by `indices`.
    """

    def __init__(self, dataset, indices) -> None:
        super().__init__(dataset, indices)
        selected_factors = dataset.repeat_factors[indices]
        self.repeat_factors = selected_factors
        # One repeat factor per selected index.
        assert len(indices) == len(selected_factors)
# Adapted from Detectron2
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a repeat_factors member to indicate the per-image factor.
    Set it to uniformly ones to disable repeat factor sampling.

    `set_epoch` must be called before `len()` or indexing: it draws the list of
    (repeated) dataset indices used for that epoch.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed
        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: int64 tensor of dataset indices to use in one epoch. Each
                index is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Repeat each index `rep_factor` times in one vectorized call. Negative
        # factors are clamped to zero repeats, matching list multiplication by
        # a negative count.
        repeats = rep_factors.to(torch.int64).clamp_(min=0)
        dataset_indices = torch.arange(len(repeats), dtype=torch.int64)
        return torch.repeat_interleave(dataset_indices, repeats)

    def __len__(self):
        if self.epoch_ids is None:
            # Raise instead of returning len(self.dataset) to avoid accidentally
            # using the unwrapped length: the length becomes
            # `len(self.epoch_ids)` once set_epoch has been called.
            raise RuntimeError("please call set_epoch first to get wrapped length")
        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Re-draw the epoch indices (seeded by `seed + epoch`) and forward the epoch."""
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )
        # NOTE: the wrapped dataset receives a 0-dim int64 tensor as its index.
        return self.dataset[self.epoch_ids[idx]]

View File

@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
from copy import deepcopy
import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset
from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint
MAX_RETRIES = 100
class VOSDataset(VisionDataset):
    """Dataset producing transformed `VideoDatapoint`s from a raw video dataset.

    For each index a video is fetched from `video_dataset`, frames and object
    ids are picked by `sampler`, a `VideoDatapoint` is assembled and the
    transforms are applied in order.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables taking `(datapoint, epoch=...)`.
            training: when True, loading failures are retried on another
                randomly chosen video; otherwise they are raised.
            video_dataset: source of `(video, segment_loader)` pairs.
            sampler: picks the frames and object ids used in a datapoint.
            multiplier: value assigned to every entry of `repeat_factors`.
            always_target: when True, objects missing from a frame get an
                all-zero mask instead of being dropped.
            target_segments_available: stored as-is; not used in this class.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")
        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """Load, assemble and transform the datapoint for `idx`.

        Raises:
            RuntimeError: in training mode, when loading failed `MAX_RETRIES`
                times in a row (chained to the last failure).
            Exception: outside training mode, the original loading error.
        """
        last_error = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if not self.training:
                    # Shouldn't fail to load a val video
                    raise
                logging.warning(
                    f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                )
                last_error = e
                # Try again with another, randomly chosen video.
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Every attempt failed. Without this, the code below would hit an
            # UnboundLocalError on `video` and hide the real cause.
            raise RuntimeError(
                f"Failed to load a datapoint after {MAX_RETRIES} retries"
            ) from last_error

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)
                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        # `h` and `w` are those of the last sampled frame.
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)
def load_images(frames):
    """Return one RGB PIL image per frame, in frame order.

    A frame with `data` already set is converted from its tensor. Otherwise the
    image is read from `image_path`; a path seen earlier in `frames` is not
    read again, a deep copy of the first loaded image is used instead.
    """
    images = []
    first_index_by_path = {}
    for frame in frames:
        if frame.data is not None:
            # The frame rgb data has already been loaded: convert it to a PILImage.
            images.append(tensor_2_PIL(frame.data))
            continue

        # Load the frame rgb data from file.
        path = frame.image_path
        cached_index = first_index_by_path.get(path)
        if cached_index is not None:
            images.append(deepcopy(images[cached_index]))
            continue
        with g_pathmgr.open(path, "rb") as fopen:
            images.append(PILImage.open(fopen).convert("RGB"))
        first_index_by_path[path] = len(images) - 1
    return images
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """Convert an image tensor to a PIL image.

    The first axis is moved last, values are multiplied by 255 and cast to uint8.
    Presumably `data` is a (C, H, W) float tensor in [0, 1] -- TODO confirm.
    """
    channels_last = data.cpu().numpy().transpose((1, 2, 0))
    pixels = (channels_last * 255.0).astype(np.uint8)
    return PILImage.fromarray(pixels)

View File

@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import logging
import os
from dataclasses import dataclass
from typing import List, Optional
import pandas as pd
import torch
from iopath.common.file_io import g_pathmgr
from omegaconf.listconfig import ListConfig
from training.dataset.vos_segment_loader import (
JSONSegmentLoader,
MultiplePNGSegmentLoader,
PalettisedPNGSegmentLoader,
SA1BSegmentLoader,
)
@dataclass
class VOSFrame:
    """A single frame of a video: its index, where to read it, and optional preloaded data."""

    # Frame number within the video.
    frame_idx: int
    # Location of the image file for this frame.
    image_path: str
    # Already-loaded image tensor; None means the image must be read from `image_path`.
    data: Optional[torch.Tensor] = None
    # NOTE(review): not read anywhere in the visible code; presumably marks frames
    # used only as conditioning input -- confirm against the sampler.
    is_conditioning_only: Optional[bool] = False
@dataclass
class VOSVideo:
    """A video described by its name, a numeric id and its list of frames."""

    # Name of the video.
    video_name: str
    # Numeric identifier of the video.
    video_id: int
    # Frames of the video.
    frames: List[VOSFrame]

    def __len__(self):
        """Number of frames in the video."""
        return len(self.frames)
class VOSRawDataset:
    """Base class of raw video datasets; subclasses implement `get_video`."""

    def __init__(self):
        """Nothing to set up in the base class."""

    def get_video(self, idx):
        """Return the video at `idx`; must be overridden by subclasses."""
        raise NotImplementedError()
class PNGRawDataset(VOSRawDataset):
    """Raw video dataset with JPEG frames and PNG mask annotations.

    Frames are read from `img_folder/<video>/*.jpg`; masks are looked up under
    `gt_folder/<video>` (or `gt_folder/<video>/<object>` in single-object mode).
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """
        Args:
            img_folder: root folder with one sub-folder of frames per video.
            gt_folder: root folder with one sub-folder of masks per video.
            file_list_txt: optional text file listing the videos to use, one per
                line (extensions are stripped); defaults to all of `img_folder`.
            excluded_videos_list_txt: optional text file listing videos to skip.
            sample_rate: keep every `sample_rate`-th frame.
            is_palette: use the palettised PNG loader instead of the multi-PNG one.
            single_object_mode: treat each `<video>/<object>` folder of
                `gt_folder` as its own entry.
            truncate_video: when > 0, keep only that many leading frames
                (applied before `sample_rate`).
            frames_sampling_mult: list each video once per image file it has.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided (a set keeps the
        # membership test below cheap).
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep only the videos that are not excluded
        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

        if self.single_object_mode:
            # single object mode: one entry per "<video>/<object>"
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                # Count the frames in the folder `get_video` reads from; in
                # single-object mode the entry name carries an "/<object>"
                # suffix that `get_video` strips before looking up frames.
                num_frames = len(os.listdir(self._frame_root(video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def _frame_root(self, video_name):
        """Folder holding the frames of the entry `video_name`."""
        if self.single_object_mode:
            return os.path.join(self.img_folder, os.path.dirname(video_name))
        return os.path.join(self.img_folder, video_name)

    def get_video(self, idx):
        """
        Given a video index, return the `VOSVideo` and the segment loader for its masks.
        """
        video_name = self.video_names[idx]

        video_frame_root = self._frame_root(video_name)
        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(
                video_mask_root, self.single_object_mode
            )

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        frames = []
        for fpath in all_frames[:: self.sample_rate]:
            # The frame id is the integer file name, e.g. "00012.jpg" -> 12.
            fid = int(os.path.basename(fpath).split(".")[0])
            frames.append(VOSFrame(fid, image_path=fpath))
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
class SA1BRawDataset(VOSRawDataset):
    """Raw dataset of single images, each exposed as a video of repeated frames.

    Images are read from `img_folder/<name>.jpg` and annotations from
    `gt_folder/<name>.json`.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        """
        Args:
            img_folder: folder containing the `.jpg` images.
            gt_folder: folder containing the `.json` annotation files.
            file_list_txt: optional text file listing the images to use, one per
                line (extensions are stripped); defaults to every `.jpg` file of
                `img_folder`.
            excluded_videos_list_txt: optional text file listing images to skip.
            num_frames: number of frames in each returned video; all of them
                point at the same image file.
            mask_area_frac_thresh: forwarded to the segment loader.
            uncertain_iou: forwarded to the segment loader.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            # Remove the extension with `splitext` so that names containing
            # dots are not truncated at the first dot.
            subset = [
                os.path.splitext(path)[0]
                for path in os.listdir(self.img_folder)
                if path.endswith(".jpg")
            ]

        # Read and process excluded files if provided (a set keeps the
        # membership test below cheap).
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep only the entries that are not excluded
        self.video_names = [
            video_name for video_name in subset if video_name not in excluded_files
        ]

    def get_video(self, idx):
        """
        Given an index, return the `VOSVideo` for that image and the segment loader for its annotations.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        # Every frame of the "video" is the same image.
        frames = [
            VOSFrame(frame_idx, image_path=video_frame_path)
            for frame_idx in range(self.num_frames)
        ]
        video_name = video_name.split("_")[-1]  # filename is sa_{int}
        # video id needs to be image_id to be able to load correct annotation file during eval
        video = VOSVideo(video_name, int(video_name), frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names)
class JSONRawDataset(VOSRawDataset):
    """
    Dataset where the annotation in the format of SA-V json files
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        """
        Args:
            img_folder: folder with one sub-folder of frames per video.
            gt_folder: folder with one ``{video_name}_manual.json`` per video.
            file_list_txt: optional text file with one video name per line; when
                omitted, every entry of ``img_folder`` is used.
            excluded_videos_list_txt: optional path (str), or a list / tuple /
                ``ListConfig`` of paths, to text files listing names to exclude.
            sample_rate: keep every ``sample_rate``-th frame of each video.
            rm_unannotated: drop frames that lack a complete annotation.
            ann_every: forwarded to ``JSONSegmentLoader``.
            frames_fps: forwarded to ``JSONSegmentLoader``.

        Raises:
            NotImplementedError: if ``excluded_videos_list_txt`` has an
                unsupported type.
        """
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (list, tuple, ListConfig)):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError(
                    "Unsupported type for excluded_videos_list_txt: "
                    f"{type(excluded_videos_list_txt).__name__}"
                )

            # Use a distinct loop name so the parameter is not shadowed
            for excluded_list_path in excluded_videos_lists:
                with open(excluded_list_path, "r") as f:
                    excluded_files.update(
                        os.path.splitext(line.strip())[0] for line in f
                    )

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(
            [video_name for video_name in subset if video_name not in excluded_files]
        )

    def get_video(self, video_idx):
        """
        Build the video and its segment loader for the video at ``video_idx``.

        Returns:
            A ``(video, segment_loader)`` tuple: a ``VOSVideo`` with the sampled
            (and, if ``rm_unannotated`` is set, annotated) frames, and the
            ``JSONSegmentLoader`` reading the video's annotation file.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        video_frame_root = os.path.join(self.img_folder, video_name)
        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(os.listdir(video_frame_root))
        ]

        # Build the path from separate components: formatting the video name
        # and the frame id in one "%"-string breaks on names containing "%".
        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(video_frame_root, f"{frame_id:05d}.jpg"),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated: the i-th
            # annotation entry belongs to frame i * ann_every, and it only
            # counts when neither the entry nor any of its objects is None.
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        """Return the number of videos kept in the dataset."""
        return len(self.video_names)

View File

@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from dataclasses import dataclass
from typing import List
from training.dataset.vos_segment_loader import LazySegments
# Maximum number of attempts RandomUniformSampler.sample makes to find a clip
# whose first frame has at least one visible object before it raises.
MAX_RETRIES = 1000
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler: the selected frames and the object ids to use."""

    # Frames chosen by the sampler (the samplers in this file store the
    # entries taken from ``video.frames`` here).
    frames: List[int]
    # Ids of the objects selected for the sampled frames.
    object_ids: List[int]
class VOSSampler:
    """Base class for samplers that select frames and objects from a video."""

    def __init__(self, sort_frames=True):
        # When sort_frames is True, frames are ordered by frame id
        self.sort_frames = sort_frames

    def sample(self, video):
        """Must be overridden by subclasses; this base version always raises."""
        raise NotImplementedError()
class RandomUniformSampler(VOSSampler):
    """
    Sampler that picks a random window of consecutive frames from a video and
    a random subset of the objects visible in the first frame of that window.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: number of consecutive frames to sample.
            max_num_objects: upper bound on the number of sampled objects.
            reverse_time_prob: probability of reversing the sampled window.
        """
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """
        Sample ``num_frames`` consecutive frames and up to ``max_num_objects``
        objects that are visible in the first sampled frame.

        Args:
            video: object exposing ``frames`` and ``video_name``.
            segment_loader: loader whose ``load(frame_idx)`` returns the
                segments of a frame.
            epoch: unused.

        Returns:
            SampledFramesAndObjects with the sampled frames and object ids.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or if
                no window with a visible object is found in MAX_RETRIES tries.
        """
        # The video length does not change between retries, so check it once.
        if len(video.frames) < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, len(video.frames) - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids (load the frame's segments only once)
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset
                visible_object_ids = list(loaded_segms.keys())
            else:
                # A segment can be None when the object has no annotation in
                # this frame; such objects are not visible.
                visible_object_ids = [
                    object_id
                    for object_id, segment in loaded_segms.items()
                    if segment is not None and segment.sum()
                ]

            # First frame needs to have at least a target to track
            if visible_object_ids:
                break
        else:
            # Every attempt produced a first frame without any visible object
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(self, sort_frames=True):
        """
        Args:
            sort_frames: order frames by frame id when True (the default),
                otherwise keep the order of ``video.frames``.
        """
        super().__init__(sort_frames=sort_frames)

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        Args:
            video: object exposing ``frames``.
            segment_loader: loader whose ``load(frame_idx)`` returns a
                mapping-like object with ``keys()``.
            epoch: unused.

        Returns:
            SampledFramesAndObjects with every frame and every object id found
            in the first frame.

        Raises:
            Exception: if the first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames

        # Materialise the ids: a dict keys view is tied to the loader's dict
        # and is not the list that SampledFramesAndObjects declares.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if len(object_ids) == 0:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)

View File

@@ -0,0 +1,300 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import json
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image as PILImage
try:
from pycocotools import mask as mask_utils
except:
pass
class JSONSegmentLoader:
    """
    Segment loader for masks stored as RLE entries in a JSON annotation file.

    ``frame_annots`` holds one entry per annotated frame; each entry is indexed
    by object id and contains either an RLE mask or ``None``.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file. The file holds
                either a list of per-frame annotations, or a dict with them
                under "masklet" (or "masks") and optionally an "fps" field.
            ann_every: annotations are provided every ``ann_every``-th frame.
                Overridden by ``frames_fps // fps`` when the file has "fps".
            frames_fps: frame rate of the extracted frames.
            valid_obj_ids: optional ids of the objects to consider.

        Raises:
            NotImplementedError: if the JSON root is neither a list nor a dict.
        """
        # Annotations in the json are provided every ann_every th frame
        self.ann_every = ann_every
        # Ids of the objects to consider when sampling this video
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as f:
            data = json.load(f)
            if isinstance(data, list):
                self.frame_annots = data
            elif isinstance(data, dict):
                masklet_field_name = "masklet" if "masklet" in data else "masks"
                self.frame_annots = data[masklet_field_name]
                if "fps" in data:
                    fps = data["fps"]
                    annotations_fps = int(fps[0] if isinstance(fps, list) else fps)
                    assert frames_fps % annotations_fps == 0, (
                        f"frames_fps ({frames_fps}) must be a multiple of the "
                        f"annotation fps ({annotations_fps})"
                    )
                    self.ann_every = frames_fps // annotations_fps
            else:
                raise NotImplementedError(
                    f"Unsupported JSON root type: {type(data).__name__}"
                )

    def load(self, frame_id, obj_ids=None):
        """
        Load the segments of one annotated frame.

        Args:
            frame_id: frame id; must be a multiple of ``ann_every``.
            obj_ids: optional ids restricting the objects to load.

        Returns:
            dict mapping each selected object id to its decoded mask tensor, or
            to ``None`` when the object has no mask in this frame.
        """
        assert (
            frame_id % self.ann_every == 0
        ), f"frame {frame_id} is not a multiple of ann_every={self.ann_every}"
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            # Remove the masklets that have been filtered out for this video
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Only keep the objects that have been sampled
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Objects without a mask in this frame stay mapped to None
        segments = {obj_id: None for obj_id in valid_objs_ids}
        annotated_ids = [
            obj_id for obj_id in valid_objs_ids if rle_mask[obj_id] is not None
        ]

        # Decode the masks. Skip the call when nothing is annotated: an empty
        # RLE list gives the decoder no mask to take a height/width from.
        if annotated_ids:
            raw_segments = torch.from_numpy(
                mask_utils.decode([rle_mask[obj_id] for obj_id in annotated_ids])
            ).permute(
                2, 0, 1
            )  # num_obj, h, w
            for idx, obj_id in enumerate(annotated_ids):
                segments[obj_id] = raw_segments[idx]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """
        For each object, find all the frames with a valid (not None) mask.

        Args:
            num_frames_min: when given, objects with fewer valid frames than
                this are left out of the result.

        Returns:
            dict associating each object id with the ids of its valid frames;
            empty when there are no frame annotations.
        """
        if not self.frame_annots:
            return {}
        num_objects = len(self.frame_annots[0])

        # The result dict associates each obj_id with the id of its valid frames
        res = {obj_id: [] for obj_id in range(num_objects)}
        for annot_idx, annot in enumerate(self.frame_annots):
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(int(annot_idx * self.ann_every))

        if num_frames_min is not None:
            # Remove masklets that have less than num_frames_min valid masks
            res = {
                obj_id: valid_frames
                for obj_id, valid_frames in res.items()
                if len(valid_frames) >= num_frames_min
            }
        return res
class PalettisedPNGSegmentLoader:
    """SegmentLoader for datasets with masks stored as palettised PNGs."""

    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.
        video_png_root: the folder contains all the masks stored in png
        """
        self.video_png_root = video_png_root
        # build a mapping from frame id to their PNG mask path
        # note that in some datasets, the PNG paths could have more
        # than 5 digits, e.g. "00000000.png" instead of "00000.png"
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            frame_id, ext = os.path.splitext(filename)
            if ext.lower() != ".png":
                # Ignore stray non-mask files instead of failing on int()
                continue
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """
        load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping each non-zero palette value found in
                the mask to a boolean tensor selecting that value's pixels
        """
        # check the path
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        # load the mask; the context manager closes the file once it is read
        with PILImage.open(mask_path) as mask_image:
            masks = np.array(mask_image.convert("P"))

        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]  # remove background (0)

        # convert into N binary segmentation masks
        binary_segments = {}
        for i in object_id:
            bs = masks == i
            binary_segments[i] = torch.from_numpy(bs)

        return binary_segments

    def __len__(self):
        """Return the number of frames that have a PNG mask file."""
        return len(self.frame_id_to_png_filename)
class MultiplePNGSegmentLoader:
    """
    SegmentLoader for datasets storing one PNG mask per object and per frame,
    with the masks of each object grouped in a folder named after the object.
    """

    def __init__(self, video_png_root, single_object_mode=False):
        """
        video_png_root: the folder contains all the masks stored in png
        single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # read a mask to know the resolution of the video
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        tmp_mask = self._read_png(tmp_mask_path)
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]
        if self.single_object_mode:
            # in this mode video_png_root itself is the object folder
            self.obj_id = self._folder_to_obj_id(video_png_root)
        else:
            self.obj_id = None

    @staticmethod
    def _folder_to_obj_id(folder):
        """Return the object id encoded by the last component of ``folder``."""
        # normpath drops a trailing separator, so "a/3/" still yields "3"
        return int(os.path.basename(os.path.normpath(folder))) + 1  # bg is 0

    @staticmethod
    def _read_png(path):
        """Read the PNG at ``path`` into a numpy array and close the file."""
        with PILImage.open(path) as image:
            return np.array(image)

    def load(self, frame_id):
        """Load the binary segments of ``frame_id`` for the configured mode."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        binary_segments = {}

        if os.path.exists(mask_path):
            mask = self._read_png(mask_path)
        else:
            # if png doesn't exist, empty mask
            mask = np.zeros((self.H, self.W), dtype=bool)
        binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        # get the path; only directories are object folders
        all_objects = sorted(
            path
            for path in glob.glob(os.path.join(self.video_png_root, "*"))
            if os.path.isdir(path)
        )
        num_objects = len(all_objects)
        assert num_objects > 0

        # load the masks
        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
            obj_id = self._folder_to_obj_id(obj_folder)  # offset 1 as bg is 0
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            if os.path.exists(mask_path):
                mask = self._read_png(mask_path)
            else:
                # if png doesn't exist, empty mask
                mask = np.zeros((self.H, self.W), dtype=bool)
            binary_segments[obj_id] = torch.from_numpy(mask > 0)

        return binary_segments

    def __len__(self):
        """Return the number of distinct PNG mask file names under the root."""
        if self.single_object_mode:
            pattern = os.path.join(self.video_png_root, "*.png")
        else:
            pattern = os.path.join(self.video_png_root, "*", "*.png")
        return len({os.path.basename(path) for path in glob.glob(pattern)})
class LazySegments:
    """
    Only decodes segments that are actually used.

    Encoded segments are stored by key; a segment is decoded the first time it
    is read and the decoded mask is then served from a cache.
    """

    def __init__(self):
        self.segments = {}  # key -> encoded segment, as stored by __setitem__
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item

    def __getitem__(self, key):
        if key not in self.cache:
            # First access: decode the single RLE and keep its only mask
            encoded = self.segments[key]
            decoded = mask_utils.decode([encoded])
            self.cache[key] = torch.from_numpy(decoded).permute(2, 0, 1)[0]
        return self.cache[key]

    def __contains__(self, key):
        return key in self.segments

    def __len__(self):
        return len(self.segments)

    def keys(self):
        return self.segments.keys()
class SA1BSegmentLoader:
    """
    Segment loader reading the "annotations" list of a JSON file and exposing
    the kept RLE masks through a LazySegments container.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON annotation file.
            mask_area_frac_thresh: when <= 1.0, annotations whose area divided
                by the image area is at least this value are dropped.
            video_frame_path: path of the image; only read when filtering by
                area fraction, to obtain the image size.
            uncertain_iou: annotations carrying an "uncertain_iou" field below
                this value are dropped (stability score).
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            # Lazily read frame; the context manager releases the file handle
            # as soon as the size is known.
            with PILImage.open(video_frame_path) as frame_image:
                orig_w, orig_h = frame_image.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            if not frame_annot["area"] > 0:
                continue
            if ("uncertain_iou" in frame_annot) and (
                frame_annot["uncertain_iou"] < uncertain_iou
            ):
                # uncertain_iou is stability score
                continue
            if filter_by_area and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        # Kept masks are keyed by their position after filtering
        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the segments container; ``frame_idx`` is not used."""
        return self.segments