This commit is contained in:
SusanSHEN
2024-08-09 20:05:01 +02:00
6 changed files with 334 additions and 15 deletions

View File

@@ -12,6 +12,11 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
[![Video Name](./assets/grounded_sam_2_intro.jpg)](https://github.com/user-attachments/assets/f0fb0022-779a-49fb-8f46-3a18a8b4e893)
## News
- `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
- `2024/08/07`: Support custom video inputs, users need only submit their video file (e.g. mp4 file) with specific text prompts to get an impressive demo videos.
## Contents
- [Installation](#installation)
- [Grounded SAM 2 Demo](#grounded-sam-2-demos)
@@ -21,6 +26,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
- [Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-grounding-dino-15--16)
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino)
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
- [Grounded SAM 2 Video Object Tracking with Continues ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
- [Citation](#citation)
@@ -167,12 +173,15 @@ And we will automatically save the tracking visualization results in `OUTPUT_VID
> [!WARNING]
> We initialize the box prompts on the first frame of the input video. If you want to start from different frame, you can refine `ann_frame_idx` by yourself in our code.
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
In above demos, we only prompt Grounded SAM 2 in specific frame, which may not be friendly to find new object during the whole video. In this demo, we try to **find new objects** and assign them with new ID across the whole video, this function is **still under develop**. it's not that stable now.
Users can upload their own video files and specify custom text prompts for grounding and tracking using the Grounding DINO and SAM 2 frameworks. To do this, execute the script:
```python
grounded_sam2_tracking_demo_with_continuous_id.py
```bash
python grounded_sam2_tracking_demo_with_continuous_id.py
```
You can customize various parameters including:
@@ -186,13 +195,15 @@ You can customize various parameters including:
- `text_threshold`: text threshold for groundingdino model
Note: This method supports only the mask type of text prompt.
The demo video is:
[![car tracking demo data](./assets/tracking_car_1.jpg)](./assets/tracking_car.mp4)
After running our demo code, you can get the tracking results as follows:
[![car tracking result data](./assets/tracking_car_mask_1.jpg)](./assets/tracking_car_output.mp4)
[![Video Name](./assets/tracking_car_mask_1.jpg)](https://github.com/user-attachments/assets/d3f91ad0-3d32-43c4-a0dc-0bed661415f4)
If you want to try `Grounding DINO 1.5` model, you can run the following scripts after setting your API token:
```bash
python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
```
### Citation
@@ -216,6 +227,15 @@ If you find this project helpful for your research, please consider citing the f
year={2023}
}
@misc{ren2024grounding,
title={Grounding DINO 1.5: Advance the "Edge" of Open-Set Object Detection},
author={Tianhe Ren and Qing Jiang and Shilong Liu and Zhaoyang Zeng and Wenlong Liu and Han Gao and Hongjie Huang and Zhengyu Ma and Xiaoke Jiang and Yihao Chen and Yuda Xiong and Hao Zhang and Feng Li and Peijun Tang and Kent Yu and Lei Zhang},
year={2024},
eprint={2405.10300},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{ren2024grounded,
title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},

Binary file not shown.

Binary file not shown.

View File

@@ -29,7 +29,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device",device)
print("device", device)
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
@@ -68,7 +68,7 @@ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)
step = 10 # the step to sample frames for groundedDino predictor
step = 10 # the step to sample frames for Grounding DINO predictor
sam2_masks = MaskDictionatyModel()
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
@@ -189,10 +189,9 @@ for start_frame_idx in range(0, len(frame_names), step):
json.dump(json_data, f)
"""
Step 6: Draw the results and save the video
"""
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
create_video_from_images(result_dir, output_video_path, frame_rate=15)
create_video_from_images(result_dir, output_video_path, frame_rate=30)

View File

@@ -0,0 +1,216 @@
# dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os
import torch
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionatyModel, ObjectInfo
import json
import copy
"""
Step 1: Environment settings and model initialization
"""
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
image_predictor = SAM2ImagePredictor(sam2_image_model)
# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car."
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/car"
# 'output_dir' is the directory to save the annotated frames
output_dir = "./outputs"
# 'output_video_path' is the path to save the final video
output_video_path = "./outputs/output.mp4"
# create the output directory
CommonUtils.creat_dirs(output_dir)
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
result_dir = os.path.join(output_dir, "result")
CommonUtils.creat_dirs(mask_data_dir)
CommonUtils.creat_dirs(json_data_dir)
# scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)
step = 10 # the step to sample frames for Grounding DINO predictor
sam2_masks = MaskDictionatyModel()
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
objects_count = 0
"""
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
"""
print("Total frames:", len(frame_names))
for start_frame_idx in range(0, len(frame_names), step):
# prompt grounding dino to get the box coordinates on specific frame
print("start_frame_idx", start_frame_idx)
# continue
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
image = Image.open(img_path)
image_base_name = frame_names[start_frame_idx].split(".")[0]
mask_dict = MaskDictionatyModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
# run Grounding DINO 1.5 on the image
API_TOKEN_FOR_GD1_5 = "Your API token"
config = Config(API_TOKEN_FOR_GD1_5)
# Step 2: initialize the client
client = Client(config)
image_url = client.upload_file(img_path)
task = DetectionTask(
image_url=image_url,
prompts=[TextPrompt(text=text)],
targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
class_names.append(obj.category)
input_boxes = np.array(input_boxes)
OBJECTS = class_names
# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(np.array(image.convert("RGB")))
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
"""
Step 3: Register each object's positive points to video predictor
"""
# If you are using point prompts, we uniformly sample positive points based on the mask
if mask_dict.promote_type == "mask":
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
else:
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
print("objects_count", objects_count)
video_predictor.reset_state(inference_state)
if len(mask_dict.labels) == 0:
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
continue
video_predictor.reset_state(inference_state)
for object_id, object_info in mask_dict.labels.items():
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
inference_state,
start_frame_idx,
object_id,
object_info.mask,
)
video_segments = {} # output the following {step} frames tracking masks
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
frame_masks = MaskDictionatyModel()
for i, out_obj_id in enumerate(out_obj_ids):
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
object_info.update_box()
frame_masks.labels[out_obj_id] = object_info
image_base_name = frame_names[out_frame_idx].split(".")[0]
frame_masks.mask_name = f"mask_{image_base_name}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
video_segments[out_frame_idx] = frame_masks
sam2_masks = copy.deepcopy(frame_masks)
print("video_segments:", len(video_segments))
"""
Step 5: save the tracking masks and json files
"""
for frame_idx, frame_masks_info in video_segments.items():
mask = frame_masks_info.labels
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
for obj_id, obj_info in mask.items():
mask_img[obj_info.mask == True] = obj_id
mask_img = mask_img.numpy().astype(np.uint16)
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
json_data = frame_masks_info.to_dict()
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
with open(json_data_path, "w") as f:
json.dump(json_data, f)
"""
Step 6: Draw the results and save the video
"""
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
create_video_from_images(result_dir, output_video_path, frame_rate=30)

View File

@@ -3,6 +3,7 @@ import json
import cv2
import numpy as np
from dataclasses import dataclass
import supervision as sv
import random
class CommonUtils:
@@ -21,7 +22,90 @@ class CommonUtils:
print(f"Path '{path}' already exists.")
except Exception as e:
print(f"An error occurred while creating the path: {e}")
@staticmethod
def draw_masks_and_box_with_supervision(raw_image_path, mask_path, json_path, output_path):
CommonUtils.creat_dirs(output_path)
raw_image_name_list = os.listdir(raw_image_path)
raw_image_name_list.sort()
for raw_image_name in raw_image_name_list:
image_path = os.path.join(raw_image_path, raw_image_name)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError("Image file not found.")
# load mask
mask_npy_path = os.path.join(mask_path, "mask_"+raw_image_name.split(".")[0]+".npy")
mask = np.load(mask_npy_path)
# color map
unique_ids = np.unique(mask)
# get each mask from unique mask file
all_object_masks = []
for uid in unique_ids:
if uid == 0: # skip background id
continue
else:
object_mask = (mask == uid)
all_object_masks.append(object_mask[None])
# get n masks: (n, h, w)
all_object_masks = np.concatenate(all_object_masks, axis=0)
# load box information
file_path = os.path.join(json_path, "mask_"+raw_image_name.split(".")[0]+".json")
all_object_boxes = []
all_object_ids = []
all_class_names = []
object_id_to_name = {}
with open(file_path, "r") as file:
json_data = json.load(file)
for obj_id, obj_item in json_data["labels"].items():
# box id
instance_id = obj_item["instance_id"]
if instance_id not in unique_ids: # not a valid box
continue
# box coordinates
x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
all_object_boxes.append([x1, y1, x2, y2])
# box name
class_name = obj_item["class_name"]
# build id list and id2name mapping
all_object_ids.append(instance_id)
all_class_names.append(class_name)
object_id_to_name[instance_id] = class_name
# Adjust object id and boxes to ascending order
paired_id_and_box = zip(all_object_ids, all_object_boxes)
sorted_pair = sorted(paired_id_and_box, key=lambda pair: pair[0])
# Because we get the mask data as ascending order, so we also need to ascend box and ids
all_object_ids = [pair[0] for pair in sorted_pair]
all_object_boxes = [pair[1] for pair in sorted_pair]
detections = sv.Detections(
xyxy=np.array(all_object_boxes),
mask=all_object_masks,
class_id=np.array(all_object_ids, dtype=np.int32),
)
# custom label to show both id and class name
labels = [
f"{instance_id}: {class_name}" for instance_id, class_name in zip(all_object_ids, all_class_names)
]
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=labels)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
output_image_path = os.path.join(output_path, raw_image_name)
cv2.imwrite(output_image_path, annotated_frame)
print(f"Annotated image saved as {output_image_path}")
@staticmethod
def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
CommonUtils.creat_dirs(output_path)
@@ -40,7 +124,7 @@ class CommonUtils:
colors = {uid: CommonUtils.random_color() for uid in unique_ids}
colors[0] = (0, 0, 0) # background color
# apply mask to image
# apply mask to image in RBG channels
colored_mask = np.zeros_like(image)
for uid in unique_ids:
colored_mask[mask == uid] = colors[uid]