better support for non-CUDA devices (CPU, MPS) (#192)
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -284,7 +284,9 @@ class SAM2AutomaticMaskGenerator:
|
|||||||
orig_h, orig_w = orig_size
|
orig_h, orig_w = orig_size
|
||||||
|
|
||||||
# Run model on this batch
|
# Run model on this batch
|
||||||
points = torch.as_tensor(points, device=self.predictor.device)
|
points = torch.as_tensor(
|
||||||
|
points, dtype=torch.float32, device=self.predictor.device
|
||||||
|
)
|
||||||
in_points = self.predictor._transforms.transform_coords(
|
in_points = self.predictor._transforms.transform_coords(
|
||||||
points, normalize=normalize, orig_hw=im_size
|
points, normalize=normalize, orig_hw=im_size
|
||||||
)
|
)
|
||||||
|
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
|||||||
# repeat freqs along seq_len dim to match k seq_len
|
# repeat freqs along seq_len dim to match k seq_len
|
||||||
if repeat_freqs_k:
|
if repeat_freqs_k:
|
||||||
r = xk_.shape[-2] // xq_.shape[-2]
|
r = xk_.shape[-2] // xq_.shape[-2]
|
||||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
if freqs_cis.is_cuda:
|
||||||
|
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||||
|
else:
|
||||||
|
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||||
|
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
|||||||
continue # skip padding frames
|
continue # skip padding frames
|
||||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||||
# Temporal positional encoding
|
# Temporal positional encoding
|
||||||
maskmem_enc = (
|
maskmem_enc = (
|
||||||
|
@@ -45,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
async_loading_frames=False,
|
async_loading_frames=False,
|
||||||
):
|
):
|
||||||
"""Initialize an inference state."""
|
"""Initialize an inference state."""
|
||||||
|
compute_device = self.device # device of the model
|
||||||
images, video_height, video_width = load_video_frames(
|
images, video_height, video_width = load_video_frames(
|
||||||
video_path=video_path,
|
video_path=video_path,
|
||||||
image_size=self.image_size,
|
image_size=self.image_size,
|
||||||
offload_video_to_cpu=offload_video_to_cpu,
|
offload_video_to_cpu=offload_video_to_cpu,
|
||||||
async_loading_frames=async_loading_frames,
|
async_loading_frames=async_loading_frames,
|
||||||
|
compute_device=compute_device,
|
||||||
)
|
)
|
||||||
inference_state = {}
|
inference_state = {}
|
||||||
inference_state["images"] = images
|
inference_state["images"] = images
|
||||||
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# the original video height and width, used for resizing final output scores
|
# the original video height and width, used for resizing final output scores
|
||||||
inference_state["video_height"] = video_height
|
inference_state["video_height"] = video_height
|
||||||
inference_state["video_width"] = video_width
|
inference_state["video_width"] = video_width
|
||||||
inference_state["device"] = torch.device("cuda")
|
inference_state["device"] = compute_device
|
||||||
if offload_state_to_cpu:
|
if offload_state_to_cpu:
|
||||||
inference_state["storage_device"] = torch.device("cpu")
|
inference_state["storage_device"] = torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
inference_state["storage_device"] = torch.device("cuda")
|
inference_state["storage_device"] = compute_device
|
||||||
# inputs on each frame
|
# inputs on each frame
|
||||||
inference_state["point_inputs_per_obj"] = {}
|
inference_state["point_inputs_per_obj"] = {}
|
||||||
inference_state["mask_inputs_per_obj"] = {}
|
inference_state["mask_inputs_per_obj"] = {}
|
||||||
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||||
|
|
||||||
if prev_out is not None and prev_out["pred_masks"] is not None:
|
if prev_out is not None and prev_out["pred_masks"] is not None:
|
||||||
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
device = inference_state["device"]
|
||||||
|
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
||||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||||
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
||||||
current_out, _ = self._run_single_frame_inference(
|
current_out, _ = self._run_single_frame_inference(
|
||||||
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
)
|
)
|
||||||
if backbone_out is None:
|
if backbone_out is None:
|
||||||
# Cache miss -- we will run inference on a single image
|
# Cache miss -- we will run inference on a single image
|
||||||
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
device = inference_state["device"]
|
||||||
|
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
||||||
backbone_out = self.forward_image(image)
|
backbone_out = self.forward_image(image)
|
||||||
# Cache the most recent frame's feature (for repeated interactions with
|
# Cache the most recent frame's feature (for repeated interactions with
|
||||||
# a frame; we can use an LRU cache for more frames in the future).
|
# a frame; we can use an LRU cache for more frames in the future).
|
||||||
|
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
|
|||||||
A list of video frames to be loaded asynchronously without blocking session start.
|
A list of video frames to be loaded asynchronously without blocking session start.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
def __init__(
|
||||||
|
self,
|
||||||
|
img_paths,
|
||||||
|
image_size,
|
||||||
|
offload_video_to_cpu,
|
||||||
|
img_mean,
|
||||||
|
img_std,
|
||||||
|
compute_device,
|
||||||
|
):
|
||||||
self.img_paths = img_paths
|
self.img_paths = img_paths
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.offload_video_to_cpu = offload_video_to_cpu
|
self.offload_video_to_cpu = offload_video_to_cpu
|
||||||
@@ -119,6 +127,7 @@ class AsyncVideoFrameLoader:
|
|||||||
# video_height and video_width will be filled when loading the first image
|
# video_height and video_width will be filled when loading the first image
|
||||||
self.video_height = None
|
self.video_height = None
|
||||||
self.video_width = None
|
self.video_width = None
|
||||||
|
self.compute_device = compute_device
|
||||||
|
|
||||||
# load the first frame to fill video_height and video_width and also
|
# load the first frame to fill video_height and video_width and also
|
||||||
# to cache it (since it's most likely where the user will click)
|
# to cache it (since it's most likely where the user will click)
|
||||||
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
|
|||||||
img -= self.img_mean
|
img -= self.img_mean
|
||||||
img /= self.img_std
|
img /= self.img_std
|
||||||
if not self.offload_video_to_cpu:
|
if not self.offload_video_to_cpu:
|
||||||
img = img.cuda(non_blocking=True)
|
img = img.to(self.compute_device, non_blocking=True)
|
||||||
self.images[index] = img
|
self.images[index] = img
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@@ -167,6 +176,7 @@ def load_video_frames(
|
|||||||
img_mean=(0.485, 0.456, 0.406),
|
img_mean=(0.485, 0.456, 0.406),
|
||||||
img_std=(0.229, 0.224, 0.225),
|
img_std=(0.229, 0.224, 0.225),
|
||||||
async_loading_frames=False,
|
async_loading_frames=False,
|
||||||
|
compute_device=torch.device("cuda"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||||
@@ -196,7 +206,12 @@ def load_video_frames(
|
|||||||
|
|
||||||
if async_loading_frames:
|
if async_loading_frames:
|
||||||
lazy_images = AsyncVideoFrameLoader(
|
lazy_images = AsyncVideoFrameLoader(
|
||||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
img_paths,
|
||||||
|
image_size,
|
||||||
|
offload_video_to_cpu,
|
||||||
|
img_mean,
|
||||||
|
img_std,
|
||||||
|
compute_device,
|
||||||
)
|
)
|
||||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||||
|
|
||||||
@@ -204,9 +219,9 @@ def load_video_frames(
|
|||||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||||
if not offload_video_to_cpu:
|
if not offload_video_to_cpu:
|
||||||
images = images.cuda()
|
images = images.to(compute_device)
|
||||||
img_mean = img_mean.cuda()
|
img_mean = img_mean.to(compute_device)
|
||||||
img_std = img_std.cuda()
|
img_std = img_std.to(compute_device)
|
||||||
# normalize by mean and std
|
# normalize by mean and std
|
||||||
images -= img_mean
|
images -= img_mean
|
||||||
images /= img_std
|
images /= img_std
|
||||||
|
Reference in New Issue
Block a user