update to latest SAM 2
This commit is contained in:
96
setup.py
96
setup.py
@@ -6,7 +6,6 @@
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
# Package metadata
|
||||
NAME = "SAM 2"
|
||||
@@ -18,35 +17,18 @@ AUTHOR_EMAIL = "segment-anything@meta.com"
|
||||
LICENSE = "Apache 2.0"
|
||||
|
||||
# Read the contents of README file
|
||||
with open("README.md", "r") as f:
|
||||
with open("README.md", "r", encoding="utf-8") as f:
|
||||
LONG_DESCRIPTION = f.read()
|
||||
|
||||
# Required dependencies
|
||||
REQUIRED_PACKAGES = [
|
||||
"torch>=2.3.1",
|
||||
"torch>=2.3.1",
|
||||
"torchvision>=0.18.1",
|
||||
"transformers",
|
||||
"numpy>=1.24.4",
|
||||
"tqdm>=4.66.1",
|
||||
"hydra-core>=1.3.2",
|
||||
"iopath>=0.1.10",
|
||||
"pillow>=9.4.0",
|
||||
"huggingface_hub",
|
||||
"diffusers[torch]==0.15.1",
|
||||
"onnxruntime==1.14.1",
|
||||
"onnx==1.13.1",
|
||||
"ipykernel==6.16.2",
|
||||
"scipy",
|
||||
"gradio",
|
||||
"openai",
|
||||
"matplotlib>=3.9.1",
|
||||
"opencv-python>=4.7.0",
|
||||
"dds_cloudapi_sdk",
|
||||
"addict",
|
||||
"yapf",
|
||||
"timm",
|
||||
"supervision>=0.22.0",
|
||||
"pycocotools",
|
||||
"pillow>=9.4.0",
|
||||
]
|
||||
|
||||
EXTRA_PACKAGES = {
|
||||
@@ -67,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
||||
CUDA_ERROR_MSG = (
|
||||
"{}\n\n"
|
||||
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
||||
"You can still use SAM 2, but some post-processing functionality may be limited "
|
||||
"You can still use SAM 2 and it's OK to ignore the error above, although some "
|
||||
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
|
||||
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
||||
)
|
||||
|
||||
@@ -77,6 +60,8 @@ def get_extensions():
|
||||
return []
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
@@ -98,29 +83,46 @@ def get_extensions():
|
||||
return ext_modules
|
||||
|
||||
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
try:
|
||||
from torch.utils.cpp_extension import BuildExtension
|
||||
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
|
||||
cmdclass = {
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
cmdclass = {}
|
||||
if BUILD_ALLOW_ERRORS:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
# Setup configuration
|
||||
@@ -135,15 +137,11 @@ setup(
|
||||
author_email=AUTHOR_EMAIL,
|
||||
license=LICENSE,
|
||||
packages=find_packages(exclude="notebooks"),
|
||||
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
|
||||
include_package_data=True,
|
||||
install_requires=REQUIRED_PACKAGES,
|
||||
extras_require=EXTRA_PACKAGES,
|
||||
python_requires=">=3.10.0",
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
),
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
)
|
||||
|
Reference in New Issue
Block a user