import cv2
import torch
import numpy as np
import supervision as sv
from supervision.draw.color import ColorPalette
from utils.supervision_utils import CUSTOM_COLOR_MAP
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# environment settings
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # use bfloat16 for the whole script
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# build SAM 2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)

# build Grounding DINO from Hugging Face
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# set up the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased and end with a dot
text = "car. tire."
img_path = "notebooks/images/truck.jpg"
image = Image.open(img_path)

sam2_predictor.set_image(np.array(image.convert("RGB")))

# run Grounding DINO to get boxes for the text prompt
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = grounding_model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.4,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]]
)

"""
`results` is a list of dicts with the following structure:
[
    {
        'scores': tensor([0.7969, 0.6469, 0.6002, 0.4220], device='cuda:0'),
        'labels': ['car', 'tire', 'tire', 'tire'],
        'boxes': tensor([[  89.3244,  278.6940, 1710.3505,  851.5143],
                         [1392.4701,  554.4064, 1628.6133,  777.5872],
                         [ 436.1182,  621.8940,  676.5255,  851.6897],
                         [1236.0990,  688.3547, 1400.2427,  753.1256]], device='cuda:0')
    }
]
"""

# use the Grounding DINO boxes as box prompts for SAM 2
input_boxes = results[0]["boxes"].cpu().numpy()

masks, scores, logits = sam2_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)

"""
Post-process the model output to get masks, scores, and logits for visualization
"""
# convert mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)

confidences = results[0]["scores"].cpu().numpy().tolist()
class_names = results[0]["labels"]
class_ids = np.array(list(range(len(class_names))))

labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence
    in zip(class_names, confidences)
]

"""
Visualize the results with the supervision API
"""
img = cv2.imread(img_path)
detections = sv.Detections(
    xyxy=input_boxes,         # (n, 4)
    mask=masks.astype(bool),  # (n, h, w)
    class_id=class_ids,
)

"""
Note: to use the default color map instead, set color=ColorPalette.DEFAULT
"""
box_annotator = sv.BoxAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
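
# Optional: a minimal sketch of persisting the detections for downstream use,
# assuming the standard-library json module is acceptable here. The output
# filename and the choice to store plain xyxy boxes (rather than, say,
# RLE-encoded masks) are illustrative, not part of the original demo.
import json

with open("grounded_sam2_results.json", "w") as f:
    json.dump(
        {
            "image_path": img_path,
            "annotations": [
                {
                    "class_name": class_name,
                    "confidence": round(float(confidence), 4),
                    "box_xyxy": [round(float(v), 2) for v in box],
                }
                for class_name, confidence, box in zip(class_names, confidences, input_boxes)
            ],
        },
        f,
        indent=2,
    )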