SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)

This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
  * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
This commit is contained in:
Ronghang Hu
2024-12-11 15:00:55 -08:00
committed by GitHub
parent c2ec8e14a1
commit 393ae336a7
27 changed files with 1794 additions and 443 deletions

View File

@@ -32,9 +32,7 @@ def window_partition(x, window_size):
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
return windows, (Hp, Wp)
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(
x = windows.reshape(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
x = x[:, :H, :W, :]
return x

View File

@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
temperature: int = 10000,
normalize: bool = True,
scale: Optional[float] = None,
# Following settings only relevant
# for warmping up cache for compilation
warmup_cache: bool = True,
image_size: int = 1024,
strides: Tuple[int] = (4, 8, 16, 32),
):
super().__init__()
assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
self.scale = scale
self.cache = {}
if warmup_cache and torch.cuda.is_available():
# Warmup cache for cuda, to help with compilation
device = torch.device("cuda")
for stride in strides:
cache_key = (image_size // stride, image_size // stride)
self._pe(1, device, *cache_key)
def _encode_xy(self, x, y):
# The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
return pos
@torch.no_grad()
def forward(self, x: torch.Tensor):
cache_key = (x.shape[-2], x.shape[-1])
def _pe(self, B, device, *cache_key):
H, W = cache_key
if cache_key in self.cache:
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
y_embed = (
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
torch.arange(1, H + 1, dtype=torch.float32, device=device)
.view(1, -1, 1)
.repeat(x.shape[0], 1, x.shape[-1])
.repeat(B, 1, W)
)
x_embed = (
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
torch.arange(1, W + 1, dtype=torch.float32, device=device)
.view(1, 1, -1)
.repeat(x.shape[0], x.shape[-2], 1)
.repeat(B, H, 1)
)
if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
self.cache[cache_key] = pos[0]
return pos
@torch.no_grad()
def forward(self, x: torch.Tensor):
B = x.shape[0]
cache_key = (x.shape[-2], x.shape[-1])
return self._pe(B, x.device, *cache_key)
class PositionEmbeddingRandom(nn.Module):
"""

View File

@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding[labels == 2] += self.point_embeddings[2].weight
point_embedding[labels == 3] += self.point_embeddings[3].weight
point_embedding = torch.where(
(labels == -1).unsqueeze(-1),
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 0).unsqueeze(-1),
point_embedding + self.point_embeddings[0].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 1).unsqueeze(-1),
point_embedding + self.point_embeddings[1].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 2).unsqueeze(-1),
point_embedding + self.point_embeddings[2].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 3).unsqueeze(-1),
point_embedding + self.point_embeddings[3].weight,
point_embedding,
)
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:

View File

@@ -4,9 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import math
import warnings
from functools import partial
from typing import Tuple, Type
@@ -16,29 +14,6 @@ from torch import nn, Tensor
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings
warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False
def sdp_kernel_context(dropout_p):
"""
Get the context for the attention scaled dot-product kernel. We use Flash Attention
by default, but fall back to all available kernels if Flash Attention fails.
"""
if ALLOW_ALL_KERNELS:
return contextlib.nullcontext()
return torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
)
class TwoWayTransformer(nn.Module):
@@ -265,20 +240,7 @@ class Attention(nn.Module):
dropout_p = self.dropout_p if self.training else 0.0
# Attention
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
# whether to repeat q rope to match k length
# this is needed for cross-attention to memories
rope_k_repeat=False,
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
**kwargs,
):
super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
)
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
self.freqs_cis = freqs_cis
self.freqs_cis = (
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
)
self.rope_k_repeat = rope_k_repeat
def forward(
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
dropout_p = self.dropout_p if self.training else 0.0
# Attention
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)

View File

@@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = torch.tensor(pos_list, device=device)
obj_pos = (
torch.tensor(pos_list)
.pin_memory()
.to(device=device, non_blocking=True)
)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)