SAM2.1
SAM2.1 checkpoints + training code + Demo
This commit is contained in:
5
training/dataset/__init__.py
Normal file
5
training/dataset/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
180
training/dataset/sam2_datasets.py
Normal file
180
training/dataset/sam2_datasets.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Callable, Iterable, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
|
||||
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
||||
class MixedDataLoader:
    """Iterator that interleaves batches drawn from several dataloaders.

    On every step one dataloader is chosen at random according to the mixing
    probabilities and its next batch is returned.
    """

    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Per-iteration state; populated by __iter__.
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        """Return the total number of batches over all dataloaders."""
        total = 0
        for loader in self.dataloaders:
            total += len(loader)
        return total

    def __iter__(self):
        """Reset the sampling state and return ``self`` as the iterator."""
        # Re-seed with a fixed value so that the sequence of sampled
        # dataloaders is the same every time iteration starts.
        self.random_generator.manual_seed(42)
        self._iter_dls = list(map(iter, self.dataloaders))
        # Work on a copy: entries are zeroed as loaders run out.
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Draw a dataloader according to the mixing probabilities and return its
        next item. An exhausted dataloader gets probability zero and another
        one is drawn, until every dataloader is exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        remaining_prob = self._iter_mixing_prob
        # Keep drawing while at least one dataloader has non-zero probability.
        while remaining_prob.any():
            picked = remaining_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                return next(self._iter_dls[picked])
            except StopIteration:
                # This dataloader is done: never pick it again, then retry.
                remaining_prob[picked] = 0
            except Exception as err:
                # Log and re-raise any other unexpected error.
                logging.error(err)
                raise err

        # Every dataloader has been exhausted.
        raise StopIteration
|
||||
|
||||
|
||||
class TorchTrainMixedDataset:
    """Builds, for a given epoch, a `MixedDataLoader` over several map-style datasets."""

    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        # Per-dataset index chunks, filled lazily by get_loader when
        # phases_per_epoch > 1.
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional
            # to its length, measured in batches.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            assert total_len > 0, "Datasets yield no batches; cannot derive probabilities"
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        # The sum is accumulated in floating point, so compare with a
        # tolerance instead of exact equality.
        assert math.isclose(
            dataset_prob.sum().item(), 1.0, abs_tol=1e-5
        ), "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        """Propagate `epoch` to `dataset` via its `epoch` attribute and/or `set_epoch` method, if present."""
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        """Create one DataLoader per dataset for `epoch` and wrap them in a `MixedDataLoader`.

        When `phases_per_epoch > 1`, each dataset is split into
        `phases_per_epoch` random chunks and only the chunk for the current
        phase is used.
        """
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch

                # Start of a new data-epoch, or job is resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set seed for the dataset epoch.
                    # If using RepeatFactorWrapper, this step re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling, seeded by
                    # the main epoch so the chunks can be rebuilt identically.
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
|
528
training/dataset/transforms.py
Normal file
528
training/dataset/transforms.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms.v2.functional as Fv2
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from training.utils.data_utils import VideoDatapoint
|
||||
|
||||
|
||||
def hflip(datapoint, index):
    """Horizontally flip frame `index` of `datapoint` and every mask attached to it.

    The datapoint is modified in place and also returned.
    """
    frame = datapoint.frames[index]
    frame.data = F.hflip(frame.data)
    for obj in frame.objects:
        if obj.segment is None:
            # Objects without a mask have nothing to flip.
            continue
        obj.segment = F.hflip(obj.segment)
    return datapoint
|
||||
|
||||
|
||||
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Compute an output `(height, width)` whose shorter side equals `size`.

    Args:
        image_size: current `(width, height)` of the image.
        size: target length of the shorter side.
        max_size: optional cap on the longer side; when the scaled longer
            side would exceed it, `size` is reduced accordingly.

    Returns:
        Tuple `(height, width)` preserving the aspect ratio (up to rounding).
    """
    width, height = image_size

    if max_size is not None:
        short_side = float(min((width, height)))
        long_side = float(max((width, height)))
        # Shrink the target so the longer side does not exceed max_size.
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    # The shorter side already has the requested length: keep the size as is.
    width_matches = width <= height and width == size
    height_matches = height <= width and height == size
    if width_matches or height_matches:
        return (height, width)

    if width < height:
        return (int(round(size * height / width)), int(round(size)))
    return (int(round(size)), int(round(size * width / height)))
|
||||
|
||||
|
||||
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    """Resize frame `index` of `datapoint`, and its object masks, in place.

    Args:
        datapoint: object exposing `frames`; each frame has `data`, `objects`
            and `size` attributes.
        index: position of the frame to resize.
        size: min_size (scalar) or a `(w, h)` list/tuple.
        max_size: optional cap on the longer side; only used for scalar `size`.
        square: if True, resize to `(size, size)`, ignoring the aspect ratio.
        v2: if True, the frame data is read via `.size()` and resized with
            `Fv2.resize(..., antialias=True)`; otherwise `.size` and `F.resize`.

    Returns:
        The same `datapoint`, with the frame's `size` set to `(h, w)`.
    """
    frame = datapoint.frames[index]

    if square:
        size = size, size
    elif isinstance(size, (list, tuple)):
        # Given as (w, h); the resize functions below are passed (h, w).
        size = size[::-1]
    else:
        # Current (w, h) of the frame, read according to the data flavour.
        cur_size = frame.data.size()[-2:][::-1] if v2 else frame.data.size
        size = get_size_with_aspect_ratio(cur_size, size, max_size)

    if v2:
        frame.data = Fv2.resize(frame.data, size, antialias=True)
    else:
        frame.data = F.resize(frame.data, size)

    for obj in frame.objects:
        if obj.segment is not None:
            # Add two leading dims for the resize call and strip exactly those
            # two afterwards; squeeze() would also drop a height or width of 1.
            obj.segment = F.resize(obj.segment[None, None], size)[0, 0]

    h, w = size
    frame.size = (h, w)
    return datapoint
