update to the latest sam2 version and support box prompts in video tracking

This commit is contained in:
rentainhe
2024-08-08 12:03:29 +08:00
parent 96cbab92e0
commit 077064c365
8 changed files with 272 additions and 45 deletions

View File

@@ -3,6 +3,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -36,22 +37,75 @@ EXTRA_PACKAGES = {
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}
# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)
def get_extensions():
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
if not BUILD_CUDA:
return []
try:
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
"nvcc": [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
}
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
except Exception as e:
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
ext_modules = []
else:
raise e
return ext_modules
class BuildExtensionIgnoreErrors(BuildExtension):
def finalize_options(self):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def build_extensions(self):
try:
super().build_extensions()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
# Setup configuration
setup(
name=NAME,
@@ -68,5 +122,11 @@ setup(
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
cmdclass={
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
)