Make huggingface_hub soft dependency

This commit is contained in:
Niels
2024-08-05 09:37:53 +02:00
parent 0c28c630c2
commit 6aeee34775
3 changed files with 24 additions and 4 deletions

View File

@@ -103,6 +103,23 @@ class SAM2VideoPredictor(SAM2Base):
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
"""
Load a pretrained model from the Hugging Face model hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2ImagePredictor): The loaded model.
"""
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)