From 077064c365bfdd57fd96019d3619e04cfdced208 Mon Sep 17 00:00:00 2001 From: rentainhe <596106517@qq.com> Date: Thu, 8 Aug 2024 12:03:29 +0800 Subject: [PATCH] update to the latest sam2 version and support box prompts in video tracking --- ..._demo_custom_video_input_gd1.0_hf_model.py | 17 +++- ..._tracking_demo_custom_video_input_gd1.5.py | 20 +++-- sam2/build_sam.py | 38 +++++++++ sam2/sam2_image_predictor.py | 17 ++++ sam2/sam2_video_predictor.py | 71 ++++++++++++++-- sam2/utils/misc.py | 22 ++++- sam2/utils/transforms.py | 48 +++++++---- setup.py | 84 ++++++++++++++++--- 8 files changed, 272 insertions(+), 45 deletions(-) diff --git a/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py b/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py index 1a72cc4..742294e 100644 --- a/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py +++ b/grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py @@ -22,7 +22,7 @@ TEXT_PROMPT = "hippopotamus." OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4" SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames" SAVE_TRACKING_RESULTS_DIR = "./tracking_results" -PROMPT_TYPE_FOR_VIDEO = "mask" # "point" +PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] """ Step 1: Environment settings and model initialization for SAM 2 @@ -128,7 +128,7 @@ if masks.ndim == 4: Step 3: Register each object's positive points to video predictor with seperate add_new_points call """ -assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"] +assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt" # If you are using point prompts, we uniformly sample positive points based on the mask if PROMPT_TYPE_FOR_VIDEO == "point": @@ -137,13 +137,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point": for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1): labels = np.ones((points.shape[0]), dtype=np.int32) - _, out_obj_ids, out_mask_logits = video_predictor.add_new_points( + _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box( inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=object_id, points=points, labels=labels, ) +# Using box prompt +elif PROMPT_TYPE_FOR_VIDEO == "box": + for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1): + _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=object_id, + box=box, + ) # Using mask prompt is a more straightforward way elif PROMPT_TYPE_FOR_VIDEO == "mask": for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1): @@ -154,6 +163,8 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask": obj_id=object_id, mask=mask ) +else: + raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts") """ diff --git a/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py b/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py index e509f4a..ef95854 100644 --- a/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py +++ b/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py @@ -28,8 +28,8 @@ TEXT_PROMPT = "hippopotamus." OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4" SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames" SAVE_TRACKING_RESULTS_DIR = "./tracking_results" -API_TOKEN_FOR_GD1_5 = "Your API token" -PROMPT_TYPE_FOR_VIDEO = "mask" # "point" +API_TOKEN_FOR_GD1_5 = "3491a2a256fb7ed01b2e757b713c4cb0" +PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] """ Step 1: Environment settings and model initialization for SAM 2 @@ -152,7 +152,7 @@ if masks.ndim == 4: Step 3: Register each object's positive points to video predictor with seperate add_new_points call """ -assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"] +assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt" # If you are using point prompts, we uniformly sample positive points based on the mask if PROMPT_TYPE_FOR_VIDEO == "point": @@ -161,13 +161,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point": for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1): labels = np.ones((points.shape[0]), dtype=np.int32) - _, out_obj_ids, out_mask_logits = video_predictor.add_new_points( + _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box( inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=object_id, points=points, labels=labels, ) +# Using box prompt +elif PROMPT_TYPE_FOR_VIDEO == "box": + for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1): + _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=object_id, + box=box, + ) # Using mask prompt is a more straightforward way elif PROMPT_TYPE_FOR_VIDEO == "mask": for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1): @@ -178,7 +187,8 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask": obj_id=object_id, mask=mask ) - +else: + raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts") """ Step 4: Propagate the video predictor to get the segmentation results for each frame diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 39defc4..e5911d4 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -76,6 +76,44 @@ def build_sam2_video_predictor( return model +def build_sam2_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + def _load_checkpoint(model, ckpt_path): if ckpt_path is not None: sd = torch.load(ckpt_path, map_location="cpu")["model"] diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 9411131..f6f9a5a 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -62,6 +62,23 @@ class SAM2ImagePredictor: (64, 64), ] + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model) + @torch.no_grad() def set_image( self, diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 0defcec..b5a6bdf 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import warnings from collections import OrderedDict import torch @@ -103,6 +104,23 @@ class SAM2VideoPredictor(SAM2Base): self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return cls(sam_model) + def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) @@ -146,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base): return len(inference_state["obj_idx_to_id"]) @torch.inference_mode() - def add_new_points( + def add_new_points_or_box( self, inference_state, frame_idx, obj_id, - points, - labels, + points=None, + labels=None, clear_old_points=True, normalize_coords=True, + box=None, ): """Add new points to a frame.""" obj_idx = self._obj_id_to_idx(inference_state, obj_id) point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] - if not isinstance(points, torch.Tensor): + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): points = torch.tensor(points, dtype=torch.float32) - if not isinstance(labels, torch.Tensor): + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): labels = torch.tensor(labels, dtype=torch.int32) if points.dim() == 2: points = points.unsqueeze(0) # add batch dimension if labels.dim() == 1: labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + if normalize_coords: video_H = inference_state["video_height"] video_W = inference_state["video_width"] @@ -251,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base): ) return frame_idx, obj_ids, video_res_masks + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + @torch.inference_mode() def add_new_mask( self, @@ -531,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base): storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Find all the frames that contain temporary outputs for any objects # (these should be the frames that have just received clicks for mask inputs - # via `add_new_points` or `add_new_mask`) + # via `add_new_points_or_box` or `add_new_mask`) temp_frame_inds = set() for obj_temp_output_dict in temp_output_dict_per_obj.values(): temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index bf6a179..df97b4a 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area): # Holes are those connected components in background with area <= self.max_area # (background regions are those with mask scores <= 0) assert max_area > 0, "max_area must be positive" - labels, areas = get_connected_components(mask <= 0) - is_hole = (labels > 0) & (areas <= max_area) - # We fill holes with a small positive mask score (0.1) to change them to foreground. - mask = torch.where(is_hole, 0.1, mask) + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + return mask diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py index d05cd3e..995baf9 100644 --- a/sam2/utils/transforms.py +++ b/sam2/utils/transforms.py @@ -4,6 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import warnings + import torch import torch.nn as nn import torch.nn.functional as F @@ -78,22 +80,38 @@ class SAM2Transforms(nn.Module): from sam2.utils.misc import get_connected_components masks = masks.float() - if self.max_hole_area > 0: - # Holes are those connected components in background with area <= self.fill_hole_area - # (background regions are those with mask scores <= self.mask_threshold) - mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image - labels, areas = get_connected_components(mask_flat <= self.mask_threshold) - is_hole = (labels > 0) & (areas <= self.max_hole_area) - is_hole = is_hole.reshape_as(masks) - # We fill holes with a small positive mask score (10.0) to change them to foreground. - masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) - if self.max_sprinkle_area > 0: - labels, areas = get_connected_components(mask_flat > self.mask_threshold) - is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) - is_hole = is_hole.reshape_as(masks) - # We fill holes with negative mask score (-10.0) to change them to background. - masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) return masks diff --git a/setup.py b/setup.py index 85ae842..94f41b5 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import os from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -36,22 +37,75 @@ EXTRA_PACKAGES = { "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"], } +# By default, we also build the SAM 2 CUDA extension. +# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. +BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" +# By default, we allow SAM 2 installation to proceed even with build errors. +# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. +BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" + +# Catch and skip errors during extension building and print a warning message +# (note that this message only shows up under verbose build mode +# "pip install -v -e ." or "python setup.py build_ext -v") +CUDA_ERROR_MSG = ( + "{}\n\n" + "Failed to build the SAM 2 CUDA extension due to the error above. " + "You can still use SAM 2, but some post-processing functionality may be limited " + "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n" +) + def get_extensions(): - srcs = ["sam2/csrc/connected_components.cu"] - compile_args = { - "cxx": [], - "nvcc": [ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ], - } - ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + if not BUILD_CUDA: + return [] + + try: + srcs = ["sam2/csrc/connected_components.cu"] + compile_args = { + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + } + ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] + except Exception as e: + if BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + ext_modules = [] + else: + raise e + return ext_modules +class BuildExtensionIgnoreErrors(BuildExtension): + + def finalize_options(self): + try: + super().finalize_options() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def build_extensions(self): + try: + super().build_extensions() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def get_ext_filename(self, ext_name): + try: + return super().get_ext_filename(ext_name) + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + return "_C.so" + + # Setup configuration setup( name=NAME, @@ -68,5 +122,11 @@ setup( extras_require=EXTRA_PACKAGES, python_requires=">=3.10.0", ext_modules=get_extensions(), - cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, + cmdclass={ + "build_ext": ( + BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) + if BUILD_ALLOW_ERRORS + else BuildExtension.with_options(no_python_abi_suffix=True) + ), + }, )