Add interface for box prompt in SAM 2 video predictor (#174)

This PR adds an example to provide box prompt in SAM 2 as inputs to the `add_new_points_or_box` API (renamed from`add_new_points`, which is kept for backward compatibility). If `box` is provided, we add it as the first two points with labels 2 and 3, along with the user-provided points (consistent with how SAM 2 is trained).

The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from box prompt as an example.
This commit is contained in:
Ronghang Hu
2024-08-07 11:54:30 -07:00
committed by GitHub
parent 6ba4c65cb2
commit 6ecb5ff8d0
3 changed files with 451 additions and 89 deletions

View File

@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
state = predictor.init_state(<your_video>)
# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
# propagate the prompts to get masklets throughout the video
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
...
```
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
## Load from 🤗 Hugging Face
@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
state = predictor.init_state(<your_video>)
# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
# propagate the prompts to get masklets throughout the video
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):