diff --git a/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py b/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
index 4b98438..e509f4a 100644
--- a/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
+++ b/grounded_sam2_tracking_demo_custom_video_input_gd1.5.py
@@ -11,7 +11,7 @@ import cv2
 import torch
 import numpy as np
 import supervision as sv
-from supervision.draw.color import ColorPalette
+
 from pathlib import Path
 from tqdm import tqdm
 from PIL import Image
@@ -29,6 +29,7 @@ OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
 SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
 SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
 API_TOKEN_FOR_GD1_5 = "Your API token"
+PROMPT_TYPE_FOR_VIDEO = "mask"  # or "point"
 
 """
 Step 1: Environment settings and model initialization for SAM 2
@@ -151,18 +152,31 @@ if masks.ndim == 4:
 Step 3: Register each object's positive points to video predictor with seperate add_new_points call
 """
 
-# sample the positive points from mask for each objects
-all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
 
-for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
-    labels = np.ones((points.shape[0]), dtype=np.int32)
-    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
-        inference_state=inference_state,
-        frame_idx=ann_frame_idx,
-        obj_id=object_id,
-        points=points,
-        labels=labels,
-    )
+# If using point prompts, uniformly sample positive points from each object's mask
+if PROMPT_TYPE_FOR_VIDEO == "point":
+    # sample positive points from the mask for each object
+    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+
+    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
+        labels = np.ones((points.shape[0]), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            points=points,
+            labels=labels,
+        )
+# Using a mask prompt is more straightforward
+elif PROMPT_TYPE_FOR_VIDEO == "mask":
+    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            mask=mask
+        )
 
 """
 Step 4: Propagate the video predictor to get the segmentation results for each frame
diff --git a/grounded_sam2_tracking_demo_with_gd1.5.py b/grounded_sam2_tracking_demo_with_gd1.5.py
index 93fe2c1..7cd42b5 100644
--- a/grounded_sam2_tracking_demo_with_gd1.5.py
+++ b/grounded_sam2_tracking_demo_with_gd1.5.py
@@ -129,18 +129,33 @@ elif masks.ndim == 4:
 Step 3: Register each object's positive points to video predictor with seperate add_new_points call
 """
 
-# sample the positive points from mask for each objects
-all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+PROMPT_TYPE_FOR_VIDEO = "mask"  # or "point"
 
-for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
-    labels = np.ones((points.shape[0]), dtype=np.int32)
-    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
-        inference_state=inference_state,
-        frame_idx=ann_frame_idx,
-        obj_id=object_id,
-        points=points,
-        labels=labels,
-    )
+assert PROMPT_TYPE_FOR_VIDEO in ["point", "mask"]
+
+# If using point prompts, uniformly sample positive points from each object's mask
+if PROMPT_TYPE_FOR_VIDEO == "point":
+    # sample positive points from the mask for each object
+    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+
+    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
+        labels = np.ones((points.shape[0]), dtype=np.int32)
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            points=points,
+            labels=labels,
+        )
+# Using a mask prompt is more straightforward
+elif PROMPT_TYPE_FOR_VIDEO == "mask":
+    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            mask=mask
+        )
 
 """
 Step 4: Propagate the video predictor to get the segmentation results for each frame
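Note on the two prompt paths: add_new_mask consumes the detector's mask directly, while the point branch first converts each mask into positive point prompts via the repo's sample_points_from_masks utility. For readers without the utils module handy, the snippet below is a minimal sketch of what uniform positive-point sampling from a binary mask can look like; it assumes boolean masks of shape (N, H, W) and is an illustration only, not the repo's actual implementation (whose internals may differ).

    import numpy as np

    def sample_points_from_masks_sketch(masks, num_points=10):
        """For each (H, W) boolean mask, draw num_points (x, y) coordinates
        uniformly from the mask's foreground pixels."""
        all_points = []
        for mask in masks:
            # np.argwhere yields (row, col) = (y, x); SAM 2 point prompts use (x, y)
            yx = np.argwhere(mask > 0)
            if len(yx) == 0:
                # no foreground pixels: return an empty point set for this object
                all_points.append(np.empty((0, 2), dtype=np.float32))
                continue
            # sample with replacement only when the mask has fewer pixels than requested
            idx = np.random.choice(len(yx), size=num_points, replace=len(yx) < num_points)
            all_points.append(yx[idx][:, ::-1].astype(np.float32))  # (y, x) -> (x, y)
        return all_points

Point sampling introduces randomness and can miss thin structures, which is presumably why both diffs default PROMPT_TYPE_FOR_VIDEO to "mask": the mask prompt hands the full detection to the video predictor with no intermediate sampling step.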