diff --git a/grounded_sam2_image_demo_florence2.py b/grounded_sam2_image_demo_florence2.py new file mode 100644 index 0000000..755c15f --- /dev/null +++ b/grounded_sam2_image_demo_florence2.py @@ -0,0 +1,184 @@ +import os +import cv2 +import torch +import numpy as np +import supervision as sv +from PIL import Image +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor +from transformers import AutoProcessor, AutoModelForCausalLM +from utils.supervision_utils import CUSTOM_COLOR_MAP + +""" +Define Some Hyperparam +""" + +TASK_PROMPT = { + "caption": "", + "detailed_caption": "", + "more_detailed_caption": "", + "dense_region_caption": "", + "region_proposal": "", + "phrase_grounding": "", + "referring_expression_segmentation": "", + "region_to_segmentation": "", + "open_vocabulary_detection": "", + "region_to_category": "", + "region_to_description": "", + "ocr": "", + "ocr_with_region": "", +} + +OUTPUT_DIR = "./outputs" + +if not os.path.exists(OUTPUT_DIR): + os.makedirs(OUTPUT_DIR, exist_ok=True) + +""" +Init Florence-2 and SAM 2 Model +""" + +FLORENCE2_MODEL_ID = "microsoft/Florence-2-large" +SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt" +SAM2_CONFIG = "sam2_hiera_l.yaml" + +# environment settings +# use bfloat16 +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +device = "cuda:0" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + +# build florence-2 +florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device) +florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True) + +# build sam 2 +sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device) +sam2_predictor = SAM2ImagePredictor(sam2_model) + +def run_florence2(task_prompt, text_input, model, processor, image): + assert model is not None, "You should pass the init florence-2 model here" + assert processor is not None, "You should set florence-2 processor here" + + device = model.device + + if text_input is None: + prompt = task_prompt + else: + prompt = task_prompt + text_input + + inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch.float16) + generated_ids = model.generate( + input_ids=inputs["input_ids"].to(device), + pixel_values=inputs["pixel_values"].to(device), + max_new_tokens=1024, + early_stopping=False, + do_sample=False, + num_beams=3, + ) + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + parsed_answer = processor.post_process_generation( + generated_text, + task=task_prompt, + image_size=(image.width, image.height) + ) + return parsed_answer + + +""" +We support a set of pipelines built by Florence-2 + SAM 2 +""" + +""" +Pipeline-1: Object Detection + Segmentation +""" +def object_detection_and_segmentation( + florence2_model, + florence2_processor, + sam2_predictor, + image_path, + task_prompt="", + text_input=None, + output_dir=OUTPUT_DIR +): + # run florence-2 object detection in demo + image = Image.open(image_path).convert("RGB") + results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image) + + """ Florence-2 Object Detection Output Format + {'': + { + 'bboxes': + [ + [33.599998474121094, 159.59999084472656, 596.7999877929688, 371.7599792480469], + [454.0799865722656, 96.23999786376953, 580.7999877929688, 261.8399963378906], + [224.95999145507812, 86.15999603271484, 333.7599792480469, 164.39999389648438], + [449.5999755859375, 276.239990234375, 554.5599975585938, 370.3199768066406], + [91.19999694824219, 280.0799865722656, 198.0800018310547, 370.3199768066406] + ], + 'labels': ['car', 'door', 'door', 'wheel', 'wheel'] + } + } + """ + results = results[""] + # parse florence-2 detection results + input_boxes = np.array(results["bboxes"]) + class_names = results["labels"] + class_ids = np.array(list(range(len(class_names)))) + + # predict mask with SAM 2 + sam2_predictor.set_image(np.array(image)) + masks, scores, logits = sam2_predictor.predict( + point_coords=None, + point_labels=None, + box=input_boxes, + multimask_output=False, + ) + + if masks.ndim == 4: + masks = masks.squeeze(1) + + # specify labels + labels = [ + f"{class_name}" for class_name in class_names + ] + + # visualization results + img = cv2.imread(image_path) + detections = sv.Detections( + xyxy=input_boxes, + mask=masks.astype(bool), + class_id=class_ids + ) + + box_annotator = sv.BoxAnnotator() + annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections) + + label_annotator = sv.LabelAnnotator() + annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) + cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_annotated_image.jpg"), annotated_frame) + + mask_annotator = sv.MaskAnnotator() + annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) + cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_florence2_det_image_with_mask.jpg"), annotated_frame) + + + +if __name__ == "__main__": + + image_path = "/comp_robot/rentianhe/code/Grounded-SAM-2/notebooks/images/groceries.jpg" + + # pipeline-1: detection + segmentation + object_detection_and_segmentation( + florence2_model=florence2_model, + florence2_processor=florence2_processor, + sam2_predictor=sam2_predictor, + image_path=image_path + ) \ No newline at end of file