update to latest SAM 2
This commit is contained in:
@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
|
||||
mask_threshold=0.0,
|
||||
max_hole_area=0.0,
|
||||
max_sprinkle_area=0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM-2 to calculate the image embedding for an image, and then
|
||||
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
|
||||
sam_model (Sam-2): The model to use for mask prediction.
|
||||
mask_threshold (float): The threshold to use when converting mask logits
|
||||
to binary masks. Masks are thresholded at 0 by default.
|
||||
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of fill_hole_area in low_res_masks.
|
||||
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of max_hole_area in low_res_masks.
|
||||
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||
the maximum area of max_sprinkle_area in low_res_masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = sam_model
|
||||
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def set_image(
|
||||
@@ -180,7 +183,7 @@ class SAM2ImagePredictor:
|
||||
normalize_coords=True,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
||||
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
|
||||
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
|
||||
"""
|
||||
assert self._is_batch, "This function should only be used when in batched mode"
|
||||
if not self._is_image_set:
|
||||
|
Reference in New Issue
Block a user