fix: zero object detection error (#64)

* update dockerfile

* fix: zero object detection error

* fix: zero object detection error
This commit is contained in:
Susan Shen
2024-10-30 13:38:30 +01:00
committed by GitHub
parent e537a1e763
commit bf57b3086c
4 changed files with 190 additions and 152 deletions

View File

@@ -107,7 +107,7 @@ for start_frame_idx in range(0, len(frame_names), step):
input_boxes = results[0]["boxes"] # .cpu().numpy()
# print("results[0]",results[0])
OBJECTS = results[0]["labels"]
if input_boxes.shape[0] != 0:
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
point_coords=None,
@@ -139,10 +139,16 @@ for start_frame_idx in range(0, len(frame_names), step):
"""
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
print("objects_count", objects_count)
video_predictor.reset_state(inference_state)
else:
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
mask_dict = sam2_masks
if len(mask_dict.labels) == 0:
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
continue
else:
video_predictor.reset_state(inference_state)
for object_id, object_info in mask_dict.labels.items():

View File

@@ -123,7 +123,7 @@ for start_frame_idx in range(0, len(frame_names), step):
input_boxes = np.array(input_boxes)
OBJECTS = class_names
if input_boxes.shape[0] != 0:
# prompt SAM image predictor to get the mask for the object
image_predictor.set_image(np.array(image.convert("RGB")))
@@ -153,15 +153,23 @@ for start_frame_idx in range(0, len(frame_names), step):
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
print("objects_count", objects_count)
else:
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
mask_dict = sam2_masks
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
print("objects_count", objects_count)
video_predictor.reset_state(inference_state)
if len(mask_dict.labels) == 0:
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
continue
else:
video_predictor.reset_state(inference_state)
for object_id, object_info in mask_dict.labels.items():

View File

@@ -68,7 +68,7 @@ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)
step = 10 # the step to sample frames for Grounding DINO predictor
step = 20 # the step to sample frames for Grounding DINO predictor
sam2_masks = MaskDictionaryModel()
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
@@ -107,6 +107,7 @@ for start_frame_idx in range(0, len(frame_names), step):
input_boxes = results[0]["boxes"] # .cpu().numpy()
# print("results[0]",results[0])
OBJECTS = results[0]["labels"]
if input_boxes.shape[0] != 0:
# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
@@ -131,7 +132,9 @@ for start_frame_idx in range(0, len(frame_names), step):
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
else:
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
else:
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
mask_dict = sam2_masks
"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
@@ -139,10 +142,12 @@ for start_frame_idx in range(0, len(frame_names), step):
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
frame_object_count[start_frame_idx] = objects_count
print("objects_count", objects_count)
video_predictor.reset_state(inference_state)
if len(mask_dict.labels) == 0:
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
continue
else:
video_predictor.reset_state(inference_state)
for object_id, object_info in mask_dict.labels.items():

View File

@@ -84,6 +84,25 @@ class MaskDictionaryModel:
iou = intersection / union
return iou
def save_empty_mask_and_json(self, mask_data_dir, json_data_dir, image_name_list=None):
    """Write all-zero mask arrays plus matching JSON annotations for frames
    in which no object was detected.

    Args:
        mask_data_dir: directory that receives the ``mask_*.npy`` files.
        json_data_dir: directory that receives the matching ``.json`` files.
        image_name_list: optional list of frame file names. When given (and
            non-empty), one empty mask/JSON pair is written per frame; when
            omitted, a single pair is written under ``self.mask_name``.
    """
    # One shared zero mask, reused for every frame. Saved as uint16 to
    # match the instance-ID mask format used by the rest of the pipeline.
    mask_img = torch.zeros((self.mask_height, self.mask_width))

    def _save_pair(mask_name):
        # Persist the .npy mask and its sidecar .json annotation.
        np.save(os.path.join(mask_data_dir, mask_name), mask_img.numpy().astype(np.uint16))
        json_data_path = os.path.join(json_data_dir, mask_name.replace(".npy", ".json"))
        print("save_empty_mask_and_json", json_data_path)
        self.to_json(json_data_path)

    if image_name_list:
        for image_base_name in image_name_list:
            # splitext (not split(".")[0]) keeps dots inside the base name,
            # consistent with how frame names are parsed elsewhere in this
            # repo (frame_names.sort(key=... os.path.splitext ...)).
            _save_pair("mask_" + os.path.splitext(image_base_name)[0] + ".npy")
    else:
        _save_pair(self.mask_name)
def to_dict(self):
return {
"mask_name": self.mask_name,