# grounded_sam2_tracking_demo.py
# Added by commit bf450b6 ("add grounded sam 2 tracking demo").
"""Grounded SAM 2 tracking demo.

Pipeline, as implemented below:

1. Build the SAM 2 video predictor, the SAM 2 image predictor and a
   Grounding DINO model.
2. Run Grounding DINO with a text prompt on one frame to get boxes, then
   prompt the SAM 2 image predictor with those boxes to get one mask per box.
3. Sample positive points from every mask and register them with the video
   predictor, one ``add_new_points`` call per object.
4. Propagate through the video and collect per-frame boolean masks.
5. Write one mask-annotated JPEG per frame into ``save_dir``.
"""
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from track_utils import sample_points_from_masks


# ---------------------------------------------------------------------------
# Step 1: Environment settings and model initialization
# ---------------------------------------------------------------------------
# Use bfloat16 autocast for the entire script. The context is entered once
# here and deliberately never exited.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Only query device 0 when CUDA is actually present: the script already falls
# back to "cpu" for Grounding DINO below, and querying a missing device fails.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)

# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "children."

# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
# (the sort key below requires the file stem to parse as an integer).
video_dir = "notebooks/videos/bedroom"

# scan all the JPEG frame names in this directory
# NOTE(review): the extension list presumably mirrors the video predictor's own
# frame discovery so that indices line up — confirm before widening it.
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)

ann_frame_idx = 0  # the frame index we interact with


# ---------------------------------------------------------------------------
# Step 2: Prompt Grounding DINO and the SAM image predictor to get the boxes
# and masks for the annotated frame
# ---------------------------------------------------------------------------
img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
image = Image.open(img_path)

# run Grounding DINO on the image
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = grounding_model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.25,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]],  # PIL size is (W, H); reversed to (H, W)
)

# process the detection results
input_boxes = results[0]["boxes"].cpu().numpy()
OBJECTS = results[0]["labels"]

# Nothing to segment or track without at least one box: stop with a clear
# message instead of failing somewhere inside the predictors.
if len(input_boxes) == 0:
    raise SystemExit(f"Grounding DINO found no object matching {text!r} in {img_path}")

# prompt SAM 2 image predictor to get the mask for each detected box
image_predictor.set_image(np.array(image.convert("RGB")))
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

# Bring the masks to rank 3, (n, H, W): sample_points_from_masks unpacks
# exactly three dimensions. A 3-D result already has that rank and is left
# untouched (prepending an axis would make it 4-D and break the unpacking).
# assumes a 4-D result is laid out as (n, 1, H, W) — TODO confirm against
# SAM2ImagePredictor.predict; squeeze(1) raises if that axis is not size 1.
if masks.ndim == 4:
    masks = masks.squeeze(1)


# ---------------------------------------------------------------------------
# Step 3: Register each object's positive points to the video predictor with
# a separate add_new_points call
# ---------------------------------------------------------------------------
# sample the positive points from the mask of each object
all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

# Object ids start at 1 and follow the detection order.
for object_id, (_label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
    # every sampled point is a positive click (label 1)
    labels = np.ones((points.shape[0]), dtype=np.int32)
    video_predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=object_id,
        points=points,
        labels=labels,
    )


# ---------------------------------------------------------------------------
# Step 4: Propagate the video predictor to get the segmentation results for
# each frame
# ---------------------------------------------------------------------------
video_segments = {}  # frame index -> {object id -> boolean mask}
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }


# ---------------------------------------------------------------------------
# Step 5: Visualize the segment results across the video and save them
# ---------------------------------------------------------------------------
save_dir = "./tracking_results"
os.makedirs(save_dir, exist_ok=True)

# object id -> detected label; not consumed by the mask-only rendering below
ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}

mask_annotator = sv.MaskAnnotator()
for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))

    object_ids = list(segments.keys())
    # join the per-object masks along the first axis into one array
    masks = np.concatenate(list(segments.values()), axis=0)

    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks,  # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32),
    )
    annotated_frame = mask_annotator.annotate(scene=img.copy(), detections=detections)
    # write into save_dir (created above) rather than the working directory
    cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx}.jpg"), annotated_frame)
# track_utils.py
# Added by commit bf450b6 ("add grounded sam 2 tracking demo").
import numpy as np
from scipy.ndimage import center_of_mass  # not used below; file-level import kept as is


def sample_points_from_masks(masks, num_points):
    """Sample foreground pixel coordinates from each mask.

    Args:
        masks: array-like with shape (n, h, w). A pixel is foreground when it
            compares equal to 1, so boolean masks work as well as 0/1 masks.
        num_points: number of points to draw per mask. Points are drawn
            without replacement when the mask has at least ``num_points``
            foreground pixels, otherwise with replacement.

    Returns:
        np.ndarray of dtype float32 with shape (n, num_points, 2) holding
        absolute (x, y) pixel coordinates. For ``n == 0`` the result has
        shape (0, num_points, 2).

    Raises:
        ValueError: if ``masks`` is not 3-dimensional, or if any mask has no
            foreground pixel (there is nothing to sample, and a regular
            (n, num_points, 2) array could not be built).
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError(f"masks must have shape (n, h, w), got {masks.shape}")

    sampled = []
    empty = []
    for index, mask in enumerate(masks):
        # np.argwhere returns (row, col) = (y, x) pairs; flip columns to (x, y).
        coords = np.argwhere(mask == 1)[:, ::-1]
        if len(coords) == 0:
            empty.append(index)
            continue

        # Repeat pixels only when the mask is smaller than the request.
        chosen = np.random.choice(
            len(coords), num_points, replace=len(coords) < num_points
        )
        sampled.append(coords[chosen])

    if empty:
        raise ValueError(
            f"cannot sample points: mask(s) at index {empty} contain no foreground pixels"
        )
    if not sampled:
        return np.empty((0, num_points, 2), dtype=np.float32)
    return np.stack(sampled).astype(np.float32)