From 245fc7206b07046dabf0b9309b79caf7b7c18fca Mon Sep 17 00:00:00 2001 From: SusanSHEN Date: Fri, 16 Aug 2024 01:46:41 +0200 Subject: [PATCH] feat: continuous_id_plus --- README.md | 7 + ...d_sam2_tracking_demo_with_continuous_id.py | 8 +- ..._tracking_demo_with_continuous_id_gd1.5.py | 2 +- ...2_tracking_demo_with_continuous_id_plus.py | 242 ++++++++++++++++++ grounded_sam2_tracking_demo_with_gd1.5.py | 2 +- sam2/utils/misc.py | 2 +- utils/mask_dictionary_model.py | 19 +- utils/video_utils.py | 2 +- 8 files changed, 274 insertions(+), 10 deletions(-) create mode 100644 grounded_sam2_tracking_demo_with_continuous_id_plus.py diff --git a/README.md b/README.md index 8e3fe51..4f3e9e0 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,13 @@ If you want to try `Grounding DINO 1.5` model, you can run the following scripts python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py ``` +### Grounded-SAM-2 Video Object Tracking with Continuous ID plus Reverse Tracking(with Grounding DINO) +This method could simply cover the whole lifetime of the object +```bash +python grounded_sam2_tracking_demo_with_continuous_id_plus.py + +``` + ## Grounded SAM 2 Florence-2 Demos ### Grounded SAM 2 Florence-2 Image Demo (Updating) diff --git a/grounded_sam2_tracking_demo_with_continuous_id.py b/grounded_sam2_tracking_demo_with_continuous_id.py index 915ae8c..5c354b9 100644 --- a/grounded_sam2_tracking_demo_with_continuous_id.py +++ b/grounded_sam2_tracking_demo_with_continuous_id.py @@ -62,13 +62,13 @@ CommonUtils.creat_dirs(json_data_dir) # scan all the JPEG frame names in this directory frame_names = [ p for p in os.listdir(video_dir) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) # init video predictor state -inference_state = video_predictor.init_state(video_path=video_dir) -step = 10 # the step to sample frames for Grounding DINO predictor +inference_state = video_predictor.init_state(video_path=video_dir, offload_video_to_cpu=True, async_loading_frames=True) +step = 20 # the step to sample frames for Grounding DINO predictor sam2_masks = MaskDictionaryModel() PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point @@ -194,4 +194,4 @@ Step 6: Draw the results and save the video """ CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir) -create_video_from_images(result_dir, output_video_path, frame_rate=30) \ No newline at end of file +create_video_from_images(result_dir, output_video_path, frame_rate=15) \ No newline at end of file diff --git a/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py b/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py index 659645c..7ea2302 100644 --- a/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py +++ b/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py @@ -68,7 +68,7 @@ CommonUtils.creat_dirs(json_data_dir) # scan all the JPEG frame names in this directory frame_names = [ p for p in os.listdir(video_dir) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) diff --git a/grounded_sam2_tracking_demo_with_continuous_id_plus.py b/grounded_sam2_tracking_demo_with_continuous_id_plus.py new file mode 100644 index 0000000..4273090 --- /dev/null +++ b/grounded_sam2_tracking_demo_with_continuous_id_plus.py @@ -0,0 +1,242 @@ +import os +import cv2 +import torch +import numpy as np +import supervision as sv +from PIL import Image +from sam2.build_sam import build_sam2_video_predictor, build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor +from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection +from utils.track_utils import sample_points_from_masks +from utils.video_utils import create_video_from_images +from utils.common_utils import CommonUtils +from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo +import json +import copy + +# This demo shows the continuous object tracking plus reverse tracking with Grounding DINO and SAM 2 +""" +Step 1: Environment settings and model initialization +""" +# use bfloat16 for the entire notebook +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +# init sam image predictor and video predictor model +sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +device = "cuda" if torch.cuda.is_available() else "cpu" +print("device", device) + +video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint) +sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device) +image_predictor = SAM2ImagePredictor(sam2_image_model) + + +# init grounding dino model from huggingface +model_id = "IDEA-Research/grounding-dino-tiny" +processor = AutoProcessor.from_pretrained(model_id) +grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device) + + +# setup the input image and text prompt for SAM 2 and Grounding DINO +# VERY important: text queries need to be lowercased + end with a dot +text = "car." + +# `video_dir` a directory of JPEG frames with filenames like `.jpg` +video_dir = "notebooks/videos/car" +# 'output_dir' is the directory to save the annotated frames +output_dir = "outputs" +# 'output_video_path' is the path to save the final video +output_video_path = "./outputs/output.mp4" +# create the output directory +mask_data_dir = os.path.join(output_dir, "mask_data") +json_data_dir = os.path.join(output_dir, "json_data") +result_dir = os.path.join(output_dir, "result") +CommonUtils.creat_dirs(mask_data_dir) +CommonUtils.creat_dirs(json_data_dir) +# scan all the JPEG frame names in this directory +frame_names = [ + p for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] +] +frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + +# init video predictor state +inference_state = video_predictor.init_state(video_path=video_dir) +step = 10 # the step to sample frames for Grounding DINO predictor + +sam2_masks = MaskDictionaryModel() +PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point +objects_count = 0 +frame_object_count = {} +""" +Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames +""" +print("Total frames:", len(frame_names)) +for start_frame_idx in range(0, len(frame_names), step): +# prompt grounding dino to get the box coordinates on specific frame + print("start_frame_idx", start_frame_idx) + # continue + img_path = os.path.join(video_dir, frame_names[start_frame_idx]) + image = Image.open(img_path).convert("RGB") + image_base_name = frame_names[start_frame_idx].split(".")[0] + mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy") + + # run Grounding DINO on the image + inputs = processor(images=image, text=text, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = grounding_model(**inputs) + + results = processor.post_process_grounded_object_detection( + outputs, + inputs.input_ids, + box_threshold=0.25, + text_threshold=0.25, + target_sizes=[image.size[::-1]] + ) + + # prompt SAM image predictor to get the mask for the object + image_predictor.set_image(np.array(image.convert("RGB"))) + + # process the detection results + input_boxes = results[0]["boxes"] # .cpu().numpy() + # print("results[0]",results[0]) + OBJECTS = results[0]["labels"] + + # prompt SAM 2 image predictor to get the mask for the object + masks, scores, logits = image_predictor.predict( + point_coords=None, + point_labels=None, + box=input_boxes, + multimask_output=False, + ) + # convert the mask shape to (n, H, W) + if masks.ndim == 2: + masks = masks[None] + scores = scores[None] + logits = logits[None] + elif masks.ndim == 4: + masks = masks.squeeze(1) + """ + Step 3: Register each object's positive points to video predictor + """ + + # If you are using point prompts, we uniformly sample positive points based on the mask + if mask_dict.promote_type == "mask": + mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS) + else: + raise NotImplementedError("SAM 2 video predictor only support mask prompts") + + + """ + Step 4: Propagate the video predictor to get the segmentation results for each frame + """ + objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count) + frame_object_count[start_frame_idx] = objects_count + print("objects_count", objects_count) + video_predictor.reset_state(inference_state) + if len(mask_dict.labels) == 0: + print("No object detected in the frame, skip the frame {}".format(start_frame_idx)) + continue + video_predictor.reset_state(inference_state) + + for object_id, object_info in mask_dict.labels.items(): + frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask( + inference_state, + start_frame_idx, + object_id, + object_info.mask, + ) + + video_segments = {} # output the following {step} frames tracking masks + for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx): + frame_masks = MaskDictionaryModel() + + for i, out_obj_id in enumerate(out_obj_ids): + out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy() + object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id)) + object_info.update_box() + frame_masks.labels[out_obj_id] = object_info + image_base_name = frame_names[out_frame_idx].split(".")[0] + frame_masks.mask_name = f"mask_{image_base_name}.npy" + frame_masks.mask_height = out_mask.shape[-2] + frame_masks.mask_width = out_mask.shape[-1] + + video_segments[out_frame_idx] = frame_masks + sam2_masks = copy.deepcopy(frame_masks) + + print("video_segments:", len(video_segments)) + """ + Step 5: save the tracking masks and json files + """ + for frame_idx, frame_masks_info in video_segments.items(): + mask = frame_masks_info.labels + mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width) + for obj_id, obj_info in mask.items(): + mask_img[obj_info.mask == True] = obj_id + + mask_img = mask_img.numpy().astype(np.uint16) + np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img) + + json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json")) + frame_masks_info.to_json(json_data_path) + + +CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir) + +print("try reverse tracking") +start_object_id = 0 +object_info_dict = {} +for frame_idx, current_object_count in frame_object_count.items(): + print("reverse tracking frame", frame_idx, frame_names[frame_idx]) + if frame_idx != 0: + video_predictor.reset_state(inference_state) + image_base_name = frame_names[frame_idx].split(".")[0] + json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json") + json_data = MaskDictionaryModel().from_json(json_data_path) + mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy") + mask_array = np.load(mask_data_path) + for object_id in range(start_object_id+1, current_object_count+1): + print("reverse tracking object", object_id) + object_info_dict[object_id] = json_data.labels[object_id] + video_predictor.add_new_mask(inference_state, frame_idx, object_id, mask_array == object_id) + start_object_id = current_object_count + + + for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step*2, start_frame_idx=frame_idx, reverse=True): + image_base_name = frame_names[out_frame_idx].split(".")[0] + json_data_path = os.path.join(json_data_dir, f"mask_{image_base_name}.json") + json_data = MaskDictionaryModel().from_json(json_data_path) + mask_data_path = os.path.join(mask_data_dir, f"mask_{image_base_name}.npy") + mask_array = np.load(mask_data_path) + # merge the reverse tracking masks with the original masks + for i, out_obj_id in enumerate(out_obj_ids): + out_mask = (out_mask_logits[i] > 0.0).cpu() + if out_mask.sum() == 0: + print("no mask for object", out_obj_id, "at frame", out_frame_idx) + continue + object_info = object_info_dict[out_obj_id] + object_info.mask = out_mask[0] + object_info.update_box() + json_data.labels[out_obj_id] = object_info + mask_array = np.where(mask_array != out_obj_id, mask_array, 0) + mask_array[object_info.mask] = out_obj_id + + np.save(mask_data_path, mask_array) + json_data.to_json(json_data_path) + + + + + +""" +Step 6: Draw the results and save the video +""" +CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir+"_reverse") + +create_video_from_images(result_dir, output_video_path, frame_rate=15) \ No newline at end of file diff --git a/grounded_sam2_tracking_demo_with_gd1.5.py b/grounded_sam2_tracking_demo_with_gd1.5.py index 8be2daa..325deda 100644 --- a/grounded_sam2_tracking_demo_with_gd1.5.py +++ b/grounded_sam2_tracking_demo_with_gd1.5.py @@ -44,7 +44,7 @@ video_dir = "notebooks/videos/bedroom" # scan all the JPEG frame names in this directory frame_names = [ p for p in os.listdir(video_dir) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index df97b4a..3c5ff13 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -184,7 +184,7 @@ def load_video_frames( frame_names = [ p for p in os.listdir(jpg_folder) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) num_frames = len(frame_names) diff --git a/utils/mask_dictionary_model.py b/utils/mask_dictionary_model.py index d98551e..b916216 100644 --- a/utils/mask_dictionary_model.py +++ b/utils/mask_dictionary_model.py @@ -67,7 +67,9 @@ class MaskDictionaryModel: def get_target_class_name(self, instance_id): return self.labels[instance_id].class_name - + def get_target_logit(self, instance_id): + return self.labels[instance_id].logit + @staticmethod def calculate_iou(mask1, mask2): # Convert masks to float tensors for calculations @@ -90,7 +92,20 @@ class MaskDictionaryModel: "promote_type": self.promote_type, "labels": {k: v.to_dict() for k, v in self.labels.items()} } - + + def to_json(self, json_file): + with open(json_file, "w") as f: + json.dump(self.to_dict(), f, indent=4) + + def from_json(self, json_file): + with open(json_file, "r") as f: + data = json.load(f) + self.mask_name = data["mask_name"] + self.mask_height = data["mask_height"] + self.mask_width = data["mask_width"] + self.promote_type = data["promote_type"] + self.labels = {int(k): ObjectInfo(**v) for k, v in data["labels"].items()} + return self @dataclass diff --git a/utils/video_utils.py b/utils/video_utils.py index d844f04..a37b9b3 100644 --- a/utils/video_utils.py +++ b/utils/video_utils.py @@ -4,7 +4,7 @@ from tqdm import tqdm def create_video_from_images(image_folder, output_video_path, frame_rate=25): # define valid extension - valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"] + valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"] # get all image files in the folder image_files = [f for f in os.listdir(image_folder)