Use classmethod

This commit is contained in:
Niels
2024-08-03 14:14:12 +02:00
parent b72a8a97f0
commit 17b74501fb

View File

@@ -62,7 +62,8 @@ class SAM2ImagePredictor:
(64, 64), (64, 64),
] ]
def from_pretrained(model_id: str, **kwargs) -> "SAM2ImagePredictor": @classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
""" """
Load a pretrained model from the Hugging Face model hub. Load a pretrained model from the Hugging Face model hub.
@@ -74,7 +75,7 @@ class SAM2ImagePredictor:
(SAM2ImagePredictor): The loaded model. (SAM2ImagePredictor): The loaded model.
""" """
sam_model = build_sam2_hf(model_id, **kwargs) sam_model = build_sam2_hf(model_id, **kwargs)
return SAM2ImagePredictor(sam_model) return cls(sam_model)
@torch.no_grad() @torch.no_grad()
def set_image( def set_image(