diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 48a9f50..7346adc 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -135,7 +135,7 @@ def build_sam2_video_predictor_hf(model_id, **kwargs): def _load_checkpoint(model, ckpt_path): if ckpt_path is not None: - sd = torch.load(ckpt_path, map_location="cpu")["model"] + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] missing_keys, unexpected_keys = model.load_state_dict(sd) if missing_keys: logging.error(missing_keys)