update to latest SAM 2
This commit is contained in:
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
||||
output_mode: str = "binary_mask",
|
||||
use_m2m: bool = False,
|
||||
multimask_output: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Using a SAM 2 model, generates masks for the entire image.
|
||||
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
||||
self.use_m2m = use_m2m
|
||||
self.multimask_output = multimask_output
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2AutomaticMaskGenerator): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# Run model on this batch
|
||||
points = torch.as_tensor(points, device=self.predictor.device)
|
||||
points = torch.as_tensor(
|
||||
points, dtype=torch.float32, device=self.predictor.device
|
||||
)
|
||||
in_points = self.predictor._transforms.transform_coords(
|
||||
points, normalize=normalize, orig_hw=im_size
|
||||
)
|
||||
|
@@ -19,6 +19,7 @@ def build_sam2(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if apply_postprocessing:
|
||||
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
hydra_overrides = [
|
||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||
|
@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
|
@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
|
||||
mask_threshold=0.0,
|
||||
max_hole_area=0.0,
|
||||
max_sprinkle_area=0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM-2 to calculate the image embedding for an image, and then
|
||||
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
|
||||
sam_model (Sam-2): The model to use for mask prediction.
|
||||
mask_threshold (float): The threshold to use when converting mask logits
|
||||
to binary masks. Masks are thresholded at 0 by default.
|
||||
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of fill_hole_area in low_res_masks.
|
||||
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of max_hole_area in low_res_masks.
|
||||
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||
the maximum area of max_sprinkle_area in low_res_masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = sam_model
|
||||
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def set_image(
|
||||
@@ -180,7 +183,7 @@ class SAM2ImagePredictor:
|
||||
normalize_coords=True,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
||||
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
|
||||
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
|
||||
"""
|
||||
assert self._is_batch, "This function should only be used when in batched mode"
|
||||
if not self._is_image_set:
|
||||
|
@@ -44,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
offload_state_to_cpu=False,
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""Initialize a inference state."""
|
||||
"""Initialize an inference state."""
|
||||
compute_device = self.device # device of the model
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state = {}
|
||||
inference_state["images"] = images
|
||||
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# the original video height and width, used for resizing final output scores
|
||||
inference_state["video_height"] = video_height
|
||||
inference_state["video_width"] = video_width
|
||||
inference_state["device"] = torch.device("cuda")
|
||||
inference_state["device"] = compute_device
|
||||
if offload_state_to_cpu:
|
||||
inference_state["storage_device"] = torch.device("cpu")
|
||||
else:
|
||||
inference_state["storage_device"] = torch.device("cuda")
|
||||
inference_state["storage_device"] = compute_device
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
@@ -119,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
from sam2.build_sam import build_sam2_video_predictor_hf
|
||||
|
||||
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
||||
return cls(sam_model)
|
||||
return sam_model
|
||||
|
||||
def _obj_id_to_idx(self, inference_state, obj_id):
|
||||
"""Map client-side object id to model-side object index."""
|
||||
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
|
||||
if prev_out is not None and prev_out["pred_masks"] is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
||||
device = inference_state["device"]
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
||||
current_out, _ = self._run_single_frame_inference(
|
||||
@@ -586,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
@@ -595,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temprary output across all objects on this frame
|
||||
# consolidate the temporary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
if backbone_out is None:
|
||||
# Cache miss -- we will run inference on a single image
|
||||
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
||||
device = inference_state["device"]
|
||||
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
||||
backbone_out = self.forward_image(image)
|
||||
# Cache the most recent frame's feature (for repeated interactions with
|
||||
# a frame; we can use an LRU cache for more frames in the future).
|
||||
|
@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
|
||||
compute bounding box given an input mask
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
|
||||
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
|
||||
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
|
||||
A list of video frames to be load asynchronously without blocking session start.
|
||||
"""
|
||||
|
||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
||||
def __init__(
|
||||
self,
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
):
|
||||
self.img_paths = img_paths
|
||||
self.image_size = image_size
|
||||
self.offload_video_to_cpu = offload_video_to_cpu
|
||||
self.img_mean = img_mean
|
||||
self.img_std = img_std
|
||||
# items in `self._images` will be loaded asynchronously
|
||||
# items in `self.images` will be loaded asynchronously
|
||||
self.images = [None] * len(img_paths)
|
||||
# catch and raise any exceptions in the async loading thread
|
||||
self.exception = None
|
||||
# video_height and video_width be filled when loading the first image
|
||||
self.video_height = None
|
||||
self.video_width = None
|
||||
self.compute_device = compute_device
|
||||
|
||||
# load the first frame to fill video_height and video_width and also
|
||||
# to cache it (since it's most likely where the user will click)
|
||||
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
|
||||
img -= self.img_mean
|
||||
img /= self.img_std
|
||||
if not self.offload_video_to_cpu:
|
||||
img = img.cuda(non_blocking=True)
|
||||
img = img.to(self.compute_device, non_blocking=True)
|
||||
self.images[index] = img
|
||||
return img
|
||||
|
||||
@@ -167,6 +176,7 @@ def load_video_frames(
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
@@ -179,12 +189,20 @@ def load_video_frames(
|
||||
if isinstance(video_path, str) and os.path.isdir(video_path):
|
||||
jpg_folder = video_path
|
||||
else:
|
||||
raise NotImplementedError("Only JPEG frames are supported at this moment")
|
||||
raise NotImplementedError(
|
||||
"Only JPEG frames are supported at this moment. For video files, you may use "
|
||||
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
|
||||
"```\n"
|
||||
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
|
||||
"```\n"
|
||||
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
|
||||
"ffmpeg to start the JPEG file from 00000.jpg."
|
||||
)
|
||||
|
||||
frame_names = [
|
||||
p
|
||||
for p in os.listdir(jpg_folder)
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||
]
|
||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
num_frames = len(frame_names)
|
||||
@@ -196,7 +214,12 @@ def load_video_frames(
|
||||
|
||||
if async_loading_frames:
|
||||
lazy_images = AsyncVideoFrameLoader(
|
||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
)
|
||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||
|
||||
@@ -204,9 +227,9 @@ def load_video_frames(
|
||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||
if not offload_video_to_cpu:
|
||||
images = images.cuda()
|
||||
img_mean = img_mean.cuda()
|
||||
img_std = img_std.cuda()
|
||||
images = images.to(compute_device)
|
||||
img_mean = img_mean.to(compute_device)
|
||||
img_std = img_std.to(compute_device)
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
@@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
except Exception as e:
|
||||
# Skip the post-processing step on removing small holes if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
|
@@ -105,8 +105,9 @@ class SAM2Transforms(nn.Module):
|
||||
except Exception as e:
|
||||
# Skip the post-processing step if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
|
@@ -72,7 +72,7 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--do_not_skip_first_and_last_frame",
|
||||
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
|
||||
"Set this to true for evaluation on settings that doen't skip first and last frames",
|
||||
"Set this to true for evaluation on settings that doesn't skip first and last frames",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
|
@@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None):
|
||||
|
||||
assert not (
|
||||
width > w | height > h | abs(ar1 - ar2) > 0.01
|
||||
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
|
||||
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
|
||||
|
||||
e = np.zeros_like(seg)
|
||||
s = np.zeros_like(seg)
|
||||
|
92
setup.py
92
setup.py
@@ -6,7 +6,6 @@
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
# Package metadata
|
||||
NAME = "SAM 2"
|
||||
@@ -18,35 +17,18 @@ AUTHOR_EMAIL = "segment-anything@meta.com"
|
||||
LICENSE = "Apache 2.0"
|
||||
|
||||
# Read the contents of README file
|
||||
with open("README.md", "r") as f:
|
||||
with open("README.md", "r", encoding="utf-8") as f:
|
||||
LONG_DESCRIPTION = f.read()
|
||||
|
||||
# Required dependencies
|
||||
REQUIRED_PACKAGES = [
|
||||
"torch>=2.3.1",
|
||||
"torchvision>=0.18.1",
|
||||
"transformers",
|
||||
"numpy>=1.24.4",
|
||||
"tqdm>=4.66.1",
|
||||
"hydra-core>=1.3.2",
|
||||
"iopath>=0.1.10",
|
||||
"pillow>=9.4.0",
|
||||
"huggingface_hub",
|
||||
"diffusers[torch]==0.15.1",
|
||||
"onnxruntime==1.14.1",
|
||||
"onnx==1.13.1",
|
||||
"ipykernel==6.16.2",
|
||||
"scipy",
|
||||
"gradio",
|
||||
"openai",
|
||||
"matplotlib>=3.9.1",
|
||||
"opencv-python>=4.7.0",
|
||||
"dds_cloudapi_sdk",
|
||||
"addict",
|
||||
"yapf",
|
||||
"timm",
|
||||
"supervision>=0.22.0",
|
||||
"pycocotools",
|
||||
]
|
||||
|
||||
EXTRA_PACKAGES = {
|
||||
@@ -67,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
||||
CUDA_ERROR_MSG = (
|
||||
"{}\n\n"
|
||||
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
||||
"You can still use SAM 2, but some post-processing functionality may be limited "
|
||||
"You can still use SAM 2 and it's OK to ignore the error above, although some "
|
||||
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
|
||||
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
||||
)
|
||||
|
||||
@@ -77,6 +60,8 @@ def get_extensions():
|
||||
return []
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
@@ -98,29 +83,46 @@ def get_extensions():
|
||||
return ext_modules
|
||||
|
||||
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
try:
|
||||
from torch.utils.cpp_extension import BuildExtension
|
||||
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
|
||||
cmdclass = {
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
cmdclass = {}
|
||||
if BUILD_ALLOW_ERRORS:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
# Setup configuration
|
||||
@@ -135,15 +137,11 @@ setup(
|
||||
author_email=AUTHOR_EMAIL,
|
||||
license=LICENSE,
|
||||
packages=find_packages(exclude="notebooks"),
|
||||
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
|
||||
include_package_data=True,
|
||||
install_requires=REQUIRED_PACKAGES,
|
||||
extras_require=EXTRA_PACKAGES,
|
||||
python_requires=">=3.10.0",
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
),
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
)
|
||||
|
Reference in New Issue
Block a user