|
||||
|
||||
|
||||
def pad(datapoint, index, padding, v2=False):
    """Pad frame `index` of `datapoint`, and its object masks, in place.

    Args:
        datapoint: object exposing `frames`; each frame has `data`, `objects`
            and `size` (stored as `(h, w)`) attributes.
        index: position of the frame to pad.
        padding: two values, applied as `(0, 0, padding[0], padding[1])`, or
            four values given as left, top, right, bottom.
        v2: if True, masks are padded with `Fv2.pad` instead of `F.pad`.

    Returns:
        The same `datapoint`, with the frame's `size` updated.
    """
    frame = datapoint.frames[index]
    h, w = frame.size
    two_values = len(padding) == 2

    if two_values:
        # Assumes that we only pad on the bottom right corners.
        frame.data = F.pad(frame.data, (0, 0, padding[0], padding[1]))
        h, w = h + padding[1], w + padding[0]
    else:
        # left, top, right, bottom
        frame.data = F.pad(
            frame.data, (padding[0], padding[1], padding[2], padding[3])
        )
        h, w = h + (padding[1] + padding[3]), w + (padding[0] + padding[2])

    frame.size = (h, w)

    # Masks receive the same padding as the image, using the requested API.
    pad_fn = Fv2.pad if v2 else F.pad
    mask_padding = (0, 0, padding[0], padding[1]) if two_values else tuple(padding)
    for obj in frame.objects:
        if obj.segment is not None:
            obj.segment = pad_fn(obj.segment, mask_padding)
    return datapoint
|
||||
|
||||
|
||||
class RandomHorizontalFlip:
    """Randomly flip the frames of a datapoint horizontally."""

    def __init__(self, consistent_transform, p=0.5):
        """
        Args:
            consistent_transform: if truthy, one random decision is taken for
                all frames; otherwise each frame is decided independently.
            p: probability of flipping.
        """
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        """Apply the random flip to `datapoint` and return it."""
        if not self.consistent_transform:
            # Independent coin toss per frame.
            for i in range(len(datapoint.frames)):
                if random.random() < self.p:
                    datapoint = hflip(datapoint, i)
            return datapoint

        # Single coin toss shared by every frame.
        flip_all = random.random() < self.p
        if not flip_all:
            return datapoint
        for i in range(len(datapoint.frames)):
            datapoint = hflip(datapoint, i)
        return datapoint
|
||||
|
||||
|
||||
class RandomResizeAPI:
    """Resize the frames of a datapoint to a size drawn at random from `sizes`."""

    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        """
        Args:
            sizes: a single int or an iterable of candidate sizes.
            consistent_transform: if truthy, one size is drawn for all frames;
                otherwise a size is drawn per frame.
            max_size, square, v2: forwarded to `resize`.
        """
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        """Resize every frame of `datapoint` and return it."""
        # In consistent mode the size is drawn once, before visiting any frame.
        shared_size = random.choice(self.sizes) if self.consistent_transform else None
        for i in range(len(datapoint.frames)):
            if self.consistent_transform:
                size = shared_size
            else:
                size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint
|
||||
|
||||
|
||||
class ToTensorAPI:
    """Convert the image data of every frame of a datapoint to a tensor."""

    def __init__(self, v2=False):
        """
        Args:
            v2: if True, use the `torchvision.transforms.v2` functional API;
                otherwise use `F.to_tensor`.
        """
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Convert `img.data` of each frame in place and return `datapoint`."""
        if self.v2:
            # Prefer the name this code was written against. If the installed
            # torchvision lacks it, fall back to `to_image`.
            # NOTE(review): newer torchvision releases appear to expose this
            # helper as `to_image` — confirm against the pinned version.
            convert = getattr(Fv2, "to_image_tensor", None)
            if convert is None:
                convert = Fv2.to_image
        else:
            convert = F.to_tensor

        for img in datapoint.frames:
            img.data = convert(img.data)
        return datapoint
|
||||
|
||||
|
||||
class NormalizeAPI:
    """Normalize the image data of every frame with a fixed mean and std."""

    def __init__(self, mean, std, v2=False):
        """
        Args:
            mean: mean passed to the normalize function.
            std: standard deviation passed to the normalize function.
            v2: if True, convert to float32 and normalize with the
                `torchvision.transforms.v2` functional API.
        """
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Normalize `img.data` of each frame in place and return `datapoint`."""
        for img in datapoint.frames:
            if not self.v2:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
                continue
            # v2 path: convert the dtype to float32 first, then normalize.
            img.data = Fv2.convert_image_dtype(img.data, torch.float32)
            img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
        return datapoint
|
||||
|
||||
|
||||
class ComposeAPI:
    """Chain several datapoint transforms into one callable."""

    def __init__(self, transforms):
        """
        Args:
            transforms: callables invoked in order as `t(datapoint, **kwargs)`.
        """
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        """Feed `datapoint` through every transform in order; return the final result."""
        result = datapoint
        for transform in self.transforms:
            result = transform(result, **kwargs)
        return result

    def __repr__(self):
        """Return the class name followed by one line per transform."""
        lines = "".join("\n" + " {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + lines + "\n)"
