diff --git a/grounded_sam2_image_demo_florence2.py b/grounded_sam2_image_demo_florence2.py
new file mode 100644
index 0000000..755c15f
--- /dev/null
+++ b/grounded_sam2_image_demo_florence2.py
@@ -0,0 +1,184 @@
+import os
+import cv2
+import torch
+import numpy as np
+import supervision as sv
+from PIL import Image
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from transformers import AutoProcessor, AutoModelForCausalLM
+from utils.supervision_utils import CUSTOM_COLOR_MAP
+
+"""
+Define Hyperparameters
+"""
+
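+# Florence-2 task tokens (per the Florence-2 model card): each entry maps a pipeline name to
+# the prompt token that selects that task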
+TASK_PROMPT = {
+    "caption": "<CAPTION>",
+    "detailed_caption": "<DETAILED_CAPTION>",
+    "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
+    "dense_region_caption": "<DENSE_REGION_CAPTION>",
+    "region_proposal": "<REGION_PROPOSAL>",
+    "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
+    "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
+    "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
+    "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
+    "region_to_category": "<REGION_TO_CATEGORY>",
+    "region_to_description": "<REGION_TO_DESCRIPTION>",
+    "ocr": "<OCR>",
+    "ocr_with_region": "<OCR_WITH_REGION>",
+}
+
+OUTPUT_DIR = "./outputs"
+
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+"""
+Initialize the Florence-2 and SAM 2 Models
+"""
+
+FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
+SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
+SAM2_CONFIG = "sam2_hiera_l.yaml"
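+# note: this assumes the SAM 2 checkpoint has already been downloaded into ./checkpoints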
+
+# environment settings
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+if torch.cuda.is_available():
+    # use bfloat16 for the whole script
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+# build florence-2
+florence2_model = AutoModelForCausalLM.from_pretrained(
+    FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype=torch_dtype
+).eval().to(device)
+florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
+
+# build sam 2
+sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
+sam2_predictor = SAM2ImagePredictor(sam2_model)
+
+def run_florence2(task_prompt, text_input, model, processor, image):
+    assert model is not None, "You should pass an initialized Florence-2 model here"
+    assert processor is not None, "You should pass an initialized Florence-2 processor here"
+
+ device = model.device
+
+ if text_input is None:
+ prompt = task_prompt
+ else:
+ prompt = task_prompt + text_input
+
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+ max_new_tokens=1024,
+ early_stopping=False,
+ do_sample=False,
+ num_beams=3,
+ )
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+ parsed_answer = processor.post_process_generation(
+ generated_text,
+ task=task_prompt,
+ image_size=(image.width, image.height)
+ )
+ return parsed_answer
+
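+# A minimal usage sketch (hypothetical, not part of the pipelines below): plain captioning with
+# `run_florence2`; `post_process_generation` keys its output by the task token, so e.g.
+#
+#   image = Image.open("./notebooks/images/groceries.jpg").convert("RGB")
+#   result = run_florence2(TASK_PROMPT["caption"], None, florence2_model, florence2_processor, image)
+#   print(result["<CAPTION>"])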
+
+"""
+We support a set of pipelines built on Florence-2 + SAM 2
+"""
+
+"""
+Pipeline-1: Object Detection + Segmentation
+"""
+def object_detection_and_segmentation(
+ florence2_model,
+ florence2_processor,
+ sam2_predictor,
+ image_path,
+    task_prompt="<OD>",
+ text_input=None,
+ output_dir=OUTPUT_DIR
+):
+    # run Florence-2 object detection on the input image
+ image = Image.open(image_path).convert("RGB")
+ results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
+
+ """ Florence-2 Object Detection Output Format
+    {'<OD>':
+ {
+ 'bboxes':
+ [
+ [33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469],
+ [454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906],
+ [224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438],
+ [449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406],
+ [91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406]
+ ],
+ 'labels': ['car', 'door', 'door', 'wheel', 'wheel']
+ }
+ }
+ """
+    # Florence-2 keys its post-processed output by the task token
+    results = results[task_prompt]
+ # parse florence-2 detection results
+ input_boxes = np.array(results["bboxes"])
+ class_names = results["labels"]
+ class_ids = np.array(list(range(len(class_names))))
+
+ # predict mask with SAM 2
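+    # set_image computes the image embedding once, so repeated predict() calls can reuse it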
+ sam2_predictor.set_image(np.array(image))
+ masks, scores, logits = sam2_predictor.predict(
+ point_coords=None,
+ point_labels=None,
+ box=input_boxes,
+ multimask_output=False,
+ )
+
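+    # with multiple box prompts and multimask_output=False, SAM 2 returns masks of shape
+    # (num_boxes, 1, H, W); drop the singleton mask dimension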
+ if masks.ndim == 4:
+ masks = masks.squeeze(1)
+
+ # specify labels
+    # specify labels
+    labels = list(class_names)
+
+    # visualize results
+ img = cv2.imread(image_path)
+ detections = sv.Detections(
+ xyxy=input_boxes,
+ mask=masks.astype(bool),
+ class_id=class_ids
+ )
+
+ box_annotator = sv.BoxAnnotator()
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
+
+ label_annotator = sv.LabelAnnotator()
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame)
+
+ mask_annotator = sv.MaskAnnotator()
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame)
+
+
+if __name__ == "__main__":
+
+    image_path = "./notebooks/images/groceries.jpg"
+
+ # pipeline-1: detection + segmentation
+ object_detection_and_segmentation(
+ florence2_model=florence2_model,
+ florence2_processor=florence2_processor,
+ sam2_predictor=sam2_predictor,
+ image_path=image_path
+ )
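+
+    # A hypothetical follow-up call (not in the original demo): phrase grounding + segmentation
+    # should work with the same function, since Florence-2 returns the same bboxes/labels format
+    # for <CAPTION_TO_PHRASE_GROUNDING>; the text_input below is an illustrative placeholder:
+    #
+    # object_detection_and_segmentation(
+    #     florence2_model=florence2_model,
+    #     florence2_processor=florence2_processor,
+    #     sam2_predictor=sam2_predictor,
+    #     image_path=image_path,
+    #     task_prompt=TASK_PROMPT["phrase_grounding"],
+    #     text_input="groceries on the table",
+    # )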