support mask prompt for video tracking
This commit is contained in:
@@ -11,7 +11,7 @@ import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
import supervision as sv
|
||||
from supervision.draw.color import ColorPalette
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
@@ -29,6 +29,7 @@ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
||||
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
||||
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
||||
API_TOKEN_FOR_GD1_5 = "Your API token"
|
||||
PROMPT_TYPE_FOR_VIDEO = "mask" # "point"
|
||||
|
||||
"""
|
||||
Step 1: Environment settings and model initialization for SAM 2
|
||||
@@ -151,18 +152,32 @@ if masks.ndim == 4:
|
||||
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
||||
"""
|
||||
|
||||
# sample the positive points from mask for each objects
|
||||
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
||||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
|
||||
|
||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
)
|
||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||
# sample the positive points from mask for each objects
|
||||
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
||||
|
||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
)
|
||||
# Using mask prompt is a more straightforward way
|
||||
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||
labels = np.ones((1), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
mask=mask
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
@@ -129,18 +129,34 @@ elif masks.ndim == 4:
|
||||
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
||||
"""
|
||||
|
||||
# sample the positive points from mask for each objects
|
||||
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
||||
PROMPT_TYPE_FOR_VIDEO = "mask" # or "point"
|
||||
|
||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
)
|
||||
assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
|
||||
|
||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||
if PROMPT_TYPE_FOR_VIDEO == "point":
|
||||
# sample the positive points from mask for each objects
|
||||
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
||||
|
||||
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
||||
labels = np.ones((points.shape[0]), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
)
|
||||
# Using mask prompt is a more straightforward way
|
||||
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
||||
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
||||
labels = np.ones((1), dtype=np.int32)
|
||||
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||
inference_state=inference_state,
|
||||
frame_idx=ann_frame_idx,
|
||||
obj_id=object_id,
|
||||
mask=mask
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
Reference in New Issue
Block a user