|
||||
|
||||
|
||||
class RandomGrayscale:
    """Randomly convert the frames of a datapoint to 3-channel grayscale."""

    def __init__(self, consistent_transform, p=0.5):
        """
        Args:
            consistent_transform: if truthy, one random decision is taken for
                all frames; otherwise each frame is decided independently.
            p: probability of converting to grayscale.
        """
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Apply the random grayscale conversion in place and return `datapoint`."""
        if not self.consistent_transform:
            # Independent coin toss per frame.
            for img in datapoint.frames:
                if random.random() < self.p:
                    img.data = self.Grayscale(img.data)
            return datapoint

        # Single coin toss shared by every frame.
        if random.random() < self.p:
            for img in datapoint.frames:
                img.data = self.Grayscale(img.data)
        return datapoint
|
||||
|
||||
|
||||
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of the frames."""

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        """
        Args:
            consistent_transform: if truthy, one set of jitter parameters is
                drawn for all frames; otherwise a new set is drawn per frame.
            brightness, contrast, saturation: either a list used as is, or a
                number `x` expanded to `[max(0, 1 - x), 1 + x]`.
            hue: a list or None used as is, or a number `x` expanded to `[-x, x]`.
        """

        def around_one(value):
            # Lists are taken verbatim; scalars become a range around 1.
            if isinstance(value, list):
                return value
            return [max(0, 1 - value), 1 + value]

        self.consistent_transform = consistent_transform
        self.brightness = around_one(brightness)
        self.contrast = around_one(contrast)
        self.saturation = around_one(saturation)
        if isinstance(hue, list) or hue is None:
            self.hue = hue
        else:
            self.hue = [-hue, hue]

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Apply the jitter to `img.data` of each frame and return `datapoint`."""

        def draw_params():
            return T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )

        # In consistent mode the parameters are drawn once, up front.
        params = draw_params() if self.consistent_transform else None

        for img in datapoint.frames:
            if not self.consistent_transform:
                params = draw_params()
            order, brightness_f, contrast_f, saturation_f, hue_f = params
            # Position in this table corresponds to the ids found in `order`.
            adjustments = (
                (F.adjust_brightness, brightness_f),
                (F.adjust_contrast, contrast_f),
                (F.adjust_saturation, saturation_f),
                (F.adjust_hue, hue_f),
            )
            for fn_id in order:
                adjust, factor = adjustments[int(fn_id)]
                # A factor of None means this adjustment is skipped.
                if factor is not None:
                    img.data = adjust(img.data, factor)
        return datapoint
|
||||
|
||||
|
||||
class RandomAffine:
    """Apply a random affine transform to the frames and masks of a datapoint."""

    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, then the same random affine is applied to all frames and masks.

        Args:
            degrees: a list used as is, or a number `d` expanded to `[-d, d]`.
            consistent_transform: see above.
            scale: passed as `scale_ranges` when sampling affine parameters.
            translate: passed as `translate` when sampling affine parameters.
            shear: a list used as is, a truthy number `s` expanded to
                `[-s, s]`, or a falsy value meaning no shear.
            image_mean: fill value used for the image after the transform.
            log_warning: whether to log a warning when the transform is skipped.
            num_tentatives: number of attempts before giving up.
            image_interpolation: "bicubic" or "bilinear".

        Raises:
            NotImplementedError: for any other `image_interpolation`.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives

        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError(
                f"Unsupported image_interpolation: {image_interpolation!r}"
            )

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        """Try the transform up to `num_tentatives` times; return `datapoint` unchanged on failure."""
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res

        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _sample_affine_params(self, img_size):
        """Draw one set of random affine parameters for an image of `img_size` (`[width, height]`)."""
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_datapoint(self, datapoint: VideoDatapoint):
        """Transform all frames and masks of `datapoint` in place.

        Returns:
            The datapoint, or None when a mask of the first frame becomes
            empty after the transform. In that case nothing has been modified,
            because the check runs before any result is written back.
        """
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # One random affine transformation shared by all frames.
            affine_params = self._sample_affine_params(img_size)

        for img_idx, img in enumerate(datapoint.frames):
            if not self.consistent_transform:
                # Not consistent: draw new affine params for every frame.
                affine_params = self._sample_affine_params(img_size)

            transformed_masks = []
            for obj in img.objects:
                if obj.segment is None:
                    transformed_masks.append(None)
                    continue
                transformed_mask = F.affine(
                    obj.segment.unsqueeze(0),
                    *affine_params,
                    interpolation=InterpolationMode.NEAREST,
                    fill=0.0,
                )
                if img_idx == 0 and transformed_mask.max() == 0:
                    # We are dealing with a video and the object is not visible in the first frame.
                    # Signal the caller to retry or skip the transformation.
                    return None
                transformed_masks.append(transformed_mask.squeeze())

            # Write back only once every mask of this frame has been accepted.
            for obj, new_mask in zip(img.objects, transformed_masks):
                obj.segment = new_mask

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
|
||||
|
||||
|
||||
def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    """Turn frame `index` into a `grid_h` x `grid_w` mosaic of downsized copies of itself.

    Every grid cell receives a downsized copy of the image; the object masks
    are downsized and placed only in cell `(target_grid_y, target_grid_x)`,
    the rest of each mask being zero.

    Args:
        datapoint: object exposing `frames`; each frame has `data` (a PIL
            image or a tensor) and `objects`.
        index: position of the frame to process.
        grid_h, grid_w: number of mosaic rows and columns.
        target_grid_y, target_grid_x: cell that keeps the masks.
        should_hflip: per-cell flags, indexed `[grid_y, grid_x]`, selecting
            the cells whose content is flipped horizontally.

    Returns:
        The same `datapoint`, modified in place.
    """
    frame = datapoint.frames[index]
    source = frame.data
    is_pil = isinstance(source, PILImage.Image)
    if is_pil:
        H_im, W_im = source.height, source.width
        canvas = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im, W_im = source.size(-2), source.size(-1)
        canvas = torch.zeros_like(source)

    def cell_bounds(cell_y, cell_x):
        # Pixel bounds (top, left, bottom, right) of one grid cell.
        top = cell_y * H_im // grid_h
        left = cell_x * W_im // grid_w
        bottom = (cell_y + 1) * H_im // grid_h
        right = (cell_x + 1) * W_im // grid_w
        return top, left, bottom, right

    # Step 1: downsize the image and paste it into every cell of the mosaic.
    # Cells of equal shape share one resize result.
    resized_by_shape = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            top, left, bottom, right = cell_bounds(grid_y, grid_x)
            shape = (bottom - top, right - left)
            if shape not in resized_by_shape:
                resized_by_shape[shape] = F.resize(
                    source,
                    size=shape,
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
            tile = resized_by_shape[shape]
            if should_hflip[grid_y, grid_x].item():
                tile = F.hflip(tile)

            if is_pil:
                canvas.paste(tile, (left, top))
            else:
                canvas[:, top:bottom, left:right] = tile

    frame.data = canvas

    # Step 2: downsize the masks and paste them into the target cell only.
    top, left, bottom, right = cell_bounds(target_grid_y, target_grid_x)
    for obj in frame.objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        pasted = torch.zeros_like(obj.segment)

        small = F.resize(
            obj.segment[None, None],
            size=(bottom - top, right - left),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            small = F.hflip(small[None, None])[0, 0]

        pasted[top:bottom, left:right] = small
        obj.segment = pasted

    return datapoint
|
||||
|
||||
|
||||
class RandomMosaicVideoAPI:
    """With probability `prob`, turn every frame of a datapoint into a mosaic."""

    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        """
        Args:
            prob: probability of applying the mosaic.
            grid_h, grid_w: number of mosaic rows and columns.
            use_random_hflip: if True, each cell is flipped horizontally with
                probability 0.5; otherwise no cell is flipped.
        """
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        """Apply the mosaic to all frames of `datapoint`, or return it untouched."""
        if random.random() > self.prob:
            return datapoint

        # One target cell and one flip pattern are drawn here and then shared
        # by all frames of the datapoint.
        mosaic_args = dict(grid_h=self.grid_h, grid_w=self.grid_w)
        # Select a random location to place the target mask in the mosaic.
        mosaic_args["target_grid_y"] = random.randint(0, self.grid_h - 1)
        mosaic_args["target_grid_x"] = random.randint(0, self.grid_w - 1)
        # Whether to flip each grid cell in the mosaic horizontally.
        mosaic_args["should_hflip"] = (
            torch.rand(self.grid_h, self.grid_w) < 0.5
            if self.use_random_hflip
            else torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
        )

        for i in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(datapoint, i, **mosaic_args)

        return datapoint
|
104
training/dataset/utils.py
Normal file
104
training/dataset/utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (
|
||||
ConcatDataset as TorchConcatDataset,
|
||||
Dataset,
|
||||
Subset as TorchSubset,
|
||||
)
|
||||
|
||||
|
||||
class ConcatDataset(TorchConcatDataset):
    """`ConcatDataset` that also concatenates the `repeat_factors` of its members."""

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        """
        Args:
            datasets: datasets to concatenate; each must expose a
                `repeat_factors` tensor.
        """
        super().__init__(datasets)

        # Read from `self.datasets` rather than the `datasets` argument: a
        # one-shot iterable (e.g. a generator) is already consumed by the
        # parent constructor at this point.
        self.repeat_factors = torch.cat([d.repeat_factors for d in self.datasets])

    def set_epoch(self, epoch: int):
        """Propagate `epoch` to every member via its `epoch` attribute and/or `set_epoch` method, if present."""
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
class Subset(TorchSubset):
    """`Subset` that also carries the `repeat_factors` of the selected items."""

    def __init__(self, dataset, indices) -> None:
        """
        Args:
            dataset: wrapped dataset; must expose an indexable `repeat_factors`.
            indices: positions of the items that make up the subset.
        """
        super().__init__(dataset, indices)

        # Keep only the repeat factors of the selected items.
        selected_factors = dataset.repeat_factors[indices]
        self.repeat_factors = selected_factors
        assert len(indices) == len(self.repeat_factors)
|
||||
|
||||
|
||||
# Adapted from Detectron2
|
||||
class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.
    The underlying dataset must have a repeat_factors member to indicate the per-image factor.
    Set it to uniformly ones to disable repeat factor sampling
    """

    def __init__(self, dataset, seed: int = 0):
        """
        Args:
            dataset: wrapped dataset exposing a `repeat_factors` tensor.
            seed: base seed; the epoch number is added to it in `set_epoch`.
        """
        self.dataset = dataset
        # Indices into `dataset` for the current epoch; None until set_epoch.
        self.epoch_ids = None
        self._seed = seed

        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Repeat each index `rep_factor` times in one vectorized call. Counts
        # are clamped at zero so a non-positive factor contributes no
        # repetitions.
        repeats = rep_factors.to(torch.int64).clamp(min=0)
        dataset_indices = torch.arange(len(repeats), dtype=torch.int64)
        return torch.repeat_interleave(dataset_indices, repeats)

    def __len__(self):
        """Return the number of (repeated) indices of the current epoch.

        Raises:
            RuntimeError: if `set_epoch` has not been called yet.
        """
        if self.epoch_ids is None:
            # Raise instead of returning len(self.dataset) to avoid accidentally
            # using the unwrapped length: the length changes to
            # `len(self.epoch_ids)` once set_epoch is called.
            raise RuntimeError("please call set_epoch first to get wrapped length")

        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        """Re-sample the epoch indices with seed `seed + epoch` and forward `epoch` to the wrapped dataset."""
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item for position `idx` of the current epoch.

        Raises:
            RuntimeError: if `set_epoch` has not been called yet.
        """
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )

        return self.dataset[self.epoch_ids[idx]]
|
162
training/dataset/vos_dataset.py
Normal file
162
training/dataset/vos_dataset.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from PIL import Image as PILImage
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
from training.dataset.vos_raw_dataset import VOSRawDataset
|
||||
from training.dataset.vos_sampler import VOSSampler
|
||||
from training.dataset.vos_segment_loader import JSONSegmentLoader
|
||||
|
||||
from training.utils.data_utils import Frame, Object, VideoDatapoint
|
||||
|
||||
MAX_RETRIES = 100
|
||||
|
||||
|
||||
class VOSDataset(VisionDataset):
    """Dataset that samples frames and objects from videos and builds `VideoDatapoint`s."""

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables applied in order as
                `transform(datapoint, epoch=...)`.
            training: if True, a failed load is retried with a random video;
                otherwise the failure is raised immediately.
            video_dataset: source of videos, queried via `get_video(idx)`.
            sampler: picks the frames and object ids of a datapoint.
            multiplier: value of every entry of `repeat_factors`.
            always_target: if True, objects without a segment in a frame get
                an all-zero mask; otherwise they are dropped from that frame.
            target_segments_available: stored on the instance; not used in
                the methods of this class visible here.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """Load, construct and transform the datapoint for video `idx`.

        In training mode a failing video is replaced by a randomly chosen one,
        up to `MAX_RETRIES` attempts.

        Raises:
            RuntimeError: in training mode, when every attempt failed.
            Exception: outside training mode, the original loading error.
        """
        last_error = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if not self.training:
                    # Shouldn't fail to load a val video
                    raise
                logging.warning(
                    f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                )
                last_error = e
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Every attempt failed: stop here rather than continue with
            # unbound or mismatched locals.
            raise RuntimeError(
                f"Failed to load a datapoint after {MAX_RETRIES} retries"
            ) from last_error

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        # NOTE(review): `h` and `w` come from the last sampled frame, so this
        # relies on at least one frame having been sampled.
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        """Return the transformed datapoint for `idx`."""
        return self._get_datapoint(idx)

    def __len__(self):
        """Return the number of videos in the underlying video dataset."""
        return len(self.video_dataset)
|
||||
|
||||
|
||||
def load_images(frames):
    """Return one PIL image per frame, in the same order as ``frames``.

    Frames that already carry tensor ``data`` are converted with
    ``tensor_2_PIL``. The others are read from ``frame.image_path`` and
    converted to RGB; a path seen earlier in the same call is not re-read,
    a deep copy of the previously loaded image is used instead.
    """
    images = []
    # image path -> position in `images` of the image loaded from that path
    first_index_of = {}
    for frame in frames:
        if frame.data is not None:
            # The frame rgb data has already been loaded: convert it to PIL
            images.append(tensor_2_PIL(frame.data))
            continue

        image_path = frame.image_path
        seen_at = first_index_of.get(image_path)
        if seen_at is not None:
            images.append(deepcopy(images[seen_at]))
            continue

        # Load the frame rgb data from file
        with g_pathmgr.open(image_path, "rb") as stream:
            images.append(PILImage.open(stream).convert("RGB"))
        first_index_of[image_path] = len(images) - 1

    return images
|
||||
|
||||
|
||||
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """Convert an image tensor to a PIL image.

    The tensor is moved to CPU, axis 0 is moved last, values are multiplied
    by 255 and cast to ``uint8`` before being handed to ``PILImage.fromarray``.
    Assumes ``data`` is channel-first with values in [0, 1] - TODO confirm
    against callers.
    """
    channels_last = np.transpose(data.cpu().numpy(), (1, 2, 0))
    scaled = channels_last * 255.0
    return PILImage.fromarray(scaled.astype(np.uint8))
|
308
training/dataset/vos_raw_dataset.py
Normal file
308
training/dataset/vos_raw_dataset.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from omegaconf.listconfig import ListConfig
|
||||
|
||||
from training.dataset.vos_segment_loader import (
|
||||
JSONSegmentLoader,
|
||||
MultiplePNGSegmentLoader,
|
||||
PalettisedPNGSegmentLoader,
|
||||
SA1BSegmentLoader,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class VOSFrame:
    """Description of one frame of a raw video."""

    # Frame number; the raw datasets in this file derive it from the image
    # file name or from the position in the video.
    frame_idx: int
    # Location of the RGB image for this frame.
    image_path: str
    # Already-loaded image tensor, or None when it must be read from image_path.
    data: Optional[torch.Tensor] = None
    # Flag stored with the frame; not read anywhere in this file.
    is_conditioning_only: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
class VOSVideo:
    """A raw video: its name, an integer id and its list of frames."""

    # Name of the video as listed by the raw dataset.
    video_name: str
    # Integer identifier chosen by the raw dataset that built this video.
    video_id: int
    # Frames of the video, in the order produced by the raw dataset.
    frames: List[VOSFrame]

    def __len__(self):
        """Number of frames in the video."""
        return len(self.frames)
|
||||
|
||||
|
||||
class VOSRawDataset:
    """Base class for raw video datasets; subclasses implement ``get_video``."""

    def __init__(self):
        pass

    def get_video(self, idx):
        """Return the video at ``idx``; must be overridden by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class PNGRawDataset(VOSRawDataset):
    """Raw dataset of JPEG frame folders with PNG mask annotations."""

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """
        Args:
            img_folder: folder containing one sub-folder of frames per video.
            gt_folder: folder containing one sub-folder of masks per video.
            file_list_txt: optional text file with one video name per line
                (extensions are stripped); defaults to listing ``img_folder``.
            excluded_videos_list_txt: optional text file of video names to skip.
            sample_rate: keep every ``sample_rate``-th frame of a video.
            is_palette: use ``PalettisedPNGSegmentLoader`` when True, otherwise
                ``MultiplePNGSegmentLoader``.
            single_object_mode: when True, every ``video/object`` mask folder
                becomes its own entry.
            truncate_video: when > 0, only the first ``truncate_video`` frames
                of each video are considered.
            frames_sampling_mult: when truthy, each video name is repeated once
                per file in its frame folder.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided.
        # A set keeps the membership test below O(1) per video.
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep the videos that are not excluded
        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

        if self.single_object_mode:
            # single object mode: one entry per "<video>/<object folder>"
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                num_frames = len(os.listdir(os.path.join(self.img_folder, video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def get_video(self, idx):
        """
        Return the video at position ``idx`` together with its segment loader.

        Returns:
            (VOSVideo, segment_loader): the video lists its ``*.jpg`` frames
            (after truncation and sub-sampling), with ``frame_idx`` parsed from
            the file name; the loader reads masks from the matching folder
            under ``gt_folder``.
        """
        video_name = self.video_names[idx]

        if self.single_object_mode:
            # video_name is "<video>/<object>": frames live in the video folder
            video_frame_root = os.path.join(
                self.img_folder, os.path.dirname(video_name)
            )
        else:
            video_frame_root = os.path.join(self.img_folder, video_name)

        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(
                video_mask_root, self.single_object_mode
            )

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]
        frames = []
        for fpath in all_frames[:: self.sample_rate]:
            fid = int(os.path.basename(fpath).split(".")[0])
            frames.append(VOSFrame(fid, image_path=fpath))
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries in ``video_names``."""
        return len(self.video_names)
