add grounded samurai demo with dino-x

This commit is contained in:
rentainhe
2024-12-11 15:55:25 +08:00
parent c38520be44
commit 859a3e4721
3 changed files with 381 additions and 127 deletions

226
grounded_samurai_dinox.py Normal file
View File

@@ -0,0 +1,226 @@
# libraries for SAMURAI
import os
import cv2
import torch
import numpy as np
import supervision as sv
import sys
from pathlib import Path
from tqdm import tqdm
from PIL import Image
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor
# dds cloudapi for DINO-X
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
from dds_cloudapi_sdk.tasks.types import DetectionTarget
from dds_cloudapi_sdk import TextPrompt
"""
Hyperparam for Ground and Tracking
"""
VIDEO_PATH = "demo.mp4"
TEXT_PROMPT = "person."
OUTPUT_VIDEO_PATH = "./tracking_demo.mp4"
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_DINOX = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2
"""
Step 1: Environment settings and model initialization for SAM 2
"""
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "/comp_robot/rentianhe/code/samurai/sam2/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/samurai/sam2.1_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
# # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
# video_dir = "notebooks/videos/bedroom"
"""
Custom video input directly using video files
"""
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
print(video_info)
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
# saving video to frames
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
source_frames.mkdir(parents=True, exist_ok=True)
with sv.ImageSink(
target_dir_path=source_frames,
overwrite=True,
image_name_pattern="{:05d}.jpg"
) as sink:
for frame in tqdm(frame_generator, desc="Saving Video Frames"):
sink.save_image(frame)
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
ann_frame_idx = 0 # the frame index we interact with
"""
Step 2: Prompt DINO-X with Cloud API for box coordinates
"""
# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
image = Image.open(img_path)
# Step 1: initialize the config
config = Config(API_TOKEN_FOR_DINOX)
# Step 2: initialize the client
client = Client(config)
# Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path)
task = DinoxTask(
image_url=image_url,
prompts=[TextPrompt(text=TEXT_PROMPT)],
bbox_threshold=0.25,
targets=[DetectionTarget.BBox],
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
class_names.append(obj.category)
input_boxes = np.array(input_boxes)
print(input_boxes)
# process the detection results
OBJECTS = class_names
print(OBJECTS)
"""
Step 3: Register each object's positive points to video predictor with seperate add_new_points call
"""
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
# Using box prompt
if PROMPT_TYPE_FOR_VIDEO == "box":
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
box=box,
)
break
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
"""
Step 5: Visualize the segment results across the video and save them
"""
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
os.makedirs(SAVE_TRACKING_RESULTS_DIR)
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
for frame_idx, segments in video_segments.items():
img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
object_ids = list(segments.keys())
masks = list(segments.values())
masks = np.concatenate(masks, axis=0)
detections = sv.Detections(
xyxy=sv.mask_to_xyxy(masks), # (n, 4)
mask=masks, # (n, h, w)
class_id=np.array(object_ids, dtype=np.int32),
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
"""
Step 6: Convert the annotated frames to video
"""
def create_video_from_images(image_folder, output_video_path, frame_rate=25):
# define valid extension
valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
# get all image files in the folder
image_files = [f for f in os.listdir(image_folder)
if os.path.splitext(f)[1] in valid_extensions]
image_files.sort() # sort the files in alphabetical order
print(image_files)
if not image_files:
raise ValueError("No valid image files found in the specified folder.")
# load the first image to get the dimensions of the video
first_image_path = os.path.join(image_folder, image_files[0])
first_image = cv2.imread(first_image_path)
height, width, _ = first_image.shape
# create a video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # codec for saving the video
video_writer = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
# write each image to the video
for image_file in tqdm(image_files):
image_path = os.path.join(image_folder, image_file)
image = cv2.imread(image_path)
video_writer.write(image)
# source release
video_writer.release()
print(f"Video saved at {output_video_path}")
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)