"""
Pipeline 6: Open-Vocabulary Detection + Segmentation
"""
def open_vocabulary_detection_and_segmentation(
    florence2_model,
    florence2_processor,
    sam2_predictor,
    image_path,
    task_prompt="<OPEN_VOCABULARY_DETECTION>",  # Florence-2 task token for open-vocabulary detection
    text_input=None,
    output_dir=OUTPUT_DIR
):
    """
    Detect objects described by a free-form text prompt with Florence-2,
    then prompt SAM 2 with the detected boxes to produce instance masks,
    and save two annotated images (boxes+labels, and boxes+labels+masks).

    Args:
        florence2_model: loaded Florence-2 model used for detection.
        florence2_processor: processor paired with ``florence2_model``.
        sam2_predictor: SAM 2 image predictor used for box-prompted masks.
        image_path (str): path to the input image.
        task_prompt (str): Florence-2 task token; the default selects the
            open-vocabulary detection task expected by this pipeline.
        text_input (str): open-vocabulary phrase describing what to detect.
            Required for this pipeline.
        output_dir (str): directory the annotated result images are written to.

    Raises:
        ValueError: if ``text_input`` is None.
    """
    # Validate before running the (expensive) model instead of after.
    if text_input is None:
        raise ValueError("Text input should not be none when calling open-vocabulary detection pipeline.")

    # Run Florence-2 open-vocabulary detection on the image.
    image = Image.open(image_path).convert("RGB")
    results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)

    """ Florence-2 Open-Vocabulary Detection Output Format
    {'<OPEN_VOCABULARY_DETECTION>':
        {
            'bboxes':
                [
                    [34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
                ],
            'bboxes_labels': ['A green car'],
            'polygons': [],
            'polygons_labels': []
        }
    }
    """
    results = results[task_prompt]

    # Parse Florence-2 detection results into SAM 2 / supervision inputs.
    input_boxes = np.array(results["bboxes"])
    class_names = results["bboxes_labels"]
    class_ids = np.array(list(range(len(class_names))))

    # Predict one mask per detected box with SAM 2.
    sam2_predictor.set_image(np.array(image))
    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # SAM 2 returns (N, 1, H, W) for single-mask output; drop the mask axis.
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    # One label per detection; Florence-2 labels are already strings.
    labels = list(class_names)

    # Visualize and save results (boxes + labels, then masks on top).
    img = cv2.imread(image_path)
    detections = sv.Detections(
        xyxy=input_boxes,
        mask=masks.astype(bool),
        class_id=class_ids
    )

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)

    print(f'Successfully save annotated image to "{output_dir}"')
    elif PIPELINE == "open_vocabulary_detection_segmentation":
        # pipeline-6: open-vocabulary detection + segmentation
        open_vocabulary_detection_and_segmentation(
            florence2_model=florence2_model,
            florence2_processor=florence2_processor,
            sam2_predictor=sam2_predictor,
            image_path=IMAGE_PATH,
            text_input=INPUT_TEXT
        )
    else:
        # Unknown --pipeline value: fail loudly rather than silently doing nothing.
        raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")