update to latest SAM 2

This commit is contained in:
rentainhe
2024-08-21 18:11:44 +08:00
parent 35efb4a5cb
commit 6e0ddadf7c
12 changed files with 140 additions and 87 deletions

View File

@@ -44,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize a inference state."""
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
@@ -119,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
return sam_model
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -586,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base):
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outptus
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs
@@ -595,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base):
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temprary output across all objects on this frame
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).