SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2: - We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. - We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
This commit is contained in:
92
sam2/benchmark.py
Normal file
92
sam2/benchmark.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
# Only cuda supported
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Config and checkpoint
|
||||
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
|
||||
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
||||
|
||||
# Build video predictor with vos_optimized=True setting
|
||||
predictor = build_sam2_video_predictor(
|
||||
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
|
||||
)
|
||||
|
||||
|
||||
# Initialize with video
|
||||
video_dir = "notebooks/videos/bedroom"
|
||||
# scan all the JPEG frame names in this directory
|
||||
frame_names = [
|
||||
p
|
||||
for p in os.listdir(video_dir)
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||
]
|
||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
inference_state = predictor.init_state(video_path=video_dir)
|
||||
|
||||
|
||||
# Number of runs, warmup etc
|
||||
warm_up, runs = 5, 25
|
||||
verbose = True
|
||||
num_frames = len(frame_names)
|
||||
total, count = 0, 0
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# We will select an object with a click.
|
||||
# See video_predictor_example.ipynb for more detailed explanation
|
||||
ann_frame_idx, ann_obj_id = 0, 1
|
||||
# Add a positive click at (x, y) = (210, 350)
|
||||
# For labels, `1` means positive click
|
||||
points = np.array([[210, 350]], dtype=np.float32)
|
||||
labels = np.array([1], np.int32)
|
||||
|
||||
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=ann_obj_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Warmup and then average FPS over several runs
|
||||
with torch.autocast("cuda", torch.bfloat16):
|
||||
with torch.inference_mode():
|
||||
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
|
||||
start = time.time()
|
||||
# Start tracking
|
||||
for (
|
||||
out_frame_idx,
|
||||
out_obj_ids,
|
||||
out_mask_logits,
|
||||
) in predictor.propagate_in_video(inference_state):
|
||||
pass
|
||||
|
||||
end = time.time()
|
||||
total += end - start
|
||||
count += 1
|
||||
if i == warm_up - 1:
|
||||
print("Warmup FPS: ", count * num_frames / total)
|
||||
total = 0
|
||||
count = 0
|
||||
|
||||
print("FPS: ", count * num_frames / total)
|
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
vos_optimized=False,
|
||||
**kwargs,
|
||||
):
|
||||
hydra_overrides = [
|
||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||
]
|
||||
if vos_optimized:
|
||||
hydra_overrides = [
|
||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
|
||||
"++model.compile_image_encoder=True", # Let sam2_base handle this
|
||||
]
|
||||
|
||||
if apply_postprocessing:
|
||||
hydra_overrides_extra = hydra_overrides_extra.copy()
|
||||
hydra_overrides_extra += [
|
||||
|
@@ -36,7 +36,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -47,7 +47,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -40,7 +40,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -51,7 +51,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -39,7 +39,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -50,7 +50,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -39,7 +39,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -50,7 +50,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -97,7 +97,7 @@ trainer:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -108,7 +108,7 @@ trainer:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -36,7 +36,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -47,7 +47,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -40,7 +40,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -51,7 +51,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -39,7 +39,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -50,7 +50,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -39,7 +39,7 @@ model:
|
||||
self_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
downsample_rate: 1
|
||||
@@ -50,7 +50,7 @@ model:
|
||||
cross_attention:
|
||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||
rope_theta: 10000.0
|
||||
feat_sizes: [32, 32]
|
||||
feat_sizes: [64, 64]
|
||||
rope_k_repeat: True
|
||||
embedding_dim: 256
|
||||
num_heads: 1
|
||||
|
@@ -32,9 +32,7 @@ def window_partition(x, window_size):
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = (
|
||||
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(
|
||||
x = windows.reshape(
|
||||
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
||||
)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
x = x[:, :H, :W, :]
|
||||
return x
|
||||
|
||||
|
||||
|
@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
|
||||
temperature: int = 10000,
|
||||
normalize: bool = True,
|
||||
scale: Optional[float] = None,
|
||||
# Following settings only relevant
|
||||
# for warmping up cache for compilation
|
||||
warmup_cache: bool = True,
|
||||
image_size: int = 1024,
|
||||
strides: Tuple[int] = (4, 8, 16, 32),
|
||||
):
|
||||
super().__init__()
|
||||
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
||||
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
|
||||
self.scale = scale
|
||||
|
||||
self.cache = {}
|
||||
if warmup_cache and torch.cuda.is_available():
|
||||
# Warmup cache for cuda, to help with compilation
|
||||
device = torch.device("cuda")
|
||||
for stride in strides:
|
||||
cache_key = (image_size // stride, image_size // stride)
|
||||
self._pe(1, device, *cache_key)
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
# The positions are expected to be normalized
|
||||
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
def _pe(self, B, device, *cache_key):
|
||||
H, W = cache_key
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
||||
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
||||
|
||||
y_embed = (
|
||||
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
|
||||
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
||||
.view(1, -1, 1)
|
||||
.repeat(x.shape[0], 1, x.shape[-1])
|
||||
.repeat(B, 1, W)
|
||||
)
|
||||
x_embed = (
|
||||
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
|
||||
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
||||
.view(1, 1, -1)
|
||||
.repeat(x.shape[0], x.shape[-2], 1)
|
||||
.repeat(B, H, 1)
|
||||
)
|
||||
|
||||
if self.normalize:
|
||||
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
|
||||
self.cache[cache_key] = pos[0]
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
B = x.shape[0]
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
return self._pe(B, x.device, *cache_key)
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
|
@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
|
||||
point_embedding = self.pe_layer.forward_with_coords(
|
||||
points, self.input_image_size
|
||||
)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
||||
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
||||
|
||||
point_embedding = torch.where(
|
||||
(labels == -1).unsqueeze(-1),
|
||||
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
|
||||
point_embedding,
|
||||
)
|
||||
point_embedding = torch.where(
|
||||
(labels == 0).unsqueeze(-1),
|
||||
point_embedding + self.point_embeddings[0].weight,
|
||||
point_embedding,
|
||||
)
|
||||
point_embedding = torch.where(
|
||||
(labels == 1).unsqueeze(-1),
|
||||
point_embedding + self.point_embeddings[1].weight,
|
||||
point_embedding,
|
||||
)
|
||||
point_embedding = torch.where(
|
||||
(labels == 2).unsqueeze(-1),
|
||||
point_embedding + self.point_embeddings[2].weight,
|
||||
point_embedding,
|
||||
)
|
||||
point_embedding = torch.where(
|
||||
(labels == 3).unsqueeze(-1),
|
||||
point_embedding + self.point_embeddings[3].weight,
|
||||
point_embedding,
|
||||
)
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
|
@@ -4,9 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Tuple, Type
|
||||
|
||||
@@ -16,29 +14,6 @@ from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
# Check whether Flash Attention is available (and use it by default)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
||||
ALLOW_ALL_KERNELS = False
|
||||
|
||||
|
||||
def sdp_kernel_context(dropout_p):
|
||||
"""
|
||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
||||
by default, but fall back to all available kernels if Flash Attention fails.
|
||||
"""
|
||||
if ALLOW_ALL_KERNELS:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
@@ -265,20 +240,7 @@ class Attention(nn.Module):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
|
||||
# whether to repeat q rope to match k length
|
||||
# this is needed for cross-attention to memories
|
||||
rope_k_repeat=False,
|
||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
||||
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
|
||||
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
||||
)
|
||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
||||
self.freqs_cis = freqs_cis
|
||||
self.freqs_cis = (
|
||||
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
|
||||
)
|
||||
self.rope_k_repeat = rope_k_repeat
|
||||
|
||||
def forward(
|
||||
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
@@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
|
||||
if self.add_tpos_enc_to_obj_ptrs:
|
||||
t_diff_max = max_obj_ptrs_in_encoder - 1
|
||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
||||
obj_pos = torch.tensor(pos_list, device=device)
|
||||
obj_pos = (
|
||||
torch.tensor(pos_list)
|
||||
.pin_memory()
|
||||
.to(device=device, non_blocking=True)
|
||||
)
|
||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
||||
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
||||
|
@@ -8,6 +8,7 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -26,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
||||
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
||||
clear_non_cond_mem_around_input=False,
|
||||
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
||||
clear_non_cond_mem_for_multi_obj=False,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
@@ -37,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
self.fill_hole_area = fill_hole_area
|
||||
self.non_overlap_masks = non_overlap_masks
|
||||
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
||||
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -87,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||
inference_state["obj_ids"] = []
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
inference_state["output_dict"] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
inference_state["output_dict_per_obj"] = {}
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
@@ -99,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["temp_output_dict_per_obj"] = {}
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
inference_state["consolidated_frame_inds"] = {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
}
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = {}
|
||||
inference_state["frames_tracked_per_obj"] = {}
|
||||
# Warm up the visual backbone and cache the image feature on frame 0
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
return inference_state
|
||||
@@ -133,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
if obj_idx is not None:
|
||||
return obj_idx
|
||||
|
||||
# This is a new object id not sent to the server before. We only allow adding
|
||||
# new objects *before* the tracking starts.
|
||||
allow_new_object = not inference_state["tracking_has_started"]
|
||||
# We always allow adding new objects (including after tracking starts).
|
||||
allow_new_object = True
|
||||
if allow_new_object:
|
||||
# get the next object slot
|
||||
obj_idx = len(inference_state["obj_id_to_idx"])
|
||||
@@ -153,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
inference_state["frames_tracked_per_obj"][obj_idx] = {}
|
||||
return obj_idx
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@@ -213,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"box prompt must be provided before any point prompt "
|
||||
"(please use clear_old_points=True instead)"
|
||||
)
|
||||
if inference_state["tracking_has_started"]:
|
||||
warnings.warn(
|
||||
"You are adding a box after tracking starts. SAM 2 may not always be "
|
||||
"able to incorporate a box prompt for *refinement*. If you intend to "
|
||||
"use box prompt as an *initial* input before tracking, please call "
|
||||
"'reset_state' on the inference state to restart from scratch.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not isinstance(box, torch.Tensor):
|
||||
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||
box_coords = box.reshape(1, 2, 2)
|
||||
@@ -251,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
||||
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
# whether to track in reverse time order
|
||||
if is_init_cond_frame:
|
||||
reverse = False
|
||||
else:
|
||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
||||
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
@@ -305,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@@ -356,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
||||
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
# whether to track in reverse time order
|
||||
if is_init_cond_frame:
|
||||
reverse = False
|
||||
else:
|
||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
||||
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
@@ -393,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@@ -428,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond,
|
||||
run_mem_encoder,
|
||||
consolidate_at_video_res=False,
|
||||
):
|
||||
"""
|
||||
@@ -445,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# Optionally, we allow consolidating the temporary outputs at the original
|
||||
# video resolution (to provide a better editing experience for mask prompts).
|
||||
if consolidate_at_video_res:
|
||||
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
||||
consolidated_H = inference_state["video_height"]
|
||||
consolidated_W = inference_state["video_width"]
|
||||
consolidated_mask_key = "pred_masks_video_res"
|
||||
@@ -458,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
||||
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
||||
consolidated_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
consolidated_mask_key: torch.full(
|
||||
size=(batch_size, 1, consolidated_H, consolidated_W),
|
||||
fill_value=NO_OBJ_SCORE,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["storage_device"],
|
||||
),
|
||||
"obj_ptr": torch.full(
|
||||
size=(batch_size, self.hidden_dim),
|
||||
fill_value=NO_OBJ_SCORE,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
"object_score_logits": torch.full(
|
||||
size=(batch_size, 1),
|
||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
||||
fill_value=10.0,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
}
|
||||
empty_mask_ptr = None
|
||||
for obj_idx in range(batch_size):
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
@@ -498,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
||||
# placeholder above) and set its object pointer to be a dummy pointer.
|
||||
if out is None:
|
||||
# Fill in dummy object pointers for those objects without any inputs or
|
||||
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
||||
# i.e. when we need to build the memory for tracking).
|
||||
if run_mem_encoder:
|
||||
if empty_mask_ptr is None:
|
||||
empty_mask_ptr = self._get_empty_mask_ptr(
|
||||
inference_state, frame_idx
|
||||
)
|
||||
# fill object pointer with a dummy pointer (based on an empty mask)
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
||||
continue
|
||||
# Add the temporary object output mask to consolidated output mask
|
||||
obj_mask = out["pred_masks"]
|
||||
@@ -523,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
align_corners=False,
|
||||
)
|
||||
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
||||
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
||||
"object_score_logits"
|
||||
]
|
||||
|
||||
# Optionally, apply non-overlapping constraints on the consolidated scores
|
||||
# and rerun the memory encoder
|
||||
if run_mem_encoder:
|
||||
device = inference_state["device"]
|
||||
high_res_masks = torch.nn.functional.interpolate(
|
||||
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
if self.non_overlap_masks_for_mem_enc:
|
||||
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
||||
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
high_res_masks=high_res_masks,
|
||||
object_score_logits=consolidated_out["object_score_logits"],
|
||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
||||
)
|
||||
consolidated_out["maskmem_features"] = maskmem_features
|
||||
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
|
||||
return consolidated_out
|
||||
|
||||
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
||||
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
||||
# A dummy (empty) mask with a single object
|
||||
batch_size = 1
|
||||
mask_inputs = torch.zeros(
|
||||
(batch_size, 1, self.image_size, self.image_size),
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
)
|
||||
|
||||
# Retrieve correct image features
|
||||
(
|
||||
_,
|
||||
_,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
||||
|
||||
# Feed the empty mask and image feature above to get a dummy object pointer
|
||||
current_out = self.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=True,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=None,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict={},
|
||||
num_frames=inference_state["num_frames"],
|
||||
track_in_reverse=False,
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=None,
|
||||
)
|
||||
return current_out["obj_ptr"]
|
||||
|
||||
@torch.inference_mode()
|
||||
def propagate_in_video_preflight(self, inference_state):
|
||||
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
||||
# Tracking has started and we don't allow adding new objects until session is reset.
|
||||
inference_state["tracking_has_started"] = True
|
||||
# Check and make sure that every object has received input points or masks.
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
if batch_size == 0:
|
||||
raise RuntimeError(
|
||||
"No input points or masks are provided for any object; please add inputs first."
|
||||
)
|
||||
|
||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||
# add them into "output_dict".
|
||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
||||
output_dict = inference_state["output_dict"]
|
||||
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
||||
# temporary outputs have been added (either in this call or any previous calls
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temporary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = (
|
||||
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
)
|
||||
# merge them into "output_dict" and also create per-object slices
|
||||
output_dict[storage_key][frame_idx] = consolidated_out
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, consolidated_out, storage_key
|
||||
)
|
||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||
)
|
||||
if clear_non_cond_mem:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
for frame_idx, out in obj_temp_output_dict[storage_key].items():
|
||||
# Run memory encoder on the temporary outputs (if the memory feature is missing)
|
||||
if out["maskmem_features"] is None:
|
||||
high_res_masks = torch.nn.functional.interpolate(
|
||||
out["pred_masks"].to(inference_state["device"]),
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
high_res_masks=high_res_masks,
|
||||
object_score_logits=out["object_score_logits"],
|
||||
# these frames are what the user interacted with
|
||||
is_mask_from_pts=True,
|
||||
)
|
||||
out["maskmem_features"] = maskmem_features
|
||||
out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
obj_output_dict[storage_key][frame_idx] = out
|
||||
if self.clear_non_cond_mem_around_input:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_obj_non_cond_mem_around_input(
|
||||
inference_state, frame_idx, obj_idx
|
||||
)
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
obj_temp_output_dict[storage_key].clear()
|
||||
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in output_dict["cond_frame_outputs"]:
|
||||
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
||||
# check and make sure that every object has received input points or masks
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
||||
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
||||
raise RuntimeError(
|
||||
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
|
||||
)
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
assert frame_idx in output_dict["cond_frame_outputs"]
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
|
||||
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
||||
# with either points or mask inputs (which should be true under a correct workflow).
|
||||
all_consolidated_frame_inds = (
|
||||
consolidated_frame_inds["cond_frame_outputs"]
|
||||
| consolidated_frame_inds["non_cond_frame_outputs"]
|
||||
)
|
||||
input_frames_inds = set()
|
||||
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
||||
input_frames_inds.update(point_inputs_per_frame.keys())
|
||||
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
||||
input_frames_inds.update(mask_inputs_per_frame.keys())
|
||||
assert all_consolidated_frame_inds == input_frames_inds
|
||||
|
||||
@torch.inference_mode()
|
||||
def propagate_in_video(
|
||||
@@ -670,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"""Propagate the input points across frames to track in the entire video."""
|
||||
self.propagate_in_video_preflight(inference_state)
|
||||
|
||||
output_dict = inference_state["output_dict"]
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
obj_ids = inference_state["obj_ids"]
|
||||
num_frames = inference_state["num_frames"]
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
raise RuntimeError("No points are provided; please add points first")
|
||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||
)
|
||||
|
||||
# set start index, end index, and processing order
|
||||
if start_frame_idx is None:
|
||||
# default: start from the earliest frame with input points
|
||||
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
||||
start_frame_idx = min(
|
||||
t
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values()
|
||||
for t in obj_output_dict["cond_frame_outputs"]
|
||||
)
|
||||
if max_frame_num_to_track is None:
|
||||
# default: track all the frames in the video
|
||||
max_frame_num_to_track = num_frames
|
||||
@@ -701,78 +581,53 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
||||
|
||||
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
||||
# We skip those frames already in consolidated outputs (these are frames
|
||||
# that received input clicks or mask). Note that we cannot directly run
|
||||
# batched forward on them via `_run_single_frame_inference` because the
|
||||
# number of clicks on each object might be different.
|
||||
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
if clear_non_cond_mem:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out, pred_masks = self._run_single_frame_inference(
|
||||
inference_state=inference_state,
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=reverse,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
output_dict[storage_key][frame_idx] = current_out
|
||||
# Create slices of per-object outputs for subsequent interaction with each
|
||||
# individual object after tracking.
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, current_out, storage_key
|
||||
)
|
||||
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
||||
pred_masks_per_obj = [None] * batch_size
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
# We skip those frames already in consolidated outputs (these are frames
|
||||
# that received input clicks or mask). Note that we cannot directly run
|
||||
# batched forward on them via `_run_single_frame_inference` because the
|
||||
# number of clicks on each object might be different.
|
||||
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = obj_output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
if self.clear_non_cond_mem_around_input:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_obj_non_cond_mem_around_input(
|
||||
inference_state, frame_idx, obj_idx
|
||||
)
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out, pred_masks = self._run_single_frame_inference(
|
||||
inference_state=inference_state,
|
||||
output_dict=obj_output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=reverse,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
obj_output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
|
||||
"reverse": reverse
|
||||
}
|
||||
pred_masks_per_obj[obj_idx] = pred_masks
|
||||
|
||||
# Resize the output mask to the original video resolution (we directly use
|
||||
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
||||
if len(pred_masks_per_obj) > 1:
|
||||
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
|
||||
else:
|
||||
all_pred_masks = pred_masks_per_obj[0]
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
inference_state, pred_masks
|
||||
inference_state, all_pred_masks
|
||||
)
|
||||
yield frame_idx, obj_ids, video_res_masks
|
||||
|
||||
def _add_output_per_object(
|
||||
self, inference_state, frame_idx, current_out, storage_key
|
||||
):
|
||||
"""
|
||||
Split a multi-object output into per-object output slices and add them into
|
||||
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
||||
"""
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
||||
|
||||
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
||||
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
||||
|
||||
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
||||
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
||||
obj_slice = slice(obj_idx, obj_idx + 1)
|
||||
obj_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
||||
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
||||
}
|
||||
if maskmem_features is not None:
|
||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
||||
if maskmem_pos_enc is not None:
|
||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
@torch.inference_mode()
|
||||
def clear_all_prompts_in_frame(
|
||||
self, inference_state, frame_idx, obj_id, need_output=True
|
||||
@@ -788,41 +643,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
||||
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
|
||||
# Check and see if there are still any inputs left on this frame
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
frame_has_input = False
|
||||
for obj_idx2 in range(batch_size):
|
||||
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
||||
frame_has_input = True
|
||||
break
|
||||
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
||||
frame_has_input = True
|
||||
break
|
||||
|
||||
# If this frame has no remaining inputs for any objects, we further clear its
|
||||
# conditioning frame status
|
||||
if not frame_has_input:
|
||||
output_dict = inference_state["output_dict"]
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
||||
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if out is not None:
|
||||
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
||||
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
||||
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
||||
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
||||
# Similarly, do it for the sliced output on each object.
|
||||
for obj_idx2 in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
||||
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if obj_out is not None:
|
||||
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
||||
|
||||
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
self._reset_tracking_results(inference_state)
|
||||
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if out is not None:
|
||||
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
||||
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
||||
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
|
||||
inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
|
||||
|
||||
if not need_output:
|
||||
return
|
||||
@@ -836,7 +664,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@@ -856,6 +683,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["mask_inputs_per_obj"].clear()
|
||||
inference_state["output_dict_per_obj"].clear()
|
||||
inference_state["temp_output_dict_per_obj"].clear()
|
||||
inference_state["frames_tracked_per_obj"].clear()
|
||||
|
||||
def _reset_tracking_results(self, inference_state):
|
||||
"""Reset all tracking inputs and results across the videos."""
|
||||
@@ -869,12 +697,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
for v in inference_state["temp_output_dict_per_obj"].values():
|
||||
v["cond_frame_outputs"].clear()
|
||||
v["non_cond_frame_outputs"].clear()
|
||||
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
||||
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
||||
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
||||
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"].clear()
|
||||
for v in inference_state["frames_tracked_per_obj"].values():
|
||||
v.clear()
|
||||
|
||||
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
||||
"""Compute the image features on a given frame."""
|
||||
@@ -1092,8 +916,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["obj_ids"] = new_obj_ids
|
||||
|
||||
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
||||
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
||||
# it's already handled in Step 0)
|
||||
def _map_keys(container):
|
||||
new_kvs = []
|
||||
for k in old_obj_inds:
|
||||
@@ -1106,30 +928,9 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
_map_keys(inference_state["mask_inputs_per_obj"])
|
||||
_map_keys(inference_state["output_dict_per_obj"])
|
||||
_map_keys(inference_state["temp_output_dict_per_obj"])
|
||||
_map_keys(inference_state["frames_tracked_per_obj"])
|
||||
|
||||
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
||||
def _slice_state(output_dict, storage_key):
|
||||
for frame_idx, out in output_dict[storage_key].items():
|
||||
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
||||
out["maskmem_pos_enc"] = [
|
||||
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
||||
]
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
||||
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
||||
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
||||
out["object_score_logits"] = out["object_score_logits"][
|
||||
remain_old_obj_inds
|
||||
]
|
||||
# also update the per-object slices
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, out, storage_key
|
||||
)
|
||||
|
||||
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
||||
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
||||
|
||||
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
||||
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
||||
# could show an updated mask for objects previously occluded by the object being removed
|
||||
if need_output:
|
||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
||||
@@ -1142,7 +943,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@@ -1164,9 +964,259 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
r = self.memory_temporal_stride_for_eval
|
||||
frame_idx_begin = frame_idx - r * self.num_maskmem
|
||||
frame_idx_end = frame_idx + r * self.num_maskmem
|
||||
output_dict = inference_state["output_dict"]
|
||||
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
non_cond_frame_outputs.pop(t, None)
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
non_cond_frame_outputs.pop(t, None)
|
||||
|
||||
|
||||
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
||||
"""Optimized for the VOS setting"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._compile_all_components()
|
||||
|
||||
def _compile_all_components(self):
|
||||
print("Compiling all components for VOS setting. First time may be very slow.")
|
||||
self.memory_encoder.forward = torch.compile(
|
||||
self.memory_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
self.memory_attention.forward = torch.compile(
|
||||
self.memory_attention.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True, # Num. of memories varies
|
||||
)
|
||||
|
||||
self.sam_prompt_encoder.forward = torch.compile(
|
||||
self.sam_prompt_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=False, # Accuracy regression on True
|
||||
)
|
||||
|
||||
self.sam_mask_decoder.forward = torch.compile(
|
||||
self.sam_mask_decoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=False, # Accuracy regression on True
|
||||
)
|
||||
|
||||
def forward_image(self, img_batch: torch.Tensor):
|
||||
"""
|
||||
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
|
||||
cloning the backbone features and pos encoding to enable compilation.
|
||||
"""
|
||||
backbone_out = self.image_encoder(img_batch)
|
||||
if self.use_high_res_features_in_sam:
|
||||
# precompute projected level 0 and level 1 features in SAM decoder
|
||||
# to avoid running it again on every SAM click
|
||||
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
|
||||
backbone_out["backbone_fpn"][0]
|
||||
)
|
||||
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
|
||||
backbone_out["backbone_fpn"][1]
|
||||
)
|
||||
# Clone to help torch.compile
|
||||
for i in range(len(backbone_out["backbone_fpn"])):
|
||||
backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
|
||||
backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
|
||||
i
|
||||
].clone()
|
||||
return backbone_out
|
||||
|
||||
def _forward_sam_heads(
|
||||
self,
|
||||
backbone_features,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
high_res_features=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
|
||||
cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
|
||||
"""
|
||||
B = backbone_features.size(0)
|
||||
device = backbone_features.device
|
||||
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
||||
assert backbone_features.size(2) == self.sam_image_embedding_size
|
||||
assert backbone_features.size(3) == self.sam_image_embedding_size
|
||||
|
||||
# a) Handle point prompts
|
||||
if point_inputs is not None:
|
||||
sam_point_coords = point_inputs["point_coords"]
|
||||
sam_point_labels = point_inputs["point_labels"]
|
||||
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
||||
else:
|
||||
# If no points are provide, pad with an empty point (with label -1)
|
||||
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
||||
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
||||
|
||||
# b) Handle mask prompts
|
||||
if mask_inputs is not None:
|
||||
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
||||
# and feed it as a dense mask prompt into the SAM mask encoder
|
||||
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
||||
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
||||
sam_mask_prompt = F.interpolate(
|
||||
mask_inputs.float(),
|
||||
size=self.sam_prompt_encoder.mask_input_size,
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
antialias=True, # use antialias for downsampling
|
||||
)
|
||||
else:
|
||||
sam_mask_prompt = mask_inputs
|
||||
else:
|
||||
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
||||
# a learned `no_mask_embed` to indicate no mask input in this case).
|
||||
sam_mask_prompt = None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
||||
points=(sam_point_coords, sam_point_labels),
|
||||
boxes=None,
|
||||
masks=sam_mask_prompt,
|
||||
)
|
||||
# Clone image_pe and the outputs of sam_prompt_encoder
|
||||
# to enable compilation
|
||||
sparse_embeddings = sparse_embeddings.clone()
|
||||
dense_embeddings = dense_embeddings.clone()
|
||||
image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
|
||||
(
|
||||
low_res_multimasks,
|
||||
ious,
|
||||
sam_output_tokens,
|
||||
object_score_logits,
|
||||
) = self.sam_mask_decoder(
|
||||
image_embeddings=backbone_features,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
repeat_image=False, # the image is already batched
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
# Clone the output of sam_mask_decoder
|
||||
# to enable compilation
|
||||
low_res_multimasks = low_res_multimasks.clone()
|
||||
ious = ious.clone()
|
||||
sam_output_tokens = sam_output_tokens.clone()
|
||||
object_score_logits = object_score_logits.clone()
|
||||
|
||||
if self.pred_obj_scores:
|
||||
is_obj_appearing = object_score_logits > 0
|
||||
|
||||
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
||||
# consistent with the actual mask prediction
|
||||
low_res_multimasks = torch.where(
|
||||
is_obj_appearing[:, None, None],
|
||||
low_res_multimasks,
|
||||
NO_OBJ_SCORE,
|
||||
)
|
||||
|
||||
# convert masks from possibly bfloat16 (or float16) to float32
|
||||
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
||||
low_res_multimasks = low_res_multimasks.float()
|
||||
high_res_multimasks = F.interpolate(
|
||||
low_res_multimasks,
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
sam_output_token = sam_output_tokens[:, 0]
|
||||
if multimask_output:
|
||||
# take the best mask prediction (with the highest IoU estimation)
|
||||
best_iou_inds = torch.argmax(ious, dim=-1)
|
||||
batch_inds = torch.arange(B, device=device)
|
||||
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
if sam_output_tokens.size(1) > 1:
|
||||
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
||||
else:
|
||||
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
||||
|
||||
# Extract object pointer from the SAM output token (with occlusion handling)
|
||||
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
||||
if self.pred_obj_scores:
|
||||
# Allow *soft* no obj ptr, unlike for masks
|
||||
if self.soft_no_obj_ptr:
|
||||
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
||||
else:
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
|
||||
if self.fixed_no_obj_ptr:
|
||||
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
||||
|
||||
return (
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
)
|
||||
|
||||
def _encode_new_memory(
|
||||
self,
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
pred_masks_high_res,
|
||||
object_score_logits,
|
||||
is_mask_from_pts,
|
||||
):
|
||||
"""
|
||||
Identical to the corresponding method in the parent (SAM2VideoPredictor), but
|
||||
cloning the memories and their pos enc to enable compilation.
|
||||
"""
|
||||
B = current_vision_feats[-1].size(1) # batch size on this frame
|
||||
C = self.hidden_dim
|
||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
||||
# top-level feature, (HW)BC => BCHW
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
||||
if self.non_overlap_masks_for_mem_enc and not self.training:
|
||||
# optionally, apply non-overlapping constraints to the masks (it's applied
|
||||
# in the batch dimension and should only be used during eval, where all
|
||||
# the objects come from the same video under batch size 1).
|
||||
pred_masks_high_res = self._apply_non_overlapping_constraints(
|
||||
pred_masks_high_res
|
||||
)
|
||||
# scale the raw mask logits with a temperature before applying sigmoid
|
||||
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
||||
if binarize and not self.training:
|
||||
mask_for_mem = (pred_masks_high_res > 0).float()
|
||||
else:
|
||||
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
|
||||
mask_for_mem = torch.sigmoid(pred_masks_high_res)
|
||||
# apply scale and bias terms to the sigmoid probabilities
|
||||
if self.sigmoid_scale_for_mem_enc != 1.0:
|
||||
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
||||
if self.sigmoid_bias_for_mem_enc != 0.0:
|
||||
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
||||
maskmem_out = self.memory_encoder(
|
||||
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
|
||||
)
|
||||
# Clone the feats and pos_enc to enable compilation
|
||||
maskmem_features = maskmem_out["vision_features"].clone()
|
||||
maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
|
||||
# add a no-object embedding to the spatial memory to indicate that the frame
|
||||
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
||||
if self.no_obj_embed_spatial is not None:
|
||||
is_obj_appearing = (object_score_logits > 0).float()
|
||||
maskmem_features += (
|
||||
1 - is_obj_appearing[..., None, None]
|
||||
) * self.no_obj_embed_spatial[..., None, None].expand(
|
||||
*maskmem_features.shape
|
||||
)
|
||||
|
||||
return maskmem_features, maskmem_pos_enc
|
||||
|
1172
sam2/sam2_video_predictor_legacy.py
Normal file
1172
sam2/sam2_video_predictor_legacy.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user