add tracking demo with gd 1.5

This commit is contained in:
rentainhe
2024-08-05 16:05:31 +08:00
parent 41640f4add
commit 2cdd3f2d92
3 changed files with 202 additions and 5 deletions

View File

@@ -109,6 +109,13 @@ python grounded_sam2_tracking_demo.py
We have observed that the video predictor in SAM 2 currently supports **only point prompts** (please feel free to point out any updates or functionalities we may have overlooked during development). However, Grounding DINO provides box prompts, which need to be converted into point prompts for use in video tracking. A straightforward approach is to directly sample the center point of the box as a point prompt. Nevertheless, this method may encounter certain issues in practical testing scenarios. To **get a more stable segmentation results**, we reuse the SAM 2 image predictor to get the prediction mask for each object first, then we **uniformly sample points from the prediction mask** to prompt SAM 2 video predictor. We have observed that the video predictor in SAM 2 currently supports **only point prompts** (please feel free to point out any updates or functionalities we may have overlooked during development). However, Grounding DINO provides box prompts, which need to be converted into point prompts for use in video tracking. A straightforward approach is to directly sample the center point of the box as a point prompt. Nevertheless, this method may encounter certain issues in practical testing scenarios. To **get a more stable segmentation results**, we reuse the SAM 2 image predictor to get the prediction mask for each object first, then we **uniformly sample points from the prediction mask** to prompt SAM 2 video predictor.
### Grounded-SAM-2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)
We've also support video object tracking demo based on our stronger `Grounding DINO 1.5` model and `SAM 2`, you can try the following demo after applying the API keys for running `Grounding DINO 1.5`:
```bash
python grounded_sam2_tracking_demo_with_gd1.5.py
```
### Citation ### Citation

View File

@@ -0,0 +1,191 @@
# dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from track_utils import sample_points_from_masks
from video_utils import create_video_from_images
"""
Step 1: Environment settings and model initialization for SAM 2
"""
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
"""
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
"""
# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
image = Image.open(img_path)
# Step 1: initialize the config
token = "Your API token"
config = Config(token)
# Step 2: initialize the client
client = Client(config)
# Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path)
task = DetectionTask(
image_url=image_url,
prompts=[TextPrompt(text="children. pillow")],
targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
class_names.append(obj.category)
input_boxes = np.array(input_boxes)
print(input_boxes)
# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(np.array(image.convert("RGB")))
# process the detection results
OBJECTS = class_names
print(OBJECTS)
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 3:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
"""
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
"""
# sample the positive points from mask for each objects
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
"""
Step 5: Visualize the segment results across the video and save them
"""
save_dir = "./tracking_results"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
for frame_idx, segments in video_segments.items():
img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
object_ids = list(segments.keys())
masks = list(segments.values())
masks = np.concatenate(masks, axis=0)
detections = sv.Detections(
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
mask=masks, # (n, h, w)
class_id=np.array(object_ids, dtype=np.int32),
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
"""
Step 6: Convert the annotated frames to video
"""
output_video_path = "./children_tracking_demo_video.mp4"
create_video_from_images(save_dir, output_video_path)

View File

@@ -1,5 +1,4 @@
import numpy as np import numpy as np
from scipy.ndimage import center_of_mass
def sample_points_from_masks(masks, num_points): def sample_points_from_masks(masks, num_points):
""" """
@@ -16,7 +15,7 @@ def sample_points_from_masks(masks, num_points):
points = [] points = []
for i in range(n): for i in range(n):
# 找到当前mask中值为1的位置的坐标 # find the valid mask points
indices = np.argwhere(masks[i] == 1) indices = np.argwhere(masks[i] == 1)
# the output format of np.argwhere is (y, x) and the shape is (num_points, 2) # the output format of np.argwhere is (y, x) and the shape is (num_points, 2)
# we should convert it to (x, y) # we should convert it to (x, y)
@@ -24,11 +23,11 @@ def sample_points_from_masks(masks, num_points):
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
if len(indices) == 0: if len(indices) == 0:
# 如果没有有效点,返回一个空数组 # if there are no valid points, append an empty array
points.append(np.array([])) points.append(np.array([]))
continue continue
# 如果mask中的点少于需要的数量则重复采样 # resampling if there's not enough points
if len(indices) < num_points: if len(indices) < num_points:
sampled_indices = np.random.choice(len(indices), num_points, replace=True) sampled_indices = np.random.choice(len(indices), num_points, replace=True)
else: else:
@@ -37,6 +36,6 @@ def sample_points_from_masks(masks, num_points):
sampled_points = indices[sampled_indices] sampled_points = indices[sampled_indices]
points.append(sampled_points) points.append(sampled_points)
# 将结果转换为numpy数组 # convert to np.array
points = np.array(points, dtype=np.float32) points = np.array(points, dtype=np.float32)
return points return points