163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
import random
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
from iopath.common.file_io import g_pathmgr
|
|
from PIL import Image as PILImage
|
|
from torchvision.datasets.vision import VisionDataset
|
|
|
|
from training.dataset.vos_raw_dataset import VOSRawDataset
|
|
from training.dataset.vos_sampler import VOSSampler
|
|
from training.dataset.vos_segment_loader import JSONSegmentLoader
|
|
|
|
from training.utils.data_utils import Frame, Object, VideoDatapoint
|
|
|
|
# Number of iterations of the load/sample retry loop in
# VOSDataset._get_datapoint (the loop runs over range(MAX_RETRIES)).
MAX_RETRIES = 100
|
|
|
|
|
|
class VOSDataset(VisionDataset):
    """
    Dataset that turns raw videos into transformed VideoDatapoint samples.

    For each index, a video is fetched from `video_dataset`, frames and object
    ids are chosen by `sampler`, a VideoDatapoint is built from the frames'
    rgb data and ground-truth segments, and every transform is applied in order.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables, each invoked as
                `transform(datapoint, epoch=curr_epoch)` and returning the
                datapoint passed to the next transform.
            training: when True, a failed load is logged and retried with a
                randomly drawn index; when False the failure is re-raised.
            video_dataset: raw dataset; must support `len()` and
                `get_video(idx)` returning `(video, segment_loader)`.
            sampler: object whose `sample(video, segment_loader, epoch=...)`
                returns the sampled `frames` and `object_ids`.
            multiplier: value every entry of `repeat_factors` is scaled by.
            always_target: when True, a sampled object with no segment in a
                frame gets an all-zero mask; when False it is dropped from
                that frame.
            target_segments_available: stored on the instance; not read by
                any method of this class.
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # One float factor per video, all equal to `multiplier`.
        # NOTE(review): presumably consumed by a repeat-factor sampler
        # elsewhere -- not used inside this class.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """
        Load, sample, construct and transform the datapoint for `idx`.

        In training mode a failure while fetching or sampling the video is
        logged and retried (up to MAX_RETRIES attempts) with a new random
        index, so the returned datapoint may come from a different video than
        the one requested. Outside training the failure is re-raised as is.

        Raises:
            RuntimeError: in training mode, if all MAX_RETRIES attempts fail;
                chained to the last underlying exception.
        """
        last_error = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if not self.training:
                    # Shouldn't fail to load a val video
                    raise
                logging.warning(
                    f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                )
                # `e` is unbound once the except block ends; keep a reference
                # so it can be reported if every attempt fails.
                last_error = e
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # The loop ran out of attempts without reaching `break`. Without
            # this guard the code below would fail on unbound locals and hide
            # the real cause.
            raise RuntimeError(
                f"Failed to load a video after {MAX_RETRIES} attempts"
            ) from last_error

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms

        Args:
            video: the sampled video; only its `video_id` is read here.
            sampled_frms_and_objs: sampler output exposing `frames` and
                `object_ids`.
            segment_loader: provides `load(frame_idx, ...)` returning a mapping
                from object id to that object's segment in the frame.

        Returns:
            A VideoDatapoint whose `size` is the (height, width) of the last
            sampled frame.

        Raises:
            ValueError: if no frames were sampled.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        if not rgb_images:
            # With no frames the loop below never binds h/w, which `size` needs.
            raise ValueError(
                "Cannot construct a VideoDatapoint from zero sampled frames"
            )

        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        """Return the transformed datapoint for `idx` (see `_get_datapoint`)."""
        return self._get_datapoint(idx)

    def __len__(self):
        """Return the number of videos in the underlying raw dataset."""
        return len(self.video_dataset)
|
|
|
|
|
|
def load_images(frames):
    """
    Return one PIL image per frame, in the same order as `frames`.

    A frame whose `data` is None is read from `frame.image_path` and converted
    to RGB; if the same path already appeared earlier in `frames`, a deep copy
    of that earlier image is used instead of reading the file again. A frame
    that already carries `data` is converted with `tensor_2_PIL`.
    """
    pil_images = []
    # image path -> position in `pil_images` of the image decoded from it
    first_index_by_path = {}
    for frm in frames:
        if frm.data is not None:
            # The rgb data is already in memory; only convert it to a PILImage
            pil_images.append(tensor_2_PIL(frm.data))
            continue

        image_path = frm.image_path
        if image_path in first_index_by_path:
            # Same file as an earlier frame: copy it rather than re-read it
            earlier = pil_images[first_index_by_path[image_path]]
            pil_images.append(deepcopy(earlier))
            continue

        # First time this file is seen: load the rgb data from storage
        with g_pathmgr.open(image_path, "rb") as handle:
            rgb = PILImage.open(handle).convert("RGB")
        pil_images.append(rgb)
        first_index_by_path[image_path] = len(pil_images) - 1

    return pil_images
|
|
|
|
|
|
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """
    Convert a 3-D tensor into a PIL image.

    The tensor is moved to CPU, its axes are reordered from (0, 1, 2) to
    (1, 2, 0), the values are multiplied by 255 and cast to uint8 (the cast
    truncates any fractional part).

    NOTE(review): assumes `data` is channel-first (C, H, W) with values in
    [0, 1] -- confirm against the callers.
    """
    channel_first = data.cpu().numpy()
    scaled = np.transpose(channel_first, (1, 2, 0)) * 255.0
    pixels = scaled.astype(np.uint8)
    return PILImage.fromarray(pixels)
|