add tracking demo with gd 1.5

This commit is contained in:
rentainhe
2024-08-05 16:05:31 +08:00
parent 41640f4add
commit 2cdd3f2d92
3 changed files with 202 additions and 5 deletions

View File

@@ -0,0 +1,191 @@
# dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from track_utils import sample_points_from_masks
from video_utils import create_video_from_images
"""
Step 1: Environment settings and model initialization for SAM 2
"""
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
"""
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
"""
# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
image = Image.open(img_path)
# Step 1: initialize the config
token = "Your API token"
config = Config(token)
# Step 2: initialize the client
client = Client(config)
# Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path)
task = DetectionTask(
image_url=image_url,
prompts=[TextPrompt(text="children. pillow")],
targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
class_names.append(obj.category)
input_boxes = np.array(input_boxes)
print(input_boxes)
# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(np.array(image.convert("RGB")))
# process the detection results
OBJECTS = class_names
print(OBJECTS)
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 3:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
"""
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
"""
# sample the positive points from mask for each objects
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
"""
Step 5: Visualize the segment results across the video and save them
"""
save_dir = "./tracking_results"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
for frame_idx, segments in video_segments.items():
img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
object_ids = list(segments.keys())
masks = list(segments.values())
masks = np.concatenate(masks, axis=0)
detections = sv.Detections(
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
mask=masks, # (n, h, w)
class_id=np.array(object_ids, dtype=np.int32),
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
"""
Step 6: Convert the annotated frames to video
"""
output_video_path = "./children_tracking_demo_video.mp4"
create_video_from_images(save_dir, output_video_path)