|
||||
|
||||
|
||||
class SA1BRawDataset(VOSRawDataset):
    """Raw dataset of single images with JSON annotations, exposed as videos.

    Each image is presented as a "video" made of ``num_frames`` copies of the
    same frame.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        """
        Args:
            img_folder: folder containing the ``<name>.jpg`` images.
            gt_folder: folder containing the ``<name>.json`` annotations.
            file_list_txt: optional text file with one image name per line
                (extensions are stripped); defaults to the ``.jpg`` files
                found in ``img_folder``.
            excluded_videos_list_txt: optional text file of names to skip.
            num_frames: number of identical frames in each returned video.
            mask_area_frac_thresh: forwarded to ``SA1BSegmentLoader``.
            uncertain_iou: forwarded to ``SA1BSegmentLoader``.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)
            subset = [
                path.split(".")[0] for path in subset if path.endswith(".jpg")
            ]  # remove extension

        # Read and process excluded files if provided.
        # A set keeps the membership test below O(1) per image.
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep the names that are not excluded (existence on disk is not checked)
        self.video_names = [
            video_name for video_name in subset if video_name not in excluded_files
        ]

    def get_video(self, idx):
        """
        Return the single-image video at position ``idx`` and its segment loader.

        Returns:
            (VOSVideo, SA1BSegmentLoader): the video holds ``num_frames`` frames
            that all point at the same image; its name and id come from the
            part of the file name after the last underscore.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = [
            VOSFrame(frame_idx, image_path=video_frame_path)
            for frame_idx in range(self.num_frames)
        ]
        video_name = video_name.split("_")[-1]  # filename is sa_{int}
        # video id needs to be image_id to be able to load correct annotation file during eval
        video = VOSVideo(video_name, int(video_name), frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries in ``video_names``."""
        return len(self.video_names)
