2025-04-20 01:04:26 +08:00
|
|
|
# DDS cloud API for Grounding DINO 1.5 - updated to the V2Task API
|
2024-08-06 17:11:55 +08:00
|
|
|
from dds_cloudapi_sdk import Config
|
|
|
|
from dds_cloudapi_sdk import Client
|
2025-04-20 01:04:26 +08:00
|
|
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
2024-08-06 17:11:55 +08:00
|
|
|
|
|
|
|
import os
|
|
|
|
import cv2
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import supervision as sv
|
2024-08-07 16:42:49 +08:00
|
|
|
|
2024-08-06 17:11:55 +08:00
|
|
|
from pathlib import Path
|
|
|
|
from tqdm import tqdm
|
|
|
|
from PIL import Image
|
|
|
|
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
|
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
2024-08-09 02:33:24 +02:00
|
|
|
from utils.track_utils import sample_points_from_masks
|
|
|
|
from utils.video_utils import create_video_from_images
|
2024-08-06 17:11:55 +08:00
|
|
|
|
|
|
|
"""
|
|
|
|
Hyperparam for Ground and Tracking
|
|
|
|
"""
|
|
|
|
VIDEO_PATH = "./assets/hippopotamus.mp4"
|
|
|
|
TEXT_PROMPT = "hippopotamus."
|
|
|
|
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
|
|
|
|
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
|
|
|
|
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
2024-08-08 12:26:59 +08:00
|
|
|
API_TOKEN_FOR_GD1_5 = "Your API token"
|
2024-08-08 12:03:29 +08:00
|
|
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
2024-10-31 15:50:14 +08:00
|
|
|
BOX_THRESHOLD = 0.2
|
2025-04-20 01:04:26 +08:00
|
|
|
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
|
2024-08-06 17:11:55 +08:00
|
|
|
|
|
|
|
"""
|
|
|
|
Step 1: Environment settings and model initialization for SAM 2
|
|
|
|
"""
|
|
|
|
# use bfloat16 for the entire notebook
|
|
|
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
|
|
|
|
|
|
|
if torch.cuda.get_device_properties(0).major >= 8:
|
|
|
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
# Build the SAM 2.1 Hiera-Large models from one config/checkpoint pair:
#   * video_predictor  - propagates object prompts across the video frames
#   * image_predictor  - wraps a separately built image model for single-frame masks
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
|
|
|
|
|
|
|
|
|
|
|
"""
Custom video input directly using video files
"""
# (Alternatively, point the video predictor at an existing directory of
# `<frame_index>.jpg` frames instead of extracting them from a video file.)

video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # resolution, fps, frame count
print(video_info)

frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)

# Dump every frame to SOURCE_VIDEO_FRAME_DIR as zero-padded JPEGs (00000.jpg, ...);
# overwrite=True lets the sink replace the output of a previous run.
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
source_frames.mkdir(parents=True, exist_ok=True)

with sv.ImageSink(target_dir_path=source_frames, overwrite=True, image_name_pattern="{:05d}.jpg") as sink:
    for frame in tqdm(frame_generator, desc="Saving Video Frames"):
        sink.save_image(frame)

