add open-vocab demo
This commit is contained in:
@@ -283,6 +283,13 @@ python grounded_sam2_image_demo_florence2.py \
|
||||
--text_input "The left red car."
|
||||
```
|
||||
|
||||
**Open-Vocabulary Detection and Segmentation**
|
||||
```bash
|
||||
python grounded_sam2_image_demo_florence2.py \
|
||||
--pipeline open_vocabulary_detection_segmentation \
|
||||
--image_path ./notebooks/images/cars.jpg \
|
||||
--text_input "two cars"
|
||||
```
|
||||
|
||||
### Citation
|
||||
|
||||
|
@@ -513,6 +513,81 @@ def referring_expression_segmentation(
|
||||
print(f'Successfully save sam2 annotated image to "{output_dir}"')
|
||||
|
||||
|
||||
"""
|
||||
Pipeline 6: Open-Vocabulary Detection + Segmentation
|
||||
"""
|
||||
def open_vocabulary_detection_and_segmentation(
|
||||
florence2_model,
|
||||
florence2_processor,
|
||||
sam2_predictor,
|
||||
image_path,
|
||||
task_prompt="<OPEN_VOCABULARY_DETECTION>",
|
||||
text_input=None,
|
||||
output_dir=OUTPUT_DIR
|
||||
):
|
||||
# run florence-2 object detection in demo
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||
|
||||
""" Florence-2 Open-Vocabulary Detection Output Format
|
||||
{'<OPEN_VOCABULARY_DETECTION>':
|
||||
{
|
||||
'bboxes':
|
||||
[
|
||||
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
|
||||
],
|
||||
'bboxes_labels': ['A green car'],
|
||||
'polygons': [],
|
||||
'polygons_labels': []
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert text_input is not None, "Text input should not be none when calling open-vocabulary detection pipeline."
|
||||
results = results[task_prompt]
|
||||
# parse florence-2 detection results
|
||||
input_boxes = np.array(results["bboxes"])
|
||||
print(results)
|
||||
class_names = results["bboxes_labels"]
|
||||
class_ids = np.array(list(range(len(class_names))))
|
||||
|
||||
# predict mask with SAM 2
|
||||
sam2_predictor.set_image(np.array(image))
|
||||
masks, scores, logits = sam2_predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_boxes,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
if masks.ndim == 4:
|
||||
masks = masks.squeeze(1)
|
||||
|
||||
# specify labels
|
||||
labels = [
|
||||
f"{class_name}" for class_name in class_names
|
||||
]
|
||||
|
||||
# visualization results
|
||||
img = cv2.imread(image_path)
|
||||
detections = sv.Detections(
|
||||
xyxy=input_boxes,
|
||||
mask=masks.astype(bool),
|
||||
class_id=class_ids
|
||||
)
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
||||
|
||||
label_annotator = sv.LabelAnnotator()
|
||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)
|
||||
|
||||
mask_annotator = sv.MaskAnnotator()
|
||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)
|
||||
|
||||
print(f'Successfully save annotated image to "{output_dir}"')
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
||||
@@ -561,7 +636,7 @@ if __name__ == "__main__":
|
||||
text_input=INPUT_TEXT
|
||||
)
|
||||
elif PIPELINE == "referring_expression_segmentation":
|
||||
# pipeline-5: referring segmentation + sam2 segmentation
|
||||
# pipeline-5: referring segmentation + segmentation
|
||||
referring_expression_segmentation(
|
||||
florence2_model=florence2_model,
|
||||
florence2_processor=florence2_processor,
|
||||
@@ -569,5 +644,14 @@ if __name__ == "__main__":
|
||||
image_path=IMAGE_PATH,
|
||||
text_input=INPUT_TEXT
|
||||
)
|
||||
elif PIPELINE == "open_vocabulary_detection_segmentation":
|
||||
# pipeline-6: open-vocabulary detection + segmentation
|
||||
open_vocabulary_detection_and_segmentation(
|
||||
florence2_model=florence2_model,
|
||||
florence2_processor=florence2_processor,
|
||||
sam2_predictor=sam2_predictor,
|
||||
image_path=IMAGE_PATH,
|
||||
text_input=INPUT_TEXT
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
Reference in New Issue
Block a user