fix: zero object detection error (#64)
* update dockerfile
* fix: zero object detection error
* fix: zero object detection error
This commit is contained in:
@@ -107,70 +107,76 @@ for start_frame_idx in range(0, len(frame_names), step):
|
||||
input_boxes = results[0]["boxes"] # .cpu().numpy()
|
||||
# print("results[0]",results[0])
|
||||
OBJECTS = results[0]["labels"]
|
||||
if input_boxes.shape[0] != 0:
|
||||
# prompt SAM 2 image predictor to get the mask for the object
|
||||
masks, scores, logits = image_predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_boxes,
|
||||
multimask_output=False,
|
||||
)
|
||||
# convert the mask shape to (n, H, W)
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None]
|
||||
scores = scores[None]
|
||||
logits = logits[None]
|
||||
elif masks.ndim == 4:
|
||||
masks = masks.squeeze(1)
|
||||
|
||||
# prompt SAM 2 image predictor to get the mask for the object
|
||||
masks, scores, logits = image_predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_boxes,
|
||||
multimask_output=False,
|
||||
)
|
||||
# convert the mask shape to (n, H, W)
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None]
|
||||
scores = scores[None]
|
||||
logits = logits[None]
|
||||
elif masks.ndim == 4:
|
||||
masks = masks.squeeze(1)
|
||||
"""
|
||||
Step 3: Register each object's positive points to video predictor
|
||||
"""
|
||||
|
||||
"""
|
||||
Step 3: Register each object's positive points to video predictor
|
||||
"""
|
||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||
if mask_dict.promote_type == "mask":
|
||||
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
||||
else:
|
||||
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||
|
||||
# If you are using point prompts, we uniformly sample positive points based on the mask
|
||||
if mask_dict.promote_type == "mask":
|
||||
mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
|
||||
|
||||
"""
|
||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||
"""
|
||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||
print("objects_count", objects_count)
|
||||
else:
|
||||
raise NotImplementedError("SAM 2 video predictor only support mask prompts")
|
||||
print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
|
||||
mask_dict = sam2_masks
|
||||
|
||||
|
||||
"""
|
||||
Step 4: Propagate the video predictor to get the segmentation results for each frame
|
||||
"""
|
||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||
print("objects_count", objects_count)
|
||||
video_predictor.reset_state(inference_state)
|
||||
|
||||
if len(mask_dict.labels) == 0:
|
||||
mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
|
||||
print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
|
||||
continue
|
||||
video_predictor.reset_state(inference_state)
|
||||
else:
|
||||
video_predictor.reset_state(inference_state)
|
||||
|
||||
for object_id, object_info in mask_dict.labels.items():
|
||||
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
start_frame_idx,
|
||||
object_id,
|
||||
object_info.mask,
|
||||
)
|
||||
|
||||
video_segments = {} # output the following {step} frames tracking masks
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
||||
frame_masks = MaskDictionaryModel()
|
||||
for object_id, object_info in mask_dict.labels.items():
|
||||
frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
start_frame_idx,
|
||||
object_id,
|
||||
object_info.mask,
|
||||
)
|
||||
|
||||
for i, out_obj_id in enumerate(out_obj_ids):
|
||||
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
||||
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
|
||||
object_info.update_box()
|
||||
frame_masks.labels[out_obj_id] = object_info
|
||||
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
||||
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
||||
frame_masks.mask_height = out_mask.shape[-2]
|
||||
frame_masks.mask_width = out_mask.shape[-1]
|
||||
video_segments = {} # output the following {step} frames tracking masks
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
|
||||
frame_masks = MaskDictionaryModel()
|
||||
|
||||
for i, out_obj_id in enumerate(out_obj_ids):
|
||||
out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
|
||||
object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id))
|
||||
object_info.update_box()
|
||||
frame_masks.labels[out_obj_id] = object_info
|
||||
image_base_name = frame_names[out_frame_idx].split(".")[0]
|
||||
frame_masks.mask_name = f"mask_{image_base_name}.npy"
|
||||
frame_masks.mask_height = out_mask.shape[-2]
|
||||
frame_masks.mask_width = out_mask.shape[-1]
|
||||
|
||||
video_segments[out_frame_idx] = frame_masks
|
||||
sam2_masks = copy.deepcopy(frame_masks)
|
||||
video_segments[out_frame_idx] = frame_masks
|
||||
sam2_masks = copy.deepcopy(frame_masks)
|
||||
|
||||
print("video_segments:", len(video_segments))
|
||||
print("video_segments:", len(video_segments))
|
||||
"""
|
||||
Step 5: save the tracking masks and json files
|
||||
"""
|
||||
|
Reference in New Issue
Block a user