106 lines
3.3 KiB
Python
106 lines
3.3 KiB
Python
![]() |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# All rights reserved.
|
||
|
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import random
from dataclasses import dataclass
from typing import Any, List

from training.dataset.vos_segment_loader import LazySegments
|
||
|
|
||
|
# Upper bound on how many times RandomUniformSampler.sample re-draws a frame
# window while looking for one whose first frame has a visible object.
MAX_RETRIES = 1000
|
||
|
|
||
|
|
||
|
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler's ``sample`` call: the chosen frames and object ids."""

    # Elements taken from ``video.frames``. The samplers in this file read
    # ``.frame_idx`` on them, so they are frame records rather than plain ints.
    frames: List[Any]
    # Ids of the objects selected on the first sampled frame.
    object_ids: List[int]
|
||
|
|
||
|
|
||
|
class VOSSampler:
    """Base class for samplers that pick frames and object ids from a video.

    Subclasses override ``sample``; the base implementation always raises
    ``NotImplementedError``.
    """

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Select frames and objects from ``video``.

        ``segment_loader`` and ``epoch`` are accepted (with defaults) so that
        the base signature matches the subclasses in this file, which are
        called as ``sample(video, segment_loader, epoch=None)``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()
|
||
|
|
||
|
|
||
|
class RandomUniformSampler(VOSSampler):
    """Sample a random window of consecutive entries of ``video.frames``.

    The window start is drawn uniformly, the window is optionally reversed,
    and up to ``max_num_objects`` object ids visible on the first frame of the
    window are drawn at random.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """
        Args:
            num_frames: number of consecutive entries of ``video.frames`` to take.
            max_num_objects: upper bound on the number of object ids returned.
            reverse_time_prob: probability of reversing the sampled window.
        """
        # Initialise the base class so that ``sort_frames`` exists on every
        # sampler; ``sample`` below does not read it.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Draw a frame window and a random subset of its first-frame objects.

        Args:
            video: object exposing ``frames`` (indexable, with ``len``) and
                ``video_name``; each frame exposes ``frame_idx``.
            segment_loader: object whose ``load(frame_idx)`` returns either a
                ``LazySegments`` or a mapping of object id to segment.
            epoch: not used by this sampler.

        Returns:
            SampledFramesAndObjects with the sampled frames and object ids.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or if
                no window with a visible first-frame object is found within
                ``MAX_RETRIES`` attempts.
        """
        # The frame count does not change between retries, so check it once.
        num_available = len(video.frames)
        if num_available < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, num_available - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids; load the segments only once.
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset
                visible_object_ids = list(loaded_segms.keys())
            else:
                # Keep only objects whose segment has a non-zero sum.
                visible_object_ids = [
                    object_id
                    for object_id, segment in loaded_segms.items()
                    if segment.sum()
                ]

            # First frame needs to have at least a target to track
            if visible_object_ids:
                break
        else:
            # Every attempt produced a first frame without visible objects.
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
||
|
|
||
|
|
||
|
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
        sort_frames=True,
    ):
        """
        Args:
            sort_frames: when True (default) frames are returned ordered by
                ``frame_idx``; when False the order of ``video.frames`` is kept.
        """
        super().__init__(sort_frames=sort_frames)

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        Args:
            video: object exposing ``frames``; each frame exposes ``frame_idx``.
            segment_loader: object whose ``load(frame_idx)`` returns a mapping
                keyed by object id.
            epoch: not used by this sampler.

        Returns:
            SampledFramesAndObjects with every frame and the object ids found
            on the first returned frame.

        Raises:
            Exception: if the first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames

        # Materialise the keys view so the result is a real list, as declared
        # by SampledFramesAndObjects.object_ids.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if not object_ids:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|