diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index e7eebbe..5d2980c 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -62,7 +62,8 @@ class SAM2ImagePredictor: (64, 64), ] - def from_pretrained(model_id: str, **kwargs) -> "SAM2ImagePredictor": + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": """ Load a pretrained model from the Hugging Face model hub. @@ -74,7 +75,7 @@ class SAM2ImagePredictor: (SAM2ImagePredictor): The loaded model. """ sam_model = build_sam2_hf(model_id, **kwargs) - return SAM2ImagePredictor(sam_model) + return cls(sam_model) @torch.no_grad() def set_image(