Add interface for box prompt in SAM 2 video predictor (#174)
This PR adds an example to provide box prompt in SAM 2 as inputs to the `add_new_points_or_box` API (renamed from`add_new_points`, which is kept for backward compatibility). If `box` is provided, we add it as the first two points with labels 2 and 3, along with the user-provided points (consistent with how SAM 2 is trained). The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from box prompt as an example.
This commit is contained in:
@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
state = predictor.init_state(<your_video>)
|
||||
|
||||
# add new prompts and instantly get the output on the same frame
|
||||
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
|
||||
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
||||
|
||||
# propagate the prompts to get masklets throughout the video
|
||||
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
||||
...
|
||||
```
|
||||
|
||||
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
|
||||
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
|
||||
|
||||
## Load from 🤗 Hugging Face
|
||||
|
||||
@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
state = predictor.init_state(<your_video>)
|
||||
|
||||
# add new prompts and instantly get the output on the same frame
|
||||
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
|
||||
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
||||
|
||||
# propagate the prompts to get masklets throughout the video
|
||||
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
||||
|
Reference in New Issue
Block a user