From b72a8a97f062eb918736925eb8942fb7a179c60e Mon Sep 17 00:00:00 2001
From: Niels
Date: Sat, 3 Aug 2024 12:57:05 +0200
Subject: [PATCH] First draft

---
Review notes (ignored by `git am`):
- build_sam2_hf now calls build_sam2 instead of build_sam2_video_predictor.
- Hub file names are derived from the last component of the repository ID.
- SAM2ImagePredictor.from_pretrained is a classmethod.
- The hunk that only stripped the trailing newline of build_sam.py is dropped.
- Post-image blob IDs on the `index` lines are carried over from the original
  draft and were not recomputed.

 sam2/build_sam.py            | 51 +++++++++++++++++++++++++++++++++++++++++++
 sam2/sam2_image_predictor.py | 17 ++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/sam2/build_sam.py b/sam2/build_sam.py
index 39defc4..e55f85e 100644
--- a/sam2/build_sam.py
+++ b/sam2/build_sam.py
@@ -11,6 +11,7 @@ from hydra import compose
 from hydra.utils import instantiate
 from omegaconf import OmegaConf
+from huggingface_hub import hf_hub_download
 
 
 def build_sam2(
     config_file,
@@ -76,6 +77,56 @@ def build_sam2_video_predictor(
     return model
 
 
+def _hf_download_config_and_ckpt(model_id):
+    """Download a model's config and checkpoint from the Hugging Face Hub.
+
+    Arguments:
+      model_id (str): The Hugging Face repository ID.
+
+    Returns:
+      (tuple): Local paths of the downloaded config and checkpoint files.
+    """
+    # A repository ID is usually "<namespace>/<name>"; only the last component
+    # is usable as a file name inside the repository.
+    # NOTE(review): assumes the repo stores "<name>.yaml" and "<name>.pt" at its
+    # root -- confirm against the published repositories.
+    repo_name = model_id.rsplit("/", 1)[-1]
+    config_file = hf_hub_download(repo_id=model_id, filename=f"{repo_name}.yaml")
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{repo_name}.pt")
+    return config_file, ckpt_path
+
+
+def build_sam2_hf(model_id, **kwargs):
+    """Build a model with `build_sam2` from files hosted on the Hugging Face Hub.
+
+    Arguments:
+      model_id (str): The Hugging Face repository ID.
+      **kwargs: Additional arguments forwarded to `build_sam2`.
+
+    Returns:
+      The model returned by `build_sam2`.
+    """
+    config_file, ckpt_path = _hf_download_config_and_ckpt(model_id)
+    return build_sam2(config_file=config_file, ckpt_path=ckpt_path, **kwargs)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+    """Build a video predictor from files hosted on the Hugging Face Hub.
+
+    Arguments:
+      model_id (str): The Hugging Face repository ID.
+      **kwargs: Additional arguments forwarded to
+        `build_sam2_video_predictor`.
+
+    Returns:
+      The model returned by `build_sam2_video_predictor`.
+    """
+    config_file, ckpt_path = _hf_download_config_and_ckpt(model_id)
+    return build_sam2_video_predictor(
+        config_file=config_file, ckpt_path=ckpt_path, **kwargs
+    )
+
+
 def _load_checkpoint(model, ckpt_path):
     if ckpt_path is not None:
         sd = torch.load(ckpt_path, map_location="cpu")["model"]
diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py
index 9411131..e7eebbe 100644
--- a/sam2/sam2_image_predictor.py
+++ b/sam2/sam2_image_predictor.py
@@ -13,7 +13,7 @@ import torch
 from PIL.Image import Image
 
 from sam2.modeling.sam2_base import SAM2Base
-
+from sam2.build_sam import build_sam2_hf
 from sam2.utils.transforms import SAM2Transforms
 
 
@@ -62,6 +62,21 @@ class SAM2ImagePredictor:
             (64, 64),
         ]
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
+        """
+        Load a pretrained model from the Hugging Face model hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments forwarded to `build_sam2_hf`.
+
+        Returns:
+          (SAM2ImagePredictor): The loaded model.
+        """
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model)
+
     @torch.no_grad()
     def set_image(
         self,