add tracking demo with gd 1.5
This commit is contained in:
@@ -109,6 +109,13 @@ python grounded_sam2_tracking_demo.py
|
|||||||
|
|
||||||
We have observed that the video predictor in SAM 2 currently supports **only point prompts** (please feel free to point out any updates or functionalities we may have overlooked during development). However, Grounding DINO provides box prompts, which need to be converted into point prompts for use in video tracking. A straightforward approach is to directly sample the center point of the box as a point prompt. Nevertheless, this method may encounter certain issues in practical testing scenarios. To **get more stable segmentation results**, we reuse the SAM 2 image predictor to get the prediction mask for each object first, then we **uniformly sample points from the prediction mask** to prompt the SAM 2 video predictor.
|
We have observed that the video predictor in SAM 2 currently supports **only point prompts** (please feel free to point out any updates or functionalities we may have overlooked during development). However, Grounding DINO provides box prompts, which need to be converted into point prompts for use in video tracking. A straightforward approach is to directly sample the center point of the box as a point prompt. Nevertheless, this method may encounter certain issues in practical testing scenarios. To **get more stable segmentation results**, we reuse the SAM 2 image predictor to get the prediction mask for each object first, then we **uniformly sample points from the prediction mask** to prompt the SAM 2 video predictor.
|
||||||
|
|
||||||
|
### Grounded-SAM-2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)
|
||||||
|
|
||||||
|
We also support a video object tracking demo based on our stronger `Grounding DINO 1.5` model and `SAM 2`. You can try the following demo after applying for the API keys required to run `Grounding DINO 1.5`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_tracking_demo_with_gd1.5.py
|
||||||
|
```
|
||||||
|
|
||||||
### Citation
|
### Citation
|
||||||
|
|
||||||
|
191
grounded_sam2_tracking_demo_with_gd1.5.py
Normal file
191
grounded_sam2_tracking_demo_with_gd1.5.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
|
from dds_cloudapi_sdk import Config
|
||||||
|
from dds_cloudapi_sdk import Client
|
||||||
|
from dds_cloudapi_sdk import DetectionTask
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import supervision as sv
|
||||||
|
from PIL import Image
|
||||||
|
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
||||||
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||||
|
from track_utils import sample_points_from_masks
|
||||||
|
from video_utils import create_video_from_images
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
|
"""
|
||||||
|
# use bfloat16 autocast for the entire script: the context manager is entered
# by hand and deliberately never exited, so it stays active for every model
# call below until the process ends
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
# (both predictors are built from the same config and checkpoint)
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)


# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
# sort numerically by the integer file stem so that list position == frame index
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)

ann_frame_idx = 0  # the frame index we interact with


"""
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
"""

# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])

# Step 1: initialize the config
# NOTE: placeholder value, replace it with your own API token before running
token = "Your API token"
config = Config(token)

# Step 2: initialize the client
client = Client(config)

# Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path)

task = DetectionTask(
    image_url=image_url,
    prompts=[TextPrompt(text="children. pillow")],
    targets=[DetectionTarget.BBox],  # detect bbox
    model=DetectionModel.GDino1_5_Pro,  # detect with GroundingDino-1.5-Pro model
)

client.run_task(task)
result = task.result

objects = result.objects  # the list of detected objects

# without at least one detection there is nothing to prompt SAM 2 with, and the
# later steps would only fail with a much less helpful error
if len(objects) == 0:
    raise RuntimeError(
        "Grounding DINO 1.5 returned no detections for the text prompt; nothing to track"
    )

# split the detections into parallel lists: box, score and category per object
input_boxes = np.array([obj.bbox for obj in objects])
confidences = [obj.score for obj in objects]
class_names = [obj.category for obj in objects]

print(input_boxes)

# feed the annotated frame to the SAM 2 image predictor as an RGB array;
# the file handle is closed as soon as the pixel data has been copied out
with Image.open(img_path) as image:
    image_predictor.set_image(np.array(image.convert("RGB")))

# process the detection results
OBJECTS = class_names

print(OBJECTS)

# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

# convert the mask shape to (n, H, W)
# With several boxes the predictor output is 4-D with a singleton mask axis at
# position 1 (multimask_output=False), which is dropped here. With a single box
# the output is 3-D with a leading axis of length 1, i.e. it is already in
# (n, H, W) layout and must NOT get another leading axis (NOTE(review): per the
# SAM 2 `predict` docstring, masks are CxHxW for one prompt -- confirm against
# the installed sam2 version).
if masks.ndim == 4:
    masks = masks.squeeze(1)

"""
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
"""

# sample the positive points from mask for each objects
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

# object ids start at 1 and follow the detection order, matching ID_TO_OBJECTS below
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
    # every sampled point is a positive (foreground) click, hence all-ones labels
    labels = np.ones((points.shape[0]), dtype=np.int32)
    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=object_id,
        points=points,
        labels=labels,
    )


"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    # threshold the logits at 0 to get one boolean mask per tracked object id
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

"""
Step 5: Visualize the segment results across the video and save them
"""

save_dir = "./tracking_results"

os.makedirs(save_dir, exist_ok=True)

ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}

# the annotators do not depend on the frame, so build them once
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()

for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    # stack the per-object masks of this frame along the first axis
    frame_masks = np.concatenate(list(segments.values()), axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(frame_masks),  # (n, 4)
        mask=frame_masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)


"""
Step 6: Convert the annotated frames to video
"""

output_video_path = "./children_tracking_demo_video.mp4"
create_video_from_images(save_dir, output_video_path)
|
@@ -1,5 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.ndimage import center_of_mass
|
|
||||||
|
|
||||||
def sample_points_from_masks(masks, num_points):
|
def sample_points_from_masks(masks, num_points):
|
||||||
"""
|
"""
|
||||||
@@ -16,7 +15,7 @@ def sample_points_from_masks(masks, num_points):
|
|||||||
points = []
|
points = []
|
||||||
|
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
# 找到当前mask中值为1的位置的坐标
|
# find the valid mask points
|
||||||
indices = np.argwhere(masks[i] == 1)
|
indices = np.argwhere(masks[i] == 1)
|
||||||
# the output format of np.argwhere is (y, x) and the shape is (num_points, 2)
|
# the output format of np.argwhere is (y, x) and the shape is (num_points, 2)
|
||||||
# we should convert it to (x, y)
|
# we should convert it to (x, y)
|
||||||
@@ -24,11 +23,11 @@ def sample_points_from_masks(masks, num_points):
|
|||||||
|
|
||||||
# import pdb; pdb.set_trace()
|
# import pdb; pdb.set_trace()
|
||||||
if len(indices) == 0:
|
if len(indices) == 0:
|
||||||
# 如果没有有效点,返回一个空数组
|
# if there are no valid points, append an empty array
|
||||||
points.append(np.array([]))
|
points.append(np.array([]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果mask中的点少于需要的数量,则重复采样
|
# resampling if there's not enough points
|
||||||
if len(indices) < num_points:
|
if len(indices) < num_points:
|
||||||
sampled_indices = np.random.choice(len(indices), num_points, replace=True)
|
sampled_indices = np.random.choice(len(indices), num_points, replace=True)
|
||||||
else:
|
else:
|
||||||
@@ -37,6 +36,6 @@ def sample_points_from_masks(masks, num_points):
|
|||||||
sampled_points = indices[sampled_indices]
|
sampled_points = indices[sampled_indices]
|
||||||
points.append(sampled_points)
|
points.append(sampled_points)
|
||||||
|
|
||||||
# 将结果转换为numpy数组
|
# convert to np.array
|
||||||
points = np.array(points, dtype=np.float32)
|
points = np.array(points, dtype=np.float32)
|
||||||
return points
|
return points
|
||||||
|
Reference in New Issue
Block a user