Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step. 1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`. 2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step. We also fall back to the all available kernels if the Flash Attention kernel fails.
This commit is contained in:
53
setup.py
53
setup.py
@@ -3,6 +3,7 @@
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
@@ -36,8 +37,18 @@ EXTRA_PACKAGES = {
|
||||
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
|
||||
}
|
||||
|
||||
# By default, we also build the SAM 2 CUDA extension.
|
||||
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
|
||||
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
|
||||
# By default, we allow SAM 2 installation to proceed even with build errors.
|
||||
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
|
||||
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
||||
|
||||
|
||||
def get_extensions():
|
||||
if not BUILD_CUDA:
|
||||
return []
|
||||
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
@@ -52,6 +63,40 @@ def get_extensions():
|
||||
return ext_modules
|
||||
|
||||
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
# Catch and skip errors during extension building and print a warning message
|
||||
# (note that this message only shows up under verbose build mode
|
||||
# "pip install -v -e ." or "python setup.py build_ext -v")
|
||||
ERROR_MSG = (
|
||||
"{}\n\n"
|
||||
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
||||
"You can still use SAM 2, but some post-processing functionality may be limited "
|
||||
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
||||
)
|
||||
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(self.ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(self.ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(self.ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
|
||||
|
||||
# Setup configuration
|
||||
setup(
|
||||
name=NAME,
|
||||
@@ -68,5 +113,11 @@ setup(
|
||||
extras_require=EXTRA_PACKAGES,
|
||||
python_requires=">=3.10.0",
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
|
||||
cmdclass={
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
Reference in New Issue
Block a user