[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
105
training/dataset/vos_sampler.py
Normal file
105
training/dataset/vos_sampler.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from training.dataset.vos_segment_loader import LazySegments
|
||||
|
||||
MAX_RETRIES = 1000
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampledFramesAndObjects:
|
||||
frames: List[int]
|
||||
object_ids: List[int]
|
||||
|
||||
|
||||
class VOSSampler:
|
||||
def __init__(self, sort_frames=True):
|
||||
# frames are ordered by frame id when sort_frames is True
|
||||
self.sort_frames = sort_frames
|
||||
|
||||
def sample(self, video):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RandomUniformSampler(VOSSampler):
|
||||
def __init__(
|
||||
self,
|
||||
num_frames,
|
||||
max_num_objects,
|
||||
reverse_time_prob=0.0,
|
||||
):
|
||||
self.num_frames = num_frames
|
||||
self.max_num_objects = max_num_objects
|
||||
self.reverse_time_prob = reverse_time_prob
|
||||
|
||||
def sample(self, video, segment_loader, epoch=None):
|
||||
|
||||
for retry in range(MAX_RETRIES):
|
||||
if len(video.frames) < self.num_frames:
|
||||
raise Exception(
|
||||
f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
|
||||
)
|
||||
start = random.randrange(0, len(video.frames) - self.num_frames + 1)
|
||||
frames = [video.frames[start + step] for step in range(self.num_frames)]
|
||||
if random.uniform(0, 1) < self.reverse_time_prob:
|
||||
# Reverse time
|
||||
frames = frames[::-1]
|
||||
|
||||
# Get first frame object ids
|
||||
visible_object_ids = []
|
||||
loaded_segms = segment_loader.load(frames[0].frame_idx)
|
||||
if isinstance(loaded_segms, LazySegments):
|
||||
# LazySegments for SA1BRawDataset
|
||||
visible_object_ids = list(loaded_segms.keys())
|
||||
else:
|
||||
for object_id, segment in segment_loader.load(
|
||||
frames[0].frame_idx
|
||||
).items():
|
||||
if segment.sum():
|
||||
visible_object_ids.append(object_id)
|
||||
|
||||
# First frame needs to have at least a target to track
|
||||
if len(visible_object_ids) > 0:
|
||||
break
|
||||
if retry >= MAX_RETRIES - 1:
|
||||
raise Exception("No visible objects")
|
||||
|
||||
object_ids = random.sample(
|
||||
visible_object_ids,
|
||||
min(len(visible_object_ids), self.max_num_objects),
|
||||
)
|
||||
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
||||
|
||||
|
||||
class EvalSampler(VOSSampler):
|
||||
"""
|
||||
VOS Sampler for evaluation: sampling all the frames and all the objects in a video
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def sample(self, video, segment_loader, epoch=None):
|
||||
"""
|
||||
Sampling all the frames and all the objects
|
||||
"""
|
||||
if self.sort_frames:
|
||||
# ordered by frame id
|
||||
frames = sorted(video.frames, key=lambda x: x.frame_idx)
|
||||
else:
|
||||
# use the original order
|
||||
frames = video.frames
|
||||
object_ids = segment_loader.load(frames[0].frame_idx).keys()
|
||||
if len(object_ids) == 0:
|
||||
raise Exception("First frame of the video has no objects")
|
||||
|
||||
return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
Reference in New Issue
Block a user