|
||||
|
||||
|
||||
class JSONRawDataset(VOSRawDataset):
    """
    Dataset where the annotation in the format of SA-V json files
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        """
        Args:
            img_folder: folder containing one sub-folder of frames per video.
            gt_folder: folder containing the ``<video>_manual.json`` files.
            file_list_txt: optional text file with one video name per line
                (extensions are stripped); defaults to listing ``img_folder``.
            excluded_videos_list_txt: path, or list/tuple/``ListConfig`` of
                paths, of text files naming videos to skip.
            sample_rate: keep every ``sample_rate``-th frame of a video.
            rm_unannotated: when True, drop frames without a complete annotation.
            ann_every: forwarded to ``JSONSegmentLoader``.
            frames_fps: forwarded to ``JSONSegmentLoader``.

        Raises:
            NotImplementedError: if ``excluded_videos_list_txt`` has an
                unsupported type.
        """
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (ListConfig, list, tuple)):
                excluded_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError

            # NOTE(review): these files are opened with the builtin open(),
            # while file_list_txt goes through g_pathmgr - kept as in the original.
            for excluded_list_path in excluded_lists:
                with open(excluded_list_path, "r") as f:
                    excluded_files.update(
                        os.path.splitext(line.strip())[0] for line in f
                    )

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [os.path.splitext(line.strip())[0] for line in f]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

    def get_video(self, video_idx):
        """
        Return the video at position ``video_idx`` and its segment loader.

        Returns:
            (VOSVideo, JSONSegmentLoader): frame ids are parsed from the file
            names in the video folder and sub-sampled by ``sample_rate``; when
            ``rm_unannotated`` is set, only frames whose annotation exists and
            contains no ``None`` entry are kept.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(
                os.listdir(os.path.join(self.img_folder, video_name))
            )
        ]

        # Format the frame number inside the f-string: applying "%" to a
        # string that already contains video_name breaks on names with "%".
        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(
                    self.img_folder, f"{video_name}/{frame_id:05d}.jpg"
                ),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated.
            # Annotation i corresponds to frame i * ann_every.
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        """Number of entries in ``video_names``."""
        return len(self.video_names)
