2024-08-01 17:05:01 +08:00
|
|
|
import cv2
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import supervision as sv
|
2024-08-06 01:59:27 +08:00
|
|
|
from supervision.draw.color import ColorPalette
|
2024-08-09 02:33:24 +02:00
|
|
|
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
2024-08-01 17:05:01 +08:00
|
|
|
from PIL import Image
|
|
|
|
from sam2.build_sam import build_sam2
|
|
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
|
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
|
|
|
|
|
|
|
# Environment settings.
# Every knob below is CUDA-specific. Guard them so that running this script on
# a CPU-only host does not crash in torch.cuda.get_device_properties(0); the
# model-building code already selects a CPU device in that case.
if torch.cuda.is_available():
    # Use bfloat16 autocast for the rest of the script: the context manager is
    # entered here and intentionally never exited.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # Turn on TF32 for Ampere GPUs (compute capability >= 8)
        # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
# Build the SAM 2 image predictor from a local checkpoint and its model config.
# Fall back to the CPU when CUDA is unavailable, mirroring the device selection
# used for Grounding DINO, instead of unconditionally requiring a GPU.
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=sam2_device)
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
|
|
|
|
|
|
|
# Build Grounding DINO (a zero-shot, text-prompted object detector) from the
# Hugging Face Hub and place it on the GPU when one is available.
model_id = "IDEA-Research/grounding-dino-tiny"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
grounding_model = grounding_model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
# Input image and text prompt shared by SAM 2 and Grounding DINO.
# VERY important: text queries need to be lowercased and each one must end
# with a dot.
text = "car. tire."
img_path = 'notebooks/images/truck.jpg'

# Convert to 3-channel RGB exactly once so that both models consume the same
# pixels. Previously SAM 2 got an RGB copy while Grounding DINO's processor got
# the raw image, which differs for palette, grayscale or RGBA files.
image = Image.open(img_path).convert("RGB")

sam2_predictor.set_image(np.array(image))
|
|
|
|
|
|
|
|
# Run Grounding DINO on the image/text pair; no gradients are needed.
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = grounding_model(**inputs)

# Turn the raw model outputs into scored, labelled boxes in pixel coordinates.
# PIL reports size as (width, height) while target_sizes is given as
# (height, width), hence the reversal.
results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.4,
    text_threshold=0.3,
    target_sizes=[tuple(reversed(image.size))],
)

# `results` is a list of dicts (the code below uses results[0]), e.g. for the
# truck example:
# [
#     {
#         'scores': tensor([0.7969, 0.6469, 0.6002, 0.4220], device='cuda:0'),
#         'labels': ['car', 'tire', 'tire', 'tire'],
#         'boxes': tensor([[  89.3244,  278.6940, 1710.3505,  851.5143],
#                          [1392.4701,  554.4064, 1628.6133,  777.5872],
#                          [ 436.1182,  621.8940,  676.5255,  851.6897],
#                          [1236.0990,  688.3547, 1400.2427,  753.1256]], device='cuda:0')
#     }
# ]
|
|
|
|
|
|
|
|
# Prompt SAM 2 with one box per Grounding DINO detection (no point prompts)
# and ask for a single mask per box.
input_boxes = results[0]["boxes"].cpu().numpy()
masks, scores, logits = sam2_predictor.predict(
    box=input_boxes,
    point_coords=None,
    point_labels=None,
    multimask_output=False,
)

# Post-process the model outputs (masks, scores, labels) for visualization.
|
|
|
|
# Normalize the mask shape to (n, H, W). SAM 2 can return (n, 1, H, W) here
# (presumably when several boxes are prompted); drop the singleton axis.
if masks.ndim == 4:
    masks = masks.squeeze(1)

confidences = results[0]["scores"].cpu().numpy().tolist()
class_names = results[0]["labels"]
# One id per detection rather than per class name (presumably so each
# instance is coloured separately by the annotators). np.arange keeps an
# integer dtype even when nothing was detected; np.array([]) is float64.
class_ids = np.arange(len(class_names))

# Human-readable label per detection, e.g. "tire 0.65".
labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence in zip(class_names, confidences)
]

# Visualize the image with supervision's annotator API.
|
|
|
|
# Build the BGR canvas from the same PIL image the models consumed instead of
# re-reading the file: cv2.imread rotates according to the EXIF orientation
# flag by default while PIL.Image.open does not, so on rotated photos the
# predicted boxes/masks would not line up with a re-read image.
img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    mask=masks.astype(bool),  # (n, h, w)
    class_id=class_ids,
)

# One palette shared by every annotator. To use supervision's default colour
# map instead, pass color=ColorPalette.DEFAULT.
color_palette = ColorPalette.from_hex(CUSTOM_COLOR_MAP)

# Boxes + labels only (Grounding DINO result).
box_annotator = sv.BoxAnnotator(color=color_palette)
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

label_annotator = sv.LabelAnnotator(color=color_palette)
annotated_frame = label_annotator.annotate(
    scene=annotated_frame, detections=detections, labels=labels
)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)

# Add the SAM 2 masks on top of the boxes/labels frame.
mask_annotator = sv.MaskAnnotator(color=color_palette)
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|