better support for non-CUDA devices (CPU, MPS) (#192)
@@ -45,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
         async_loading_frames=False,
     ):
         """Initialize an inference state."""
+        compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
             async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
         )
         inference_state = {}
         inference_state["images"] = images
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
         # the original video height and width, used for resizing final output scores
         inference_state["video_height"] = video_height
         inference_state["video_width"] = video_width
-        inference_state["device"] = torch.device("cuda")
+        inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
         else:
-            inference_state["storage_device"] = torch.device("cuda")
+            inference_state["storage_device"] = compute_device
         # inputs on each frame
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
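The split between inference_state["device"] and inference_state["storage_device"] above is what lets offload_state_to_cpu work on any backend: state tensors are kept on the (possibly CPU) storage device and moved to the compute device only when a frame is actually processed, as the next two hunks show. A minimal sketch of that pattern, with assumed devices and an illustrative tensor name that is not taken from the codebase:

import torch

# Assumed setup: MPS when available (e.g. Apple silicon), else CPU,
# with state offloaded to CPU as offload_state_to_cpu=True would do.
compute_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
storage_device = torch.device("cpu")

# State tensors (e.g. per-frame mask logits) live on the storage device...
pred_masks = torch.zeros(1, 1, 256, 256, device=storage_device)

# ...and are copied to the compute device only for the frame being processed.
mask_logits = pred_masks.to(compute_device, non_blocking=True)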
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
                 prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
             if prev_out is not None and prev_out["pred_masks"] is not None:
-                prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+                device = inference_state["device"]
+                prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
                 # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
                 prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
             current_out, _ = self._run_single_frame_inference(
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
         )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
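Taken together, these hunks remove the hard-coded .cuda() and torch.device("cuda") calls from the video predictor, so inference runs on whatever device the model itself was built on. A rough sketch of how a caller might pick that device across CUDA, MPS, and CPU (the config and checkpoint paths below are illustrative, not fixed by this commit):

import torch
from sam2.build_sam import build_sam2_video_predictor

# Prefer CUDA, then Apple MPS, then fall back to plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Illustrative config/checkpoint names; adjust to your local setup.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device=device
)
inference_state = predictor.init_state(video_path="./videos/example.mp4")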