support more prompt in simple demo
This commit is contained in:
@@ -107,18 +107,45 @@ elif masks.ndim == 4:
|
|||||||
Step 3: Register each object's positive points to video predictor with separate add_new_points call
|
Step 3: Register each object's positive points to video predictor with separate add_new_points call
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
|
||||||
|
|
||||||
|
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
||||||
|
|
||||||
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||||
|
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||||
# sample the positive points from mask for each object
|
# sample the positive points from mask for each object
|
||||||
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
||||||
|
|
||||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
inference_state=inference_state,
|
inference_state=inference_state,
|
||||||
frame_idx=ann_frame_idx,
|
frame_idx=ann_frame_idx,
|
||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
points=points,
|
points=points,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
|
# Using box prompt
|
||||||
|
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
||||||
|
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
||||||
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=object_id,
|
||||||
|
box=box,
|
||||||
|
)
|
||||||
|
# Using mask prompt is a more straightforward way
|
||||||
|
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||||
|
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||||
|
labels = np.ones((1), dtype=np.int32)
|
||||||
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=object_id,
|
||||||
|
mask=mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@@ -28,7 +28,7 @@ TEXT_PROMPT = "hippopotamus."
|
|||||||
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
||||||
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
||||||
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
||||||
API_TOKEN_FOR_GD1_5 = "3491a2a256fb7ed01b2e757b713c4cb0"
|
API_TOKEN_FOR_GD1_5 = "Your API token"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@@ -129,9 +129,9 @@ elif masks.ndim == 4:
|
|||||||
Step 3: Register each object's positive points to video predictor with separate add_new_points call
|
Step 3: Register each object's positive points to video predictor with separate add_new_points call
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PROMPT_TYPE_FOR_VIDEO = "mask" # or "point"
|
PROMPT_TYPE_FOR_VIDEO = "box" # or "point"
|
||||||
|
|
||||||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
|
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
||||||
|
|
||||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||||
if PROMPT_TYPE_FOR_VIDEO == "point":
|
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||||
@@ -140,13 +140,22 @@ if PROMPT_TYPE_FOR_VIDEO == "point":
|
|||||||
|
|
||||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
inference_state=inference_state,
|
inference_state=inference_state,
|
||||||
frame_idx=ann_frame_idx,
|
frame_idx=ann_frame_idx,
|
||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
points=points,
|
points=points,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
|
# Using box prompt
|
||||||
|
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
||||||
|
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
||||||
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=object_id,
|
||||||
|
box=box,
|
||||||
|
)
|
||||||
# Using mask prompt is a more straightforward way
|
# Using mask prompt is a more straightforward way
|
||||||
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||||
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||||
@@ -157,6 +166,9 @@ elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
|||||||
obj_id=object_id,
|
obj_id=object_id,
|
||||||
mask=mask
|
mask=mask
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user