support auto label pipeline with florence-2
README.md
@@ -28,10 +28,12 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
- [Grounded SAM 2 Video Object Tracking with Custom Video Input (using Grounding DINO 1.5 & 1.6)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-grounding-dino-15--16)
- [Grounded SAM 2 Video Object Tracking with Continuous ID (using Grounding DINO)](#grounded-sam-2-video-object-tracking-with-continuous-id-with-grounding-dino)
- [Grounded SAM 2 Florence-2 Demos](#grounded-sam-2-florence-2-demos)
  - [Grounded SAM 2 Florence-2 Image Demo](#grounded-sam-2-florence-2-image-demo)
  - [Grounded SAM 2 Florence-2 Image Auto-Labeling Demo](#grounded-sam-2-florence-2-image-auto-labeling-demo)
- [Citation](#citation)

## Installation

Download the pretrained `SAM 2` checkpoints:
@@ -231,7 +233,7 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
```

## Grounded SAM 2 Florence-2 Demos

### Grounded SAM 2 Florence-2 Image Demo

In this section, we will explore how to integrate the feature-rich and robust open-source models [Florence-2](https://arxiv.org/abs/2311.06242) and SAM 2 to develop practical applications.
@@ -244,6 +246,7 @@ In this section, we will explore how to integrate the feature-rich and robust op
| Region Proposal | `<REGION_PROPOSAL>` | ✘ | Generate proposals without category name |
| Phrase Grounding | `<CAPTION_TO_PHRASE_GROUNDING>` | ✔ | Ground main objects in image mentioned in caption |
| Referring Expression Segmentation | `<REFERRING_EXPRESSION_SEGMENTATION>` | ✔ | Ground the object which is most related to the text input |
| Open Vocabulary Detection and Segmentation | `<OPEN_VOCABULARY_DETECTION>` | ✔ | Ground any object with text input |

By integrating `Florence-2` with `SAM 2`, we can build a strong vision pipeline to solve complex vision tasks. You can try the following scripts to run the demos:
@@ -298,6 +301,27 @@ python grounded_sam2_florence2_image_demo.py \
    --text_input "two cars"
```

### Grounded SAM 2 Florence-2 Image Auto-Labeling Demo

`Florence-2` can be used as an automatic image annotator by cascading its caption capability with its grounding capability.

| Task | Task Prompt | Text Input |
|:---:|:---:|:---:|
| Caption + Phrase Grounding | `<CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |
| Detailed Caption + Phrase Grounding | `<DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |
| More Detailed Caption + Phrase Grounding | `<MORE_DETAILED_CAPTION>` + `<CAPTION_TO_PHRASE_GROUNDING>` | ✘ |

You can try the following scripts to run these demos:

**Caption to Phrase Grounding**
```bash
python grounded_sam2_florence2_autolabel_pipeline.py \
    --image_path ./notebooks/images/groceries.jpg \
    --pipeline caption_to_phrase_grounding \
    --caption_type caption
```

- You can specify `caption_type` to control the granularity of the caption. If you want a more detailed caption, try `--caption_type detailed_caption` or `--caption_type more_detailed_caption`, as shown in the sketch below.
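
For example, a more fine-grained labeling pass over the same image only needs a different `--caption_type` (a sketch using the flags defined by the script above):

```bash
python grounded_sam2_florence2_autolabel_pipeline.py \
    --image_path ./notebooks/images/groceries.jpg \
    --pipeline caption_to_phrase_grounding \
    --caption_type more_detailed_caption
```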

### Citation

If you find this project helpful for your research, please consider citing the following BibTeX entry.
grounded_sam2_florence2_autolabel_pipeline.py (new file)
@@ -0,0 +1,198 @@
import os
import cv2
import torch
import argparse
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForCausalLM
from utils.supervision_utils import CUSTOM_COLOR_MAP
"""
|
||||
Define Some Hyperparam
|
||||
"""
|
||||
|
||||
TASK_PROMPT = {
|
||||
"caption": "<CAPTION>",
|
||||
"detailed_caption": "<DETAILED_CAPTION>",
|
||||
"more_detailed_caption": "<MORE_DETAILED_CAPTION",
|
||||
"object_detection": "<OD>",
|
||||
"dense_region_caption": "<DENSE_REGION_CAPTION>",
|
||||
"region_proposal": "<REGION_PROPOSAL>",
|
||||
"phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
||||
"referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
"region_to_segmentation": "<REGION_TO_SEGMENTATION>",
|
||||
"open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
|
||||
"region_to_category": "<REGION_TO_CATEGORY>",
|
||||
"region_to_description": "<REGION_TO_DESCRIPTION>",
|
||||
"ocr": "<OCR>",
|
||||
"ocr_with_region": "<OCR_WITH_REGION>",
|
||||
}
|
||||
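# NOTE: TASK_PROMPT lists all Florence-2 task prompts for reference; the
# auto-labeling pipeline below only uses the caption prompts together with
# <CAPTION_TO_PHRASE_GROUNDING>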

OUTPUT_DIR = "./outputs"

os.makedirs(OUTPUT_DIR, exist_ok=True)
"""
|
||||
Init Florence-2 and SAM 2 Model
|
||||
"""
|
||||
|
||||
FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
|
||||
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||
SAM2_CONFIG = "sam2_hiera_l.yaml"
|
||||
|
||||
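# NOTE: the checkpoint and config paths above assume the SAM 2 weights have been
# downloaded into ./checkpoints as described in the Installation section of the README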

# environment settings: use bfloat16 autocast and TF32 where the hardware supports it
if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# build florence-2 (torch_dtype matches the dtype used for inputs in run_florence2)
florence2_model = AutoModelForCausalLM.from_pretrained(
    FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype=torch_dtype
).eval().to(device)
florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)

# build sam 2
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)

def run_florence2(task_prompt, text_input, model, processor, image):
    assert model is not None, "You should pass the initialized Florence-2 model here"
    assert processor is not None, "You should pass the initialized Florence-2 processor here"

    device = model.device

    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"].to(device),
        pixel_values=inputs["pixel_values"].to(device),
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer
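
# Example (hypothetical standalone usage of the helper above):
#   image = Image.open("./notebooks/images/groceries.jpg").convert("RGB")
#   result = run_florence2("<CAPTION>", None, florence2_model, florence2_processor, image)
#   print(result["<CAPTION>"])  # post_process_generation keys its results by task prompt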
"""
|
||||
We try to support a series of cascaded auto-labelling pipelines with Florence-2 and SAM 2
|
||||
"""
|
||||
|
||||
"""
|
||||
Auto-Labelling Pipeline: Caption/Detailed Caption/More Detailed Caption + Phrase Grounding + Segmentation
|
||||
"""
|
||||
def caption_phrase_grounding_and_segmentation(
    florence2_model,
    florence2_processor,
    sam2_predictor,
    image_path,
    caption_task_prompt='<CAPTION>',
    output_dir=OUTPUT_DIR
):
    assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
    image = Image.open(image_path).convert("RGB")

    # image caption
    caption_results = run_florence2(caption_task_prompt, None, florence2_model, florence2_processor, image)
    text_input = caption_results[caption_task_prompt]
    print(f'Image caption for "{image_path}": ', text_input)

    # phrase grounding
    grounding_results = run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, florence2_model, florence2_processor, image)
    grounding_results = grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']

    # parse florence-2 detection results
    input_boxes = np.array(grounding_results["bboxes"])
    class_names = grounding_results["labels"]
    class_ids = np.array(list(range(len(class_names))))

    # predict mask with SAM 2
    sam2_predictor.set_image(np.array(image))
    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )
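    # NOTE: with batched box prompts SAM 2 typically returns masks of shape
    # (num_boxes, 1, H, W); the ndim check below drops the singleton mask dim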
    if masks.ndim == 4:
        masks = masks.squeeze(1)

    # specify labels
    labels = [f"{class_name}" for class_name in class_names]

    # visualization results
    img = cv2.imread(image_path)
    detections = sv.Detections(
        xyxy=input_boxes,
        mask=masks.astype(bool),
        class_id=class_ids
    )

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling.jpg"), annotated_frame)

    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(output_dir, "grounded_sam2_florence2_auto_labelling_with_mask.jpg"), annotated_frame)

    print(f'Successfully saved annotated image to "{output_dir}"')

if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
    parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg", help="path to image file")
    parser.add_argument("--pipeline", type=str, default="caption_to_phrase_grounding", help="pipeline to use")
    parser.add_argument("--caption_type", type=str, default="caption", required=False, help="granularity of caption")
    args = parser.parse_args()

    CAPTION_TO_TASK_PROMPT = {
        "caption": "<CAPTION>",
        "detailed_caption": "<DETAILED_CAPTION>",
        "more_detailed_caption": "<MORE_DETAILED_CAPTION>"
    }

    IMAGE_PATH = args.image_path
    PIPELINE = args.pipeline
    CAPTION_TYPE = args.caption_type
    assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption"]

    print(f"Running pipeline: {PIPELINE} now.")

    if PIPELINE == "caption_to_phrase_grounding":
        # pipeline-1: caption + phrase grounding + segmentation
        caption_phrase_grounding_and_segmentation(
            florence2_model=florence2_model,
            florence2_processor=florence2_processor,
            sam2_predictor=sam2_predictor,
            caption_task_prompt=CAPTION_TO_TASK_PROMPT[CAPTION_TYPE],
            image_path=IMAGE_PATH
        )
    else:
        raise NotImplementedError(f"Pipeline: {args.pipeline} is not implemented at this time")