Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)

In this PR, we make building the SAM 2 CUDA extension optional, since many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. The CUDA extension build can also be turned off explicitly with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from the connected-components step and print a warning that the post-processing step is being skipped.
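
As a rough illustration of the two points above (a minimal sketch, not the code in this PR): the helper names `env_flag` and `run_post_processing` are hypothetical, and `post_process` stands in for the connected-components step. The flag check mirrors the `os.getenv(..., "1") == "1"` comparison used in `setup.py`.

```python
import os
import warnings


def env_flag(name, default="1", environ=None):
    """Return True only when the variable equals "1" (an unset variable uses `default`)."""
    environ = os.environ if environ is None else environ
    return environ.get(name, default) == "1"


def run_post_processing(mask, post_process):
    """Apply `post_process` to `mask`; on any error, warn and return `mask` unchanged."""
    try:
        return post_process(mask)
    except Exception as e:
        warnings.warn(
            f"{e}\n\nSkipping the post-processing step due to the error above.",
            category=UserWarning,
            stacklevel=2,
        )
        return mask
```

Because the comparison is against the exact string `"1"`, any other value (for example `0` or `true`) disables the flag in this sketch.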

We also fall back to all available kernels if the Flash Attention kernel fails.
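
The fallback can be pictured with a framework-free sketch, under the assumption that a restrictive context manager limits which kernels may run. Here `restricted_context` and `compute` are hypothetical stand-ins for the kernel-selection context and the attention call; only the module-level flag and the retry-on-failure shape follow this PR.

```python
import contextlib
import warnings

# Flipped to True after the first failure, so later calls skip the restriction.
ALLOW_ALL_KERNELS = False


def kernel_context(restricted_context):
    """Return a no-op context once the fallback is active, else the restricted one."""
    if ALLOW_ALL_KERNELS:
        return contextlib.nullcontext()
    return restricted_context()


def run_with_fallback(compute, restricted_context):
    """Try `compute` under the restricted context; on failure, retry unrestricted."""
    global ALLOW_ALL_KERNELS
    try:
        with kernel_context(restricted_context):
            return compute()
    except Exception as e:
        warnings.warn(
            f"Preferred kernel failed due to: {e}\nFalling back to all available kernels.",
            category=UserWarning,
            stacklevel=2,
        )
        ALLOW_ALL_KERNELS = True
        return compute()
```

The flag is process-wide, so in this sketch the warning is emitted once and every later call runs without the restriction.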
Ronghang Hu
2024-08-06 10:52:01 -07:00
committed by GitHub
parent 0230c5ff93
commit 6f7e700c37
5 changed files with 173 additions and 33 deletions

@@ -11,6 +11,28 @@ Then, install SAM 2 from the root of this repository via
 pip install -e ".[demo]"
 ```
+Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
+```bash
+# skip the SAM 2 CUDA extension
+SAM2_BUILD_CUDA=0 pip install -e ".[demo]"
+```
+This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
+### Building the SAM 2 CUDA extension
+By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
+If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, you can still use SAM 2 for both image and video applications, but the post-processing step (removing small holes and sprinkles in the output masks) will be skipped. This shouldn't affect the results in most cases.
+If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
+```bash
+pip uninstall -y SAM-2; SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
+```
+Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
+Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.
 ### Common Installation Issues
 Click each issue for its solutions:
@@ -22,6 +44,8 @@ I got `ImportError: cannot import name '_C' from 'sam2'`
 <br/>
 This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
+In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/segment-anything-2/issues/77.
 </details>
 <details>

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import contextlib
 import math
 import warnings
 from functools import partial
@@ -14,12 +15,30 @@ import torch.nn.functional as F
 from torch import nn, Tensor
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
 from sam2.utils.misc import get_sdpa_settings
 warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
 OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+def sdp_kernel_context(dropout_p):
+    """
+    Get the context for the attention scaled dot-product kernel. We use Flash Attention
+    by default, but fall back to all available kernels if Flash Attention fails.
+    """
+    if ALLOW_ALL_KERNELS:
+        return contextlib.nullcontext()
+    return torch.backends.cuda.sdp_kernel(
+        enable_flash=USE_FLASH_ATTN,
+        # if Flash attention kernel is off, then math kernel needs to be enabled
+        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        enable_mem_efficient=OLD_GPU,
+    )
 class TwoWayTransformer(nn.Module):
@@ -246,12 +265,19 @@ class Attention(nn.Module):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
@@ -313,12 +339,19 @@ class RoPEAttention(Attention):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)

@@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
     # Holes are those connected components in background with area <= self.max_area
     # (background regions are those with mask scores <= 0)
     assert max_area > 0, "max_area must be positive"
-    labels, areas = get_connected_components(mask <= 0)
-    is_hole = (labels > 0) & (areas <= max_area)
-    # We fill holes with a small positive mask score (0.1) to change them to foreground.
-    mask = torch.where(is_hole, 0.1, mask)
+    input_mask = mask
+    try:
+        labels, areas = get_connected_components(mask <= 0)
+        is_hole = (labels > 0) & (areas <= max_area)
+        # We fill holes with a small positive mask score (0.1) to change them to foreground.
+        mask = torch.where(is_hole, 0.1, mask)
+    except Exception as e:
+        # Skip the post-processing step on removing small holes if the CUDA kernel fails
+        warnings.warn(
+            f"{e}\n\nSkipping the post-processing step due to the error above. "
+            "Consider building SAM 2 with CUDA extension to enable post-processing (see "
+            "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
+            category=UserWarning,
+            stacklevel=2,
+        )
+        mask = input_mask
     return mask

@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -78,22 +80,38 @@ class SAM2Transforms(nn.Module):
         from sam2.utils.misc import get_connected_components
         masks = masks.float()
-        if self.max_hole_area > 0:
-            # Holes are those connected components in background with area <= self.fill_hole_area
-            # (background regions are those with mask scores <= self.mask_threshold)
-            mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
-            labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
-            is_hole = (labels > 0) & (areas <= self.max_hole_area)
-            is_hole = is_hole.reshape_as(masks)
-            # We fill holes with a small positive mask score (10.0) to change them to foreground.
-            masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
+        input_masks = masks
+        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
+        try:
+            if self.max_hole_area > 0:
+                # Holes are those connected components in background with area <= self.fill_hole_area
+                # (background regions are those with mask scores <= self.mask_threshold)
+                labels, areas = get_connected_components(
+                    mask_flat <= self.mask_threshold
+                )
+                is_hole = (labels > 0) & (areas <= self.max_hole_area)
+                is_hole = is_hole.reshape_as(masks)
+                # We fill holes with a small positive mask score (10.0) to change them to foreground.
+                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
-        if self.max_sprinkle_area > 0:
-            labels, areas = get_connected_components(mask_flat > self.mask_threshold)
-            is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
-            is_hole = is_hole.reshape_as(masks)
-            # We fill holes with negative mask score (-10.0) to change them to background.
-            masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+            if self.max_sprinkle_area > 0:
+                labels, areas = get_connected_components(
+                    mask_flat > self.mask_threshold
+                )
+                is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
+                is_hole = is_hole.reshape_as(masks)
+                # We fill holes with negative mask score (-10.0) to change them to background.
+                masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+        except Exception as e:
+            # Skip the post-processing step if the CUDA kernel fails
+            warnings.warn(
+                f"{e}\n\nSkipping the post-processing step due to the error above. "
+                "Consider building SAM 2 with CUDA extension to enable post-processing (see "
+                "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            masks = input_masks
         masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
         return masks

@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -36,8 +37,18 @@ EXTRA_PACKAGES = {
     "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
 }
+# By default, we also build the SAM 2 CUDA extension.
+# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
+BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
+# By default, we allow SAM 2 installation to proceed even with build errors.
+# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
+BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
 def get_extensions():
+    if not BUILD_CUDA:
+        return []
     srcs = ["sam2/csrc/connected_components.cu"]
     compile_args = {
         "cxx": [],
@@ -52,6 +63,40 @@ def get_extensions():
     return ext_modules
+class BuildExtensionIgnoreErrors(BuildExtension):
+    # Catch and skip errors during extension building and print a warning message
+    # (note that this message only shows up under verbose build mode
+    # "pip install -v -e ." or "python setup.py build_ext -v")
+    ERROR_MSG = (
+        "{}\n\n"
+        "Failed to build the SAM 2 CUDA extension due to the error above. "
+        "You can still use SAM 2, but some post-processing functionality may be limited "
+        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
+    )
+    def finalize_options(self):
+        try:
+            super().finalize_options()
+        except Exception as e:
+            print(self.ERROR_MSG.format(e))
+            self.extensions = []
+    def build_extensions(self):
+        try:
+            super().build_extensions()
+        except Exception as e:
+            print(self.ERROR_MSG.format(e))
+            self.extensions = []
+    def get_ext_filename(self, ext_name):
+        try:
+            return super().get_ext_filename(ext_name)
+        except Exception as e:
+            print(self.ERROR_MSG.format(e))
+            self.extensions = []
+            return "_C.so"
 # Setup configuration
 setup(
     name=NAME,
@@ -68,5 +113,11 @@ setup(
     extras_require=EXTRA_PACKAGES,
     python_requires=">=3.10.0",
     ext_modules=get_extensions(),
-    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
+    cmdclass={
+        "build_ext": (
+            BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
+            if BUILD_ALLOW_ERRORS
+            else BuildExtension.with_options(no_python_abi_suffix=True)
+        ),
+    },
 )