update to the latest sam2 version and support box prompts in video tracking

This commit is contained in:
rentainhe
2024-08-08 12:03:29 +08:00
parent 96cbab92e0
commit 077064c365
8 changed files with 272 additions and 45 deletions

View File

@@ -28,8 +28,8 @@ TEXT_PROMPT = "hippopotamus."
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_GD1_5 = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "mask" # "point"
API_TOKEN_FOR_GD1_5 = "3491a2a256fb7ed01b2e757b713c4cb0"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
"""
Step 1: Environment settings and model initialization for SAM 2
@@ -152,7 +152,7 @@ if masks.ndim == 4:
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
"""
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
@@ -161,13 +161,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point":
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
# Using box prompt
elif PROMPT_TYPE_FOR_VIDEO == "box":
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
box=box,
)
# Using mask prompt is a more straightforward way
elif PROMPT_TYPE_FOR_VIDEO == "mask":
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
@@ -178,7 +187,8 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask":
obj_id=object_id,
mask=mask
)
else:
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame