From f5b99eea3df6a7c06e2c700341f68d55dd39953b Mon Sep 17 00:00:00 2001 From: rentainhe <596106517@qq.com> Date: Thu, 1 Aug 2024 17:58:42 +0800 Subject: [PATCH] support gdino local model (load local ckpt) --- README.md | 29 +++++++++ grounded_sam2_hf_model_demo.py | 2 +- grounded_sam2_local_demo.py | 105 +++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 grounded_sam2_local_demo.py diff --git a/README.md b/README.md index 38ddca3..d93a7d0 100644 --- a/README.md +++ b/README.md @@ -24,3 +24,32 @@ Install `grounding dino`: ```bash pip install --no-build-isolation -e grounding_dino ``` + +Download the pretrained `grounding dino` and `sam 2` checkpoints: + +```bash +cd checkpoints +bash download_ckpts.sh +``` + +```bash +cd gdino_checkpoints +wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth +wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth +``` + +## Run demo +### Grounded-SAM-2 Image Demo + +Note that `Grounding DINO` has already been supported in [Huggingface](https://huggingface.co/IDEA-Research/grounding-dino-tiny), so we provide two choices for running `Grounded-SAM-2` model: +- Use huggingface API to inference Grounding DINO (which is simple and clear) + +```bash +python grounded_sam2_hf_model_demo.py +``` + +- Load local pretrained Grounding DINO checkpoint and inference with Grounding DINO original API (make sure you've already downloaded the pretrained checkpoint) + +```bash +python grounded_sam2_local_demo.py +``` diff --git a/grounded_sam2_hf_model_demo.py b/grounded_sam2_hf_model_demo.py index e3b8be6..e92ec9b 100644 --- a/grounded_sam2_hf_model_demo.py +++ b/grounded_sam2_hf_model_demo.py @@ -69,8 +69,8 @@ input_boxes = results[0]["boxes"].cpu().numpy() masks, scores, logits = sam2_predictor.predict( point_coords=None, - box=input_boxes, point_labels=None, + box=input_boxes, multimask_output=False, ) diff --git a/grounded_sam2_local_demo.py b/grounded_sam2_local_demo.py new file mode 100644 index 0000000..ac03b2d --- /dev/null +++ b/grounded_sam2_local_demo.py @@ -0,0 +1,105 @@ +import cv2 +import torch +import numpy as np +import supervision as sv +from torchvision.ops import box_convert +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor +from grounding_dino.groundingdino.util.inference import load_model, load_image, predict + +# environment settings +# use bfloat16 + +# build SAM2 image predictor +sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") +sam2_predictor = SAM2ImagePredictor(sam2_model) + +# build grounding dino model +model_id = "IDEA-Research/grounding-dino-tiny" +device = "cuda" if torch.cuda.is_available() else "cpu" +grounding_model = load_model( + model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py", + model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth", + device=device +) + + +# setup the input image and text prompt for SAM 2 and Grounding DINO +# VERY important: text queries need to be lowercased + end with a dot +text = "car. tire." +img_path = 'notebooks/images/truck.jpg' + +image_source, image = load_image(img_path) + +sam2_predictor.set_image(image_source) + +boxes, confidences, labels = predict( + model=grounding_model, + image=image, + caption=text, + box_threshold=0.35, + text_threshold=0.25 +) + +# process the box prompt for SAM 2 +h, w, _ = image_source.shape +boxes = boxes * torch.Tensor([w, h, w, h]) +input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() + + +# FIXME: figure how does this influence the G-DINO model +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +masks, scores, logits = sam2_predictor.predict( + point_coords=None, + point_labels=None, + box=input_boxes, + multimask_output=False, +) + +import pdb; pdb.set_trace() + +""" +Post-process the output of the model to get the masks, scores, and logits for visualization +""" +# convert the shape to (n, H, W) +if masks.ndim == 3: + masks = masks[None] + scores = scores[None] + logits = logits[None] +elif masks.ndim == 4: + masks = masks.squeeze(1) + + +confidences = confidences.numpy().tolist() +class_names = labels + +labels = [ + f"{class_name} {confidence:.2f}" + for class_name, confidence + in zip(class_names, confidences) +] + +""" +Visualize image with supervision useful API +""" +img = cv2.imread(img_path) +detections = sv.Detections( + xyxy=input_boxes, # (n, 4) + mask=masks, # (n, h, w) + +) +box_annotator = sv.BoxAnnotator() +annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels) +cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) + +mask_annotator = sv.MaskAnnotator() +annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) +cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)