|
105
training/dataset/vos_sampler.py
Normal file
105
training/dataset/vos_sampler.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from training.dataset.vos_segment_loader import LazySegments
|
||||
|
||||
MAX_RETRIES = 1000
|
||||
|
||||
|
||||
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler: which frames and which object ids to use."""

    # NOTE(review): annotated as List[int], but the samplers in this file
    # store the video's frame objects here - confirm the intended type.
    frames: List[int]
    # Ids of the objects selected for the datapoint.
    object_ids: List[int]
|
||||
|
||||
|
||||
class VOSSampler:
    """Base class for samplers that pick frames and objects from a video."""

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Pick frames and objects from ``video``; must be overridden.

        The optional ``segment_loader`` and ``epoch`` parameters mirror the
        way subclasses are called, so calling the base class raises
        ``NotImplementedError`` rather than a ``TypeError``.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class RandomUniformSampler(VOSSampler):
    """Sampler picking a random window of consecutive frames and random objects."""

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: number of consecutive frames to sample.
            max_num_objects: upper bound on the number of sampled objects.
            reverse_time_prob: probability of reversing the sampled window.
        """
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Sample ``num_frames`` consecutive frames and up to ``max_num_objects`` objects.

        Windows are redrawn (up to ``MAX_RETRIES`` times) until the first
        frame of the window shows at least one object.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or if
                no window with a visible object was found.
        """
        # The video length does not change between retries: check it once.
        if len(video.frames) < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for retry in range(MAX_RETRIES):
            start = random.randrange(0, len(video.frames) - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset: every stored segment counts
                visible_object_ids = list(loaded_segms.keys())
            else:
                # Reuse the segments loaded above instead of loading them again.
                # A None segment (object without a mask on this frame) is not visible.
                visible_object_ids = [
                    object_id
                    for object_id, segment in loaded_segms.items()
                    if segment is not None and segment.sum()
                ]

            # First frame needs to have at least a target to track
            if len(visible_object_ids) > 0:
                break
            if retry >= MAX_RETRIES - 1:
                raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
||||
|
||||
|
||||
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        The object ids are those returned by the segment loader for the first
        frame, materialised as a list.

        Raises:
            Exception: if the first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames
        # Copy the keys into a list: the mapping's keys() is a live view,
        # whereas SampledFramesAndObjects declares a list of ids.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if len(object_ids) == 0:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
300
training/dataset/vos_segment_loader.py
Normal file
300
training/dataset/vos_segment_loader.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
try:
|
||||
from pycocotools import mask as mask_utils
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class JSONSegmentLoader:
    """Loads per-frame RLE masks stored in a video-level JSON file."""

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file. The file holds
                either a list of per-frame annotations, or a dict with a
                "masklet" (or "masks") field and an optional "fps" field.
            ann_every: annotations are provided every ``ann_every``-th frame;
                overridden by ``frames_fps // fps`` when the file has "fps".
            frames_fps: frame rate of the extracted frames.
            valid_obj_ids: optional ids of the objects to consider.

        Raises:
            NotImplementedError: if the JSON root is neither a list nor a dict.
        """
        # Annotations in the json are provided every ann_every th frame
        self.ann_every = ann_every
        # Ids of the objects to consider when sampling this video
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as f:
            data = json.load(f)
            if isinstance(data, list):
                self.frame_annots = data
            elif isinstance(data, dict):
                masklet_field_name = "masklet" if "masklet" in data else "masks"
                self.frame_annots = data[masklet_field_name]
                if "fps" in data:
                    if isinstance(data["fps"], list):
                        annotations_fps = int(data["fps"][0])
                    else:
                        annotations_fps = int(data["fps"])
                    assert frames_fps % annotations_fps == 0
                    self.ann_every = frames_fps // annotations_fps
            else:
                raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Decode the masks of frame ``frame_id``.

        Args:
            frame_id: frame number; must be a multiple of ``ann_every``.
            obj_ids: optional ids restricting which objects are decoded.

        Returns:
            dict mapping each selected object id to its decoded mask tensor,
            or to None when the object has no mask on this frame.
        """
        assert frame_id % self.ann_every == 0
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        selected_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            # Remove the masklets that have been filtered out for this video
            selected_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Only keep the objects that have been sampled
            selected_ids &= set(obj_ids)
        selected_ids = sorted(selected_ids)

        # Construct rle_mask_filtered that only contains the rle masks we are interested in
        id_2_idx = {}
        rle_mask_filtered = []
        for obj_id in selected_ids:
            if rle_mask[obj_id] is not None:
                id_2_idx[obj_id] = len(rle_mask_filtered)
                rle_mask_filtered.append(rle_mask[obj_id])
            else:
                id_2_idx[obj_id] = None

        # Decode the masks. Skip the decoder when nothing is left to decode:
        # an empty RLE list carries no height/width for it to work with.
        if rle_mask_filtered:
            raw_segments = torch.from_numpy(
                mask_utils.decode(rle_mask_filtered)
            ).permute(
                2, 0, 1
            )  # (num_obj, h, w)
        else:
            raw_segments = None

        segments = {}
        for obj_id, idx in id_2_idx.items():
            segments[obj_id] = None if idx is None else raw_segments[idx]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """Map each object id to the frame ids on which it has a mask.

        Frame annotations that are None are skipped; the number of objects is
        taken from the first annotation that is not None.

        Args:
            num_frames_min: when given, objects with fewer valid frames than
                this are removed from the result.
        """
        # For each object, find all the frames with a valid (not None) mask
        first_annot = next(
            (annot for annot in self.frame_annots if annot is not None), None
        )
        num_objects = 0 if first_annot is None else len(first_annot)

        # The result dict associates each obj_id with the id of its valid frames
        res = {obj_id: [] for obj_id in range(num_objects)}

        for annot_idx, annot in enumerate(self.frame_annots):
            if annot is None:
                # Frame without any annotation: no object is valid here
                continue
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(int(annot_idx * self.ann_every))

        if num_frames_min is not None:
            # Remove masklets that have less than num_frames_min valid masks
            for obj_id, valid_frames in list(res.items()):
                if len(valid_frames) < num_frames_min:
                    res.pop(obj_id)

        return res
|
||||
|
||||
|
||||
class PalettisedPNGSegmentLoader:
    """Segment loader for videos whose masks are one palettised PNG per frame."""

    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.
        video_png_root: the folder contains all the masks stored in png
        """
        self.video_png_root = video_png_root
        # build a mapping from frame id to their PNG mask path
        # note that in some datasets, the PNG paths could have more
        # than 5 digits, e.g. "00000000.png" instead of "00000.png"
        png_filenames = os.listdir(self.video_png_root)
        self.frame_id_to_png_filename = {}
        for filename in png_filenames:
            frame_id, _ = os.path.splitext(filename)
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """
        load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping each non-zero palette index found in
                the mask to a boolean tensor of the pixels with that index
        """
        # check the path
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        # load the mask
        masks = PILImage.open(mask_path).convert("P")
        masks = np.array(masks)

        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]  # remove background (0)

        # convert into N binary segmentation masks
        binary_segments = {}
        for i in object_id:
            bs = masks == i
            binary_segments[i] = torch.from_numpy(bs)

        return binary_segments

    def __len__(self):
        """Number of mask files indexed for this video.

        The original returned None, which makes ``len()`` raise ``TypeError``.
        """
        return len(self.frame_id_to_png_filename)
