update to the latest sam2 version and support box prompts in video tracking
This commit is contained in:
@@ -22,7 +22,7 @@ TEXT_PROMPT = "hippopotamus."
|
|||||||
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
||||||
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
||||||
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "mask" # "point"
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -128,7 +128,7 @@ if masks.ndim == 4:
|
|||||||
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
|
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
||||||
|
|
||||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||||
if PROMPT_TYPE_FOR_VIDEO == "point":
|
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||||
@@ -137,13 +137,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point":
|
|||||||
|
|
||||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
inference_state=inference_state,
|
inference_state=inference_state,
|
||||||
frame_idx=ann_frame_idx,
|
frame_idx=ann_frame_idx,
|
||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
points=points,
|
points=points,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
|
# Using box prompt
|
||||||
|
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
||||||
|
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
||||||
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=object_id,
|
||||||
|
box=box,
|
||||||
|
)
|
||||||
# Using mask prompt is a more straightforward way
|
# Using mask prompt is a more straightforward way
|
||||||
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||||
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||||
@@ -154,6 +163,8 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
|||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
mask=mask
|
mask=mask
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@@ -28,8 +28,8 @@ TEXT_PROMPT = "hippopotamus."
|
|||||||
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
||||||
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
||||||
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
||||||
API_TOKEN_FOR_GD1_5 = "Your API token"
|
API_TOKEN_FOR_GD1_5 = "3491a2a256fb7ed01b2e757b713c4cb0"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "mask" # "point"
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -152,7 +152,7 @@ if masks.ndim == 4:
|
|||||||
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
|
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
||||||
|
|
||||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||||
if PROMPT_TYPE_FOR_VIDEO == "point":
|
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||||
@@ -161,13 +161,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point":
|
|||||||
|
|
||||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
inference_state=inference_state,
|
inference_state=inference_state,
|
||||||
frame_idx=ann_frame_idx,
|
frame_idx=ann_frame_idx,
|
||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
points=points,
|
points=points,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
|
# Using box prompt
|
||||||
|
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
||||||
|
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
||||||
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=object_id,
|
||||||
|
box=box,
|
||||||
|
)
|
||||||
# Using mask prompt is a more straightforward way
|
# Using mask prompt is a more straightforward way
|
||||||
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||||
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||||
@@ -178,7 +187,8 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
|||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
mask=mask
|
mask=mask
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||||
|
@@ -76,6 +76,44 @@ def build_sam2_video_predictor(
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_sam2_hf(model_id, **kwargs):
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
model_id_to_filenames = {
|
||||||
|
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||||
|
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||||
|
"facebook/sam2-hiera-base-plus": (
|
||||||
|
"sam2_hiera_b+.yaml",
|
||||||
|
"sam2_hiera_base_plus.pt",
|
||||||
|
),
|
||||||
|
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||||
|
}
|
||||||
|
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||||
|
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||||
|
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
model_id_to_filenames = {
|
||||||
|
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||||
|
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||||
|
"facebook/sam2-hiera-base-plus": (
|
||||||
|
"sam2_hiera_b+.yaml",
|
||||||
|
"sam2_hiera_base_plus.pt",
|
||||||
|
),
|
||||||
|
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||||
|
}
|
||||||
|
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||||
|
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||||
|
return build_sam2_video_predictor(
|
||||||
|
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_checkpoint(model, ckpt_path):
|
def _load_checkpoint(model, ckpt_path):
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
sd = torch.load(ckpt_path, map_location="cpu")["model"]
|
sd = torch.load(ckpt_path, map_location="cpu")["model"]
|
||||||
|
@@ -62,6 +62,23 @@ class SAM2ImagePredictor:
|
|||||||
(64, 64),
|
(64, 64),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
|
||||||
|
"""
|
||||||
|
Load a pretrained model from the Hugging Face hub.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model_id (str): The Hugging Face repository ID.
|
||||||
|
**kwargs: Additional arguments to pass to the model constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(SAM2ImagePredictor): The loaded model.
|
||||||
|
"""
|
||||||
|
from sam2.build_sam import build_sam2_hf
|
||||||
|
|
||||||
|
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||||
|
return cls(sam_model)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def set_image(
|
def set_image(
|
||||||
self,
|
self,
|
||||||
|
@@ -4,6 +4,7 @@
|
|||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -103,6 +104,23 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||||
return inference_state
|
return inference_state
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
||||||
|
"""
|
||||||
|
Load a pretrained model from the Hugging Face hub.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model_id (str): The Hugging Face repository ID.
|
||||||
|
**kwargs: Additional arguments to pass to the model constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(SAM2VideoPredictor): The loaded model.
|
||||||
|
"""
|
||||||
|
from sam2.build_sam import build_sam2_video_predictor_hf
|
||||||
|
|
||||||
|
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
||||||
|
return cls(sam_model)
|
||||||
|
|
||||||
def _obj_id_to_idx(self, inference_state, obj_id):
|
def _obj_id_to_idx(self, inference_state, obj_id):
|
||||||
"""Map client-side object id to model-side object index."""
|
"""Map client-side object id to model-side object index."""
|
||||||
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||||
@@ -146,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
return len(inference_state["obj_idx_to_id"])
|
return len(inference_state["obj_idx_to_id"])
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def add_new_points(
|
def add_new_points_or_box(
|
||||||
self,
|
self,
|
||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
obj_id,
|
obj_id,
|
||||||
points,
|
points=None,
|
||||||
labels,
|
labels=None,
|
||||||
clear_old_points=True,
|
clear_old_points=True,
|
||||||
normalize_coords=True,
|
normalize_coords=True,
|
||||||
|
box=None,
|
||||||
):
|
):
|
||||||
"""Add new points to a frame."""
|
"""Add new points to a frame."""
|
||||||
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
||||||
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
||||||
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
||||||
|
|
||||||
if not isinstance(points, torch.Tensor):
|
if (points is not None) != (labels is not None):
|
||||||
|
raise ValueError("points and labels must be provided together")
|
||||||
|
if points is None and box is None:
|
||||||
|
raise ValueError("at least one of points or box must be provided as input")
|
||||||
|
|
||||||
|
if points is None:
|
||||||
|
points = torch.zeros(0, 2, dtype=torch.float32)
|
||||||
|
elif not isinstance(points, torch.Tensor):
|
||||||
points = torch.tensor(points, dtype=torch.float32)
|
points = torch.tensor(points, dtype=torch.float32)
|
||||||
if not isinstance(labels, torch.Tensor):
|
if labels is None:
|
||||||
|
labels = torch.zeros(0, dtype=torch.int32)
|
||||||
|
elif not isinstance(labels, torch.Tensor):
|
||||||
labels = torch.tensor(labels, dtype=torch.int32)
|
labels = torch.tensor(labels, dtype=torch.int32)
|
||||||
if points.dim() == 2:
|
if points.dim() == 2:
|
||||||
points = points.unsqueeze(0) # add batch dimension
|
points = points.unsqueeze(0) # add batch dimension
|
||||||
if labels.dim() == 1:
|
if labels.dim() == 1:
|
||||||
labels = labels.unsqueeze(0) # add batch dimension
|
labels = labels.unsqueeze(0) # add batch dimension
|
||||||
|
|
||||||
|
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
||||||
|
# along with the user-provided points (consistent with how SAM 2 is trained).
|
||||||
|
if box is not None:
|
||||||
|
if not clear_old_points:
|
||||||
|
raise ValueError(
|
||||||
|
"cannot add box without clearing old points, since "
|
||||||
|
"box prompt must be provided before any point prompt "
|
||||||
|
"(please use clear_old_points=True instead)"
|
||||||
|
)
|
||||||
|
if inference_state["tracking_has_started"]:
|
||||||
|
warnings.warn(
|
||||||
|
"You are adding a box after tracking starts. SAM 2 may not always be "
|
||||||
|
"able to incorporate a box prompt for *refinement*. If you intend to "
|
||||||
|
"use box prompt as an *initial* input before tracking, please call "
|
||||||
|
"'reset_state' on the inference state to restart from scratch.",
|
||||||
|
category=UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if not isinstance(box, torch.Tensor):
|
||||||
|
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||||
|
box_coords = box.reshape(1, 2, 2)
|
||||||
|
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
||||||
|
box_labels = box_labels.reshape(1, 2)
|
||||||
|
points = torch.cat([box_coords, points], dim=1)
|
||||||
|
labels = torch.cat([box_labels, labels], dim=1)
|
||||||
|
|
||||||
if normalize_coords:
|
if normalize_coords:
|
||||||
video_H = inference_state["video_height"]
|
video_H = inference_state["video_height"]
|
||||||
video_W = inference_state["video_width"]
|
video_W = inference_state["video_width"]
|
||||||
@@ -251,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
)
|
)
|
||||||
return frame_idx, obj_ids, video_res_masks
|
return frame_idx, obj_ids, video_res_masks
|
||||||
|
|
||||||
|
def add_new_points(self, *args, **kwargs):
|
||||||
|
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
||||||
|
return self.add_new_points_or_box(*args, **kwargs)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def add_new_mask(
|
def add_new_mask(
|
||||||
self,
|
self,
|
||||||
@@ -531,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||||
# Find all the frames that contain temporary outputs for any objects
|
# Find all the frames that contain temporary outputs for any objects
|
||||||
# (these should be the frames that have just received clicks for mask inputs
|
# (these should be the frames that have just received clicks for mask inputs
|
||||||
# via `add_new_points` or `add_new_mask`)
|
# via `add_new_points_or_box` or `add_new_mask`)
|
||||||
temp_frame_inds = set()
|
temp_frame_inds = set()
|
||||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||||
|
@@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
|
|||||||
# Holes are those connected components in background with area <= self.max_area
|
# Holes are those connected components in background with area <= self.max_area
|
||||||
# (background regions are those with mask scores <= 0)
|
# (background regions are those with mask scores <= 0)
|
||||||
assert max_area > 0, "max_area must be positive"
|
assert max_area > 0, "max_area must be positive"
|
||||||
labels, areas = get_connected_components(mask <= 0)
|
|
||||||
is_hole = (labels > 0) & (areas <= max_area)
|
input_mask = mask
|
||||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
try:
|
||||||
mask = torch.where(is_hole, 0.1, mask)
|
labels, areas = get_connected_components(mask <= 0)
|
||||||
|
is_hole = (labels > 0) & (areas <= max_area)
|
||||||
|
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||||
|
mask = torch.where(is_hole, 0.1, mask)
|
||||||
|
except Exception as e:
|
||||||
|
# Skip the post-processing step on removing small holes if the CUDA kernel fails
|
||||||
|
warnings.warn(
|
||||||
|
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||||
|
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||||
|
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||||
|
category=UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
mask = input_mask
|
||||||
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
@@ -4,6 +4,8 @@
|
|||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -78,22 +80,38 @@ class SAM2Transforms(nn.Module):
|
|||||||
from sam2.utils.misc import get_connected_components
|
from sam2.utils.misc import get_connected_components
|
||||||
|
|
||||||
masks = masks.float()
|
masks = masks.float()
|
||||||
if self.max_hole_area > 0:
|
input_masks = masks
|
||||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||||
# (background regions are those with mask scores <= self.mask_threshold)
|
try:
|
||||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
if self.max_hole_area > 0:
|
||||||
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
|
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
# (background regions are those with mask scores <= self.mask_threshold)
|
||||||
is_hole = is_hole.reshape_as(masks)
|
labels, areas = get_connected_components(
|
||||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
mask_flat <= self.mask_threshold
|
||||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
)
|
||||||
|
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||||
|
is_hole = is_hole.reshape_as(masks)
|
||||||
|
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||||
|
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||||
|
|
||||||
if self.max_sprinkle_area > 0:
|
if self.max_sprinkle_area > 0:
|
||||||
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
|
labels, areas = get_connected_components(
|
||||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
mask_flat > self.mask_threshold
|
||||||
is_hole = is_hole.reshape_as(masks)
|
)
|
||||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
is_hole = is_hole.reshape_as(masks)
|
||||||
|
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||||
|
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||||
|
except Exception as e:
|
||||||
|
# Skip the post-processing step if the CUDA kernel fails
|
||||||
|
warnings.warn(
|
||||||
|
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||||
|
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||||
|
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||||
|
category=UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
masks = input_masks
|
||||||
|
|
||||||
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
||||||
return masks
|
return masks
|
||||||
|
84
setup.py
84
setup.py
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import os
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
@@ -36,22 +37,75 @@ EXTRA_PACKAGES = {
|
|||||||
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
|
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# By default, we also build the SAM 2 CUDA extension.
|
||||||
|
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
|
||||||
|
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
|
||||||
|
# By default, we allow SAM 2 installation to proceed even with build errors.
|
||||||
|
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
|
||||||
|
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
||||||
|
|
||||||
|
# Catch and skip errors during extension building and print a warning message
|
||||||
|
# (note that this message only shows up under verbose build mode
|
||||||
|
# "pip install -v -e ." or "python setup.py build_ext -v")
|
||||||
|
CUDA_ERROR_MSG = (
|
||||||
|
"{}\n\n"
|
||||||
|
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
||||||
|
"You can still use SAM 2, but some post-processing functionality may be limited "
|
||||||
|
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_extensions():
|
def get_extensions():
|
||||||
srcs = ["sam2/csrc/connected_components.cu"]
|
if not BUILD_CUDA:
|
||||||
compile_args = {
|
return []
|
||||||
"cxx": [],
|
|
||||||
"nvcc": [
|
try:
|
||||||
"-DCUDA_HAS_FP16=1",
|
srcs = ["sam2/csrc/connected_components.cu"]
|
||||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
compile_args = {
|
||||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
"cxx": [],
|
||||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
"nvcc": [
|
||||||
],
|
"-DCUDA_HAS_FP16=1",
|
||||||
}
|
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||||
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
|
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
||||||
|
except Exception as e:
|
||||||
|
if BUILD_ALLOW_ERRORS:
|
||||||
|
print(CUDA_ERROR_MSG.format(e))
|
||||||
|
ext_modules = []
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
return ext_modules
|
return ext_modules
|
||||||
|
|
||||||
|
|
||||||
|
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
try:
|
||||||
|
super().finalize_options()
|
||||||
|
except Exception as e:
|
||||||
|
print(CUDA_ERROR_MSG.format(e))
|
||||||
|
self.extensions = []
|
||||||
|
|
||||||
|
def build_extensions(self):
|
||||||
|
try:
|
||||||
|
super().build_extensions()
|
||||||
|
except Exception as e:
|
||||||
|
print(CUDA_ERROR_MSG.format(e))
|
||||||
|
self.extensions = []
|
||||||
|
|
||||||
|
def get_ext_filename(self, ext_name):
|
||||||
|
try:
|
||||||
|
return super().get_ext_filename(ext_name)
|
||||||
|
except Exception as e:
|
||||||
|
print(CUDA_ERROR_MSG.format(e))
|
||||||
|
self.extensions = []
|
||||||
|
return "_C.so"
|
||||||
|
|
||||||
|
|
||||||
# Setup configuration
|
# Setup configuration
|
||||||
setup(
|
setup(
|
||||||
name=NAME,
|
name=NAME,
|
||||||
@@ -68,5 +122,11 @@ setup(
|
|||||||
extras_require=EXTRA_PACKAGES,
|
extras_require=EXTRA_PACKAGES,
|
||||||
python_requires=">=3.10.0",
|
python_requires=">=3.10.0",
|
||||||
ext_modules=get_extensions(),
|
ext_modules=get_extensions(),
|
||||||
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
|
cmdclass={
|
||||||
|
"build_ext": (
|
||||||
|
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||||
|
if BUILD_ALLOW_ERRORS
|
||||||
|
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||||
|
),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user