diff --git a/utils/common_utils.py b/utils/common_utils.py index 9373bdc..ac60224 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -48,6 +48,10 @@ class CommonUtils: object_mask = (mask == uid) all_object_masks.append(object_mask[None]) + if len(all_object_masks) == 0: + output_image_path = os.path.join(output_path, raw_image_name) + cv2.imwrite(output_image_path, image) + continue # get n masks: (n, h, w) all_object_masks = np.concatenate(all_object_masks, axis=0)