add grounded sam 2 tracking demo
This commit is contained in:
172
grounded_sam2_tracking_demo.py
Normal file
172
grounded_sam2_tracking_demo.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
import supervision as sv
|
||||
from PIL import Image
|
||||
from sam2.build_sam import build_sam2_video_predictor, build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
from track_utils import sample_points_from_masks
|
||||
|
||||
|
||||
"""
|
||||
Step 1: Environment settings and model initialization
|
||||
"""
|
||||
# Pick the compute device once, before any model is built, so every model
# below is placed on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use bfloat16 autocast for the entire script. The context manager is entered
# by hand and never exited, so it stays active for everything that follows.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Only query GPU properties when a GPU is present: `device` above already
# falls back to "cpu", so an unguarded query would defeat that fallback.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# Both predictors are built from the same config and checkpoint.
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
image_predictor = SAM2ImagePredictor(sam2_image_model)

# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "children."

# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"

# scan all the JPEG frame names in this directory
# NOTE(review): the extension list is kept exactly as-is; presumably it has to
# select the same files the video predictor loads from `video_dir`, because
# `frame_names` is later indexed with the predictor's frame indices -- confirm.
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
# Sort by the integer value of the file stem, not lexically ("10" after "9").
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
|
||||
|
||||
|
||||
"""
|
||||
Step 2: Prompt Grounding DINO and the SAM 2 image predictor to get the boxes and masks for a specific frame
|
||||
"""
|
||||
|
||||
# prompt grounding dino to get the box coordinates on the annotation frame
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
image = Image.open(img_path)

# run Grounding DINO on the image
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = grounding_model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.25,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]],  # PIL reports (width, height); reversed here
)

# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(np.array(image.convert("RGB")))

# process the detection results (single image, hence index 0)
input_boxes = results[0]["boxes"].cpu().numpy()
OBJECTS = results[0]["labels"]

# Fail early with an explicit message: without any box there is nothing to
# segment on this frame and nothing to register for tracking afterwards.
if len(input_boxes) == 0:
    raise RuntimeError(
        f"Grounding DINO found no boxes for prompt {text!r} in {img_path}"
    )

# prompt SAM 2 image predictor to get the mask for each detected box
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

# convert the mask shape to (n, H, W)
# A 4-D result is treated as (n, 1, H, W) and its singleton axis is dropped.
# A 3-D result is left untouched: prepending an axis would make it 4-D, which
# `sample_points_from_masks` cannot unpack as (n, h, w).
if masks.ndim == 4:
    masks = masks.squeeze(1)
|
||||
|
||||
"""
|
||||
Step 3: Register each object's positive points to the video predictor with a separate add_new_points call
|
||||
"""
|
||||
|
||||
# Draw 10 prompt points from inside every object's mask.
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

# Pair each detected label with its sampled points; object ids start at 1 and
# follow the detection order.
object_prompts = zip(OBJECTS, all_sample_points)
for object_id, (_label, click_points) in enumerate(object_prompts, start=1):
    # Every sampled point is passed with label 1, i.e. as a positive prompt.
    positive_labels = np.ones(len(click_points), dtype=np.int32)
    _, out_obj_ids, out_mask_logits = video_predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=object_id,
        points=click_points,
        labels=positive_labels,
    )
|
||||
|
||||
|
||||
"""
|
||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||
"""
|
||||
# Per-frame segmentation results: frame index -> {object id -> boolean mask}.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    frame_masks = {}
    for mask_idx, tracked_id in enumerate(out_obj_ids):
        # Threshold the logits at zero to get a boolean mask as a NumPy array.
        frame_masks[tracked_id] = (out_mask_logits[mask_idx] > 0.0).cpu().numpy()
    video_segments[out_frame_idx] = frame_masks
|
||||
|
||||
"""
|
||||
Step 5: Visualize the segment results across the video and save them
|
||||
"""
|
||||
|
||||
save_dir = "./tracking_results"

# exist_ok=True replaces the separate "exists?" check and is safe to re-run.
os.makedirs(save_dir, exist_ok=True)

# Map tracking ids (1-based, detection order) back to their detected labels.
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}

# Built once and reused for every frame instead of once per iteration.
mask_annotator = sv.MaskAnnotator()

for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    # Stack the per-object masks along axis 0.
    # NOTE(review): assumes each stored mask is (1, h, w) so the result is
    # (n, h, w) -- confirm against the video predictor's output.
    masks = list(segments.values())
    masks = np.concatenate(masks, axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    annotated_frame = mask_annotator.annotate(scene=img.copy(), detections=detections)

    # Write into `save_dir`; the directory was created above but the frames
    # used to be written to the current working directory instead.
    cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx}.jpg"), annotated_frame)
|
||||
|
||||
|
||||
# import cv2
|
||||
# import supervision as sv
|
||||
# # visualize each mask
|
||||
# for out_frame_idx, masks in video_segments.items():
|
||||
# img = cv2.imread(os.path.join(video_dir, frame_names[out_frame_idx]))
|
||||
# detections = sv.Detections(
|
||||
# xyxy=np.array([[0, 0, 100, 100]]), # (n, 4)
|
||||
# mask=masks[1]
|
||||
# )
|
||||
# mask_annotator = sv.MaskAnnotator()
|
||||
# annotated_frame = mask_annotator.annotate(scene=img.copy(), detections=detections)
|
||||
# cv2.imwrite(f"annotated_frame_{out_frame_idx}.jpg", annotated_frame)
|
||||
# import pdb; pdb.set_trace()
|
42
track_utils.py
Normal file
42
track_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import numpy as np
|
||||
from scipy.ndimage import center_of_mass
|
||||
|
||||
def sample_points_from_masks(masks, num_points):
    """
    Sample points from masks and return their absolute coordinates.

    Args:
        masks: np.array with shape (n, h, w); pixels equal to 1 are sampled from
        num_points: int, number of points drawn for every mask

    Returns:
        points: np.float32 array with shape (n, num_points, 2), each point
            stored in (x, y) order

    Raises:
        ValueError: if `masks` is not 3-dimensional, or if a mask contains no
            pixel equal to 1 (no point can be sampled from it, so the
            documented output shape could not be honoured)
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError(f"masks must have shape (n, h, w), got {masks.shape}")

    # Pre-allocating guarantees the documented shape, including for n == 0.
    points = np.empty((masks.shape[0], num_points, 2), dtype=np.float32)

    for i, mask in enumerate(masks):
        # Coordinates of the pixels whose value is 1 in the current mask.
        indices = np.argwhere(mask == 1)

        if len(indices) == 0:
            # An empty row here would make the result ragged and fail later
            # with an opaque NumPy error; report the offending mask instead.
            raise ValueError(
                f"mask at index {i} is empty; cannot sample points from it"
            )

        # np.argwhere yields (y, x) rows; flip the columns to get (x, y).
        indices = indices[:, ::-1]

        # Sample with replacement only when the mask has fewer pixels than
        # the number of points requested.
        sampled_indices = np.random.choice(
            len(indices), num_points, replace=len(indices) < num_points
        )
        points[i] = indices[sampled_indices]

    return points
|
Reference in New Issue
Block a user