accept kwargs in auto_mask_generator
This commit is contained in:
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
|||||||
output_mode: str = "binary_mask",
|
output_mode: str = "binary_mask",
|
||||||
use_m2m: bool = False,
|
use_m2m: bool = False,
|
||||||
multimask_output: bool = True,
|
multimask_output: bool = True,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Using a SAM 2 model, generates masks for the entire image.
|
Using a SAM 2 model, generates masks for the entire image.
|
||||||
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
|||||||
self.use_m2m = use_m2m
|
self.use_m2m = use_m2m
|
||||||
self.multimask_output = multimask_output
|
self.multimask_output = multimask_output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
||||||
|
"""
|
||||||
|
Load a pretrained model from the Hugging Face hub.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model_id (str): The Hugging Face repository ID.
|
||||||
|
**kwargs: Additional arguments to pass to the model constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(SAM2AutomaticMaskGenerator): The loaded model.
|
||||||
|
"""
|
||||||
|
from sam2.build_sam import build_sam2_hf
|
||||||
|
|
||||||
|
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||||
|
return cls(sam_model, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user