From da69bf587e070f88fcbe979afd9493ace37ef2f9 Mon Sep 17 00:00:00 2001 From: rentainhe <596106517@qq.com> Date: Mon, 19 Aug 2024 16:37:50 +0800 Subject: [PATCH] support auto label pipeline with florence-2 --- README.md | 28 ++- grounded_sam2_florence2_autolabel_pipeline.py | 198 ++++++++++++++++++ 2 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 grounded_sam2_florence2_autolabel_pipeline.py diff --git a/README.md b/README.md index ae828b0..d6b595a 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,12 @@ Grounded SAM 2 does not introduce significant methodological changes compared to - [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16) - [Grounded SAM 2 Video Object Tracking with Continues ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino) - [Grounded SAM 2 Florence-2 Demos](#grounded-sam-2-florence-2-demos) - - [Grounded SAM 2 Florence-2 Image Demo (Updating)](#grounded-sam-2-florence-2-image-demo-updating) + - [Grounded SAM 2 Florence-2 Image Demo](#grounded-sam-2-florence-2-image-demo) + - [Grounded SAM 2 Florence-2 Image Auto-Labeling Demo](#grounded-sam-2-florence-2-image-auto-labeling-demo) - [Citation](#citation) + ## Installation Download the pretrained `SAM 2` checkpoints: @@ -231,7 +233,7 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py ``` ## Grounded SAM 2 Florence-2 Demos -### Grounded SAM 2 Florence-2 Image Demo (Updating) +### Grounded SAM 2 Florence-2 Image Demo In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications. 
@@ -244,6 +246,7 @@ In this section, we will explore how to integrate the feature-rich and robust op | Region Proposal | `<REGION_PROPOSAL>` | ✘ | Generate proposals without category name | | Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | ✔ | Ground main objects in image mentioned in caption | | Referring Expression Segmentation | `<REFERRING_EXPRESSION_SEGMENTATION>` | ✔ | Ground the object which is most related to the text input | +| Open Vocabulary Detection and Segmentation | `<OPEN_VOCABULARY_DETECTION>` | ✔ | Ground any object with text input | Integrate `Florence-2` with `SAM-2`, we can build a strong vision pipeline to solve complex vision tasks, you can try the following scripts to run the demo: @@ -298,6 +301,27 @@ python grounded_sam2_florence2_image_demo.py \ --text_input "two cars" ``` +### Grounded SAM 2 Florence-2 Image Auto-Labeling Demo +`Florence-2` can be used as an automatic image annotator by cascading its caption capability with its grounding capability. + +| Task | Task Prompt | Text Input | +|:---:|:---:|:---:| +| Caption + Phrase Grounding | `<CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ | +| Detailed Caption + Phrase Grounding | `<DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ | +| More Detailed Caption + Phrase Grounding | `<MORE_DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ | + +You can try the following scripts to run these demos: + +**Caption to Phrase Grounding** +```bash +python grounded_sam2_florence2_autolabel_pipeline.py \ + --image_path ./notebooks/images/groceries.jpg \ + --pipeline caption_to_phrase_grounding \ + --caption_type caption +``` + +- You can specify `caption_type` to control the granularity of the caption. If you want a more detailed caption, try `--caption_type detailed_caption` or `--caption_type more_detailed_caption`. + ### Citation If you find this project helpful for your research, please consider citing the following BibTeX entry. 
"""Auto-labelling pipelines that cascade Florence-2 with SAM 2.

Florence-2 first captions the image, then grounds the phrases of that caption
as bounding boxes; SAM 2 turns every box into a segmentation mask. The
annotated images are written to ``OUTPUT_DIR``.
"""
import os
import cv2
import torch
import argparse
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForCausalLM
from utils.supervision_utils import CUSTOM_COLOR_MAP

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------

# Florence-2 task tokens. The token is both the prompt prefix sent to the
# model and the key under which the post-processed answer is returned.
TASK_PROMPT = {
    "caption": "<CAPTION>",
    "detailed_caption": "<DETAILED_CAPTION>",
    "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
    "dense_region_caption": "<DENSE_REGION_CAPTION>",
    "region_proposal": "<REGION_PROPOSAL>",
    "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
    "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
    "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
    "region_to_category": "<REGION_TO_CATEGORY>",
    "region_to_description": "<REGION_TO_DESCRIPTION>",
    "ocr": "<OCR>",
    "ocr_with_region": "<OCR_WITH_REGION>",
}

# Caption granularities accepted by ``--caption_type`` and their task tokens.
CAPTION_TO_TASK_PROMPT = {
    name: TASK_PROMPT[name]
    for name in ("caption", "detailed_caption", "more_detailed_caption")
}

OUTPUT_DIR = "./outputs"

# exist_ok=True already tolerates an existing directory; no pre-check needed.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Init Florence-2 and SAM 2 models
# ---------------------------------------------------------------------------

FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_CONFIG = "sam2_hiera_l.yaml"

# Environment settings. Everything CUDA-specific is guarded so that the "cpu"
# fallback below is actually reachable on hosts without a GPU
# (torch.cuda.get_device_properties raises when CUDA is unavailable).
if torch.cuda.is_available():
    # use bfloat16; the context is entered once and deliberately never
    # exited so that it stays active for the whole process
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE: not consumed below (Florence-2 is loaded with torch_dtype='auto');
# kept because it has always been a module-level name of this script.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# build florence-2
florence2_model = AutoModelForCausalLM.from_pretrained(
    FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto'
).eval().to(device)
florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)

# build sam 2
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)


