Merge pull request #205 from facebookresearch/haitham/fix_hf_image_predictor
Fix HF image predictor
This commit is contained in:
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
|||||||
output_mode: str = "binary_mask",
|
output_mode: str = "binary_mask",
|
||||||
use_m2m: bool = False,
|
use_m2m: bool = False,
|
||||||
multimask_output: bool = True,
|
multimask_output: bool = True,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Using a SAM 2 model, generates masks for the entire image.
|
Using a SAM 2 model, generates masks for the entire image.
|
||||||
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
|||||||
self.use_m2m = use_m2m
|
self.use_m2m = use_m2m
|
||||||
self.multimask_output = multimask_output
|
self.multimask_output = multimask_output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
||||||
|
"""
|
||||||
|
Load a pretrained model from the Hugging Face hub.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model_id (str): The Hugging Face repository ID.
|
||||||
|
**kwargs: Additional arguments to pass to the model constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(SAM2AutomaticMaskGenerator): The loaded model.
|
||||||
|
"""
|
||||||
|
from sam2.build_sam import build_sam2_hf
|
||||||
|
|
||||||
|
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||||
|
return cls(sam_model, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
@@ -19,6 +19,7 @@ def build_sam2(
|
|||||||
mode="eval",
|
mode="eval",
|
||||||
hydra_overrides_extra=[],
|
hydra_overrides_extra=[],
|
||||||
apply_postprocessing=True,
|
apply_postprocessing=True,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
if apply_postprocessing:
|
if apply_postprocessing:
|
||||||
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
|
|||||||
mode="eval",
|
mode="eval",
|
||||||
hydra_overrides_extra=[],
|
hydra_overrides_extra=[],
|
||||||
apply_postprocessing=True,
|
apply_postprocessing=True,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
hydra_overrides = [
|
hydra_overrides = [
|
||||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||||
|
@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
|
|||||||
mask_threshold=0.0,
|
mask_threshold=0.0,
|
||||||
max_hole_area=0.0,
|
max_hole_area=0.0,
|
||||||
max_sprinkle_area=0.0,
|
max_sprinkle_area=0.0,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Uses SAM-2 to calculate the image embedding for an image, and then
|
Uses SAM-2 to calculate the image embedding for an image, and then
|
||||||
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
|
|||||||
sam_model (Sam-2): The model to use for mask prediction.
|
sam_model (Sam-2): The model to use for mask prediction.
|
||||||
mask_threshold (float): The threshold to use when converting mask logits
|
mask_threshold (float): The threshold to use when converting mask logits
|
||||||
to binary masks. Masks are thresholded at 0 by default.
|
to binary masks. Masks are thresholded at 0 by default.
|
||||||
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
||||||
the maximum area of fill_hole_area in low_res_masks.
|
the maximum area of max_hole_area in low_res_masks.
|
||||||
|
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||||
|
the maximum area of max_sprinkle_area in low_res_masks.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = sam_model
|
self.model = sam_model
|
||||||
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
|
|||||||
from sam2.build_sam import build_sam2_hf
|
from sam2.build_sam import build_sam2_hf
|
||||||
|
|
||||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||||
return cls(sam_model)
|
return cls(sam_model, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def set_image(
|
def set_image(
|
||||||
|
@@ -121,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
from sam2.build_sam import build_sam2_video_predictor_hf
|
from sam2.build_sam import build_sam2_video_predictor_hf
|
||||||
|
|
||||||
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
||||||
return cls(sam_model)
|
return sam_model
|
||||||
|
|
||||||
def _obj_id_to_idx(self, inference_state, obj_id):
|
def _obj_id_to_idx(self, inference_state, obj_id):
|
||||||
"""Map client-side object id to model-side object index."""
|
"""Map client-side object id to model-side object index."""
|
||||||
|
Reference in New Issue
Block a user