add referring demo

This commit is contained in:
rentainhe
2024-08-16 02:12:41 +08:00
parent 1fc4d469ab
commit 122c46d823
2 changed files with 152 additions and 8 deletions

View File

@@ -109,6 +109,7 @@ def object_detection_and_segmentation(
text_input=None,
output_dir=OUTPUT_DIR
):
assert text_input is None, "Text input should not be none when calling object detection pipeline."
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -184,6 +185,7 @@ def dense_region_caption_and_segmentation(
text_input=None,
output_dir=OUTPUT_DIR
):
assert text_input is None, "Text input should not be none when calling dense region caption pipeline."
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -260,6 +262,7 @@ def region_proposal_and_segmentation(
text_input=None,
output_dir=OUTPUT_DIR
):
assert text_input is None, "Text input should not be none when calling region proposal pipeline."
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -398,15 +401,129 @@ def phrase_grounding_and_segmentation(
print(f'Successfully save annotated image to "{output_dir}"')
"""
Pipeline 5: Referring Expression Segmentation
Note that Florence-2 directly support referring segmentation with polygon output format, which may be not that accurate,
therefore we try to decode box from polygon and use SAM 2 for mask prediction
"""
def referring_expression_segmentation(
florence2_model,
florence2_processor,
sam2_predictor,
image_path,
task_prompt="<REFERRING_EXPRESSION_SEGMENTATION>",
text_input=None,
output_dir=OUTPUT_DIR
):
# run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
""" Florence-2 Object Detection Output Format
{'<REFERRING_EXPRESSION_SEGMENTATION>':
{
'polygons': [[[...]]]
'labels': ['']
}
}
"""
assert text_input is not None, "Text input should not be none when calling referring segmentation pipeline."
results = results[task_prompt]
# parse florence-2 detection results
polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
class_names = [text_input]
class_ids = np.array(list(range(len(class_names))))
# parse polygon format to mask
img_width, img_height = image.size[0], image.size[1]
florence2_mask = np.zeros((img_height, img_width), dtype=np.uint8)
if len(polygon_points) < 3:
print("Invalid polygon:", polygon_points)
exit()
cv2.fillPoly(florence2_mask, [polygon_points], 1)
if florence2_mask.ndim == 2:
florence2_mask = florence2_mask[None]
# compute bounding box based on polygon points
x_min = np.min(polygon_points[:, 0])
y_min = np.min(polygon_points[:, 1])
x_max = np.max(polygon_points[:, 0])
y_max = np.max(polygon_points[:, 1])
input_boxes = np.array([[x_min, y_min, x_max, y_max]])
# predict mask with SAM 2
sam2_predictor.set_image(np.array(image))
sam2_masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
if sam2_masks.ndim == 4:
sam2_masks = sam2_masks.squeeze(1)
# specify labels
labels = [
f"{class_name}" for class_name in class_names
]
# visualization florence2 mask
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=florence2_mask.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "florence2_referring_segmentation_box_with_mask.jpg"), annotated_frame)
print(f'Successfully save florence-2 annotated image to "{output_dir}"')
# visualize sam2 mask
img = cv2.imread(image_path)
detections = sv.Detections(
xyxy=input_boxes,
mask=sam2_masks.astype(bool),
class_id=class_ids
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_referring_box_with_sam2_mask.jpg"), annotated_frame)
print(f'Successfully save sam2 annotated image to "{output_dir}"')
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", required=True, help="path to image file")
parser.add_argument("--pipeline", type=str, default="object_detection_segmentation", required=True, help="path to image file")
parser.add_argument("--text_input", type=str, default=None, required=False, help="path to image file")
args = parser.parse_args()
IMAGE_PATH = args.image_path
PIPELINE = args.pipeline
INPUT_TEXT = args.text_input
print(f"Running pipeline: {PIPELINE} now.")
@@ -441,9 +558,16 @@ if __name__ == "__main__":
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH,
text_input="The image shows two vintage Chevrolet cars parked side by side, with one being a red convertible and the other a pink sedan, \
set against the backdrop of an urban area with a multi-story building and trees. \
The cars have Cuban license plates, indicating a location likely in Cuba."
text_input=INPUT_TEXT
)
elif PIPELINE == "referring_expression_segmentation":
# pipeline-5: referring segmentation + sam2 segmentation
referring_expression_segmentation(
florence2_model=florence2_model,
florence2_processor=florence2_processor,
sam2_predictor=sam2_predictor,
image_path=IMAGE_PATH,
text_input=INPUT_TEXT
)
else:
raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")