# Collect the JPEG file names and order them by their numeric stem so that
# frame_names[i] is the i-th frame of the video.
frame_names = sorted(
    (
        name
        for name in os.listdir(SOURCE_VIDEO_FRAME_DIR)
        if os.path.splitext(name)[1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
    ),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# Initialise the video predictor's state from the extracted frame directory.
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)

ann_frame_idx = 0  # index of the frame that gets the detection/segmentation prompts
|
|
|
|
"""
|
|
|
|
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
|
|
|
"""
|
|
|
|
|
|
|
|
# prompt grounding dino to get the box coordinates on specific frame
|
|
|
|
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
|
|
|
|
image = Image.open(img_path)
|
|
|
|
|
|
|
|
# Step 1: initialize the config
|
|
|
|
config = Config(API_TOKEN_FOR_GD1_5)
|
|
|
|
|
|
|
|
# Step 2: initialize the client
|
|
|
|
client = Client(config)
|
|
|
|
|
2025-04-20 01:04:26 +08:00
|
|
|
# Step 3: run the task using V2Task class
|
2024-08-06 17:11:55 +08:00
|
|
|
# if you are processing local image file, upload them to DDS server to get the image url
|
|
|
|
image_url = client.upload_file(img_path)
|
|
|
|
|
2025-04-20 01:04:26 +08:00
|
|
|
task = V2Task(
|
|
|
|
api_path="/v2/task/grounding_dino/detection",
|
|
|
|
api_body={
|
|
|
|
"model": "GroundingDino-1.5-Pro",
|
|
|
|
"image": image_url,
|
|
|
|
"prompt": {
|
|
|
|
"type": "text",
|
|
|
|
"text": TEXT_PROMPT
|
|
|
|
},
|
|
|
|
"targets": ["bbox"],
|
|
|
|
"bbox_threshold": BOX_THRESHOLD,
|
|
|
|
"iou_threshold": IOU_THRESHOLD,
|
|
|
|
}
|
2024-08-06 17:11:55 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
client.run_task(task)
|
|
|
|
result = task.result
|
|
|
|
|
2025-04-20 01:04:26 +08:00
|
|
|
objects = result["objects"] # the list of detected objects
|
2024-08-06 17:11:55 +08:00
|
|
|
|
|
|
|
# Fail early with an actionable message: with zero detections input_boxes would
# be an empty (0,) array and SAM / the video predictor would fail further down
# with an unrelated-looking error.
if not objects:
    raise RuntimeError(
        f"Grounding DINO returned no detections for prompt {TEXT_PROMPT!r} "
        f"on frame {ann_frame_idx}; try lowering BOX_THRESHOLD or changing TEXT_PROMPT."
    )

# Split the detection response into parallel per-object sequences.
# NOTE(review): "bbox" is assumed to be [x1, y1, x2, y2] in pixel coordinates of
# the uploaded frame (the format SAM's box prompt expects) — confirm against the
# DDS V2 API documentation.
input_boxes = np.array([obj["bbox"] for obj in objects])
confidences = [obj["score"] for obj in objects]
class_names = [obj["category"] for obj in objects]

print(input_boxes)

# Register the annotation frame (as an RGB array) with the SAM 2 image predictor
# so the predict() call below runs against it.
image_predictor.set_image(np.array(image.convert("RGB")))

# process the detection results
OBJECTS = class_names

print(OBJECTS)

# prompt SAM 2 image predictor to get one mask per detected box
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

# convert the mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)
|
|
|
|
|
|
|
|
"""
|
|
|
|
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
|
|
|
|
"""
|
|
|
|
|
2024-08-08 12:03:29 +08:00
|
|
|
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
|
2024-08-07 16:42:49 +08:00
|
|
|
|
|
|
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
|
|
|
if PROMPT_TYPE_FOR_VIDEO == "point":
|
|
|
|
# sample the positive points from mask for each objects
|
|
|
|
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
|
|
|
|
|
|
|
|
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
|
|
|
|
labels = np.ones((points.shape[0]), dtype=np.int32)
|
2024-08-08 12:03:29 +08:00
|
|
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
2024-08-07 16:42:49 +08:00
|
|
|
inference_state=inference_state,
|
|
|
|
frame_idx=ann_frame_idx,
|
|
|
|
obj_id=object_id,
|
|
|
|
points=points,
|
|
|
|
labels=labels,
|
|
|
|
)
|
2024-08-08 12:03:29 +08:00
|
|
|
# Using box prompt
|
|
|
|
elif PROMPT_TYPE_FOR_VIDEO == "box":
|
|
|
|
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
|
|
|
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
|
|
|
|
inference_state=inference_state,
|
|
|
|
frame_idx=ann_frame_idx,
|
|
|
|
obj_id=object_id,
|
|
|
|
box=box,
|
|
|
|
)
|
2024-08-07 16:42:49 +08:00
|
|
|
# Using mask prompt is a more straightforward way
|
|
|
|
elif PROMPT_TYPE_FOR_VIDEO == "mask":
|
|
|
|
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
|
|
|
|
labels = np.ones((1), dtype=np.int32)
|
|
|
|
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
|
|
|
inference_state=inference_state,
|
|
|
|
frame_idx=ann_frame_idx,
|
|
|
|
obj_id=object_id,
|
|
|
|
mask=mask
|
|
|
|
)
|
2024-08-08 12:03:29 +08:00
|
|
|
else:
|
|
|
|
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
|
2024-08-06 17:11:55 +08:00
|
|
|
|
|
|
|
"""
|
|
|
|
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
|
|
|
"""
|
|
|
|
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
|
|
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
|
|
|
|
video_segments[out_frame_idx] = {
|
|
|
|
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
|
|
|
for i, out_obj_id in enumerate(out_obj_ids)
|
|
|
|
}
|
|
|
|
|
|
|
|
"""
|
|
|
|
Step 5: Visualize the segment results across the video and save them
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
|
|
|
|
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
|
|
|
|
|
|
|
|
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
|
|
|
|
|
|
|
|
for frame_idx, segments in video_segments.items():
|
|
|
|
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
|
|
|
|
|
|
|
|
object_ids = list(segments.keys())
|
|
|
|
masks = list(segments.values())
|
|
|
|
masks = np.concatenate(masks, axis=0)
|
|
|
|
|
|
|
|
detections = sv.Detections(
|
|
|
|
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
|
|
|
|
mask=masks, # (n, h, w)
|
|
|
|
class_id=np.array(object_ids, dtype=np.int32),
|
|
|
|
)
|
|
|
|
box_annotator = sv.BoxAnnotator()
|
|
|
|
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
|
|
|
label_annotator = sv.LabelAnnotator()
|
|
|
|
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
|
|
|
|
mask_annotator = sv.MaskAnnotator()
|
|
|
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
|
|
|
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
Step 6: Convert the annotated frames to video
|
|
|
|
"""
|
|
|
|
|
|
|
|
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
|