better support for non-CUDA devices (CPU, MPS) (#192)
This commit is contained in:
@@ -284,7 +284,9 @@ class SAM2AutomaticMaskGenerator:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# Run model on this batch
|
||||
points = torch.as_tensor(points, device=self.predictor.device)
|
||||
points = torch.as_tensor(
|
||||
points, dtype=torch.float32, device=self.predictor.device
|
||||
)
|
||||
in_points = self.predictor._transforms.transform_coords(
|
||||
points, normalize=normalize, orig_hw=im_size
|
||||
)
|
||||
|
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
|
@@ -45,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""Initialize a inference state."""
|
||||
compute_device = self.device # device of the model
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state = {}
|
||||
inference_state["images"] = images
|
||||
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# the original video height and width, used for resizing final output scores
|
||||
inference_state["video_height"] = video_height
|
||||
inference_state["video_width"] = video_width
|
||||
inference_state["device"] = torch.device("cuda")
|
||||
inference_state["device"] = compute_device
|
||||
if offload_state_to_cpu:
|
||||
inference_state["storage_device"] = torch.device("cpu")
|
||||
else:
|
||||
inference_state["storage_device"] = torch.device("cuda")
|
||||
inference_state["storage_device"] = compute_device
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
|
||||
if prev_out is not None and prev_out["pred_masks"] is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
||||
device = inference_state["device"]
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
||||
current_out, _ = self._run_single_frame_inference(
|
||||
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
if backbone_out is None:
|
||||
# Cache miss -- we will run inference on a single image
|
||||
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
||||
device = inference_state["device"]
|
||||
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
||||
backbone_out = self.forward_image(image)
|
||||
# Cache the most recent frame's feature (for repeated interactions with
|
||||
# a frame; we can use an LRU cache for more frames in the future).
|
||||
|
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
|
||||
A list of video frames to be load asynchronously without blocking session start.
|
||||
"""
|
||||
|
||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
||||
def __init__(
|
||||
self,
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
):
|
||||
self.img_paths = img_paths
|
||||
self.image_size = image_size
|
||||
self.offload_video_to_cpu = offload_video_to_cpu
|
||||
@@ -119,6 +127,7 @@ class AsyncVideoFrameLoader:
|
||||
# video_height and video_width be filled when loading the first image
|
||||
self.video_height = None
|
||||
self.video_width = None
|
||||
self.compute_device = compute_device
|
||||
|
||||
# load the first frame to fill video_height and video_width and also
|
||||
# to cache it (since it's most likely where the user will click)
|
||||
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
|
||||
img -= self.img_mean
|
||||
img /= self.img_std
|
||||
if not self.offload_video_to_cpu:
|
||||
img = img.cuda(non_blocking=True)
|
||||
img = img.to(self.compute_device, non_blocking=True)
|
||||
self.images[index] = img
|
||||
return img
|
||||
|
||||
@@ -167,6 +176,7 @@ def load_video_frames(
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
@@ -196,7 +206,12 @@ def load_video_frames(
|
||||
|
||||
if async_loading_frames:
|
||||
lazy_images = AsyncVideoFrameLoader(
|
||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
)
|
||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||
|
||||
@@ -204,9 +219,9 @@ def load_video_frames(
|
||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||
if not offload_video_to_cpu:
|
||||
images = images.cuda()
|
||||
img_mean = img_mean.cuda()
|
||||
img_std = img_std.cuda()
|
||||
images = images.to(compute_device)
|
||||
img_mean = img_mean.to(compute_device)
|
||||
img_std = img_std.to(compute_device)
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
|
Reference in New Issue
Block a user