"""Grounded SAM 2 local demo.

Detect objects described by a text prompt with Grounding DINO, segment each
detected box with SAM 2, and write annotated images to disk.
"""
import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict

# Environment settings: bfloat16 autocast and TF32 are enabled only for the
# SAM 2 prediction step, after Grounding DINO inference, because their
# influence on Grounding DINO has not been evaluated.

# One device choice shared by both models: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)

# build grounding dino model
# NOTE: model_id is not used in this script; the model is loaded from the
# local config/checkpoint pair below.
model_id = "IDEA-Research/grounding-dino-tiny"
grounding_model = load_model(
    model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
    device=device,
)

# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car. tire."
img_path = 'notebooks/images/truck.jpg'
image_source, image = load_image(img_path)

# The SAM 2 image embedding is computed here, before the bfloat16 autocast
# region below, so it runs at default precision.
sam2_predictor.set_image(image_source)

# Grounding DINO: text-prompted box detection.
boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=text,
    box_threshold=0.35,
    text_threshold=0.25,
)

# process the box prompt for SAM 2: scale boxes by the image size and convert
# cxcywh -> xyxy (Grounding DINO boxes are assumed to be relative coordinates).
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

# Enable TF32 on Ampere (compute capability >= 8) GPUs. The is_available()
# guard keeps the device-property query from failing on CPU-only machines.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# FIXME: figure how does this influence the G-DINO model
# bfloat16 autocast is scoped to SAM 2 prediction with a context manager
# rather than entered globally and never exited.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=(device == "cuda")):
    masks, scores, logits = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

# Post-process the output of the model for visualization.
# Convert the mask shape to (n, H, W). A 3-D mask array already has that
# shape; a 4-D one is assumed to be (n, 1, H, W) with multimask_output=False,
# so the singleton axis is dropped.
if masks.ndim == 4:
    masks = masks.squeeze(1)

confidences = confidences.numpy().tolist()
class_names = labels
# One id per detection so annotators can resolve a distinct color for each.
class_ids = np.arange(len(class_names))

labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence in zip(class_names, confidences)
]

# Visualize image with supervision useful API
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError(f"cv2 could not read image: {img_path}")

detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    mask=masks.astype(bool),  # (n, h, w); supervision expects boolean masks
    class_id=class_ids,
)

# NOTE(review): `labels=` on BoxAnnotator.annotate is the older supervision
# API; newer releases move labels to sv.LabelAnnotator. Confirm against the
# installed supervision version.
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)

mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)