support mask prompt for video tracking

This commit is contained in:
rentainhe
2024-08-07 16:42:49 +08:00
parent 7c0995e9c3
commit 37cf27cfe3
2 changed files with 54 additions and 23 deletions

View File

@@ -11,7 +11,7 @@ import cv2
 import torch
 import numpy as np
 import supervision as sv
+from supervision.draw.color import ColorPalette
 from pathlib import Path
 from tqdm import tqdm
 from PIL import Image
@@ -29,6 +29,7 @@ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
 SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
 SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
 API_TOKEN_FOR_GD1_5 = "Your API token"
+PROMPT_TYPE_FOR_VIDEO = "mask" # "point"
 """
 Step 1: Environment settings and model initialization for SAM 2
@@ -151,18 +152,32 @@ if masks.ndim == 4:
 Step 3: Register each object's positive points to video predictor with seperate add_new_points call
 """
-# sample the positive points from mask for each objects
-all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
-
-for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
-    labels = np.ones((points.shape[0]), dtype=np.int32)
-    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
-        inference_state=inference_state,
-        frame_idx=ann_frame_idx,
-        obj_id=object_id,
-        points=points,
-        labels=labels,
-    )
+assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
+
+# If you are using point prompts, we uniformly sample positive points based on the mask
+if PROMPT_TYPE_FOR_VIDEO == "point":
+    # sample the positive points from mask for each objects
+    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+
+    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
+        labels = np.ones((points.shape[0]), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            points=points,
+            labels=labels,
+        )
+# Using mask prompt is a more straightforward way
+elif PROMPT_TYPE_FOR_VIDEO == "mask":
+    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
+        labels = np.ones((1), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            mask=mask
+        )
 """

View File

@@ -129,18 +129,34 @@ elif masks.ndim == 4:
 Step 3: Register each object's positive points to video predictor with seperate add_new_points call
 """
-# sample the positive points from mask for each objects
-all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
-
-for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
-    labels = np.ones((points.shape[0]), dtype=np.int32)
-    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
-        inference_state=inference_state,
-        frame_idx=ann_frame_idx,
-        obj_id=object_id,
-        points=points,
-        labels=labels,
-    )
+PROMPT_TYPE_FOR_VIDEO = "mask" # or "point"
+
+assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
+
+# If you are using point prompts, we uniformly sample positive points based on the mask
+if PROMPT_TYPE_FOR_VIDEO == "point":
+    # sample the positive points from mask for each objects
+    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+
+    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
+        labels = np.ones((points.shape[0]), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            points=points,
+            labels=labels,
+        )
+# Using mask prompt is a more straightforward way
+elif PROMPT_TYPE_FOR_VIDEO == "mask":
+    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
+        labels = np.ones((1), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            mask=mask
+        )
 """