SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2: - We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. - We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
This commit is contained in:
@@ -4,9 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Tuple, Type
|
||||
|
||||
@@ -16,29 +14,6 @@ from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
# Check whether Flash Attention is available (and use it by default)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
||||
ALLOW_ALL_KERNELS = False
|
||||
|
||||
|
||||
def sdp_kernel_context(dropout_p):
|
||||
"""
|
||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
||||
by default, but fall back to all available kernels if Flash Attention fails.
|
||||
"""
|
||||
if ALLOW_ALL_KERNELS:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
@@ -265,20 +240,7 @@ class Attention(nn.Module):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
|
||||
# whether to repeat q rope to match k length
|
||||
# this is needed for cross-attention to memories
|
||||
rope_k_repeat=False,
|
||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
||||
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
|
||||
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
||||
)
|
||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
||||
self.freqs_cis = freqs_cis
|
||||
self.freqs_cis = (
|
||||
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
|
||||
)
|
||||
self.rope_k_repeat = rope_k_repeat
|
||||
|
||||
def forward(
|
||||
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
Reference in New Issue
Block a user