diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 4637706..5a7e1a0 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -591,7 +591,8 @@ class SAM2VideoPredictor(SAM2Base): if frame_idx in obj_output_dict["cond_frame_outputs"]: storage_key = "cond_frame_outputs" current_out = obj_output_dict[storage_key][frame_idx] - pred_masks = current_out["pred_masks"] + device = inference_state["device"] + pred_masks = current_out["pred_masks"].to(device, non_blocking=True) if self.clear_non_cond_mem_around_input: # clear non-conditioning memory of the surrounding frames self._clear_obj_non_cond_mem_around_input(