[New Feature] Support SAM 2.1 (#59)
* support sam 2.1
* refine config path and ckpt path
* update README
training/dataset/vos_dataset.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np

import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader

from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # Per-video repeat factors; a multiplier > 1 lets the epoch sampler
        # revisit each video that many times (e.g. to rebalance a data mixture)
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):

        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    # Fall back to a random video so training keeps moving
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (segments)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target; we either use a zero-mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)


def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                # Same path sampled twice: deep-copy the decoded image so
                # downstream transforms don't mutate a shared object
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))

    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    # Assumes a float CHW tensor in [0, 1]; convert to uint8 HWC for PIL
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
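
For orientation, below is a minimal, self-contained sketch of the contract VOSDataset expects from its collaborators. The _Dummy* classes are hypothetical stand-ins for the real VOSRawDataset, VOSSampler, and segment loader in training/dataset/, and are not part of this commit:

import torch

from training.dataset.vos_dataset import VOSDataset


class _DummySegmentLoader:
    # Hypothetical loader: one 32x32 mask for object id 0 on every frame
    def load(self, frame_idx):
        return {0: torch.ones(32, 32, dtype=torch.uint8)}


class _DummyVideo:
    video_id = 0


class _DummyRawDataset:
    # Stands in for a VOSRawDataset: returns (video, segment_loader) pairs
    def get_video(self, idx):
        return _DummyVideo(), _DummySegmentLoader()

    def __len__(self):
        return 1


class _DummyFrame:
    # Pre-decoded frame: float CHW data in [0, 1], as tensor_2_PIL expects
    def __init__(self, frame_idx):
        self.frame_idx = frame_idx
        self.image_path = None
        self.data = torch.rand(3, 32, 32)


class _DummySample:
    def __init__(self):
        self.frames = [_DummyFrame(i) for i in range(2)]
        self.object_ids = [0]


class _DummySampler:
    # Stands in for a VOSSampler: picks frames and object ids for a datapoint
    def sample(self, video, segment_loader, epoch):
        return _DummySample()


dataset = VOSDataset(
    transforms=[],                    # real training passes augmentation callables
    training=True,
    video_dataset=_DummyRawDataset(),
    sampler=_DummySampler(),
    multiplier=2,                     # repeat_factors become 2.0 per video
)
dp = dataset[0]                       # a VideoDatapoint
print(dp.video_id, dp.size, len(dp.frames), len(dp.frames[0].objects))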
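
As for repeat_factors: the constructor only builds the tensor; how it is consumed is up to whatever sampler assembles the epoch. A hedged sketch of one way to honor it with a stock PyTorch sampler (not necessarily what this training loop does):

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical: oversample videos in proportion to their repeat factors
repeat_factors = torch.tensor([2.0, 2.0, 1.0])
sampler = WeightedRandomSampler(
    weights=repeat_factors,
    num_samples=int(repeat_factors.sum().item()),  # ~5 draws per epoch
    replacement=True,
)
print(list(sampler))  # e.g. [0, 1, 0, 1, 2]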