def run_florence2(task_prompt, text_input, model, processor, image):
    """Run one Florence-2 task on ``image`` and return the parsed answer.

    Args:
        task_prompt: Florence-2 task token, e.g. ``TASK_PROMPT["caption"]``.
        text_input: Extra text appended to the task token, or ``None`` when
            the task takes no text.
        model: The loaded Florence-2 model.
        processor: The matching Florence-2 processor.
        image: PIL image to run the task on.

    Returns:
        The result of ``processor.post_process_generation`` for this task.

    Raises:
        ValueError: If ``model`` or ``processor`` is ``None``.
    """
    # Explicit raises instead of assert: asserts disappear under ``python -O``.
    if model is None:
        raise ValueError("You should pass the init florence-2 model here")
    if processor is None:
        raise ValueError("You should set florence-2 processor here")

    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Cast the floating-point inputs to the dtype the model was actually
    # loaded with (torch_dtype='auto') rather than assuming float16.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept: the post-processing step parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )


def _write_image(path, frame):
    """Write ``frame`` to ``path``; raise instead of failing silently."""
    # cv2.imwrite reports failure through its return value only.
    if not cv2.imwrite(path, frame):
        raise OSError(f'Failed to write annotated image to "{path}"')


# We try to support a series of cascaded auto-labelling pipelines with
# Florence-2 and SAM 2.

def caption_phrase_grounding_and_segmentation(
    florence2_model,
    florence2_processor,
    sam2_predictor,
    image_path,
    caption_task_prompt='<CAPTION>',
    output_dir=OUTPUT_DIR
):
    """Auto-label an image: caption -> phrase grounding -> segmentation.

    Writes two files into ``output_dir``:
    ``grounded_sam2_florence2_auto_labelling.jpg`` (boxes and labels) and
    ``grounded_sam2_florence2_auto_labelling_with_mask.jpg`` (the same frame
    with the masks drawn on top).

    Args:
        florence2_model: The loaded Florence-2 model.
        florence2_processor: The matching Florence-2 processor.
        sam2_predictor: A ``SAM2ImagePredictor`` instance.
        image_path: Path of the image to label.
        caption_task_prompt: One of the three caption task tokens listed in
            ``CAPTION_TO_TASK_PROMPT``.
        output_dir: Directory for the annotated images; created if missing.

    Raises:
        ValueError: If ``caption_task_prompt`` is not a caption task token.
        OSError: If an annotated image cannot be written.
    """
    valid_prompts = list(CAPTION_TO_TASK_PROMPT.values())
    if caption_task_prompt not in valid_prompts:
        raise ValueError(
            f"caption_task_prompt must be one of {valid_prompts}, got {caption_task_prompt!r}"
        )

    image = Image.open(image_path).convert("RGB")

    # image caption
    caption_results = run_florence2(caption_task_prompt, None, florence2_model, florence2_processor, image)
    text_input = caption_results[caption_task_prompt]
    print(f'Image caption for "{image_path}": ', text_input)

    # phrase grounding
    grounding_prompt = TASK_PROMPT["phrase_grounding"]
    grounding_results = run_florence2(grounding_prompt, text_input, florence2_model, florence2_processor, image)
    grounding_results = grounding_results[grounding_prompt]

    # parse florence-2 detection results
    input_boxes = np.array(grounding_results["bboxes"])
    class_names = grounding_results["labels"]

    if input_boxes.size == 0:
        # Nothing was grounded, so there is no box prompt to hand to SAM 2.
        print(f'No phrase could be grounded in "{image_path}", nothing to annotate')
        return

    class_ids = np.arange(len(class_names))

    # predict mask with SAM 2
    image_rgb = np.array(image)
    sam2_predictor.set_image(image_rgb)
    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    # Drop the singleton per-box mask axis when the predictor returns 4-D masks.
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    # specify labels
    labels = [f"{class_name}" for class_name in class_names]

    # visualization results
    # Draw on the very pixels the boxes were computed from. Re-reading the
    # file with cv2.imread can return None, and by default it applies the
    # EXIF orientation, which PIL's Image.open does not -- the boxes would
    # then be misaligned on rotated photos.
    img = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    detections = sv.Detections(
        xyxy=input_boxes,
        mask=masks.astype(bool),
        class_id=class_ids
    )

    # A caller-supplied output_dir may not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
    _write_image(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling.jpg"), annotated_frame)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    _write_image(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling_with_mask.jpg"), annotated_frame)

    print(f'Successfully save annotated image to "{output_dir}"')


def main():
    """Parse the command line and run the selected auto-labelling pipeline."""
    parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
    # These two options used to combine a default with required=True, which
    # made the defaults unreachable; they are optional now, so existing
    # invocations keep working and the defaults take effect.
    parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", help="path to image file")
    parser.add_argument("--pipeline", type=str, default="caption_to_phrase_grounding", help="pipeline to use")
    parser.add_argument(
        "--caption_type",
        type=str,
        default="caption",
        choices=list(CAPTION_TO_TASK_PROMPT),
        help="granularity of caption",
    )
    args = parser.parse_args()

    print(f"Running pipeline: {args.pipeline} now.")

    if args.pipeline == "caption_to_phrase_grounding":
        # pipeline-1: caption + phrase grounding + segmentation
        caption_phrase_grounding_and_segmentation(
            florence2_model=florence2_model,
            florence2_processor=florence2_processor,
            sam2_predictor=sam2_predictor,
            caption_task_prompt=CAPTION_TO_TASK_PROMPT[args.caption_type],
            image_path=args.image_path
        )
    else:
        raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")


if __name__ == "__main__":
    main()