Grounded-SAM-2/grounded_sam2_local_demo.py

import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
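
"""
Pipeline overview: Grounding DINO detects boxes for a free-text prompt,
then SAM 2 turns those box prompts into instance masks.
"""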
# environment settings (bfloat16 autocast + TF32) are enabled further down,
# right before the SAM 2 prediction

# build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)

# build Grounding DINO model from a local config and checkpoint
# (the separate HF demo uses model_id = "IDEA-Research/grounding-dino-tiny"; it is unused here)
grounding_model = load_model(
    model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
    device=device
)
# set up the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car. tire."
img_path = "notebooks/images/truck.jpg"
image_source, image = load_image(img_path)
sam2_predictor.set_image(image_source)
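# set_image computes the SAM 2 image embedding once up front; the box
# prompts below reuse that embedding instead of re-encoding the image
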
boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=text,
    box_threshold=0.35,
    text_threshold=0.25
)
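# box_threshold filters low-confidence box predictions; text_threshold filters
# weak phrase-to-box matches (0.35 / 0.25 are the Grounding DINO defaults)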
# process the box prompt for SAM 2: Grounding DINO returns boxes as
# normalized cxcywh, so scale them to pixel coordinates and convert to xyxy
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
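# illustrative sanity check (hypothetical values): a normalized box
# (cx=0.5, cy=0.5, w=0.2, h=0.4) on a 640x480 image scales to pixel cxcywh
# (320, 240, 128, 192) and converts to xyxy (256, 144, 384, 336)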
# FIXME: figure out how this influences the G-DINO model
if device == "cuda":
    # use bfloat16 autocast for the SAM 2 prediction below
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
masks, scores, logits = sam2_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
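# with multimask_output=False, masks is (n, 1, H, W) for n boxes, or
# (1, H, W) when a single box is passed; the block below normalizes
# both cases to (n, H, W)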
"""
Post-process the output of the model to get the masks, scores, and logits for visualization
"""
# convert the shape to (n, H, W)
if masks.ndim == 3:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
confidences = confidences.numpy().tolist()
class_names = labels

labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence
    in zip(class_names, confidences)
]
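# produces labels like ["car 0.92", "tire 0.74"] (illustrative scores)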
"""
Visualize image with supervision useful API
"""
img = cv2.imread(img_path)
detections = sv.Detections(
xyxy=input_boxes, # (n, 4)
mask=masks, # (n, h, w)
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)