fix: zero object detection error (#64)
* update dockerfile * fix: zero object detection error * fix: zero object detection error
This commit is contained in:
@@ -107,7 +107,7 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
||||||
# print("results[0]",results[0])
|
# print("results[0]",results[0])
|
||||||
OBJECTS = results[0]["labels"]
|
OBJECTS = results[0]["labels"]
|
||||||
|
if input_boxes.shape[0] != 0:
|
||||||
# prompt SAM 2 image predictor to get the mask for the object
|
# prompt SAM 2 image predictor to get the mask for the object
|
||||||
masks, scores, logits = image_predictor.predict(
|
masks, scores, logits = image_predictor.predict(
|
||||||
point_coords=None,
|
point_coords=None,
|
||||||
@@ -139,10 +139,16 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
"""
|
"""
|
||||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||||
print("objects_count", objects_count)
|
print("objects_count", objects_count)
|
||||||
video_predictor.reset_state(inference_state)
|
else:
|
||||||
|
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
||||||
|
mask_dict = sam2_masks
|
||||||
|
|
||||||
|
|
||||||
if len(mask_dict.labels) == 0:
|
if len(mask_dict.labels) == 0:
|
||||||
|
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
||||||
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
video_predictor.reset_state(inference_state)
|
video_predictor.reset_state(inference_state)
|
||||||
|
|
||||||
for object_id, object_info in mask_dict.labels.items():
|
for object_id, object_info in mask_dict.labels.items():
|
||||||
|
@@ -123,7 +123,7 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
OBJECTS = class_names
|
OBJECTS = class_names
|
||||||
|
if input_boxes.shape[0] != 0:
|
||||||
# prompt SAM image predictor to get the mask for the object
|
# prompt SAM image predictor to get the mask for the object
|
||||||
image_predictor.set_image(np.array(image.convert("RGB")))
|
image_predictor.set_image(np.array(image.convert("RGB")))
|
||||||
|
|
||||||
@@ -153,15 +153,23 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||||
|
print("objects_count", objects_count)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
||||||
|
mask_dict = sam2_masks
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||||
"""
|
"""
|
||||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
|
||||||
print("objects_count", objects_count)
|
|
||||||
video_predictor.reset_state(inference_state)
|
|
||||||
if len(mask_dict.labels) == 0:
|
if len(mask_dict.labels) == 0:
|
||||||
|
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
||||||
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
video_predictor.reset_state(inference_state)
|
video_predictor.reset_state(inference_state)
|
||||||
|
|
||||||
for object_id, object_info in mask_dict.labels.items():
|
for object_id, object_info in mask_dict.labels.items():
|
||||||
|
@@ -68,7 +68,7 @@ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
|||||||
|
|
||||||
# init video predictor state
|
# init video predictor state
|
||||||
inference_state = video_predictor.init_state(video_path=video_dir)
|
inference_state = video_predictor.init_state(video_path=video_dir)
|
||||||
step = 10 # the step to sample frames for Grounding DINO predictor
|
step = 20 # the step to sample frames for Grounding DINO predictor
|
||||||
|
|
||||||
sam2_masks = MaskDictionaryModel()
|
sam2_masks = MaskDictionaryModel()
|
||||||
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
|
||||||
@@ -107,6 +107,7 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
||||||
# print("results[0]",results[0])
|
# print("results[0]",results[0])
|
||||||
OBJECTS = results[0]["labels"]
|
OBJECTS = results[0]["labels"]
|
||||||
|
if input_boxes.shape[0] != 0:
|
||||||
|
|
||||||
# prompt SAM 2 image predictor to get the mask for the object
|
# prompt SAM 2 image predictor to get the mask for the object
|
||||||
masks, scores, logits = image_predictor.predict(
|
masks, scores, logits = image_predictor.predict(
|
||||||
@@ -131,7 +132,9 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||||
|
else:
|
||||||
|
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
||||||
|
mask_dict = sam2_masks
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||||
@@ -139,10 +142,12 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||||
frame_object_count[start_frame_idx] = objects_count
|
frame_object_count[start_frame_idx] = objects_count
|
||||||
print("objects_count", objects_count)
|
print("objects_count", objects_count)
|
||||||
video_predictor.reset_state(inference_state)
|
|
||||||
if len(mask_dict.labels) == 0:
|
if len(mask_dict.labels) == 0:
|
||||||
|
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
||||||
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
video_predictor.reset_state(inference_state)
|
video_predictor.reset_state(inference_state)
|
||||||
|
|
||||||
for object_id, object_info in mask_dict.labels.items():
|
for object_id, object_info in mask_dict.labels.items():
|
||||||
|
@@ -84,6 +84,25 @@ class MaskDictionaryModel:
|
|||||||
iou = intersection / union
|
iou = intersection / union
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
|
|
||||||
|
def save_empty_mask_and_json(self, mask_data_dir, json_data_dir, image_name_list=None):
|
||||||
|
mask_img = torch.zeros((self.mask_height, self.mask_width))
|
||||||
|
if image_name_list:
|
||||||
|
for image_base_name in image_name_list:
|
||||||
|
image_base_name = image_base_name.split(".")[0]+".npy"
|
||||||
|
mask_name = "mask_"+image_base_name
|
||||||
|
np.save(os.path.join(mask_data_dir, mask_name), mask_img.numpy().astype(np.uint16))
|
||||||
|
|
||||||
|
json_data_path = os.path.join(json_data_dir, mask_name.replace(".npy", ".json"))
|
||||||
|
print("save_empty_mask_and_json", json_data_path)
|
||||||
|
self.to_json(json_data_path)
|
||||||
|
else:
|
||||||
|
np.save(os.path.join(mask_data_dir, self.mask_name), mask_img.numpy().astype(np.uint16))
|
||||||
|
json_data_path = os.path.join(json_data_dir, self.mask_name.replace(".npy", ".json"))
|
||||||
|
print("save_empty_mask_and_json", json_data_path)
|
||||||
|
self.to_json(json_data_path)
|
||||||
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"mask_name": self.mask_name,
|
"mask_name": self.mask_name,
|
||||||
|
Reference in New Issue
Block a user