|
||||
|
||||
|
||||
class MultiplePNGSegmentLoader:
    """Segment loader for videos whose masks are stored as one PNG per object and frame."""

    def __init__(self, video_png_root, single_object_mode=False):
        """
        video_png_root: the folder contains all the masks stored in png
        single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # read a mask to know the resolution of the video
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        tmp_mask = np.array(PILImage.open(tmp_mask_path))
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]
        if self.single_object_mode:
            # the root folder is named after the object
            self.obj_id = self._object_id_from_dir(video_png_root)
        else:
            self.obj_id = None

    @staticmethod
    def _object_id_from_dir(path):
        """Return the object id encoded by the last component of ``path``.

        Folder names are zero-based while id 0 is the background, hence the
        offset of 1. ``normpath`` + ``basename`` also handles a trailing
        separator and the platform's own separator.
        """
        return int(os.path.basename(os.path.normpath(path))) + 1

    def load(self, frame_id):
        """Return a dict mapping object id to the boolean mask of ``frame_id``."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        binary_segments = {}

        if os.path.exists(mask_path):
            mask = np.array(PILImage.open(mask_path))
        else:
            # if png doesn't exist, empty mask
            mask = np.zeros((self.H, self.W), dtype=bool)
        binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict
        """
        # get the path
        all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
        num_objects = len(all_objects)
        assert num_objects > 0

        # load the masks
        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
            obj_id = self._object_id_from_dir(obj_folder)
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            if os.path.exists(mask_path):
                mask = np.array(PILImage.open(mask_path))
            else:
                # if png doesn't exist, empty mask
                mask = np.zeros((self.H, self.W), dtype=bool)
            binary_segments[obj_id] = torch.from_numpy(mask > 0)

        return binary_segments

    def __len__(self):
        # NOTE(review): returns None, so len() on this loader raises TypeError.
        # Left unchanged because no natural length is stored on the instance -
        # confirm what it should report.
        return
|
||||
|
||||
|
||||
class LazySegments:
    """
    Mapping-like container of RLE segments that decodes a mask only when it
    is read, and remembers the decoded result.
    """

    def __init__(self):
        self.segments = {}  # key -> encoded (RLE) segment
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        """Store the encoded segment ``item`` under ``key``."""
        self.segments[key] = item

    def __getitem__(self, key):
        """Return the decoded mask for ``key``, decoding it on first access."""
        try:
            return self.cache[key]
        except KeyError:
            pass
        encoded = self.segments[key]
        decoded = torch.from_numpy(mask_utils.decode([encoded])).permute(2, 0, 1)[0]
        self.cache[key] = decoded
        return decoded

    def __contains__(self, key):
        """True when an encoded segment is stored under ``key``."""
        return key in self.segments

    def __len__(self):
        """Number of stored segments."""
        return len(self.segments)

    def keys(self):
        """Keys of the stored segments."""
        return self.segments.keys()
|
||||
|
||||
|
||||
class SA1BSegmentLoader:
    """Segment loader for a single image annotated with a JSON file of RLE masks."""

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON file; its "annotations" entry
                lists the masks.
            mask_area_frac_thresh: when <= 1.0, masks whose area divided by
                the image area reaches this value are dropped.
            video_frame_path: image path, opened only to read the image size
                when area filtering is active.
            uncertain_iou: masks carrying an "uncertain_iou" below this value
                are dropped.
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            # The image is only opened when its size is actually needed
            orig_w, orig_h = PILImage.open(video_frame_path).size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        self.segments = LazySegments()
        num_kept = 0
        for annot in self.frame_annots:
            # Drop masks without a positive area
            if not annot["area"] > 0:
                continue
            # uncertain_iou is stability score
            if "uncertain_iou" in annot and annot["uncertain_iou"] < uncertain_iou:
                continue
            # Drop masks covering too large a fraction of the image
            if filter_by_area and (annot["area"] / area) >= mask_area_frac_thresh:
                continue
            # Kept masks are numbered consecutively from 0
            self.segments[num_kept] = annot["segmentation"]
            num_kept += 1

    def load(self, frame_idx):
        """Return the lazily decoded segments; ``frame_idx`` is not used."""
        return self.segments
|
Reference in New Issue
Block a user