update to latest SAM 2

This commit is contained in:
rentainhe
2024-08-21 18:11:44 +08:00
parent 35efb4a5cb
commit 6e0ddadf7c
12 changed files with 140 additions and 87 deletions

View File

@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
output_mode: str = "binary_mask", output_mode: str = "binary_mask",
use_m2m: bool = False, use_m2m: bool = False,
multimask_output: bool = True, multimask_output: bool = True,
**kwargs,
) -> None: ) -> None:
""" """
Using a SAM 2 model, generates masks for the entire image. Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
self.use_m2m = use_m2m self.use_m2m = use_m2m
self.multimask_output = multimask_output self.multimask_output = multimask_output
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)
@torch.no_grad() @torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
""" """
@@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator:
orig_h, orig_w = orig_size orig_h, orig_w = orig_size
# Run model on this batch # Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device) points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords( in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size points, normalize=normalize, orig_hw=im_size
) )

View File

@@ -19,6 +19,7 @@ def build_sam2(
mode="eval", mode="eval",
hydra_overrides_extra=[], hydra_overrides_extra=[],
apply_postprocessing=True, apply_postprocessing=True,
**kwargs,
): ):
if apply_postprocessing: if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
mode="eval", mode="eval",
hydra_overrides_extra=[], hydra_overrides_extra=[],
apply_postprocessing=True, apply_postprocessing=True,
**kwargs,
): ):
hydra_overrides = [ hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",

View File

@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
self.dim = dim self.dim = dim
self.dim_out = dim_out self.dim_out = dim_out
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.q_pool = q_pool self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3) self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out) self.proj = nn.Linear(dim_out, dim_out)

View File

@@ -16,7 +16,7 @@ from torch import nn
class PositionEmbeddingSine(nn.Module): class PositionEmbeddingSine(nn.Module):
""" """
This is a more standard version of the position embedding, very similar to the one This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images. used by the Attention Is All You Need paper, generalized to work on images.
""" """
def __init__( def __init__(
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len # repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k: if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2] r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)

View File

@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
continue # skip padding frames continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases, # "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU). # so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True) feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval) # Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding # Temporal positional encoding
maskmem_enc = ( maskmem_enc = (
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

View File

@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
mask_threshold=0.0, mask_threshold=0.0,
max_hole_area=0.0, max_hole_area=0.0,
max_sprinkle_area=0.0, max_sprinkle_area=0.0,
**kwargs,
) -> None: ) -> None:
""" """
Uses SAM-2 to calculate the image embedding for an image, and then Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
sam_model (Sam-2): The model to use for mask prediction. sam_model (Sam-2): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default. to binary masks. Masks are thresholded at 0 by default.
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks. the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
""" """
super().__init__() super().__init__()
self.model = sam_model self.model = sam_model
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
from sam2.build_sam import build_sam2_hf from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs) sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model) return cls(sam_model, **kwargs)
@torch.no_grad() @torch.no_grad()
def set_image( def set_image(
@@ -180,7 +183,7 @@ class SAM2ImagePredictor:
normalize_coords=True, normalize_coords=True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tupele of lists of masks, ious, and low_res_masks_logits. It returns a tuple of lists of masks, ious, and low_res_masks_logits.
""" """
assert self._is_batch, "This function should only be used when in batched mode" assert self._is_batch, "This function should only be used when in batched mode"
if not self._is_image_set: if not self._is_image_set:

View File

@@ -44,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
offload_state_to_cpu=False, offload_state_to_cpu=False,
async_loading_frames=False, async_loading_frames=False,
): ):
"""Initialize a inference state.""" """Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames( images, video_height, video_width = load_video_frames(
video_path=video_path, video_path=video_path,
image_size=self.image_size, image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu, offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames, async_loading_frames=async_loading_frames,
compute_device=compute_device,
) )
inference_state = {} inference_state = {}
inference_state["images"] = images inference_state["images"] = images
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
# the original video height and width, used for resizing final output scores # the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height inference_state["video_height"] = video_height
inference_state["video_width"] = video_width inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda") inference_state["device"] = compute_device
if offload_state_to_cpu: if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu") inference_state["storage_device"] = torch.device("cpu")
else: else:
inference_state["storage_device"] = torch.device("cuda") inference_state["storage_device"] = compute_device
# inputs on each frame # inputs on each frame
inference_state["point_inputs_per_obj"] = {} inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {} inference_state["mask_inputs_per_obj"] = {}
@@ -119,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
from sam2.build_sam import build_sam2_video_predictor_hf from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model) return sam_model
def _obj_id_to_idx(self, inference_state, obj_id): def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index.""" """Map client-side object id to model-side object index."""
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None: if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference( current_out, _ = self._run_single_frame_inference(
@@ -586,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base):
# to `propagate_in_video_preflight`). # to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"] consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]: for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outptus # Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects # Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs # (these should be the frames that have just received clicks for mask inputs
@@ -595,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base):
for obj_temp_output_dict in temp_output_dict_per_obj.values(): for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds) consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temprary output across all objects on this frame # consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds: for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj( consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
) )
if backbone_out is None: if backbone_out is None:
# Cache miss -- we will run inference on a single image # Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image) backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with # Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future). # a frame; we can use an LRU cache for more frames in the future).

