add open-vocab demo

This commit is contained in:
rentainhe
2024-08-19 00:22:47 +08:00
parent 5f886743d9
commit afa91ca407
2 changed files with 92 additions and 1 deletions

View File

@@ -283,6 +283,13 @@ python grounded_sam2_image_demo_florence2.py \
--text_input "The left red car." --text_input "The left red car."
``` ```
**Open-Vocabulary Detection and Segmentation**
```bash
python grounded_sam2_image_demo_florence2.py \
--pipeline open_vocabulary_detection_segmentation \
--image_path ./notebooks/images/cars.jpg \
--text_input "two cars"
```
### Citation ### Citation

View File

@@ -513,6 +513,81 @@ def referring_expression_segmentation(
print(f'Successfully save sam2 annotated image to "{output_dir}"') print(f'Successfully save sam2 annotated image to "{output_dir}"')
"""
Pipeline 6: Open-Vocabulary Detection + Segmentation
"""
def open_vocabulary_detection_and_segmentation(
florence2_model,
florence2_processor,
sam2_predictor,
image_path,
task_prompt="<OPEN_VOCABULARY_DETECTION>",
text_input=None,
output_dir=OUTPUT_DIR
):
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
""" Florence-2 Open-Vocabulary Detection Output Format
{'<OPEN_VOCABULARY_DETECTION>':
{
'bboxes':
[
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
],
'bboxes_labels': ['A green car'],
'polygons': [],
'polygons_labels': []
}
}
"""
assert text_input is not None, "Text input should not be none when calling open-vocabulary detection pipeline."
results = results[task_prompt]
# parse florence-2 detection results
input_boxes = np.array(results["bboxes"])
print(results)
class_names = results["bboxes_labels"]
class_ids = np.array(list(range(len(class_names))))
# predict mask with SAM 2
sam2_predictor.set_image(np.array(image))
masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if masks.ndim == 4:
masks = masks.squeeze(1)
# specify labels
labels = [
f"{class_name}" for class_name in class_names
]
# visualization results
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=masks.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)
print(f'Successfully save annotated image to "{output_dir}"')
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True) parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
@@ -561,7 +636,7 @@ if __name__ == "__main__":
text_input=INPUT_TEXT text_input=INPUT_TEXT
) )
elif PIPELINE == "referring_expression_segmentation": elif PIPELINE == "referring_expression_segmentation":
# pipeline-5: referring segmentation + sam2 segmentation # pipeline-5: referring segmentation + segmentation
referring_expression_segmentation( referring_expression_segmentation(
florence2_model=florence2_model, florence2_model=florence2_model,
florence2_processor=florence2_processor, florence2_processor=florence2_processor,
@@ -569,5 +644,14 @@ if __name__ == "__main__":
image_path=IMAGE_PATH, image_path=IMAGE_PATH,
text_input=INPUT_TEXT text_input=INPUT_TEXT
) )
elif PIPELINE == "open_vocabulary_detection_segmentation":
# pipeline-6: open-vocabulary detection + segmentation
open_vocabulary_detection_and_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH,
text_input=INPUT_TEXT
)
else: else:
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time") raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")