Merge branch 'main' of https://github.com/ShuoShenDe/Grounded-SAM-2 into main
This commit is contained in:
36
README.md
36
README.md
@@ -12,6 +12,11 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
|
||||
|
||||
[](https://github.com/user-attachments/assets/f0fb0022-779a-49fb-8f46-3a18a8b4e893)
|
||||
|
||||
## News
|
||||
|
||||
- `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
|
||||
- `2024/08/07`: Support custom video inputs, users need only submit their video file (e.g. mp4 file) with specific text prompts to get an impressive demo videos.
|
||||
|
||||
## Contents
|
||||
- [Installation](#installation)
|
||||
- [Grounded SAM 2 Demo](#grounded-sam-2-demos)
|
||||
@@ -21,6 +26,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
|
||||
- [Grounded SAM 2 Video Object Tracking Demo (with Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-grounding-dino-15--16)
|
||||
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino)
|
||||
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
|
||||
- [Grounded SAM 2 Video Object Tracking with Continues ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
|
||||
- [Citation](#citation)
|
||||
|
||||
|
||||
@@ -167,12 +173,15 @@ And we will automatically save the tracking visualization results in `OUTPUT_VID
|
||||
> [!WARNING]
|
||||
> We initialize the box prompts on the first frame of the input video. If you want to start from different frame, you can refine `ann_frame_idx` by yourself in our code.
|
||||
|
||||
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
|
||||
### Grounded-SAM-2 Video Object Tracking with Continuous ID (with Grounding DINO)
|
||||
|
||||
In above demos, we only prompt Grounded SAM 2 in specific frame, which may not be friendly to find new object during the whole video. In this demo, we try to **find new objects** and assign them with new ID across the whole video, this function is **still under develop**. it's not that stable now.
|
||||
|
||||
Users can upload their own video files and specify custom text prompts for grounding and tracking using the Grounding DINO and SAM 2 frameworks. To do this, execute the script:
|
||||
|
||||
|
||||
```python
|
||||
grounded_sam2_tracking_demo_with_continuous_id.py
|
||||
```bash
|
||||
python grounded_sam2_tracking_demo_with_continuous_id.py
|
||||
```
|
||||
|
||||
You can customize various parameters including:
|
||||
@@ -186,13 +195,15 @@ You can customize various parameters including:
|
||||
- `text_threshold`: text threshold for groundingdino model
|
||||
Note: This method supports only the mask type of text prompt.
|
||||
|
||||
The demo video is:
|
||||
[](./assets/tracking_car.mp4)
|
||||
|
||||
|
||||
After running our demo code, you can get the tracking results as follows:
|
||||
[](./assets/tracking_car_output.mp4)
|
||||
|
||||
[](https://github.com/user-attachments/assets/d3f91ad0-3d32-43c4-a0dc-0bed661415f4)
|
||||
|
||||
If you want to try `Grounding DINO 1.5` model, you can run the following scripts after setting your API token:
|
||||
|
||||
```bash
|
||||
python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
|
||||
```
|
||||
|
||||
### Citation
|
||||
|
||||
@@ -216,6 +227,15 @@ If you find this project helpful for your research, please consider citing the f
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@misc{ren2024grounding,
|
||||
title={Grounding DINO 1.5: Advance the "Edge" of Open-Set Object Detection},
|
||||
author={Tianhe Ren and Qing Jiang and Shilong Liu and Zhaoyang Zeng and Wenlong Liu and Han Gao and Hongjie Huang and Zhengyu Ma and Xiaoke Jiang and Yihao Chen and Yuda Xiong and Hao Zhang and Feng Li and Peijun Tang and Kent Yu and Lei Zhang},
|
||||
year={2024},
|
||||
eprint={2405.10300},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
|
||||
@misc{ren2024grounded,
|
||||
title={Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks},
|
||||
author={Tianhe Ren and Shilong Liu and Ailing Zeng and Jing Lin and Kunchang Li and He Cao and Jiayu Chen and Xinyu Huang and Yukang Chen and Feng Yan and Zhaoyang Zeng and Hao Zhang and Feng Li and Jie Yang and Hongyang Li and Qing Jiang and Lei Zhang},
|
||||
|
Binary file not shown.
Binary file not shown.
@@ -29,7 +29,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
||||
model_cfg = "sam2_hiera_l.yaml"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print("device",device)
|
||||
print("device", device)
|
||||
|
||||
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
||||
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
||||
@@ -68,7 +68,7 @@ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
|
||||
# init video predictor state
|
||||
inference_state = video_predictor.init_state(video_path=video_dir)
|
||||
step = 10 # the step to sample frames for groundedDino predictor
|
||||
step = 10 # the step to sample frames for Grounding DINO predictor
|
||||
|
||||
sam2_masks = MaskDictionatyModel()
|
||||
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
||||
@@ -189,10 +189,9 @@ for start_frame_idx in range(0, len(frame_names), step):
|
||||
json.dump(json_data, f)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Step 6: Draw the results and save the video
|
||||
"""
|
||||
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||
|
||||
create_video_from_images(result_dir, output_video_path, frame_rate=15)
|
||||
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
216
grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
Normal file
216
grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# dds cloudapi for Grounding DINO 1.5
|
||||
from dds_cloudapi_sdk import Config
|
||||
from dds_cloudapi_sdk import Client
|
||||
from dds_cloudapi_sdk import DetectionTask
|
||||
from dds_cloudapi_sdk import TextPrompt
|
||||
from dds_cloudapi_sdk import DetectionModel
|
||||
from dds_cloudapi_sdk import DetectionTarget
|
||||
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
from utils.video_utils import create_video_from_images
|
||||
from utils.common_utils import CommonUtils
|
||||
from utils.mask_dictionary_model import MaskDictionatyModel, ObjectInfo
|
||||
import json
|
||||
import copy
|
||||
|
||||
"""
|
||||
Step 1: Environment settings and model initialization
|
||||
"""
|
||||
# use bfloat16 for the entire notebook
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# init sam image predictor and video predictor model
|
||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
||||
model_cfg = "sam2_hiera_l.yaml"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print("device", device)
|
||||
|
||||
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
||||
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
||||
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
||||
|
||||
|
||||
# init grounding dino model from huggingface
|
||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
|
||||
|
||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||
# VERY important: text queries need to be lowercased + end with a dot
|
||||
text = "car."
|
||||
|
||||
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
||||
video_dir = "notebooks/videos/car"
|
||||
# 'output_dir' is the directory to save the annotated frames
|
||||
output_dir = "./outputs"
|
||||
# 'output_video_path' is the path to save the final video
|
||||
output_video_path = "./outputs/output.mp4"
|
||||
# create the output directory
|
||||
CommonUtils.creat_dirs(output_dir)
|
||||
mask_data_dir = os.path.join(output_dir, "mask_data")
|
||||
json_data_dir = os.path.join(output_dir, "json_data")
|
||||
result_dir = os.path.join(output_dir, "result")
|
||||
CommonUtils.creat_dirs(mask_data_dir)
|
||||
CommonUtils.creat_dirs(json_data_dir)
|
||||
# scan all the JPEG frame names in this directory
|
||||
frame_names = [
|
||||
p for p in os.listdir(video_dir)
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||
]
|
||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
|
||||
# init video predictor state
|
||||
inference_state = video_predictor.init_state(video_path=video_dir)
|
||||
step = 10 # the step to sample frames for Grounding DINO predictor
|
||||
|
||||
sam2_masks = MaskDictionatyModel()
|
||||
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
||||
objects_count = 0
|
||||
|
||||
"""
|
||||
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
||||
"""
|
||||
print("Total frames:", len(frame_names))
|
||||
for start_frame_idx in range(0, len(frame_names), step):
|
||||
# prompt grounding dino to get the box coordinates on specific frame
|
||||
print("start_frame_idx", start_frame_idx)
|
||||
# continue
|
||||
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
||||
image = Image.open(img_path)
|
||||
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
||||
mask_dict = MaskDictionatyModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
||||
|
||||
# run Grounding DINO 1.5 on the image
|
||||
|
||||
API_TOKEN_FOR_GD1_5 = "Your API token"
|
||||
|
||||
config = Config(API_TOKEN_FOR_GD1_5)
|
||||
# Step 2: initialize the client
|
||||
client = Client(config)
|
||||
|
||||
image_url = client.upload_file(img_path)
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=text)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
||||
)
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
|
||||
objects = result.objects # the list of detected objects
|
||||
input_boxes = []
|
||||
confidences = []
|
||||
class_names = []
|
||||
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
class_names.append(obj.category)
|
||||
|
||||
input_boxes = np.array(input_boxes)
|
||||
OBJECTS = class_names
|
||||
|
||||
# prompt SAM image predictor to get the mask for the object
|
||||
image_predictor.set_image(np.array(image.convert("RGB")))
|
||||
|
||||
# prompt SAM 2 image predictor to get the mask for the object
|
||||
masks, scores, logits = image_predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_boxes,
|
||||
multimask_output=False,
|
||||
)
|
||||
# convert the mask shape to (n, H, W)
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None]
|
||||
scores = scores[None]
|
||||
logits = logits[None]
|
||||
elif masks.ndim == 4:
|
||||
masks = masks.squeeze(1)
|
||||
|
||||
"""
|
||||
Step 3: Register each object's positive points to video predictor
|
||||
"""
|
||||
|
||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||
if mask_dict.promote_type == "mask":
|
||||
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
||||
else:
|
||||
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||
|
||||
|
||||
"""
|
||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||
"""
|
||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||
print("objects_count", objects_count)
|
||||
video_predictor.reset_state(inference_state)
|
||||
if len(mask_dict.labels) == 0:
|
||||
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||
continue
|
||||
video_predictor.reset_state(inference_state)
|
||||
|
||||
for object_id, object_info in mask_dict.labels.items():
|
||||
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
start_frame_idx,
|
||||
object_id,
|
||||
object_info.mask,
|
||||
)
|
||||
|
||||
video_segments = {} # output the following {step} frames tracking masks
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
||||
frame_masks = MaskDictionatyModel()
|
||||
|
||||
for i, out_obj_id in enumerate(out_obj_ids):
|
||||
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
||||
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
|
||||
object_info.update_box()
|
||||
frame_masks.labels[out_obj_id] = object_info
|
||||
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
||||
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
||||
frame_masks.mask_height = out_mask.shape[-2]
|
||||
frame_masks.mask_width = out_mask.shape[-1]
|
||||
|
||||
video_segments[out_frame_idx] = frame_masks
|
||||
sam2_masks = copy.deepcopy(frame_masks)
|
||||
|
||||
print("video_segments:", len(video_segments))
|
||||
"""
|
||||
Step 5: save the tracking masks and json files
|
||||
"""
|
||||
for frame_idx, frame_masks_info in video_segments.items():
|
||||
mask = frame_masks_info.labels
|
||||
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
|
||||
for obj_id, obj_info in mask.items():
|
||||
mask_img[obj_info.mask == True] = obj_id
|
||||
|
||||
mask_img = mask_img.numpy().astype(np.uint16)
|
||||
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
|
||||
|
||||
json_data = frame_masks_info.to_dict()
|
||||
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
|
||||
with open(json_data_path, "w") as f:
|
||||
json.dump(json_data, f)
|
||||
|
||||
|
||||
"""
|
||||
Step 6: Draw the results and save the video
|
||||
"""
|
||||
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||
|
||||
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
@@ -3,6 +3,7 @@ import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
import supervision as sv
|
||||
import random
|
||||
|
||||
class CommonUtils:
|
||||
@@ -22,6 +23,89 @@ class CommonUtils:
|
||||
except Exception as e:
|
||||
print(f"An error occurred while creating the path: {e}")
|
||||
|
||||
@staticmethod
|
||||
def draw_masks_and_box_with_supervision(raw_image_path, mask_path, json_path, output_path):
|
||||
CommonUtils.creat_dirs(output_path)
|
||||
raw_image_name_list = os.listdir(raw_image_path)
|
||||
raw_image_name_list.sort()
|
||||
for raw_image_name in raw_image_name_list:
|
||||
image_path = os.path.join(raw_image_path, raw_image_name)
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise FileNotFoundError("Image file not found.")
|
||||
# load mask
|
||||
mask_npy_path = os.path.join(mask_path, "mask_"+raw_image_name.split(".")[0]+".npy")
|
||||
mask = np.load(mask_npy_path)
|
||||
# color map
|
||||
unique_ids = np.unique(mask)
|
||||
|
||||
# get each mask from unique mask file
|
||||
all_object_masks = []
|
||||
for uid in unique_ids:
|
||||
if uid == 0: # skip background id
|
||||
continue
|
||||
else:
|
||||
object_mask = (mask == uid)
|
||||
all_object_masks.append(object_mask[None])
|
||||
|
||||
# get n masks: (n, h, w)
|
||||
all_object_masks = np.concatenate(all_object_masks, axis=0)
|
||||
|
||||
# load box information
|
||||
file_path = os.path.join(json_path, "mask_"+raw_image_name.split(".")[0]+".json")
|
||||
|
||||
all_object_boxes = []
|
||||
all_object_ids = []
|
||||
all_class_names = []
|
||||
object_id_to_name = {}
|
||||
with open(file_path, "r") as file:
|
||||
json_data = json.load(file)
|
||||
for obj_id, obj_item in json_data["labels"].items():
|
||||
# box id
|
||||
instance_id = obj_item["instance_id"]
|
||||
if instance_id not in unique_ids: # not a valid box
|
||||
continue
|
||||
# box coordinates
|
||||
x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
|
||||
all_object_boxes.append([x1, y1, x2, y2])
|
||||
# box name
|
||||
class_name = obj_item["class_name"]
|
||||
|
||||
# build id list and id2name mapping
|
||||
all_object_ids.append(instance_id)
|
||||
all_class_names.append(class_name)
|
||||
object_id_to_name[instance_id] = class_name
|
||||
|
||||
# Adjust object id and boxes to ascending order
|
||||
paired_id_and_box = zip(all_object_ids, all_object_boxes)
|
||||
sorted_pair = sorted(paired_id_and_box, key=lambda pair: pair[0])
|
||||
|
||||
# Because we get the mask data as ascending order, so we also need to ascend box and ids
|
||||
all_object_ids = [pair[0] for pair in sorted_pair]
|
||||
all_object_boxes = [pair[1] for pair in sorted_pair]
|
||||
|
||||
detections = sv.Detections(
|
||||
xyxy=np.array(all_object_boxes),
|
||||
mask=all_object_masks,
|
||||
class_id=np.array(all_object_ids, dtype=np.int32),
|
||||
)
|
||||
|
||||
# custom label to show both id and class name
|
||||
labels = [
|
||||
f"{instance_id}: {class_name}" for instance_id, class_name in zip(all_object_ids, all_class_names)
|
||||
]
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)
|
||||
label_annotator = sv.LabelAnnotator()
|
||||
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=labels)
|
||||
mask_annotator = sv.MaskAnnotator()
|
||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||
|
||||
output_image_path = os.path.join(output_path, raw_image_name)
|
||||
cv2.imwrite(output_image_path, annotated_frame)
|
||||
print(f"Annotated image saved as {output_image_path}")
|
||||
|
||||
@staticmethod
|
||||
def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
|
||||
CommonUtils.creat_dirs(output_path)
|
||||
@@ -40,7 +124,7 @@ class CommonUtils:
|
||||
colors = {uid: CommonUtils.random_color() for uid in unique_ids}
|
||||
colors[0] = (0, 0, 0) # background color
|
||||
|
||||
# apply mask to image
|
||||
# apply mask to image in RBG channels
|
||||
colored_mask = np.zeros_like(image)
|
||||
for uid in unique_ids:
|
||||
colored_mask[mask == uid] = colors[uid]
|
||||
|
Reference in New Issue
Block a user