From 6e0ddadf7cf8f33610ba7f6fdce30673526afa0c Mon Sep 17 00:00:00 2001 From: rentainhe <596106517@qq.com> Date: Wed, 21 Aug 2024 18:11:44 +0800 Subject: [PATCH] update to latest SAM 2 --- sam2/automatic_mask_generator.py | 22 ++++++- sam2/build_sam.py | 2 + sam2/modeling/backbones/hieradet.py | 4 -- sam2/modeling/position_encoding.py | 9 ++- sam2/modeling/sam2_base.py | 6 +- sam2/sam2_image_predictor.py | 11 ++-- sam2/sam2_video_predictor.py | 20 +++--- sam2/utils/misc.py | 48 +++++++++++---- sam2/utils/transforms.py | 5 +- sav_dataset/sav_evaluator.py | 2 +- sav_dataset/utils/sav_benchmark.py | 2 +- setup.py | 96 ++++++++++++++--------------- 12 files changed, 140 insertions(+), 87 deletions(-) diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py index 67668b2..065e469 100644 --- a/sam2/automatic_mask_generator.py +++ b/sam2/automatic_mask_generator.py @@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator: output_mode: str = "binary_mask", use_m2m: bool = False, multimask_output: bool = True, + **kwargs, ) -> None: """ Using a SAM 2 model, generates masks for the entire image. @@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator: self.use_m2m = use_m2m self.multimask_output = multimask_output + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. 
+ """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: """ @@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator: orig_h, orig_w = orig_size # Run model on this batch - points = torch.as_tensor(points, device=self.predictor.device) + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) in_points = self.predictor._transforms.transform_coords( points, normalize=normalize, orig_hw=im_size ) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index e5911d4..3a29eda 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -19,6 +19,7 @@ def build_sam2( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + **kwargs, ): if apply_postprocessing: @@ -47,6 +48,7 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + **kwargs, ): hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py index 1ae7d4c..973d622 100644 --- a/sam2/modeling/backbones/hieradet.py +++ b/sam2/modeling/backbones/hieradet.py @@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module): self.dim = dim self.dim_out = dim_out - self.num_heads = num_heads - head_dim = dim_out // num_heads - self.scale = head_dim**-0.5 - self.q_pool = q_pool self.qkv = nn.Linear(dim, dim_out * 3) self.proj = nn.Linear(dim_out, dim_out) diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index f4b57ae..52ac226 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -16,7 +16,7 @@ from torch import nn class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on 
images. + used by the Attention Is All You Need paper, generalized to work on images. """ def __init__( @@ -211,6 +211,11 @@ def apply_rotary_enc( # repeat freqs along seq_len dim to match k seq_len if repeat_freqs_k: r = xk_.shape[-2] // xq_.shape[-2] - freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index 2b5251f..224a8c1 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module): continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, # so we load it back to GPU (it's a no-op if it's already on GPU). 
- feats = prev["maskmem_features"].cuda(non_blocking=True) + feats = prev["maskmem_features"].to(device, non_blocking=True) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # Spatial positional encoding (it might have been offloaded to CPU in eval) - maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) # Temporal positional encoding maskmem_enc = ( @@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module): pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem - # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index f6f9a5a..41ce53a 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -24,6 +24,7 @@ class SAM2ImagePredictor: mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, + **kwargs, ) -> None: """ Uses SAM-2 to calculate the image embedding for an image, and then @@ -33,8 +34,10 @@ class SAM2ImagePredictor: sam_model (Sam-2): The model to use for mask prediction. mask_threshold (float): The threshold to use when converting mask logits to binary masks. Masks are thresholded at 0 by default. - fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to - the maximum area of fill_hole_area in low_res_masks. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. 
""" super().__init__() self.model = sam_model @@ -77,7 +80,7 @@ class SAM2ImagePredictor: from sam2.build_sam import build_sam2_hf sam_model = build_sam2_hf(model_id, **kwargs) - return cls(sam_model) + return cls(sam_model, **kwargs) @torch.no_grad() def set_image( @@ -180,7 +183,7 @@ class SAM2ImagePredictor: normalize_coords=True, ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. - It returns a tupele of lists of masks, ious, and low_res_masks_logits. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. """ assert self._is_batch, "This function should only be used when in batched mode" if not self._is_image_set: diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index b5a6bdf..8b2fd6c 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -44,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base): offload_state_to_cpu=False, async_loading_frames=False, ): - """Initialize a inference state.""" + """Initialize an inference state.""" + compute_device = self.device # device of the model images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size, offload_video_to_cpu=offload_video_to_cpu, async_loading_frames=async_loading_frames, + compute_device=compute_device, ) inference_state = {} inference_state["images"] = images @@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base): # the original video height and width, used for resizing final output scores inference_state["video_height"] = video_height inference_state["video_width"] = video_width - inference_state["device"] = torch.device("cuda") + inference_state["device"] = compute_device if offload_state_to_cpu: inference_state["storage_device"] = torch.device("cpu") else: - inference_state["storage_device"] = torch.device("cuda") + 
inference_state["storage_device"] = compute_device # inputs on each frame inference_state["point_inputs_per_obj"] = {} inference_state["mask_inputs_per_obj"] = {} @@ -119,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base): from sam2.build_sam import build_sam2_video_predictor_hf sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) - return cls(sam_model) + return sam_model def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" @@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base): prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) if prev_out is not None and prev_out["pred_masks"] is not None: - prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) current_out, _ = self._run_single_frame_inference( @@ -586,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base): # to `propagate_in_video_preflight`). 
consolidated_frame_inds = inference_state["consolidated_frame_inds"] for is_cond in [False, True]: - # Separately consolidate conditioning and non-conditioning temp outptus + # Separately consolidate conditioning and non-conditioning temp outputs storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" # Find all the frames that contain temporary outputs for any objects # (these should be the frames that have just received clicks for mask inputs @@ -595,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base): for obj_temp_output_dict in temp_output_dict_per_obj.values(): temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) consolidated_frame_inds[storage_key].update(temp_frame_inds) - # consolidate the temprary output across all objects on this frame + # consolidate the temporary output across all objects on this frame for frame_idx in temp_frame_inds: consolidated_out = self._consolidate_temp_output_across_obj( inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True @@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base): ) if backbone_out is None: # Cache miss -- we will run inference on a single image - image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) backbone_out = self.forward_image(image) # Cache the most recent frame's feature (for repeated interactions with # a frame; we can use an LRU cache for more frames in the future). 
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index 3c5ff13..525e8cb 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor): compute bounding box given an input mask Inputs: - - masks: [B, 1, H, W] boxes, dtype=torch.Tensor + - masks: [B, 1, H, W] masks, dtype=torch.Tensor Returns: - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor @@ -106,19 +106,28 @@ class AsyncVideoFrameLoader: A list of video frames to be load asynchronously without blocking session start. """ - def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): self.img_paths = img_paths self.image_size = image_size self.offload_video_to_cpu = offload_video_to_cpu self.img_mean = img_mean self.img_std = img_std - # items in `self._images` will be loaded asynchronously + # items in `self.images` will be loaded asynchronously self.images = [None] * len(img_paths) # catch and raise any exceptions in the async loading thread self.exception = None # video_height and video_width be filled when loading the first image self.video_height = None self.video_width = None + self.compute_device = compute_device # load the first frame to fill video_height and video_width and also # to cache it (since it's most likely where the user will click) @@ -152,7 +161,7 @@ class AsyncVideoFrameLoader: img -= self.img_mean img /= self.img_std if not self.offload_video_to_cpu: - img = img.cuda(non_blocking=True) + img = img.to(self.compute_device, non_blocking=True) self.images[index] = img return img @@ -167,6 +176,7 @@ def load_video_frames( img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225), async_loading_frames=False, + compute_device=torch.device("cuda"), ): """ Load the video frames from a directory of JPEG files (".jpg" format). 
@@ -179,12 +189,20 @@ def load_video_frames( if isinstance(video_path, str) and os.path.isdir(video_path): jpg_folder = video_path else: - raise NotImplementedError("Only JPEG frames are supported at this moment") + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) frame_names = [ p for p in os.listdir(jpg_folder) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) num_frames = len(frame_names) @@ -196,7 +214,12 @@ def load_video_frames( if async_loading_frames: lazy_images = AsyncVideoFrameLoader( - img_paths, image_size, offload_video_to_cpu, img_mean, img_std + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, ) return lazy_images, lazy_images.video_height, lazy_images.video_width @@ -204,9 +227,9 @@ def load_video_frames( for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) if not offload_video_to_cpu: - images = images.cuda() - img_mean = img_mean.cuda() - img_std = img_std.cuda() + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) # normalize by mean and std images -= img_mean images /= img_std @@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area): except Exception as e: # Skip the post-processing step on removing small holes if the CUDA kernel fails warnings.warn( - f"{e}\n\nSkipping the post-processing step due to the error above. 
" - "Consider building SAM 2 with CUDA extension to enable post-processing (see " + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", category=UserWarning, stacklevel=2, diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py index 995baf9..65ef770 100644 --- a/sam2/utils/transforms.py +++ b/sam2/utils/transforms.py @@ -105,8 +105,9 @@ class SAM2Transforms(nn.Module): except Exception as e: # Skip the post-processing step if the CUDA kernel fails warnings.warn( - f"{e}\n\nSkipping the post-processing step due to the error above. " - "Consider building SAM 2 with CUDA extension to enable post-processing (see " + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", category=UserWarning, stacklevel=2, diff --git a/sav_dataset/sav_evaluator.py b/sav_dataset/sav_evaluator.py index 1c319e1..d4b0ef0 100644 --- a/sav_dataset/sav_evaluator.py +++ b/sav_dataset/sav_evaluator.py @@ -72,7 +72,7 @@ parser.add_argument( parser.add_argument( "--do_not_skip_first_and_last_frame", help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. 
" - "Set this to true for evaluation on settings that doen't skip first and last frames", + "Set this to true for evaluation on settings that doesn't skip first and last frames", action="store_true", ) diff --git a/sav_dataset/utils/sav_benchmark.py b/sav_dataset/utils/sav_benchmark.py index babb330..c4b2444 100644 --- a/sav_dataset/utils/sav_benchmark.py +++ b/sav_dataset/utils/sav_benchmark.py @@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None): assert not ( width > w | height > h | abs(ar1 - ar2) > 0.01 - ), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height) + ), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height) e = np.zeros_like(seg) s = np.zeros_like(seg) diff --git a/setup.py b/setup.py index d49ddc5..ebef97c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,6 @@ import os from setuptools import find_packages, setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension # Package metadata NAME = "SAM 2" @@ -18,35 +17,18 @@ AUTHOR_EMAIL = "segment-anything@meta.com" LICENSE = "Apache 2.0" # Read the contents of README file -with open("README.md", "r") as f: +with open("README.md", "r", encoding="utf-8") as f: LONG_DESCRIPTION = f.read() # Required dependencies REQUIRED_PACKAGES = [ - "torch>=2.3.1", + "torch>=2.3.1", "torchvision>=0.18.1", - "transformers", "numpy>=1.24.4", "tqdm>=4.66.1", "hydra-core>=1.3.2", "iopath>=0.1.10", - "pillow>=9.4.0", - "huggingface_hub", - "diffusers[torch]==0.15.1", - "onnxruntime==1.14.1", - "onnx==1.13.1", - "ipykernel==6.16.2", - "scipy", - "gradio", - "openai", - "matplotlib>=3.9.1", - "opencv-python>=4.7.0", - "dds_cloudapi_sdk", - "addict", - "yapf", - "timm", - "supervision>=0.22.0", - "pycocotools", + "pillow>=9.4.0", ] EXTRA_PACKAGES = { @@ -67,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" CUDA_ERROR_MSG = ( "{}\n\n" "Failed to build the SAM 2 CUDA extension due to the error above. 
" - "You can still use SAM 2, but some post-processing functionality may be limited " + "You can still use SAM 2 and it's OK to ignore the error above, although some " + "post-processing functionality may be limited (which doesn't affect the results in most cases; " "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n" ) @@ -77,6 +60,8 @@ def get_extensions(): return [] try: + from torch.utils.cpp_extension import CUDAExtension + srcs = ["sam2/csrc/connected_components.cu"] compile_args = { "cxx": [], @@ -98,29 +83,46 @@ def get_extensions(): return ext_modules -class BuildExtensionIgnoreErrors(BuildExtension): +try: + from torch.utils.cpp_extension import BuildExtension - def finalize_options(self): - try: - super().finalize_options() - except Exception as e: - print(CUDA_ERROR_MSG.format(e)) - self.extensions = [] + class BuildExtensionIgnoreErrors(BuildExtension): - def build_extensions(self): - try: - super().build_extensions() - except Exception as e: - print(CUDA_ERROR_MSG.format(e)) - self.extensions = [] + def finalize_options(self): + try: + super().finalize_options() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] - def get_ext_filename(self, ext_name): - try: - return super().get_ext_filename(ext_name) - except Exception as e: - print(CUDA_ERROR_MSG.format(e)) - self.extensions = [] - return "_C.so" + def build_extensions(self): + try: + super().build_extensions() + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + + def get_ext_filename(self, ext_name): + try: + return super().get_ext_filename(ext_name) + except Exception as e: + print(CUDA_ERROR_MSG.format(e)) + self.extensions = [] + return "_C.so" + + cmdclass = { + "build_ext": ( + BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) + if BUILD_ALLOW_ERRORS + else BuildExtension.with_options(no_python_abi_suffix=True) + ) + } +except Exception as e: + cmdclass = {} + if 
BUILD_ALLOW_ERRORS: + print(CUDA_ERROR_MSG.format(e)) + else: + raise e # Setup configuration @@ -135,15 +137,11 @@ setup( author_email=AUTHOR_EMAIL, license=LICENSE, packages=find_packages(exclude="notebooks"), + package_data={"": ["*.yaml"]}, # SAM 2 configuration files + include_package_data=True, install_requires=REQUIRED_PACKAGES, extras_require=EXTRA_PACKAGES, python_requires=">=3.10.0", ext_modules=get_extensions(), - cmdclass={ - "build_ext": ( - BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) - if BUILD_ALLOW_ERRORS - else BuildExtension.with_options(no_python_abi_suffix=True) - ), - }, + cmdclass=cmdclass, )