diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index 8aa1a0b..d9f4e51 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module): if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = ( - torch.tensor(pos_list) - .pin_memory() - .to(device=device, non_blocking=True) + obj_pos = torch.tensor(pos_list).to( + device=device, non_blocking=True ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos)