fix: Fix setup
This commit is contained in:
29
setup.py
29
setup.py
@@ -27,14 +27,19 @@ import subprocess
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def install_torch():
|
def install_torch():
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
|
||||||
|
|
||||||
|
|
||||||
# Call the function to ensure torch is installed
|
# Call the function to ensure torch is installed
|
||||||
install_torch()
|
# install_torch()
|
||||||
|
sys.path.insert(
|
||||||
|
0, "/home/nguyendc/cuong-dev/GroundingDINO/.venv/lib/python3.13/site-packages"
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
@@ -48,7 +53,11 @@ cwd = os.path.dirname(os.path.abspath(__file__))
|
|||||||
|
|
||||||
sha = "Unknown"
|
sha = "Unknown"
|
||||||
try:
|
try:
|
||||||
sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
|
sha = (
|
||||||
|
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
|
||||||
|
.decode("ascii")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -67,7 +76,9 @@ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
|
|||||||
|
|
||||||
def get_extensions():
|
def get_extensions():
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
|
extensions_dir = os.path.join(
|
||||||
|
this_dir, "groundingdino", "models", "GroundingDINO", "csrc"
|
||||||
|
)
|
||||||
|
|
||||||
main_source = os.path.join(extensions_dir, "vision.cpp")
|
main_source = os.path.join(extensions_dir, "vision.cpp")
|
||||||
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
|
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
|
||||||
@@ -82,7 +93,9 @@ def get_extensions():
|
|||||||
extra_compile_args = {"cxx": []}
|
extra_compile_args = {"cxx": []}
|
||||||
define_macros = []
|
define_macros = []
|
||||||
|
|
||||||
if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
|
if CUDA_HOME is not None and (
|
||||||
|
torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ
|
||||||
|
):
|
||||||
print("Compiling with CUDA")
|
print("Compiling with CUDA")
|
||||||
extension = CUDAExtension
|
extension = CUDAExtension
|
||||||
sources += source_cuda
|
sources += source_cuda
|
||||||
@@ -92,6 +105,10 @@ def get_extensions():
|
|||||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||||
|
"-gencode=arch=compute_70,code=sm_70",
|
||||||
|
"-gencode=arch=compute_75,code=sm_75",
|
||||||
|
"-gencode=arch=compute_80,code=sm_80",
|
||||||
|
"-gencode=arch=compute_86,code=sm_86",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
print("Compiling without CUDA")
|
print("Compiling without CUDA")
|
||||||
@@ -99,7 +116,7 @@ def get_extensions():
|
|||||||
extra_compile_args["nvcc"] = []
|
extra_compile_args["nvcc"] = []
|
||||||
return None
|
return None
|
||||||
|
|
||||||
sources = [os.path.join(extensions_dir, s) for s in sources]
|
sources = [x.replace(this_dir + "/", "") for x in sources]
|
||||||
include_dirs = [extensions_dir]
|
include_dirs = [extensions_dir]
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [
|
||||||
@@ -208,7 +225,7 @@ if __name__ == "__main__":
|
|||||||
url="https://github.com/IDEA-Research/GroundingDINO",
|
url="https://github.com/IDEA-Research/GroundingDINO",
|
||||||
description="open-set object detector",
|
description="open-set object detector",
|
||||||
license=license,
|
license=license,
|
||||||
install_requires=parse_requirements("requirements.txt"),
|
# install_requires=parse_requirements("requirements.txt"),
|
||||||
packages=find_packages(
|
packages=find_packages(
|
||||||
exclude=(
|
exclude=(
|
||||||
"configs",
|
"configs",
|
||||||
|
Reference in New Issue
Block a user