support more demos with florence-2
This commit is contained in:
35
README.md
35
README.md
@@ -224,10 +224,43 @@ python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
## Grounded SAM 2 Florence-2 Demos
|
## Grounded SAM 2 Florence-2 Demos
|
||||||
### Grounded SAM 2 Florence-2 Image Demo
|
### Grounded SAM 2 Florence-2 Image Demo (Updating)
|
||||||
|
|
||||||
In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
|
In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
|
||||||
|
|
||||||
|
[Florence-2](https://arxiv.org/abs/2311.06242) is a powerful vision foundation model by Microsoft which supports a series of vision tasks by prompting with special `task_prompt` includes but not limited to:
|
||||||
|
|
||||||
|
| Task | Task Prompt | Text Input | Task Introduction |
|
||||||
|
|:---:|:---:|:---:|:---:|
|
||||||
|
| Object Detection | `<OD>` | ✘ | Detect main objects with single category name |
|
||||||
|
| Dense Region Caption | `<DENSE_REGION_CAPTION>` | ✘ | Detect main objects with short description |
|
||||||
|
| Region Proposal | `<REGION_PROPOSAL>` | ✘ | Generate proposals without category name |
|
||||||
|
| Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | ✔ | Ground main objects in image mentioned in caption |
|
||||||
|
|
||||||
|
|
||||||
|
Integrate `Florence-2` with `SAM-2`, we can build a strong vision pipeline to solve complex vision tasks, you can try the following scripts to run the demo:
|
||||||
|
|
||||||
|
**Object Detection and Segmentation**
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_image_demo_florence2.py --pipeline object_detection_segmentation --image_path ./notebooks/images/cars.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Dense Region Caption and Segmentation**
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_image_demo_florence2.py --pipeline dense_region_caption_segmentation --image_path ./notebooks/images/cars.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Region Proposal and Segmentation**
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_image_demo_florence2.py --pipeline region_proposal_segmentation --image_path ./notebooks/images/cars.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
**Phrase Grounding and Segmentation**
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_image_demo_florence2.py --pipeline phrase_grounding_and_segmentation --image_path ./notebooks/images/cars.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Citation
|
### Citation
|
||||||
|
|
||||||
If you find this project helpful for your research, please consider citing the following BibTeX entry.
|
If you find this project helpful for your research, please consider citing the following BibTeX entry.
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import supervision as sv
|
import supervision as sv
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -127,7 +128,7 @@ def object_detection_and_segmentation(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
results = results["<OD>"]
|
results = results[task_prompt]
|
||||||
# parse florence-2 detection results
|
# parse florence-2 detection results
|
||||||
input_boxes = np.array(results["bboxes"])
|
input_boxes = np.array(results["bboxes"])
|
||||||
class_names = results["labels"]
|
class_names = results["labels"]
|
||||||
@@ -163,22 +164,286 @@ def object_detection_and_segmentation(
|
|||||||
|
|
||||||
label_annotator = sv.LabelAnnotator()
|
label_annotator = sv.LabelAnnotator()
|
||||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
|
||||||
|
|
||||||
mask_annotator = sv.MaskAnnotator()
|
mask_annotator = sv.MaskAnnotator()
|
||||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
print(f'Successfully save annotated image to "{output_dir}"')
|
||||||
|
|
||||||
|
"""
|
||||||
|
Pipeline 2: Dense Region Caption + Segmentation
|
||||||
|
"""
|
||||||
|
def dense_region_caption_and_segmentation(
|
||||||
|
florence2_model,
|
||||||
|
florence2_processor,
|
||||||
|
sam2_predictor,
|
||||||
|
image_path,
|
||||||
|
task_prompt="<DENSE_REGION_CAPTION>",
|
||||||
|
text_input=None,
|
||||||
|
output_dir=OUTPUT_DIR
|
||||||
|
):
|
||||||
|
# run florence-2 object detection in demo
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
|
|
||||||
|
""" Florence-2 Object Detection Output Format
|
||||||
|
{'<DENSE_REGION_CAPTION>':
|
||||||
|
{
|
||||||
|
'bboxes':
|
||||||
|
[
|
||||||
|
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
|
||||||
|
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
|
||||||
|
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
|
||||||
|
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
|
||||||
|
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
|
||||||
|
],
|
||||||
|
'labels': ['turquoise Volkswagen Beetle', 'wooden double doors with metal handles', 'wheel', 'wheel', 'door']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
results = results[task_prompt]
|
||||||
|
# parse florence-2 detection results
|
||||||
|
input_boxes = np.array(results["bboxes"])
|
||||||
|
class_names = results["labels"]
|
||||||
|
class_ids = np.array(list(range(len(class_names))))
|
||||||
|
|
||||||
|
# predict mask with SAM 2
|
||||||
|
sam2_predictor.set_image(np.array(image))
|
||||||
|
masks, scores, logits = sam2_predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=input_boxes,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if masks.ndim == 4:
|
||||||
|
masks = masks.squeeze(1)
|
||||||
|
|
||||||
|
# specify labels
|
||||||
|
labels = [
|
||||||
|
f"{class_name}" for class_name in class_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# visualization results
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
detections = sv.Detections(
|
||||||
|
xyxy=input_boxes,
|
||||||
|
mask=masks.astype(bool),
|
||||||
|
class_id=class_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
||||||
|
|
||||||
|
label_annotator = sv.LabelAnnotator()
|
||||||
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_annotated_image.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
mask_annotator = sv.MaskAnnotator()
|
||||||
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_image_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
print(f'Successfully save annotated image to "{output_dir}"')
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Pipeline 3: Region Proposal + Segmentation
|
||||||
|
"""
|
||||||
|
def region_proposal_and_segmentation(
|
||||||
|
florence2_model,
|
||||||
|
florence2_processor,
|
||||||
|
sam2_predictor,
|
||||||
|
image_path,
|
||||||
|
task_prompt="<REGION_PROPOSAL>",
|
||||||
|
text_input=None,
|
||||||
|
output_dir=OUTPUT_DIR
|
||||||
|
):
|
||||||
|
# run florence-2 object detection in demo
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
|
|
||||||
|
""" Florence-2 Object Detection Output Format
|
||||||
|
{'<REGION_PROPOSAL>':
|
||||||
|
{
|
||||||
|
'bboxes':
|
||||||
|
[
|
||||||
|
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
|
||||||
|
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
|
||||||
|
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
|
||||||
|
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
|
||||||
|
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
|
||||||
|
],
|
||||||
|
'labels': ['', '', '', '', '', '', '']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
results = results[task_prompt]
|
||||||
|
# parse florence-2 detection results
|
||||||
|
input_boxes = np.array(results["bboxes"])
|
||||||
|
class_names = results["labels"]
|
||||||
|
class_ids = np.array(list(range(len(class_names))))
|
||||||
|
|
||||||
|
# predict mask with SAM 2
|
||||||
|
sam2_predictor.set_image(np.array(image))
|
||||||
|
masks, scores, logits = sam2_predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=input_boxes,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if masks.ndim == 4:
|
||||||
|
masks = masks.squeeze(1)
|
||||||
|
|
||||||
|
# specify labels
|
||||||
|
labels = [
|
||||||
|
f"region_{idx}" for idx, class_name in enumerate(class_names)
|
||||||
|
]
|
||||||
|
|
||||||
|
# visualization results
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
detections = sv.Detections(
|
||||||
|
xyxy=input_boxes,
|
||||||
|
mask=masks.astype(bool),
|
||||||
|
class_id=class_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
||||||
|
|
||||||
|
label_annotator = sv.LabelAnnotator()
|
||||||
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
mask_annotator = sv.MaskAnnotator()
|
||||||
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
print(f'Successfully save annotated image to "{output_dir}"')
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Pipeline 4: Phrase Grounding + Segmentation
|
||||||
|
"""
|
||||||
|
def phrase_grounding_and_segmentation(
|
||||||
|
florence2_model,
|
||||||
|
florence2_processor,
|
||||||
|
sam2_predictor,
|
||||||
|
image_path,
|
||||||
|
task_prompt="<CAPTION_TO_PHRASE_GROUNDING>",
|
||||||
|
text_input=None,
|
||||||
|
output_dir=OUTPUT_DIR
|
||||||
|
):
|
||||||
|
# run florence-2 object detection in demo
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
|
|
||||||
|
""" Florence-2 Object Detection Output Format
|
||||||
|
{'<CAPTION_TO_PHRASE_GROUNDING>':
|
||||||
|
{
|
||||||
|
'bboxes':
|
||||||
|
[
|
||||||
|
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594],
|
||||||
|
[1.5999999046325684, 4.079999923706055, 639.0399780273438, 305.03997802734375]
|
||||||
|
],
|
||||||
|
'labels': ['A green car', 'a yellow building']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
assert text_input is not None, "Text input should not be none when calling phrase grounding pipeline."
|
||||||
|
results = results[task_prompt]
|
||||||
|
# parse florence-2 detection results
|
||||||
|
input_boxes = np.array(results["bboxes"])
|
||||||
|
class_names = results["labels"]
|
||||||
|
class_ids = np.array(list(range(len(class_names))))
|
||||||
|
|
||||||
|
# predict mask with SAM 2
|
||||||
|
sam2_predictor.set_image(np.array(image))
|
||||||
|
masks, scores, logits = sam2_predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=input_boxes,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if masks.ndim == 4:
|
||||||
|
masks = masks.squeeze(1)
|
||||||
|
|
||||||
|
# specify labels
|
||||||
|
labels = [
|
||||||
|
f"{class_name}" for class_name in class_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# visualization results
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
detections = sv.Detections(
|
||||||
|
xyxy=input_boxes,
|
||||||
|
mask=masks.astype(bool),
|
||||||
|
class_id=class_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
|
||||||
|
|
||||||
|
label_annotator = sv.LabelAnnotator()
|
||||||
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
mask_annotator = sv.MaskAnnotator()
|
||||||
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
|
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
print(f'Successfully save annotated image to "{output_dir}"')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
image_path = "./notebooks/images/groceries.jpg"
|
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
|
||||||
|
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
|
||||||
|
parser.add_argument("--pipeline", type=str, default="object_detection_segmentation", required=True, help="path to image file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# pipeline-1: detection + segmentation
|
IMAGE_PATH = args.image_path
|
||||||
object_detection_and_segmentation(
|
PIPELINE = args.pipeline
|
||||||
florence2_model=florence2_model,
|
|
||||||
florence2_processor=florence2_processor,
|
print(f"Running pipeline: {PIPELINE} now.")
|
||||||
sam2_predictor=sam2_predictor,
|
|
||||||
image_path=image_path
|
if PIPELINE == "object_detection_segmentation":
|
||||||
)
|
# pipeline-1: detection + segmentation
|
||||||
|
object_detection_and_segmentation(
|
||||||
|
florence2_model=florence2_model,
|
||||||
|
florence2_processor=florence2_processor,
|
||||||
|
sam2_predictor=sam2_predictor,
|
||||||
|
image_path=IMAGE_PATH
|
||||||
|
)
|
||||||
|
elif PIPELINE == "dense_region_caption_segmentation":
|
||||||
|
# pipeline-2: dense region caption + segmentation
|
||||||
|
dense_region_caption_and_segmentation(
|
||||||
|
florence2_model=florence2_model,
|
||||||
|
florence2_processor=florence2_processor,
|
||||||
|
sam2_predictor=sam2_predictor,
|
||||||
|
image_path=IMAGE_PATH
|
||||||
|
)
|
||||||
|
elif PIPELINE == "region_proposal_segmentation":
|
||||||
|
# pipeline-3: dense region caption + segmentation
|
||||||
|
region_proposal_and_segmentation(
|
||||||
|
florence2_model=florence2_model,
|
||||||
|
florence2_processor=florence2_processor,
|
||||||
|
sam2_predictor=sam2_predictor,
|
||||||
|
image_path=IMAGE_PATH
|
||||||
|
)
|
||||||
|
elif PIPELINE == "phrase_grounding_segmentation":
|
||||||
|
# pipeline-4: phrase grounding + segmentation
|
||||||
|
phrase_grounding_and_segmentation(
|
||||||
|
florence2_model=florence2_model,
|
||||||
|
florence2_processor=florence2_processor,
|
||||||
|
sam2_predictor=sam2_predictor,
|
||||||
|
image_path=IMAGE_PATH,
|
||||||
|
text_input="The image shows two vintage Chevrolet cars parked side by side, with one being a red convertible and the other a pink sedan, \
|
||||||
|
set against the backdrop of an urban area with a multi-story building and trees. \
|
||||||
|
The cars have Cuban license plates, indicating a location likely in Cuba."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")
|
Reference in New Issue
Block a user