Merge branch 'main' of github.com:IDEA-Research/Grounded-SAM-2 into main
This commit is contained in:
@@ -223,6 +223,13 @@ If you want to try `Grounding DINO 1.5` model, you can run the following scripts
|
|||||||
python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
|
python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Grounded-SAM-2 Video Object Tracking with Continuous ID plus Reverse Tracking(with Grounding DINO)
|
||||||
|
This method could simply cover the whole lifetime of the object
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_tracking_demo_with_continuous_id_plus.py
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
## Grounded SAM 2 Florence-2 Demos
|
## Grounded SAM 2 Florence-2 Demos
|
||||||
### Grounded SAM 2 Florence-2 Image Demo (Updating)
|
### Grounded SAM 2 Florence-2 Image Demo (Updating)
|
||||||
|
|
||||||
|
@@ -62,13 +62,13 @@ CommonUtils.creat_dirs(json_data_dir)
|
|||||||
# scan all the JPEG frame names in this directory
|
# scan all the JPEG frame names in this directory
|
||||||
frame_names = [
|
frame_names = [
|
||||||
p for p in os.listdir(video_dir)
|
p for p in os.listdir(video_dir)
|
||||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
]
|
]
|
||||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
|
|
||||||
# init video predictor state
|
# init video predictor state
|
||||||
inference_state = video_predictor.init_state(video_path=video_dir)
|
inference_state = video_predictor.init_state(video_path=video_dir, offload_video_to_cpu=True, async_loading_frames=True)
|
||||||
step = 10 # the step to sample frames for Grounding DINO predictor
|
step = 20 # the step to sample frames for Grounding DINO predictor
|
||||||
|
|
||||||
sam2_masks = MaskDictionaryModel()
|
sam2_masks = MaskDictionaryModel()
|
||||||
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
||||||
@@ -194,4 +194,4 @@ Step 6: Draw the results and save the video
|
|||||||
"""
|
"""
|
||||||
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||||
|
|
||||||
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
create_video_from_images(result_dir, output_video_path, frame_rate=15)
|
@@ -68,7 +68,7 @@ CommonUtils.creat_dirs(json_data_dir)
|
|||||||
# scan all the JPEG frame names in this directory
|
# scan all the JPEG frame names in this directory
|
||||||
frame_names = [
|
frame_names = [
|
||||||
p for p in os.listdir(video_dir)
|
p for p in os.listdir(video_dir)
|
||||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
]
|
]
|
||||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
|
|
||||||
|
242
grounded_sam2_tracking_demo_with_continuous_id_plus.py
Normal file
242
grounded_sam2_tracking_demo_with_continuous_id_plus.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import supervision as sv
|
||||||
|
from PIL import Image
|
||||||
|
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
||||||
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||||
|
from utils.track_utils import sample_points_from_masks
|
||||||
|
from utils.video_utils import create_video_from_images
|
||||||
|
from utils.common_utils import CommonUtils
|
||||||
|
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
|
||||||
|
# This demo shows the continuous object tracking plus reverse tracking with Grounding DINO and SAM 2
|
||||||
|
"""
|
||||||
|
Step 1: Environment settings and model initialization
|
||||||
|
"""
|
||||||
|
# use bfloat16 for the entire notebook
|
||||||
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# init sam image predictor and video predictor model
|
||||||
|
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
||||||
|
model_cfg = "sam2_hiera_l.yaml"
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print("device", device)
|
||||||
|
|
||||||
|
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
||||||
|
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
||||||
|
image_predictor = SAM2ImagePredictor(sam2_image_model)
|
||||||
|
|
||||||
|
|
||||||
|
# init grounding dino model from huggingface
|
||||||
|
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||||
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||||
|
# VERY important: text queries need to be lowercased + end with a dot
|
||||||
|
text = "car."
|
||||||
|
|
||||||
|
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
||||||
|
video_dir = "notebooks/videos/car"
|
||||||
|
# 'output_dir' is the directory to save the annotated frames
|
||||||
|
output_dir = "outputs"
|
||||||
|
# 'output_video_path' is the path to save the final video
|
||||||
|
output_video_path = "./outputs/output.mp4"
|
||||||
|
# create the output directory
|
||||||
|
mask_data_dir = os.path.join(output_dir, "mask_data")
|
||||||
|
json_data_dir = os.path.join(output_dir, "json_data")
|
||||||
|
result_dir = os.path.join(output_dir, "result")
|
||||||
|
CommonUtils.creat_dirs(mask_data_dir)
|
||||||
|
CommonUtils.creat_dirs(json_data_dir)
|
||||||
|
# scan all the JPEG frame names in this directory
|
||||||
|
frame_names = [
|
||||||
|
p for p in os.listdir(video_dir)
|
||||||
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
|
]
|
||||||
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
|
|
||||||
|
# init video predictor state
|
||||||
|
inference_state = video_predictor.init_state(video_path=video_dir)
|
||||||
|
step = 10 # the step to sample frames for Grounding DINO predictor
|
||||||
|
|
||||||
|
sam2_masks = MaskDictionaryModel()
|
||||||
|
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
||||||
|
objects_count = 0
|
||||||
|
frame_object_count = {}
|
||||||
|
"""
|
||||||
|
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
||||||
|
"""
|
||||||
|
print("Total frames:", len(frame_names))
|
||||||
|
for start_frame_idx in range(0, len(frame_names), step):
|
||||||
|
# prompt grounding dino to get the box coordinates on specific frame
|
||||||
|
print("start_frame_idx", start_frame_idx)
|
||||||
|
# continue
|
||||||
|
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
||||||
|
image = Image.open(img_path).convert("RGB")
|
||||||
|
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
||||||
|
mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
||||||
|
|
||||||
|
# run Grounding DINO on the image
|
||||||
|
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = grounding_model(**inputs)
|
||||||
|
|
||||||
|
results = processor.post_process_grounded_object_detection(
|
||||||
|
outputs,
|
||||||
|
inputs.input_ids,
|
||||||
|
box_threshold=0.25,
|
||||||
|
text_threshold=0.25,
|
||||||
|
target_sizes=[image.size[::-1]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# prompt SAM image predictor to get the mask for the object
|
||||||
|
image_predictor.set_image(np.array(image.convert("RGB")))
|
||||||
|
|
||||||
|
# process the detection results
|
||||||
|
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
||||||
|
# print("results[0]",results[0])
|
||||||
|
OBJECTS = results[0]["labels"]
|
||||||
|
|
||||||
|
# prompt SAM 2 image predictor to get the mask for the object
|
||||||
|
masks, scores, logits = image_predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=input_boxes,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
# convert the mask shape to (n, H, W)
|
||||||
|
if masks.ndim == 2:
|
||||||
|
masks = masks[None]
|
||||||
|
scores = scores[None]
|
||||||
|
logits = logits[None]
|
||||||
|
elif masks.ndim == 4:
|
||||||
|
masks = masks.squeeze(1)
|
||||||
|
"""
|
||||||
|
Step 3: Register each object's positive points to video predictor
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||||
|
if mask_dict.promote_type == "mask":
|
||||||
|
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||||
|
"""
|
||||||
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||||
|
frame_object_count[start_frame_idx] = objects_count
|
||||||
|
print("objects_count", objects_count)
|
||||||
|
video_predictor.reset_state(inference_state)
|
||||||
|
if len(mask_dict.labels) == 0:
|
||||||
|
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||||
|
continue
|
||||||
|
video_predictor.reset_state(inference_state)
|
||||||
|
|
||||||
|
for object_id, object_info in mask_dict.labels.items():
|
||||||
|
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||||
|
inference_state,
|
||||||
|
start_frame_idx,
|
||||||
|
object_id,
|
||||||
|
object_info.mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_segments = {} # output the following {step} frames tracking masks
|
||||||
|
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
||||||
|
frame_masks = MaskDictionaryModel()
|
||||||
|
|
||||||
|
for i, out_obj_id in enumerate(out_obj_ids):
|
||||||
|
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
||||||
|
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
|
||||||
|
object_info.update_box()
|
||||||
|
frame_masks.labels[out_obj_id] = object_info
|
||||||
|
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
||||||
|
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
||||||
|
frame_masks.mask_height = out_mask.shape[-2]
|
||||||
|
frame_masks.mask_width = out_mask.shape[-1]
|
||||||
|
|
||||||
|
video_segments[out_frame_idx] = frame_masks
|
||||||
|
sam2_masks = copy.deepcopy(frame_masks)
|
||||||
|
|
||||||
|
print("video_segments:", len(video_segments))
|
||||||
|
"""
|
||||||
|
Step 5: save the tracking masks and json files
|
||||||
|
"""
|
||||||
|
for frame_idx, frame_masks_info in video_segments.items():
|
||||||
|
mask = frame_masks_info.labels
|
||||||
|
mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
|
||||||
|
for obj_id, obj_info in mask.items():
|
||||||
|
mask_img[obj_info.mask == True] = obj_id
|
||||||
|
|
||||||
|
mask_img = mask_img.numpy().astype(np.uint16)
|
||||||
|
np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
|
||||||
|
|
||||||
|
json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
|
||||||
|
frame_masks_info.to_json(json_data_path)
|
||||||
|
|
||||||
|
|
||||||
|
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||||
|
|
||||||
|
print("try reverse tracking")
|
||||||
|
start_object_id = 0
|
||||||
|
object_info_dict = {}
|
||||||
|
for frame_idx, current_object_count in frame_object_count.items():
|
||||||
|
print("reverse tracking frame", frame_idx, frame_names[frame_idx])
|
||||||
|
if frame_idx != 0:
|
||||||
|
video_predictor.reset_state(inference_state)
|
||||||
|
image_base_name = frame_names[frame_idx].split(".")[0]
|
||||||
|
json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
|
||||||
|
json_data = MaskDictionaryModel().from_json(json_data_path)
|
||||||
|
mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
|
||||||
|
mask_array = np.load(mask_data_path)
|
||||||
|
for object_id in range(start_object_id+1, current_object_count+1):
|
||||||
|
print("reverse tracking object", object_id)
|
||||||
|
object_info_dict[object_id] = json_data.labels[object_id]
|
||||||
|
video_predictor.add_new_mask(inference_state, frame_idx, object_id, mask_array == object_id)
|
||||||
|
start_object_id = current_object_count
|
||||||
|
|
||||||
|
|
||||||
|
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step*2, start_frame_idx=frame_idx, reverse=True):
|
||||||
|
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
||||||
|
json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json")
|
||||||
|
json_data = MaskDictionaryModel().from_json(json_data_path)
|
||||||
|
mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy")
|
||||||
|
mask_array = np.load(mask_data_path)
|
||||||
|
# merge the reverse tracking masks with the original masks
|
||||||
|
for i, out_obj_id in enumerate(out_obj_ids):
|
||||||
|
out_mask = (out_mask_logits[i] > 0.0).cpu()
|
||||||
|
if out_mask.sum() == 0:
|
||||||
|
print("no mask for object", out_obj_id, "at frame", out_frame_idx)
|
||||||
|
continue
|
||||||
|
object_info = object_info_dict[out_obj_id]
|
||||||
|
object_info.mask = out_mask[0]
|
||||||
|
object_info.update_box()
|
||||||
|
json_data.labels[out_obj_id] = object_info
|
||||||
|
mask_array = np.where(mask_array != out_obj_id, mask_array, 0)
|
||||||
|
mask_array[object_info.mask] = out_obj_id
|
||||||
|
|
||||||
|
np.save(mask_data_path, mask_array)
|
||||||
|
json_data.to_json(json_data_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Step 6: Draw the results and save the video
|
||||||
|
"""
|
||||||
|
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir+"_reverse")
|
||||||
|
|
||||||
|
create_video_from_images(result_dir, output_video_path, frame_rate=15)
|
@@ -44,7 +44,7 @@ video_dir = "notebooks/videos/bedroom"
|
|||||||
# scan all the JPEG frame names in this directory
|
# scan all the JPEG frame names in this directory
|
||||||
frame_names = [
|
frame_names = [
|
||||||
p for p in os.listdir(video_dir)
|
p for p in os.listdir(video_dir)
|
||||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
]
|
]
|
||||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
|
|
||||||
|
@@ -184,7 +184,7 @@ def load_video_frames(
|
|||||||
frame_names = [
|
frame_names = [
|
||||||
p
|
p
|
||||||
for p in os.listdir(jpg_folder)
|
for p in os.listdir(jpg_folder)
|
||||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
]
|
]
|
||||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
num_frames = len(frame_names)
|
num_frames = len(frame_names)
|
||||||
|
@@ -67,7 +67,9 @@ class MaskDictionaryModel:
|
|||||||
def get_target_class_name(self, instance_id):
|
def get_target_class_name(self, instance_id):
|
||||||
return self.labels[instance_id].class_name
|
return self.labels[instance_id].class_name
|
||||||
|
|
||||||
|
def get_target_logit(self, instance_id):
|
||||||
|
return self.labels[instance_id].logit
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calculate_iou(mask1, mask2):
|
def calculate_iou(mask1, mask2):
|
||||||
# Convert masks to float tensors for calculations
|
# Convert masks to float tensors for calculations
|
||||||
@@ -90,7 +92,20 @@ class MaskDictionaryModel:
|
|||||||
"promote_type": self.promote_type,
|
"promote_type": self.promote_type,
|
||||||
"labels": {k: v.to_dict() for k, v in self.labels.items()}
|
"labels": {k: v.to_dict() for k, v in self.labels.items()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def to_json(self, json_file):
|
||||||
|
with open(json_file, "w") as f:
|
||||||
|
json.dump(self.to_dict(), f, indent=4)
|
||||||
|
|
||||||
|
def from_json(self, json_file):
|
||||||
|
with open(json_file, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.mask_name = data["mask_name"]
|
||||||
|
self.mask_height = data["mask_height"]
|
||||||
|
self.mask_width = data["mask_width"]
|
||||||
|
self.promote_type = data["promote_type"]
|
||||||
|
self.labels = {int(k): ObjectInfo(**v) for k, v in data["labels"].items()}
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@@ -4,7 +4,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
def create_video_from_images(image_folder, output_video_path, frame_rate=25):
|
def create_video_from_images(image_folder, output_video_path, frame_rate=25):
|
||||||
# define valid extension
|
# define valid extension
|
||||||
valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
|
||||||
|
|
||||||
# get all image files in the folder
|
# get all image files in the folder
|
||||||
image_files = [f for f in os.listdir(image_folder)
|
image_files = [f for f in os.listdir(image_folder)
|
||||||
|
Reference in New Issue
Block a user