support gdino local model (load local ckpt)
README.md (29 additions)
@@ -24,3 +24,32 @@ Install `grounding dino`:
```bash
pip install --no-build-isolation -e grounding_dino
```

Download the pretrained `grounding dino` and `sam 2` checkpoints:

```bash
cd checkpoints
bash download_ckpts.sh
```

```bash
cd gdino_checkpoints
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth
```

## Run demo
### Grounded-SAM-2 Image Demo

Note that `Grounding DINO` is already supported in [Huggingface](https://huggingface.co/IDEA-Research/grounding-dino-tiny), so we provide two choices for running the `Grounded-SAM-2` model:

- Use the Huggingface API to run Grounding DINO inference (simple and clear; a sketch of this path follows the two options below):

```bash
python grounded_sam2_hf_model_demo.py
```

- Load a local pretrained Grounding DINO checkpoint and run inference with the original Grounding DINO API (make sure you've already downloaded the pretrained checkpoint):

```bash
python grounded_sam2_local_demo.py
```

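For reference, the Huggingface path boils down to roughly the following (a minimal sketch, assuming the `transformers` zero-shot object detection API; `grounded_sam2_hf_model_demo.py` is the full version):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "IDEA-Research/grounding-dino-tiny"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image = Image.open("notebooks/images/truck.jpg")
text = "car. tire."  # lowercase queries, each ending with a dot

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# map raw logits back to boxes and phrase labels
results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.35,
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],
)
input_boxes = results[0]["boxes"].cpu().numpy()  # (n, 4) xyxy boxes for SAM 2
```
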
@@ -69,8 +69,8 @@ input_boxes = results[0]["boxes"].cpu().numpy()
 masks, scores, logits = sam2_predictor.predict(
     point_coords=None,
-    box=input_boxes,
     point_labels=None,
+    box=input_boxes,
     multimask_output=False,
 )

grounded_sam2_local_demo.py (new file, 105 lines)

@@ -0,0 +1,105 @@
import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict

# environment settings
# use bfloat16

# build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_predictor = SAM2ImagePredictor(sam2_model)

# build grounding dino model from a local checkpoint and config
device = "cuda" if torch.cuda.is_available() else "cpu"
grounding_model = load_model(
    model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
    device=device
)

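# note: the download step above also fetches the SwinB checkpoint; a hedged
# sketch of swapping it in (assumption: config filename as shipped under
# grounding_dino/groundingdino/config/):
#   model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinB_cfg.py",
#   model_checkpoint_path="gdino_checkpoints/groundingdino_swinb_cogcoor.pth",
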
# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car. tire."
img_path = 'notebooks/images/truck.jpg'

image_source, image = load_image(img_path)

sam2_predictor.set_image(image_source)

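# a hypothetical helper (not part of this commit) that builds a valid caption
# from arbitrary class names, enforcing the lowercase + trailing-dot format:
def make_caption(class_names):
    return " ".join(name.strip().lower().rstrip(".") + "." for name in class_names)

assert make_caption(["Car", "Tire"]) == "car. tire."
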
boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=text,
    box_threshold=0.35,
    text_threshold=0.25
)

# process the box prompt for SAM 2
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

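# worked example (illustrative numbers, not from this script): Grounding DINO
# returns boxes normalized to [0, 1] in cxcywh order. On a 640x480 image, the
# box (0.5, 0.5, 0.25, 0.5) scales to center (320, 240) with size (160, 240),
# and box_convert maps it to the xyxy corners (240.0, 120.0, 400.0, 360.0).
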
# FIXME: figure out how this influences the G-DINO model
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

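# a possible alternative (an assumption, not what this commit does): scope the
# autocast to the SAM 2 call only, so Grounding DINO keeps running in full
# precision, e.g.
#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#       masks, scores, logits = sam2_predictor.predict(...)
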
masks, scores, logits = sam2_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

"""
|
||||
Post-process the output of the model to get the masks, scores, and logits for visualization
|
||||
"""
|
||||
# convert the shape to (n, H, W)
|
||||
if masks.ndim == 3:
|
||||
masks = masks[None]
|
||||
scores = scores[None]
|
||||
logits = logits[None]
|
||||
elif masks.ndim == 4:
|
||||
masks = masks.squeeze(1)
|
||||
|
||||
|
||||
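# expected shapes after this block (assuming n input boxes on an (h, w) image):
# masks -> (n, h, w), one mask per box; scores and logits keep one entry per box
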
confidences = confidences.numpy().tolist()
class_names = labels

labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence
    in zip(class_names, confidences)
]

"""
|
||||
Visualize image with supervision useful API
|
||||
"""
|
||||
img = cv2.imread(img_path)
|
||||
detections = sv.Detections(
|
||||
xyxy=input_boxes, # (n, 4)
|
||||
mask=masks, # (n, h, w)
|
||||
|
||||
)
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
|
||||
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
|
||||
|
||||
mask_annotator = sv.MaskAnnotator()
|
||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|