From 2b90b9f5ceec907a1c18123530e92e794ad901a4 Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Sun, 15 Dec 2024 16:47:17 -0800 Subject: [PATCH] remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS (#495) In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487. (close https://github.com/facebookresearch/sam2/issues/487) --- sam2/modeling/sam2_base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index 8aa1a0b..d9f4e51 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module): if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = ( - torch.tensor(pos_list) - .pin_memory() - .to(device=device, non_blocking=True) + obj_pos = torch.tensor(pos_list).to( + device=device, non_blocking=True ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos)