Address comment

This commit is contained in:
Niels
2024-08-07 17:48:12 +02:00
parent 43c385c263
commit 9b58611e24

View File

@@ -127,8 +127,14 @@ from sam2.sam2_video_predictor import SAM2VideoPredictor
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
predictor.set_image(<your_image>) state = predictor.init_state(<your_video>)
masks, _, _ = predictor.predict(<input_prompts>)
# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
# propagate the prompts to get masklets throughout the video
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
...
``` ```
## Model Description ## Model Description