fix: Fix setup

This commit is contained in:
kiennt
2025-08-16 09:47:52 +00:00
parent 856dde20ae
commit e1420f9335

View File

@@ -27,14 +27,19 @@ import subprocess
import subprocess
import sys
def install_torch():
try:
import torch
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
# Call the function to ensure torch is installed
install_torch()
# install_torch()
sys.path.insert(
0, "/home/nguyendc/cuong-dev/GroundingDINO/.venv/lib/python3.13/site-packages"
)
import torch
from setuptools import find_packages, setup
@@ -48,7 +53,11 @@ cwd = os.path.dirname(os.path.abspath(__file__))
sha = "Unknown"
try:
sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
sha = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd)
.decode("ascii")
.strip()
)
except Exception:
pass
@@ -67,7 +76,9 @@ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
extensions_dir = os.path.join(
this_dir, "groundingdino", "models", "GroundingDINO", "csrc"
)
main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
@@ -82,7 +93,9 @@ def get_extensions():
extra_compile_args = {"cxx": []}
define_macros = []
if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
if CUDA_HOME is not None and (
torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ
):
print("Compiling with CUDA")
extension = CUDAExtension
sources += source_cuda
@@ -92,6 +105,10 @@ def get_extensions():
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"-gencode=arch=compute_70,code=sm_70",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
]
else:
print("Compiling without CUDA")
@@ -99,7 +116,7 @@ def get_extensions():
extra_compile_args["nvcc"] = []
return None
sources = [os.path.join(extensions_dir, s) for s in sources]
sources = [x.replace(this_dir + "/", "") for x in sources]
include_dirs = [extensions_dir]
ext_modules = [
@@ -208,7 +225,7 @@ if __name__ == "__main__":
url="https://github.com/IDEA-Research/GroundingDINO",
description="open-set object detector",
license=license,
install_requires=parse_requirements("requirements.txt"),
# install_requires=parse_requirements("requirements.txt"),
packages=find_packages(
exclude=(
"configs",