From bf57b3086caaa66922b88e072b2e33b18ad2d917 Mon Sep 17 00:00:00 2001
From: Susan Shen <31400000+ShuoShenDe@users.noreply.github.com>
Date: Wed, 30 Oct 2024 13:38:30 +0100
Subject: [PATCH] fix: zero object detection error (#64)

* update dockerfile

* fix: zero object detection error

* fix: zero object detection error

---
 ...d_sam2_tracking_demo_with_continuous_id.py | 110 +++++++++--------
 ..._tracking_demo_with_continuous_id_gd1.5.py | 112 ++++++++++--------
 ...2_tracking_demo_with_continuous_id_plus.py | 101 ++++++++--------
 utils/mask_dictionary_model.py                |  19 +++
 4 files changed, 190 insertions(+), 152 deletions(-)

diff --git a/grounded_sam2_tracking_demo_with_continuous_id.py b/grounded_sam2_tracking_demo_with_continuous_id.py
index 7baac35..4ddba65 100644
--- a/grounded_sam2_tracking_demo_with_continuous_id.py
+++ b/grounded_sam2_tracking_demo_with_continuous_id.py
@@ -107,70 +107,76 @@ for start_frame_idx in range(0, len(frame_names), step):
     input_boxes = results[0]["boxes"] # .cpu().numpy()
     # print("results[0]",results[0])
     OBJECTS = results[0]["labels"]
+    if input_boxes.shape[0] != 0:
+        # prompt SAM 2 image predictor to get the mask for the object
+        masks, scores, logits = image_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        # convert the mask shape to (n, H, W)
+        if masks.ndim == 2:
+            masks = masks[None]
+            scores = scores[None]
+            logits = logits[None]
+        elif masks.ndim == 4:
+            masks = masks.squeeze(1)
 
-    # prompt SAM 2 image predictor to get the mask for the object
-    masks, scores, logits = image_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
-    # convert the mask shape to (n, H, W)
-    if masks.ndim == 2:
-        masks = masks[None]
-        scores = scores[None]
-        logits = logits[None]
-    elif masks.ndim == 4:
-        masks = masks.squeeze(1)
+        """
+        Step 3: Register each object's positive points to video predictor
+        """
 
-    """
-    Step 3: Register each object's positive points to video predictor
-    """
+        # If you are using point prompts, we uniformly sample positive points based on the mask
+        if mask_dict.promote_type == "mask":
+            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+        else:
+            raise NotImplementedError("SAM 2 video predictor only support mask prompts")
 
-    # If you are using point prompts, we uniformly sample positive points based on the mask
-    if mask_dict.promote_type == "mask":
-        mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+
+        """
+        Step 4: Propagate the video predictor to get the segmentation results for each frame
+        """
+        objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
+        print("objects_count", objects_count)
     else:
-        raise NotImplementedError("SAM 2 video predictor only support mask prompts")
+        print("No object detected in the frame, skip merging the frame {}".format(frame_names[start_frame_idx]))
+        mask_dict = sam2_masks
 
-
-    """
-    Step 4: Propagate the video predictor to get the segmentation results for each frame
-    """
-    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
-    print("objects_count", objects_count)
-    video_predictor.reset_state(inference_state)
 
+
     if len(mask_dict.labels) == 0:
+        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
         print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
         continue
-    video_predictor.reset_state(inference_state)
+    else:
+        video_predictor.reset_state(inference_state)
 
-    for object_id, object_info in mask_dict.labels.items():
-        frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
-            inference_state,
-            start_frame_idx,
-            object_id,
-            object_info.mask,
-        )
-
-    video_segments = {} # output the following {step} frames tracking masks
-    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
-        frame_masks = MaskDictionaryModel()
+        for object_id, object_info in mask_dict.labels.items():
+            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+                inference_state,
+                start_frame_idx,
+                object_id,
+                object_info.mask,
+            )
 
-        for i, out_obj_id in enumerate(out_obj_ids):
-            out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
-            object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
-            object_info.update_box()
-            frame_masks.labels[out_obj_id] = object_info
-            image_base_name = frame_names[out_frame_idx].split(".")[0]
-            frame_masks.mask_name = f"mask_{image_base_name}.npy"
-            frame_masks.mask_height = out_mask.shape[-2]
-            frame_masks.mask_width = out_mask.shape[-1]
+        video_segments = {} # output the following {step} frames tracking masks
+        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
+            frame_masks = MaskDictionaryModel()
+
+            for i, out_obj_id in enumerate(out_obj_ids):
+                out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
+                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
+                object_info.update_box()
+                frame_masks.labels[out_obj_id] = object_info
+                image_base_name = frame_names[out_frame_idx].split(".")[0]
+                frame_masks.mask_name = f"mask_{image_base_name}.npy"
+                frame_masks.mask_height = out_mask.shape[-2]
+                frame_masks.mask_width = out_mask.shape[-1]
 
-        video_segments[out_frame_idx] = frame_masks
-        sam2_masks = copy.deepcopy(frame_masks)
+            video_segments[out_frame_idx] = frame_masks
+            sam2_masks = copy.deepcopy(frame_masks)
 
-    print("video_segments:", len(video_segments))
+        print("video_segments:", len(video_segments))
     """
     Step 5: save the tracking masks and json files
     """
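
The hunk above is the heart of the fix, and the two demo scripts patched below apply the same change: SAM 2 prompting and mask registration now only run when Grounding DINO actually returned boxes, and a zero-detection frame keeps the masks already being tracked instead of prompting the predictors with an empty box array. A self-contained sketch of that control flow (merge_frame, detections, and tracked are hypothetical stand-ins for the scripts' results[0]["boxes"] and sam2_masks bookkeeping, not names from this patch):

import numpy as np

def merge_frame(detections: np.ndarray, tracked: dict) -> dict:
    # Normal path: detections exist, so register each box under a fresh object id.
    if detections.shape[0] != 0:
        merged = dict(tracked)
        for obj_id, box in enumerate(detections, start=max(tracked, default=0) + 1):
            merged[obj_id] = box
        return merged
    # Zero-detection path: carry the previously tracked objects forward unchanged.
    print("No object detected in the frame, skip merging the frame")
    return tracked

print(merge_frame(np.empty((0, 4)), {1: "previously tracked mask"}))  # -> {1: ...}
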
diff --git a/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py b/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
index cd7dd5f..99ea4bc 100644
--- a/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
+++ b/grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
@@ -123,73 +123,81 @@ for start_frame_idx in range(0, len(frame_names), step):
     input_boxes = np.array(input_boxes)
     OBJECTS = class_names
 
+    if input_boxes.shape[0] != 0:
+        # prompt SAM image predictor to get the mask for the object
+        image_predictor.set_image(np.array(image.convert("RGB")))
 
-    # prompt SAM image predictor to get the mask for the object
-    image_predictor.set_image(np.array(image.convert("RGB")))
+        # prompt SAM 2 image predictor to get the mask for the object
+        masks, scores, logits = image_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        # convert the mask shape to (n, H, W)
+        if masks.ndim == 2:
+            masks = masks[None]
+            scores = scores[None]
+            logits = logits[None]
+        elif masks.ndim == 4:
+            masks = masks.squeeze(1)
 
-    # prompt SAM 2 image predictor to get the mask for the object
-    masks, scores, logits = image_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
-    # convert the mask shape to (n, H, W)
-    if masks.ndim == 2:
-        masks = masks[None]
-        scores = scores[None]
-        logits = logits[None]
-    elif masks.ndim == 4:
-        masks = masks.squeeze(1)
+        """
+        Step 3: Register each object's positive points to video predictor
+        """
 
-    """
-    Step 3: Register each object's positive points to video predictor
-    """
+        # If you are using point prompts, we uniformly sample positive points based on the mask
+        if mask_dict.promote_type == "mask":
+            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+        else:
+            raise NotImplementedError("SAM 2 video predictor only support mask prompts")
 
-    # If you are using point prompts, we uniformly sample positive points based on the mask
-    if mask_dict.promote_type == "mask":
-        mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+
+
+        objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
+        print("objects_count", objects_count)
+
     else:
-        raise NotImplementedError("SAM 2 video predictor only support mask prompts")
-
+        print("No object detected in the frame, skip merging the frame {}".format(frame_names[start_frame_idx]))
+        mask_dict = sam2_masks
+
     """
-    Step 4: Propagate the video predictor to get the segmentation results for each frame
+    Step 4: Propagate the video predictor to get the segmentation results for each frame
     """
-    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
-    print("objects_count", objects_count)
-    video_predictor.reset_state(inference_state)
 
     if len(mask_dict.labels) == 0:
+        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
         print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
         continue
-    video_predictor.reset_state(inference_state)
+    else:
+        video_predictor.reset_state(inference_state)
 
-    for object_id, object_info in mask_dict.labels.items():
-        frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
-            inference_state,
-            start_frame_idx,
-            object_id,
-            object_info.mask,
-        )
-
-    video_segments = {} # output the following {step} frames tracking masks
-    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
-        frame_masks = MaskDictionaryModel()
+        for object_id, object_info in mask_dict.labels.items():
+            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+                inference_state,
+                start_frame_idx,
+                object_id,
+                object_info.mask,
+            )
 
-        for i, out_obj_id in enumerate(out_obj_ids):
-            out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
-            object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
-            object_info.update_box()
-            frame_masks.labels[out_obj_id] = object_info
-            image_base_name = frame_names[out_frame_idx].split(".")[0]
-            frame_masks.mask_name = f"mask_{image_base_name}.npy"
-            frame_masks.mask_height = out_mask.shape[-2]
-            frame_masks.mask_width = out_mask.shape[-1]
+        video_segments = {} # output the following {step} frames tracking masks
+        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
+            frame_masks = MaskDictionaryModel()
+
+            for i, out_obj_id in enumerate(out_obj_ids):
+                out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
+                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
+                object_info.update_box()
+                frame_masks.labels[out_obj_id] = object_info
+                image_base_name = frame_names[out_frame_idx].split(".")[0]
+                frame_masks.mask_name = f"mask_{image_base_name}.npy"
+                frame_masks.mask_height = out_mask.shape[-2]
+                frame_masks.mask_width = out_mask.shape[-1]
 
-        video_segments[out_frame_idx] = frame_masks
-        sam2_masks = copy.deepcopy(frame_masks)
+            video_segments[out_frame_idx] = frame_masks
+            sam2_masks = copy.deepcopy(frame_masks)
 
-    print("video_segments:", len(video_segments))
+        print("video_segments:", len(video_segments))
     """
     Step 5: save the tracking masks and json files
    """
diff --git a/grounded_sam2_tracking_demo_with_continuous_id_plus.py b/grounded_sam2_tracking_demo_with_continuous_id_plus.py
index e023737..8887642 100644
--- a/grounded_sam2_tracking_demo_with_continuous_id_plus.py
+++ b/grounded_sam2_tracking_demo_with_continuous_id_plus.py
@@ -68,7 +68,7 @@ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
 
 # init video predictor state
 inference_state = video_predictor.init_state(video_path=video_dir)
-step = 10 # the step to sample frames for Grounding DINO predictor
+step = 20 # the step to sample frames for Grounding DINO predictor
 sam2_masks = MaskDictionaryModel()
 PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
 
@@ -107,31 +107,34 @@ for start_frame_idx in range(0, len(frame_names), step):
     input_boxes = results[0]["boxes"] # .cpu().numpy()
     # print("results[0]",results[0])
     OBJECTS = results[0]["labels"]
+    if input_boxes.shape[0] != 0:
 
-    # prompt SAM 2 image predictor to get the mask for the object
-    masks, scores, logits = image_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
-    # convert the mask shape to (n, H, W)
-    if masks.ndim == 2:
-        masks = masks[None]
-        scores = scores[None]
-        logits = logits[None]
-    elif masks.ndim == 4:
-        masks = masks.squeeze(1)
-    """
-    Step 3: Register each object's positive points to video predictor
-    """
+        # prompt SAM 2 image predictor to get the mask for the object
+        masks, scores, logits = image_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        # convert the mask shape to (n, H, W)
+        if masks.ndim == 2:
+            masks = masks[None]
+            scores = scores[None]
+            logits = logits[None]
+        elif masks.ndim == 4:
+            masks = masks.squeeze(1)
+        """
+        Step 3: Register each object's positive points to video predictor
+        """
 
-    # If you are using point prompts, we uniformly sample positive points based on the mask
-    if mask_dict.promote_type == "mask":
-        mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+        # If you are using point prompts, we uniformly sample positive points based on the mask
+        if mask_dict.promote_type == "mask":
+            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
+        else:
+            raise NotImplementedError("SAM 2 video predictor only support mask prompts")
     else:
-        raise NotImplementedError("SAM 2 video predictor only support mask prompts")
-
+        print("No object detected in the frame, skip merging the frame {}".format(frame_names[start_frame_idx]))
+        mask_dict = sam2_masks
 
     """
     Step 4: Propagate the video predictor to get the segmentation results for each frame
@@ -139,38 +142,40 @@ for start_frame_idx in range(0, len(frame_names), step):
     objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
     frame_object_count[start_frame_idx] = objects_count
     print("objects_count", objects_count)
-    video_predictor.reset_state(inference_state)
+
     if len(mask_dict.labels) == 0:
+        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
         print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
         continue
-    video_predictor.reset_state(inference_state)
+    else:
+        video_predictor.reset_state(inference_state)
 
-    for object_id, object_info in mask_dict.labels.items():
-        frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
-            inference_state,
-            start_frame_idx,
-            object_id,
-            object_info.mask,
-        )
-
-    video_segments = {} # output the following {step} frames tracking masks
-    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
-        frame_masks = MaskDictionaryModel()
+        for object_id, object_info in mask_dict.labels.items():
+            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
+                inference_state,
+                start_frame_idx,
+                object_id,
+                object_info.mask,
+            )
 
-        for i, out_obj_id in enumerate(out_obj_ids):
-            out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
-            object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
-            object_info.update_box()
-            frame_masks.labels[out_obj_id] = object_info
-            image_base_name = frame_names[out_frame_idx].split(".")[0]
-            frame_masks.mask_name = f"mask_{image_base_name}.npy"
-            frame_masks.mask_height = out_mask.shape[-2]
-            frame_masks.mask_width = out_mask.shape[-1]
+        video_segments = {} # output the following {step} frames tracking masks
+        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
+            frame_masks = MaskDictionaryModel()
+
+            for i, out_obj_id in enumerate(out_obj_ids):
+                out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
+                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
+                object_info.update_box()
+                frame_masks.labels[out_obj_id] = object_info
+                image_base_name = frame_names[out_frame_idx].split(".")[0]
+                frame_masks.mask_name = f"mask_{image_base_name}.npy"
+                frame_masks.mask_height = out_mask.shape[-2]
+                frame_masks.mask_width = out_mask.shape[-1]
 
-        video_segments[out_frame_idx] = frame_masks
-        sam2_masks = copy.deepcopy(frame_masks)
+            video_segments[out_frame_idx] = frame_masks
+            sam2_masks = copy.deepcopy(frame_masks)
 
-    print("video_segments:", len(video_segments))
+        print("video_segments:", len(video_segments))
     """
     Step 5: save the tracking masks and json files
     """
diff --git a/utils/mask_dictionary_model.py b/utils/mask_dictionary_model.py
index b916216..6febe62 100644
--- a/utils/mask_dictionary_model.py
+++ b/utils/mask_dictionary_model.py
@@ -84,6 +84,25 @@ class MaskDictionaryModel:
             iou = intersection / union
         return iou
 
+
+    def save_empty_mask_and_json(self, mask_data_dir, json_data_dir, image_name_list=None):
+        mask_img = torch.zeros((self.mask_height, self.mask_width))
+        if image_name_list:
+            for image_base_name in image_name_list:
+                image_base_name = image_base_name.split(".")[0]+".npy"
+                mask_name = "mask_"+image_base_name
+                np.save(os.path.join(mask_data_dir, mask_name), mask_img.numpy().astype(np.uint16))
+
+                json_data_path = os.path.join(json_data_dir, mask_name.replace(".npy", ".json"))
+                print("save_empty_mask_and_json", json_data_path)
+                self.to_json(json_data_path)
+        else:
+            np.save(os.path.join(mask_data_dir, self.mask_name), mask_img.numpy().astype(np.uint16))
+            json_data_path = os.path.join(json_data_dir, self.mask_name.replace(".npy", ".json"))
+            print("save_empty_mask_and_json", json_data_path)
+            self.to_json(json_data_path)
+
+
     def to_dict(self):
         return {
             "mask_name": self.mask_name,
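
The new MaskDictionaryModel.save_empty_mask_and_json helper above closes the remaining gap: when a whole chunk of frames ends up with no tracked objects, the pipeline still writes one all-zero uint16 mask and one JSON annotation per frame, so downstream steps that expect a file for every frame keep working. A minimal self-contained sketch of the same idea (save_empty_masks is a hypothetical stand-in, not the method from this patch; the real method also mirrors each mask with a JSON file via self.to_json):

import os

import numpy as np
import torch

def save_empty_masks(mask_data_dir, height, width, image_name_list):
    # One all-zero uint16 mask per frame, named like the real outputs
    # ("mask_<frame>.npy"), so every frame has a file even with no detections.
    os.makedirs(mask_data_dir, exist_ok=True)
    empty = torch.zeros((height, width)).numpy().astype(np.uint16)
    for image_name in image_name_list:
        base = image_name.split(".")[0]
        np.save(os.path.join(mask_data_dir, f"mask_{base}.npy"), empty)

save_empty_masks("./mask_data", 480, 640, ["00001.jpg", "00002.jpg"])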