
In this PR, we make building the SAM 2 CUDA extension optional, since we observed that many users have difficulty with the CUDA compilation step. 1. During installation, we catch build errors and print a warning message instead of failing (you can still force the installation to stop on build errors with `SAM2_BUILD_ALLOW_ERRORS=0`). We also allow explicitly turning off the CUDA extension build with `SAM2_BUILD_CUDA=0`. 2. At runtime, we catch CUDA kernel errors from connected components and print a warning that the post-processing step is being skipped. We also fall back to all available attention kernels if the Flash Attention kernel fails.
124 lines
3.7 KiB
Python
124 lines
3.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import os
|
|
|
|
from setuptools import find_packages, setup
|
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
|
|
|
# Package metadata
NAME = "SAM 2"
VERSION = "1.0"
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
URL = "https://github.com/facebookresearch/segment-anything-2"
AUTHOR = "Meta AI"
AUTHOR_EMAIL = "segment-anything@meta.com"
LICENSE = "Apache 2.0"

# Read the contents of the README file; it is passed to `setup()` as the
# markdown long description.
# The path is resolved relative to this setup.py (not the current working
# directory) so the build also works when invoked from another directory, and
# the file is decoded explicitly as UTF-8 so that non-ASCII characters do not
# break on platforms whose default locale encoding is not UTF-8 (e.g. Windows).
_README_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md")
with open(_README_PATH, "r", encoding="utf-8") as f:
    LONG_DESCRIPTION = f.read()

# Required dependencies
REQUIRED_PACKAGES = [
    "torch>=2.3.1",
    "torchvision>=0.18.1",
    "numpy>=1.24.4",
    "tqdm>=4.66.1",
    "hydra-core>=1.3.2",
    "iopath>=0.1.10",
    "pillow>=9.4.0",
]

# Optional dependency groups, passed to `setup()` as `extras_require`.
EXTRA_PACKAGES = {
    "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
    "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}

# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
# NOTE: only the exact value "1" enables the build; any other value disables it.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"

# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
# NOTE: only the exact value "1" tolerates errors; any other value makes them fatal.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
|
|
|
|
|
def get_extensions():
    """Return the list of extension modules to compile for SAM 2.

    If CUDA building has been disabled (``BUILD_CUDA`` is false, i.e.
    ``SAM2_BUILD_CUDA`` is not "1"), an empty list is returned and nothing is
    compiled. Otherwise the list holds a single ``sam2._C`` CUDA extension
    built from ``sam2/csrc/connected_components.cu``.
    """
    if BUILD_CUDA:
        nvcc_flags = [
            "-DCUDA_HAS_FP16=1",
            # Turn off CUDA's built-in half-precision operators/conversions
            # (presumably to avoid clashing with PyTorch's own half overloads).
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
        connected_components_ext = CUDAExtension(
            "sam2._C",
            ["sam2/csrc/connected_components.cu"],
            extra_compile_args={"cxx": [], "nvcc": nvcc_flags},
        )
        return [connected_components_ext]
    return []
|
|
|
|
|
|
class BuildExtensionIgnoreErrors(BuildExtension):
    """A ``BuildExtension`` that skips the extension build on errors instead of aborting.

    Any exception raised while finalizing options, building the extensions, or
    computing an extension's output filename is caught. A warning is printed and
    the extension list is cleared, so SAM 2 can be installed without the
    compiled extension.

    Note that the printed message only shows up under verbose build mode
    ("pip install -v -e ." or "python setup.py build_ext -v").
    """

    ERROR_MSG = (
        "{}\n\n"
        "Failed to build the SAM 2 CUDA extension due to the error above. "
        "You can still use SAM 2, but some post-processing functionality may be limited "
        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
    )

    def _skip_extensions(self, error):
        # Print the failure followed by the explanatory hint, then clear the
        # extension list so the rest of the install proceeds without it.
        print(self.ERROR_MSG.format(error))
        self.extensions = []

    def finalize_options(self):
        try:
            super().finalize_options()
        except Exception as error:
            self._skip_extensions(error)

    def build_extensions(self):
        try:
            super().build_extensions()
        except Exception as error:
            self._skip_extensions(error)

    def get_ext_filename(self, ext_name):
        try:
            return super().get_ext_filename(ext_name)
        except Exception as error:
            self._skip_extensions(error)
            # Hard-coded fallback name, returned when the real filename
            # cannot be computed.
            return "_C.so"
|
|
|
|
|
|
# Setup configuration.
# When BUILD_ALLOW_ERRORS is true, `build_ext` uses BuildExtensionIgnoreErrors,
# so a failed CUDA build only prints a warning instead of aborting the install;
# otherwise the stock torch BuildExtension is used and build errors are fatal.
# `no_python_abi_suffix=True` is a torch BuildExtension option that omits the
# Python ABI suffix from the built library's filename.
setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url=URL,
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    license=LICENSE,
    # `exclude` must be a sequence of patterns. A bare string would be unpacked
    # character by character ("n", "o", "t", ...) and so would never exclude
    # the notebooks directory.
    packages=find_packages(exclude=["notebooks", "notebooks.*"]),
    install_requires=REQUIRED_PACKAGES,
    extras_require=EXTRA_PACKAGES,
    python_requires=">=3.10.0",
    ext_modules=get_extensions(),
    cmdclass={
        "build_ext": (
            BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
            if BUILD_ALLOW_ERRORS
            else BuildExtension.with_options(no_python_abi_suffix=True)
        ),
    },
)
|