support more demos with florence-2

This commit is contained in:
rentainhe
2024-08-15 02:13:30 +08:00
parent 35541890cc
commit 1fc4d469ab
2 changed files with 310 additions and 12 deletions

View File

@@ -224,10 +224,43 @@ python grounded_sam2_tracking_demo_with_continuous_id_gd1.5.py
``` ```
## Grounded SAM 2 Florence-2 Demos ## Grounded SAM 2 Florence-2 Demos
### Grounded SAM 2 Florence-2 Image Demo ### Grounded SAM 2 Florence-2 Image Demo (Updating)
In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications. In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
[Florence-2](https://arxiv.org/abs/2311.06242) is a powerful vision foundation model by Microsoft which supports a series of vision tasks by prompting with special `task_prompt` includes but not limited to:
| Task | Task Prompt | Text Input | Task Introduction |
|:---:|:---:|:---:|:---:|
| Object Detection | `<OD>` | &#10008; | Detect main objects with single category name |
| Dense Region Caption | `<DENSE_REGION_CAPTION>` | &#10008; | Detect main objects with short description |
| Region Proposal | `<REGION_PROPOSAL>` | &#10008; | Generate proposals without category name |
| Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | &#10004; | Ground main objects in image mentioned in caption |
Integrate `Florence-2` with `SAM-2`, we can build a strong vision pipeline to solve complex vision tasks, you can try the following scripts to run the demo:
**Object Detection and Segmentation**
```bash
python grounded_sam2_image_demo_florence2.py --pipeline object_detection_segmentation --image_path ./notebooks/images/cars.jpg
```
**Dense Region Caption and Segmentation**
```bash
python grounded_sam2_image_demo_florence2.py --pipeline dense_region_caption_segmentation --image_path ./notebooks/images/cars.jpg
```
**Region Proposal and Segmentation**
```bash
python grounded_sam2_image_demo_florence2.py --pipeline region_proposal_segmentation --image_path ./notebooks/images/cars.jpg
```
**Phrase Grounding and Segmentation**
```bash
python grounded_sam2_image_demo_florence2.py --pipeline phrase_grounding_and_segmentation --image_path ./notebooks/images/cars.jpg
```
### Citation ### Citation
If you find this project helpful for your research, please consider citing the following BibTeX entry. If you find this project helpful for your research, please consider citing the following BibTeX entry.

View File

@@ -1,6 +1,7 @@
import os import os
import cv2 import cv2
import torch import torch
import argparse
import numpy as np import numpy as np
import supervision as sv import supervision as sv
from PIL import Image from PIL import Image
@@ -127,7 +128,7 @@ def object_detection_and_segmentation(
} }
} }
""" """
results = results["<OD>"] results = results[task_prompt]
# parse florence-2 detection results # parse florence-2 detection results
input_boxes = np.array(results["bboxes"]) input_boxes = np.array(results["bboxes"])
class_names = results["labels"] class_names = results["labels"]
@@ -163,22 +164,286 @@ def object_detection_and_segmentation(
label_annotator = sv.LabelAnnotator() label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame) cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator() mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame) cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
print(f'Successfully save annotated image to "{output_dir}"')
"""
Pipeline 2: Dense Region Caption + Segmentation
"""
def dense_region_caption_and_segmentation(
florence2_model,
florence2_processor,
sam2_predictor,
image_path,
task_prompt="<DENSE_REGION_CAPTION>",
text_input=None,
output_dir=OUTPUT_DIR
):
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
""" Florence-2 Object Detection Output Format
{'<DENSE_REGION_CAPTION>':
{
'bboxes':
[
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
],
'labels': ['turquoise Volkswagen Beetle', 'wooden double doors with metal handles', 'wheel', 'wheel', 'door']
}
}
"""
results = results[task_prompt]
# parse florence-2 detection results
input_boxes = np.array(results["bboxes"])
class_names = results["labels"]
class_ids = np.array(list(range(len(class_names))))
# predict mask with SAM 2
sam2_predictor.set_image(np.array(image))
masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if masks.ndim == 4:
masks = masks.squeeze(1)
# specify labels
labels = [
f"{class_name}" for class_name in class_names
]
# visualization results
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=masks.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_dense_region_cap_image_with_mask.jpg"), annotated_frame)
print(f'Successfully save annotated image to "{output_dir}"')
"""
Pipeline 3: Region Proposal + Segmentation
"""
def region_proposal_and_segmentation(
florence2_model,
florence2_processor,
sam2_predictor,
image_path,
task_prompt="<REGION_PROPOSAL>",
text_input=None,
output_dir=OUTPUT_DIR
):
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
""" Florence-2 Object Detection Output Format
{'<REGION_PROPOSAL>':
{
'bboxes':
[
[33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
[454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
[224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
[449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
[91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
],
'labels': ['', '', '', '', '', '', '']
}
}
"""
results = results[task_prompt]
# parse florence-2 detection results
input_boxes = np.array(results["bboxes"])
class_names = results["labels"]
class_ids = np.array(list(range(len(class_names))))
# predict mask with SAM 2
sam2_predictor.set_image(np.array(image))
masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if masks.ndim == 4:
masks = masks.squeeze(1)
# specify labels
labels = [
f"region_{idx}" for idx, class_name in enumerate(class_names)
]
# visualization results
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=masks.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_region_proposal_with_mask.jpg"), annotated_frame)
print(f'Successfully save annotated image to "{output_dir}"')
"""
Pipeline 4: Phrase Grounding + Segmentation
"""
def phrase_grounding_and_segmentation(
florence2_model,
florence2_processor,
sam2_predictor,
image_path,
task_prompt="<CAPTION_TO_PHRASE_GROUNDING>",
text_input=None,
output_dir=OUTPUT_DIR
):
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
""" Florence-2 Object Detection Output Format
{'<CAPTION_TO_PHRASE_GROUNDING>':
{
'bboxes':
[
[34.23999786376953, 159.1199951171875, 582.0800170898438, 374.6399841308594],
[1.5999999046325684, 4.079999923706055, 639.0399780273438, 305.03997802734375]
],
'labels': ['A green car', 'a yellow building']
}
}
"""
assert text_input is not None, "Text input should not be none when calling phrase grounding pipeline."
results = results[task_prompt]
# parse florence-2 detection results
input_boxes = np.array(results["bboxes"])
class_names = results["labels"]
class_ids = np.array(list(range(len(class_names))))
# predict mask with SAM 2
sam2_predictor.set_image(np.array(image))
masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if masks.ndim == 4:
masks = masks.squeeze(1)
# specify labels
labels = [
f"{class_name}" for class_name in class_names
]
# visualization results
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=masks.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_phrase_grounding_with_mask.jpg"), annotated_frame)
print(f'Successfully save annotated image to "{output_dir}"')
if __name__ == "__main__": if __name__ == "__main__":
image_path = "./notebooks/images/groceries.jpg" parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
parser.add_argument("--pipeline", type=str, default="object_detection_segmentation", required=True, help="path to image file")
args = parser.parse_args()
# pipeline-1: detection + segmentation IMAGE_PATH = args.image_path
object_detection_and_segmentation( PIPELINE = args.pipeline
florence2_model=florence2_model,
florence2_processor=florence2_processor, print(f"Running pipeline: {PIPELINE} now.")
sam2_predictor=sam2_predictor,
image_path=image_path if PIPELINE == "object_detection_segmentation":
) # pipeline-1: detection + segmentation
object_detection_and_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH
)
elif PIPELINE == "dense_region_caption_segmentation":
# pipeline-2: dense region caption + segmentation
dense_region_caption_and_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH
)
elif PIPELINE == "region_proposal_segmentation":
# pipeline-3: dense region caption + segmentation
region_proposal_and_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH
)
elif PIPELINE == "phrase_grounding_segmentation":
# pipeline-4: phrase grounding + segmentation
phrase_grounding_and_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH,
text_input="The image shows two vintage Chevrolet cars parked side by side, with one being a red convertible and the other a pink sedan, \
set against the backdrop of an urban area with a multi-story building and trees. \
The cars have Cuban license plates, indicating a location likely in Cuba."
)
else:
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")