diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py index 697b421..065e469 100644 --- a/sam2/automatic_mask_generator.py +++ b/sam2/automatic_mask_generator.py @@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator: output_mode: str = "binary_mask", use_m2m: bool = False, multimask_output: bool = True, + **kwargs, ) -> None: """ Using a SAM 2 model, generates masks for the entire image. @@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator: self.use_m2m = use_m2m self.multimask_output = multimask_output + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: """