remove .pin_memory()
in obj_pos
of SAM2Base
to resolve and error in MPS (#495)
In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487. (close https://github.com/facebookresearch/sam2/issues/487)
This commit is contained in:
@@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module):
|
|||||||
if self.add_tpos_enc_to_obj_ptrs:
|
if self.add_tpos_enc_to_obj_ptrs:
|
||||||
t_diff_max = max_obj_ptrs_in_encoder - 1
|
t_diff_max = max_obj_ptrs_in_encoder - 1
|
||||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
||||||
obj_pos = (
|
obj_pos = torch.tensor(pos_list).to(
|
||||||
torch.tensor(pos_list)
|
device=device, non_blocking=True
|
||||||
.pin_memory()
|
|
||||||
.to(device=device, non_blocking=True)
|
|
||||||
)
|
)
|
||||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
||||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
||||||
|
Reference in New Issue
Block a user