improving warning message and adding further tips for installation (#204)

This commit is contained in:
Ronghang Hu
2024-08-12 11:37:41 -07:00
committed by GitHub
parent 1034ee2a1a
commit dce7b5446f
5 changed files with 84 additions and 40 deletions

View File

@@ -6,7 +6,6 @@
import os
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Package metadata
NAME = "SAM 2"
@@ -50,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited "
"You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)
@@ -60,6 +60,8 @@ def get_extensions():
return []
try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
@@ -81,29 +83,46 @@ def get_extensions():
return ext_modules
class BuildExtensionIgnoreErrors(BuildExtension):
try:
from torch.utils.cpp_extension import BuildExtension
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
class BuildExtensionIgnoreErrors(BuildExtension):
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration
@@ -124,11 +143,5 @@ setup(
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass={
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
cmdclass=cmdclass,
)