remove .pin_memory() in obj_pos of SAM2Base to resolve and error in MPS (#495)

In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487.

(close https://github.com/facebookresearch/sam2/issues/487)
This commit is contained in:
Ronghang Hu
2024-12-15 16:47:17 -08:00
committed by GitHub
parent 722d1d1511
commit 2b90b9f5ce

View File

@@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module):
if self.add_tpos_enc_to_obj_ptrs: if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1 t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = ( obj_pos = torch.tensor(pos_list).to(
torch.tensor(pos_list) device=device, non_blocking=True
.pin_memory()
.to(device=device, non_blocking=True)
) )
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = self.obj_ptr_tpos_proj(obj_pos)