add dino-x sam2 tracking demo

rentainhe
2024-12-02 21:01:55 +08:00
parent 779147bc48
commit de62b7fb0b
2 changed files with 244 additions and 1 deletions


@@ -21,7 +21,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
## Latest updates
- `2024/12/02`: Support **DINO-X SAM 2 Demos** (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
- `2024/10/24`: Support [SAHI (Slicing Aided Hyper Inference)](https://docs.ultralytics.com/guides/sahi-tiled-inference/) on Grounded SAM 2 (with Grounding DINO 1.5), which may be helpful for inference on high-resolution images with dense small objects (e.g. **4K** images).
- `2024/10/10`: Support `SAM-2.1` models. If you want to use the `SAM 2.1` model, you need to update to the latest code and reinstall SAM 2 following [SAM 2.1 Installation](https://github.com/facebookresearch/sam2?tab=readme-ov-file#latest-updates).
- `2024/08/31`: Support `dump json results` in Grounded SAM 2 Image Demos (with Grounding DINO).
@@ -41,6 +41,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
- [Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-grounding-dino-15--16)
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino)
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
- [Grounded SAM 2 Video Object Tracking Demo (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x)
- [Grounded SAM 2 Video Object Tracking with Continuous ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
- [Grounded SAM 2 Florence-2 Demos](#grounded-sam-2-florence-2-demos)
- [Grounded SAM 2 Florence-2 Image Demo](#grounded-sam-2-florence-2-image-demo)
@@ -280,6 +281,14 @@ And we will automatically save the tracking visualization results in `OUTPUT_VID
> [!WARNING]
> We initialize the box prompts on the first frame of the input video. If you want to start from a different frame, you can modify `ann_frame_idx` in our code, as in the sketch below.
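For example, a minimal sketch of that change (the demo scripts default to frame `0`; the index `60` here is only an illustrative value):

```python
ann_frame_idx = 60  # the frame index we interact with (default is 0, i.e. the first frame)
```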
### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with DINO-X)
Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and specify custom text prompts for grounding and tracking with DINO-X and SAM 2 using the following script:
```bash
python grounded_sam2_tracking_demo_custom_video_input_dinox.py
```
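Before running the script, set your API token and adjust the hyperparameters at the top of the file for your own input. A minimal sketch of the values you would typically edit (the defaults shown here mirror the demo script added in this commit):

```python
VIDEO_PATH = "./assets/hippopotamus.mp4"    # path to your own video file
TEXT_PROMPT = "hippopotamus."               # custom text prompt for grounding
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
API_TOKEN_FOR_GD1_5 = "Your API token"      # DDS Cloud API token
PROMPT_TYPE_FOR_VIDEO = "box"               # choose from ["point", "box", "mask"]
```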
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
In the above demos, we only prompt Grounded SAM 2 on a specific frame, which may not be well suited to finding new objects that appear later in the video. In this demo, we try to **find new objects** and assign them new IDs across the whole video. This function is **still under development** and is not very stable yet.


@@ -0,0 +1,234 @@
# dds cloudapi for DINO-X
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
from dds_cloudapi_sdk import TextPrompt
import os
import cv2
import torch
import numpy as np
import supervision as sv
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
"""
Hyperparameters for Grounding and Tracking
"""
VIDEO_PATH = "./assets/hippopotamus.mp4"
TEXT_PROMPT = "hippopotamus."
OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_GD1_5 = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2
"""
Step 1: Environment settings and model initialization for SAM 2
"""
# use bfloat16 for the entire script
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)
# # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
# video_dir = "notebooks/videos/bedroom"
"""
Custom video input directly using video files
"""
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH) # get video info
print(video_info)
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
# saving video to frames
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
source_frames.mkdir(parents=True, exist_ok=True)
with sv.ImageSink(
    target_dir_path=source_frames,
    overwrite=True,
    image_name_pattern="{:05d}.jpg"
) as sink:
    for frame in tqdm(frame_generator, desc="Saving Video Frames"):
        sink.save_image(frame)
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
ann_frame_idx = 0 # the frame index we interact with
"""
Step 2: Prompt DINO-X with Cloud API for box coordinates
"""
# prompt DINO-X to get the box coordinates on a specific frame
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
image = Image.open(img_path)
# Step 1: initialize the config
config = Config(API_TOKEN_FOR_GD1_5)
# Step 2: initialize the client
client = Client(config)
# Step 3: run the task with the DinoxTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path)
task = DinoxTask(
    image_url=image_url,
    prompts=[TextPrompt(text=TEXT_PROMPT)]
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
    input_boxes.append(obj.bbox)
    confidences.append(obj.score)
    class_names.append(obj.category)
input_boxes = np.array(input_boxes)
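# note: each obj.bbox returned by the DINO-X API is assumed to be in xyxy ([x1, y1, x2, y2])
# pixel coordinates, which is the box format the SAM 2 predictors below expect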
print(input_boxes)
# set the annotation frame image for the SAM 2 image predictor
image_predictor.set_image(np.array(image.convert("RGB")))
# process the detection results
OBJECTS = class_names
print(OBJECTS)
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)
"""
Step 3: Register each object's positive points to video predictor with separate add_new_points call
"""
assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only supports point/box/mask prompts"
# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
# sample the positive points from mask for each objects
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
labels = np.ones((points.shape[0]), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
points=points,
labels=labels,
)
# Using box prompt
elif PROMPT_TYPE_FOR_VIDEO == "box":
for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
_, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
box=box,
)
# Using mask prompt is a more straightforward way
elif PROMPT_TYPE_FOR_VIDEO == "mask":
for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
labels = np.ones((1), dtype=np.int32)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
mask=mask
)
else:
raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
"""
Step 5: Visualize the segmentation results across the video and save them
"""
if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
    os.makedirs(SAVE_TRACKING_RESULTS_DIR)
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    masks = list(segments.values())
    masks = np.concatenate(masks, axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
"""
Step 6: Convert the annotated frames to video
"""
create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)