Add interface for box prompt in SAM 2 video predictor (#174)
This PR adds an interface for providing a box prompt in SAM 2 as input to the `add_new_points_or_box` API (renamed from `add_new_points`, which is kept for backward compatibility). If `box` is provided, we add it as the first two points with labels 2 and 3, along with the user-provided points (consistent with how SAM 2 is trained). The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from a box prompt as an example.
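For instance, a box-only prompt on a single frame can be passed through the new `box` argument. This is a minimal sketch, not the notebook's exact cell: the frame index, object id, and box coordinates are illustrative, and `state` is assumed to come from `predictor.init_state(...)` as in the README snippet below. The box is given as `(x_min, y_min, x_max, y_max)` corner coordinates, which the predictor reshapes into two corner points labeled 2 and 3.

```python
import numpy as np

# Illustrative values: segment one object on frame 0 using only a box prompt.
ann_frame_idx = 0   # frame on which the prompt is placed
ann_obj_id = 1      # arbitrary id for the object being tracked
box = np.array([300, 0, 500, 400], dtype=np.float32)  # (x_min, y_min, x_max, y_max)

frame_idx, object_ids, masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    box=box,
)
```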
README.md
@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 state = predictor.init_state(<your_video>)

 # add new prompts and instantly get the output on the same frame
-frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):

 # propagate the prompts to get masklets throughout the video
 for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
     ...
 ```

-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.

 ## Load from 🤗 Hugging Face

@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 state = predictor.init_state(<your_video>)

 # add new prompts and instantly get the output on the same frame
-frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):

 # propagate the prompts to get masklets throughout the video
 for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
notebooks/video_predictor_example.ipynb: file diff suppressed because one or more lines are too long
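Since the notebook diff is suppressed above (notebook JSON lines are too long to render), here is a rough sketch of the kind of combined box-and-click prompt the new API accepts; the coordinates and ids are made up, not the notebook's actual cell. The box is internally prepended as two points with labels 2 and 3, and the user click (label 1 = positive, 0 = negative) is appended after it.

```python
import numpy as np

# A box plus one positive refinement click in the same call (illustrative values).
box = np.array([300, 0, 500, 400], dtype=np.float32)   # (x_min, y_min, x_max, y_max)
points = np.array([[460, 60]], dtype=np.float32)       # one (x, y) click inside the box
labels = np.array([1], dtype=np.int32)                 # 1 = positive click, 0 = negative

frame_idx, object_ids, masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
    box=box,
)
```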
sam2/sam2_video_predictor.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import warnings
 from collections import OrderedDict

 import torch
@@ -163,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])

     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
@@ -268,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks

+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
@@ -548,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base):
         storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
         # Find all the frames that contain temporary outputs for any objects
         # (these should be the frames that have just received clicks for mask inputs
-        # via `add_new_points` or `add_new_mask`)
+        # via `add_new_points_or_box` or `add_new_mask`)
         temp_frame_inds = set()
         for obj_temp_output_dict in temp_output_dict_per_obj.values():
             temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
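For callers still using the old name, the shim added above keeps the previous entry point working unchanged. A sketch, reusing the illustrative `state`, `points`, and `labels` from the earlier examples:

```python
# `add_new_points` now simply forwards to `add_new_points_or_box`,
# so existing point-only code keeps working without changes.
frame_idx, object_ids, masks = predictor.add_new_points(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
)
```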