support mask prompt for video tracking

This commit is contained in:
rentainhe
2024-08-07 16:42:49 +08:00
parent 7c0995e9c3
commit 37cf27cfe3
2 changed files with 54 additions and 23 deletions

View File

@@ -129,18 +129,34 @@ elif masks.ndim == 4:
Step 3: Register each object's positive points to the video predictor with a separate add_new_points call
"""
# sample the positive points from the mask for each object
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
PROMPT_TYPE_FOR_VIDEO = "mask" # or "point"
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
# sample the positive points from the mask for each object
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
# Using a mask prompt is a more straightforward approach
elif PROMPT_TYPE_FOR_VIDEO == "mask":
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
labels = np.ones((1), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
mask=mask
)
"""