View File

@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
compute bounding box given an input mask compute bounding box given an input mask
Inputs: Inputs:
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor - masks: [B, 1, H, W] masks, dtype=torch.Tensor
Returns: Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
A list of video frames to be load asynchronously without blocking session start. A list of video frames to be load asynchronously without blocking session start.
""" """
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths self.img_paths = img_paths
self.image_size = image_size self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean self.img_mean = img_mean
self.img_std = img_std self.img_std = img_std
# items in `self._images` will be loaded asynchronously # items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths) self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread # catch and raise any exceptions in the async loading thread
self.exception = None self.exception = None
# video_height and video_width be filled when loading the first image # video_height and video_width be filled when loading the first image
self.video_height = None self.video_height = None
self.video_width = None self.video_width = None
self.compute_device = compute_device
# load the first frame to fill video_height and video_width and also # load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click) # to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
img -= self.img_mean img -= self.img_mean
img /= self.img_std img /= self.img_std
if not self.offload_video_to_cpu: if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True) img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img self.images[index] = img
return img return img
@@ -167,6 +176,7 @@ def load_video_frames(
img_mean=(0.485, 0.456, 0.406), img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225), img_std=(0.229, 0.224, 0.225),
async_loading_frames=False, async_loading_frames=False,
compute_device=torch.device("cuda"),
): ):
""" """
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format). Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,12 +189,20 @@ def load_video_frames(
if isinstance(video_path, str) and os.path.isdir(video_path): if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path jpg_folder = video_path
else: else:
raise NotImplementedError("Only JPEG frames are supported at this moment") raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)
frame_names = [ frame_names = [
p p
for p in os.listdir(jpg_folder) for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
] ]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names) num_frames = len(frame_names)
@@ -196,7 +214,12 @@ def load_video_frames(
if async_loading_frames: if async_loading_frames:
lazy_images = AsyncVideoFrameLoader( lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
) )
return lazy_images, lazy_images.video_height, lazy_images.video_width return lazy_images, lazy_images.video_height, lazy_images.video_width
@@ -204,9 +227,9 @@ def load_video_frames(
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu: if not offload_video_to_cpu:
images = images.cuda() images = images.to(compute_device)
img_mean = img_mean.cuda() img_mean = img_mean.to(compute_device)
img_std = img_std.cuda() img_std = img_std.to(compute_device)
# normalize by mean and std # normalize by mean and std
images -= img_mean images -= img_mean
images /= img_std images /= img_std
@@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area):
except Exception as e: except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails # Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn( warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. " f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"Consider building SAM 2 with CUDA extension to enable post-processing (see " "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning, category=UserWarning,
stacklevel=2, stacklevel=2,

View File

@@ -105,8 +105,9 @@ class SAM2Transforms(nn.Module):
except Exception as e: except Exception as e:
# Skip the post-processing step if the CUDA kernel fails # Skip the post-processing step if the CUDA kernel fails
warnings.warn( warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. " f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"Consider building SAM 2 with CUDA extension to enable post-processing (see " "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning, category=UserWarning,
stacklevel=2, stacklevel=2,

View File

@@ -72,7 +72,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--do_not_skip_first_and_last_frame", "--do_not_skip_first_and_last_frame",
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. " help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
"Set this to true for evaluation on settings that doen't skip first and last frames", "Set this to true for evaluation on settings that doesn't skip first and last frames",
action="store_true", action="store_true",
) )

View File

@@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None):
assert not ( assert not (
width > w | height > h | abs(ar1 - ar2) > 0.01 width > w | height > h | abs(ar1 - ar2) > 0.01
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height) ), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg) e = np.zeros_like(seg)
s = np.zeros_like(seg) s = np.zeros_like(seg)

View File

@@ -6,7 +6,6 @@
import os import os
from setuptools import find_packages, setup from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Package metadata # Package metadata
NAME = "SAM 2" NAME = "SAM 2"
@@ -18,35 +17,18 @@ AUTHOR_EMAIL = "segment-anything@meta.com"
LICENSE = "Apache 2.0" LICENSE = "Apache 2.0"
# Read the contents of README file # Read the contents of README file
with open("README.md", "r") as f: with open("README.md", "r", encoding="utf-8") as f:
LONG_DESCRIPTION = f.read() LONG_DESCRIPTION = f.read()
# Required dependencies # Required dependencies
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
"torch>=2.3.1", "torch>=2.3.1",
"torchvision>=0.18.1", "torchvision>=0.18.1",
"transformers",
"numpy>=1.24.4", "numpy>=1.24.4",
"tqdm>=4.66.1", "tqdm>=4.66.1",
"hydra-core>=1.3.2", "hydra-core>=1.3.2",
"iopath>=0.1.10", "iopath>=0.1.10",
"pillow>=9.4.0", "pillow>=9.4.0",
"huggingface_hub",
"diffusers[torch]==0.15.1",
"onnxruntime==1.14.1",
"onnx==1.13.1",
"ipykernel==6.16.2",
"scipy",
"gradio",
"openai",
"matplotlib>=3.9.1",
"opencv-python>=4.7.0",
"dds_cloudapi_sdk",
"addict",
"yapf",
"timm",
"supervision>=0.22.0",
"pycocotools",
] ]
EXTRA_PACKAGES = { EXTRA_PACKAGES = {
@@ -67,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
CUDA_ERROR_MSG = ( CUDA_ERROR_MSG = (
"{}\n\n" "{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. " "Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited " "You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n" "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
) )
@@ -77,6 +60,8 @@ def get_extensions():
return [] return []
try: try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2/csrc/connected_components.cu"] srcs = ["sam2/csrc/connected_components.cu"]
compile_args = { compile_args = {
"cxx": [], "cxx": [],
@@ -98,29 +83,46 @@ def get_extensions():
return ext_modules return ext_modules
class BuildExtensionIgnoreErrors(BuildExtension): try:
from torch.utils.cpp_extension import BuildExtension
def finalize_options(self): class BuildExtensionIgnoreErrors(BuildExtension):
try:
super().finalize_options()
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
def build_extensions(self): def finalize_options(self):
try: try:
super().build_extensions() super().finalize_options()
except Exception as e: except Exception as e:
print(CUDA_ERROR_MSG.format(e)) print(CUDA_ERROR_MSG.format(e))
self.extensions = [] self.extensions = []
def get_ext_filename(self, ext_name): def build_extensions(self):
try: try:
return super().get_ext_filename(ext_name) super().build_extensions()
except Exception as e: except Exception as e:
print(CUDA_ERROR_MSG.format(e)) print(CUDA_ERROR_MSG.format(e))
self.extensions = [] self.extensions = []
return "_C.so"
def get_ext_filename(self, ext_name):
try:
return super().get_ext_filename(ext_name)
except Exception as e:
print(CUDA_ERROR_MSG.format(e))
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration # Setup configuration
@@ -135,15 +137,11 @@ setup(
author_email=AUTHOR_EMAIL, author_email=AUTHOR_EMAIL,
license=LICENSE, license=LICENSE,
packages=find_packages(exclude="notebooks"), packages=find_packages(exclude="notebooks"),
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
include_package_data=True,
install_requires=REQUIRED_PACKAGES, install_requires=REQUIRED_PACKAGES,
extras_require=EXTRA_PACKAGES, extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0", python_requires=">=3.10.0",
ext_modules=get_extensions(), ext_modules=get_extensions(),
cmdclass={ cmdclass=cmdclass,
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
) )