add open-vocab demo
This commit is contained in:
@@ -283,6 +283,13 @@ python grounded_sam2_image_demo_florence2.py \
|
|||||||
--text_input "The left red car."
|
--text_input "The left red car."
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Open-Vocabulary Detection and Segmentation**
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_image_demo_florence2.py \
|
||||||
|
--pipeline open_vocabulary_detection_segmentation \
|
||||||
|
--image_path ./notebooks/images/cars.jpg \
|
||||||
|
--text_input "two cars"
|
||||||
|
```
|
||||||
|
|
||||||
### Citation
|
### Citation
|
||||||
|
|
||||||
|
@@ -513,6 +513,81 @@ def referring_expression_segmentation(
|
|||||||
print(f'Successfully save sam2 annotated image to "{output_dir}"')
|
print(f'Successfully save sam2 annotated image to "{output_dir}"')
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Pipeline 6: Open-Vocabulary Detection + Segmentation
|
||||||
|
"""
|
||||||
|
def open_vocabulary_detection_and_segmentation(
    florence2_model,
    florence2_processor,
    sam2_predictor,
    image_path,
    task_prompt="<OPEN_VOCABULARY_DETECTION>",
    text_input=None,
    output_dir=OUTPUT_DIR
):
    """Pipeline 6: detect objects described by a text prompt, then segment them with SAM 2.

    Florence-2 is run with ``task_prompt`` and ``text_input`` on the image at
    ``image_path``; the returned boxes are passed to ``sam2_predictor`` as box
    prompts. Two annotated images are written into ``output_dir``:

    * ``grounded_sam2_florence2_open_vocabulary_detection.jpg`` (boxes + labels)
    * ``grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg``
      (boxes + labels + masks)

    Args:
        florence2_model: Florence-2 model, forwarded to ``run_florence2``.
        florence2_processor: Florence-2 processor, forwarded to ``run_florence2``.
        sam2_predictor: predictor exposing ``set_image`` and ``predict``.
        image_path: path of the input image.
        task_prompt: Florence-2 task token; also the key under which the
            detections are looked up in the Florence-2 result.
        text_input: free-form description of what to detect. Required.
        output_dir: directory that receives the annotated images; it is
            created if it does not exist.

    Returns:
        None.

    Raises:
        AssertionError: if ``text_input`` is ``None``.
    """
    # Validate up front so a missing prompt fails before the expensive model
    # call. An explicit raise (rather than ``assert``) keeps the check alive
    # under ``python -O`` while preserving the exception type.
    if text_input is None:
        raise AssertionError(
            "Text input should not be none when calling open-vocabulary detection pipeline."
        )

    # run florence-2 open-vocabulary detection
    image = Image.open(image_path).convert("RGB")
    results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)

    # Florence-2 Open-Vocabulary Detection Output Format:
    # {'<OPEN_VOCABULARY_DETECTION>':
    #     {
    #         'bboxes':
    #             [
    #                 [34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594]
    #             ],
    #         'bboxes_labels': ['A green car'],
    #         'polygons': [],
    #         'polygons_labels': []
    #     }
    # }
    results = results[task_prompt]
    print(results)

    # parse florence-2 detection results
    input_boxes = np.array(results["bboxes"])
    class_names = results["bboxes_labels"]

    # With no detections ``input_boxes`` has shape (0,), which is not a valid
    # (n, 4) box array for the calls below, so stop here with a clear message.
    if input_boxes.size == 0:
        print(f'No objects matching "{text_input}" were detected in "{image_path}"')
        return

    class_ids = np.arange(len(class_names))

    # predict mask with SAM 2, using the detected boxes as prompts
    sam2_predictor.set_image(np.array(image))
    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # drop the singleton axis when the predictor returns a 4-D mask array
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    # specify labels
    labels = [str(class_name) for class_name in class_names]

    # visualization results
    img = cv2.imread(image_path)
    detections = sv.Detections(
        xyxy=input_boxes,
        mask=masks.astype(bool),
        class_id=class_ids
    )

    # cv2.imwrite signals failure only through its return value, so make sure
    # the target directory exists before reporting success.
    os.makedirs(output_dir, exist_ok=True)

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection.jpg"), annotated_frame)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_open_vocabulary_detection_with_mask.jpg"), annotated_frame)

    print(f'Successfully save annotated image to "{output_dir}"')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
||||||
@@ -561,7 +636,7 @@ if __name__ == "__main__":
|
|||||||
text_input=INPUT_TEXT
|
text_input=INPUT_TEXT
|
||||||
)
|
)
|
||||||
elif PIPELINE == "referring_expression_segmentation":
|
elif PIPELINE == "referring_expression_segmentation":
|
||||||
# pipeline-5: referring segmentation + sam2 segmentation
|
# pipeline-5: referring segmentation + segmentation
|
||||||
referring_expression_segmentation(
|
referring_expression_segmentation(
|
||||||
florence2_model=florence2_model,
|
florence2_model=florence2_model,
|
||||||
florence2_processor=florence2_processor,
|
florence2_processor=florence2_processor,
|
||||||
@@ -569,5 +644,14 @@ if __name__ == "__main__":
|
|||||||
image_path=IMAGE_PATH,
|
image_path=IMAGE_PATH,
|
||||||
text_input=INPUT_TEXT
|
text_input=INPUT_TEXT
|
||||||
)
|
)
|
||||||
|
elif PIPELINE == "open_vocabulary_detection_segmentation":
|
||||||
|
# pipeline-6: open-vocabulary detection + segmentation
|
||||||
|
open_vocabulary_detection_and_segmentation(
|
||||||
|
florence2_model=florence2_model,
|
||||||
|
florence2_processor=florence2_processor,
|
||||||
|
sam2_predictor=sam2_predictor,
|
||||||
|
image_path=IMAGE_PATH,
|
||||||
|
text_input=INPUT_TEXT
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
Reference in New Issue
Block a user