Fix HF image predictor

This commit is contained in:
Haitham Khedr
2024-08-12 23:41:41 +00:00
parent dce7b5446f
commit 1191677e1e
3 changed files with 9 additions and 4 deletions

View File

@@ -19,6 +19,7 @@ def build_sam2(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    **kwargs,
 ):
     if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",

View File

@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
         mask_threshold=0.0,
         max_hole_area=0.0,
         max_sprinkle_area=0.0,
+        **kwargs,
     ) -> None:
         """
         Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
           sam_model (Sam-2): The model to use for mask prediction.
           mask_threshold (float): The threshold to use when converting mask logits
             to binary masks. Masks are thresholded at 0 by default.
-          fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
-            the maximum area of fill_hole_area in low_res_masks.
+          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
+            the maximum area of max_hole_area in low_res_masks.
+          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
+            the maximum area of max_sprinkle_area in low_res_masks.
         """
         super().__init__()
         self.model = sam_model
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
         from sam2.build_sam import build_sam2_hf

         sam_model = build_sam2_hf(model_id, **kwargs)
-        return cls(sam_model)
+        return cls(sam_model, **kwargs)

     @torch.no_grad()
     def set_image(

View File

@@ -121,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
         from sam2.build_sam import build_sam2_video_predictor_hf

         sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
-        return cls(sam_model)
+        return sam_model

     def _obj_id_to_idx(self, inference_state, obj_id):
         """Map client-side object id to model-side object index."""