From 379e35cb4042b858dddc6c64d7a8d645db913393 Mon Sep 17 00:00:00 2001 From: rentainhe <596106517@qq.com> Date: Thu, 5 Sep 2024 14:57:17 +0800 Subject: [PATCH] support custom video tracking demo with local gd1.0 model --- README.md | 6 + ...mo_custom_video_input_gd1.0_local_model.py | 220 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py diff --git a/README.md b/README.md index a3deafd..5abd22d 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,12 @@ Users can upload their own video file (e.g. `assets/hippopotamus.mp4`) and speci python grounded_sam2_tracking_demo_custom_video_input_gd1.0_hf_model.py ``` +If it is not convenient for you to use the Hugging Face demo, you can also run the tracking demo with a local Grounding DINO model using the following script: + +```bash +python grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py +``` + ### Grounded SAM 2 Video Object Tracking Demo with Custom Video Input (with Grounding DINO 1.5 & 1.6) Users can upload their own video file (e.g. 
`assets/hippopotamus.mp4`) and specify their custom text prompts for grounding and tracking with Grounding DINO 1.5 and SAM 2 by using the following scripts:
diff --git a/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py b/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
new file mode 100644
index 0000000..81806c1
--- /dev/null
+++ b/grounded_sam2_tracking_demo_custom_video_input_gd1.0_local_model.py
@@ -0,0 +1,220 @@
+import os
+import cv2
+import torch
+import numpy as np
+import supervision as sv
+from torchvision.ops import box_convert
+from pathlib import Path
+from tqdm import tqdm
+from PIL import Image
+from sam2.build_sam import build_sam2_video_predictor, build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
+from utils.track_utils import sample_points_from_masks
+from utils.video_utils import create_video_from_images
+
+"""
+Hyperparameters for grounding and tracking
+"""
+GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
+BOX_THRESHOLD = 0.35
+TEXT_THRESHOLD = 0.25
+VIDEO_PATH = "./assets/hippopotamus.mp4"
+TEXT_PROMPT = "hippopotamus."
+OUTPUT_VIDEO_PATH = "./hippopotamus_tracking_demo.mp4"
+SOURCE_VIDEO_FRAME_DIR = "./custom_video_frames"
+SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
+PROMPT_TYPE_FOR_VIDEO = "box"  # choose from ["point", "box", "mask"]
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+"""
+Step 1: Environment settings and model initialization for Grounding DINO and SAM 2
+"""
+# build grounding dino model from local path
+grounding_model = load_model(
+    model_config_path=GROUNDING_DINO_CONFIG,
+    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
+    device=DEVICE
+)
+
+# init sam image predictor and video predictor model on the same device as Grounding DINO
+sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
+model_cfg = "sam2_hiera_l.yaml"
+
+video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=DEVICE)
+sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
+image_predictor = SAM2ImagePredictor(sam2_image_model)
+
+"""
+Custom video input directly using video files
+"""
+video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # get video info
+print(video_info)
+frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
+
+# saving video to frames
+source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
+source_frames.mkdir(parents=True, exist_ok=True)
+
+with sv.ImageSink(
+    target_dir_path=source_frames,
+    overwrite=True,
+    image_name_pattern="{:05d}.jpg"
+) as sink:
+    for frame in tqdm(frame_generator, desc="Saving Video Frames"):
+        sink.save_image(frame)
+
+# scan all the JPEG frame names in this directory
+frame_names = [
+    p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
+    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+]
+frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+
+# init video predictor state
+inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
+
+ann_frame_idx = 0  # the frame index we interact with
+"""
+Step 2: Prompt the local Grounding DINO 1.0 model for box coordinates
+"""
+
+# prompt grounding dino to get the box coordinates on specific frame
+img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
+image_source, image = load_image(img_path)
+
+boxes, confidences, labels = predict(
+    model=grounding_model,
+    image=image,
+    caption=TEXT_PROMPT,
+    box_threshold=BOX_THRESHOLD,
+    text_threshold=TEXT_THRESHOLD,
+)
+
+# process the box prompt for SAM 2: scale by the frame width/height, then convert cxcywh -> xyxy
+h, w, _ = image_source.shape
+boxes = boxes * torch.Tensor([w, h, w, h])
+input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+confidences = confidences.numpy().tolist()
+class_names = labels
+
+print(input_boxes)
+
+# stop early with a clear error when nothing was detected on the annotated frame
+if len(input_boxes) == 0:
+    raise RuntimeError(f"Grounding DINO found no boxes for prompt {TEXT_PROMPT!r} on frame {ann_frame_idx}")
+
+# prompt SAM image predictor to get the mask for the object
+image_predictor.set_image(image_source)
+
+# process the detection results
+OBJECTS = class_names
+print(OBJECTS)
+
+# FIXME: figure out how this influences the G-DINO model
+# autocast is entered here and never exited, so it stays active for every SAM 2 call below
+torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+
+if DEVICE == "cuda" and torch.cuda.get_device_properties(0).major >= 8:  # only query CUDA when it is in use
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+# prompt SAM 2 image predictor to get the mask for the object
+masks, scores, logits = image_predictor.predict(
+    point_coords=None,
+    point_labels=None,
+    box=input_boxes,
+    multimask_output=False,
+)
+# convert the mask shape to (n, H, W)
+if masks.ndim == 4:
+    masks = masks.squeeze(1)
+
+"""
+Step 3: Register each object's prompt (points, box or mask) with the video predictor, one call per object
+"""
+
+assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only support point/box/mask prompt"
+
+# If you are using point prompts, we uniformly sample positive points based on the mask
+if PROMPT_TYPE_FOR_VIDEO == "point":
+    # sample the positive points from the mask of each object
+    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)
+
+    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
+        point_labels = np.ones((points.shape[0]), dtype=np.int32)  # all ones: one label per sampled point
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            points=points,
+            labels=point_labels,
+        )
+# Using box prompt
+elif PROMPT_TYPE_FOR_VIDEO == "box":
+    for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            box=box,
+        )
+# Using mask prompt is a more straightforward way
+elif PROMPT_TYPE_FOR_VIDEO == "mask":
+    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            mask=mask
+        )
+else:
+    raise NotImplementedError("SAM 2 video predictor only support point/box/mask prompts")
+
+"""
+Step 4: Propagate the video predictor to get the segmentation results for each frame
+"""
+video_segments = {}  # video_segments contains the per-frame segmentation results
+for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
+    video_segments[out_frame_idx] = {
+        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+        for i, out_obj_id in enumerate(out_obj_ids)
+    }
+
+"""
+Step 5: Visualize the segment results across the video and save them
+"""
+
+os.makedirs(SAVE_TRACKING_RESULTS_DIR, exist_ok=True)
+
+ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
+
+# build the annotators once and reuse them for every frame
+box_annotator = sv.BoxAnnotator()
+label_annotator = sv.LabelAnnotator()
+mask_annotator = sv.MaskAnnotator()
+
+for frame_idx, segments in video_segments.items():
+    img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
+
+    object_ids = list(segments.keys())
+    masks = list(segments.values())
+    masks = np.concatenate(masks, axis=0)
+
+    detections = sv.Detections(
+        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
+        mask=masks,  # (n, h, w)
+        class_id=np.array(object_ids, dtype=np.int32),
+    )
+    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
+    annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
+    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+    cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
+
+"""
+Step 6: Convert the annotated frames to video
+"""
+
+create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
\ No newline at end of file