diff --git a/sam2/build_sam.py b/sam2/build_sam.py index e5911d4..3a29eda 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -19,6 +19,7 @@ def build_sam2( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + **kwargs, ): if apply_postprocessing: @@ -47,6 +48,7 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + **kwargs, ): hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index f6f9a5a..56d9325 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -24,6 +24,7 @@ class SAM2ImagePredictor: mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0, + **kwargs, ) -> None: """ Uses SAM-2 to calculate the image embedding for an image, and then @@ -33,8 +34,10 @@ class SAM2ImagePredictor: sam_model (Sam-2): The model to use for mask prediction. mask_threshold (float): The threshold to use when converting mask logits to binary masks. Masks are thresholded at 0 by default. - fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to - the maximum area of fill_hole_area in low_res_masks. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. """ super().__init__() self.model = sam_model @@ -77,7 +80,7 @@ class SAM2ImagePredictor: from sam2.build_sam import build_sam2_hf sam_model = build_sam2_hf(model_id, **kwargs) - return cls(sam_model) + return cls(sam_model, **kwargs) @torch.no_grad() def set_image( diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 78284e2..e0a9c99 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -121,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base): from sam2.build_sam import build_sam2_video_predictor_hf sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) - return cls(sam_model) + return sam_model def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index."""