Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step. 1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`. 2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step. We also fall back to the all available kernels if the Flash Attention kernel fails.
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
@@ -14,12 +15,30 @@ import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
# Check whether Flash Attention is available (and use it by default)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
||||
ALLOW_ALL_KERNELS = False
|
||||
|
||||
|
||||
def sdp_kernel_context(dropout_p):
|
||||
"""
|
||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
||||
by default, but fall back to all available kernels if Flash Attention fails.
|
||||
"""
|
||||
if ALLOW_ALL_KERNELS:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
@@ -246,12 +265,19 @@ class Attention(nn.Module):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
@@ -313,12 +339,19 @@ class RoPEAttention(Attention):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
|
Reference in New Issue
Block a user