2024-08-01 21:30:56 +08:00
|
|
|
# dds cloudapi for Grounding DINO 1.5
|
|
|
|
from dds_cloudapi_sdk import Config
|
|
|
|
from dds_cloudapi_sdk import Client
|
|
|
|
from dds_cloudapi_sdk import DetectionTask
|
|
|
|
from dds_cloudapi_sdk import TextPrompt
|
|
|
|
from dds_cloudapi_sdk import DetectionModel
|
|
|
|
from dds_cloudapi_sdk import DetectionTarget
|
|
|
|
|
2024-08-31 20:22:17 +08:00
|
|
|
import os
|
2024-08-01 21:30:56 +08:00
|
|
|
import cv2
|
2024-08-31 20:22:17 +08:00
|
|
|
import json
|
2024-08-01 21:30:56 +08:00
|
|
|
import torch
|
2024-10-24 16:57:04 +08:00
|
|
|
import tempfile
|
2024-08-01 21:30:56 +08:00
|
|
|
import numpy as np
|
|
|
|
import supervision as sv
|
2024-08-31 20:22:17 +08:00
|
|
|
import pycocotools.mask as mask_util
|
|
|
|
from pathlib import Path
|
2024-08-01 21:30:56 +08:00
|
|
|
from PIL import Image
|
|
|
|
from sam2.build_sam import build_sam2
|
|
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
|
|
|
2024-08-31 20:22:17 +08:00
|
|
|
"""
Hyper parameters
"""
# --- Grounding DINO 1.5 cloud API ---
API_TOKEN = "Your API token"
# Text prompt: class names separated by " . "
TEXT_PROMPT = "car . building ."
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro

# --- input image ---
IMG_PATH = "notebooks/images/cars.jpg"

# --- SAM 2 checkpoint and model config ---
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

# --- sliced inference (passed to sv.InferenceSlicer when enabled) ---
WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2)

# --- runtime ---
# fall back to CPU when no CUDA device is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- outputs ---
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
DUMP_JSON_RESULTS = True

# create output directory (including missing parents; no error if it already exists)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
2024-08-01 21:30:56 +08:00
|
|
|
"""
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
"""
# Step 1: build the SDK config from the API token
token = API_TOKEN
config = Config(token)

# Step 2: create the client from that config
client = Client(config)

# Step 3: run the detection through the DetectionTask class.
# An image URL can be used directly, e.g.:
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# A local image file must first be uploaded to the DDS server to obtain its image URL.
|
|
|
|
|
2024-10-24 16:57:04 +08:00
|
|
|
# Split the "."-separated text prompt into normalized (stripped, lower-cased)
# class names. Whitespace-only segments (e.g. after a trailing ". ") are
# skipped so that no empty-string class is registered.
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x.strip()]
# Lookup tables between class names and their integer ids (position in `classes`).
class_name_to_id = {name: idx for idx, name in enumerate(classes)}
class_id_to_name = {idx: name for name, idx in class_name_to_id.items()}
|
|
|
|
|
|
|
|
if WITH_SLICE_INFERENCE:
    def callback(image_slice: np.ndarray) -> sv.Detections:
        """Detect objects on a single image slice through the cloud API.

        The slice is written to a temporary JPEG file, uploaded, and sent to
        the grounding model together with TEXT_PROMPT. The temporary file is
        removed afterwards, even when the upload or the task fails.

        Args:
            image_slice: image slice handed over by ``sv.InferenceSlicer``.

        Returns:
            ``sv.Detections`` holding the boxes, confidences, and class ids
            of the objects detected on this slice.

        Raises:
            KeyError: if a returned category is not one of the prompt classes.
        """
        print("Inference on image slice")
        # save the slice as a temp image file for GD-1.5 API usage;
        # delete=False keeps the file on disk after the handle is closed
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
            temp_filename = tmpfile.name
        try:
            cv2.imwrite(temp_filename, image_slice)
            image_url = client.upload_file(temp_filename)
            task = DetectionTask(
                image_url=image_url,
                prompts=[TextPrompt(text=TEXT_PROMPT)],
                targets=[DetectionTarget.BBox],  # detect bbox
                model=GROUNDING_MODEL,  # detect with the configured Grounding DINO model
            )
            client.run_task(task)
            result = task.result
        finally:
            # delete the temp file, also when the upload or the task raised
            os.remove(temp_filename)

        input_boxes = []
        confidences = []
        class_ids = []
        for obj in result.objects:
            input_boxes.append(obj.bbox)
            confidences.append(obj.score)
            cls_name = obj.category.lower().strip()
            class_ids.append(class_name_to_id[cls_name])

        # ensure input_boxes has shape (_, 4), also when nothing was detected
        input_boxes = np.array(input_boxes).reshape(-1, 4)
        class_ids = np.array(class_ids)
        confidences = np.array(confidences)
        return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)

    slicer = sv.InferenceSlicer(
        callback=callback,
        slice_wh=SLICE_WH,
        overlap_ratio_wh=OVERLAP_RATIO,
        iou_threshold=0.5,
        overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
    )
    detections = slicer(cv2.imread(IMG_PATH))
    class_names = [class_id_to_name[class_id] for class_id in detections.class_id]
    confidences = detections.confidence
    class_ids = detections.class_id
    input_boxes = detections.xyxy
else:
    # single-shot inference: upload the whole image and detect once
    image_url = client.upload_file(IMG_PATH)

    task = DetectionTask(
        image_url=image_url,
        prompts=[TextPrompt(text=TEXT_PROMPT)],
        targets=[DetectionTarget.BBox],  # detect bbox
        model=GROUNDING_MODEL,  # detect with the configured Grounding DINO model
    )

    client.run_task(task)
    result = task.result

    objects = result.objects  # the list of detected objects

    input_boxes = []
    confidences = []
    class_names = []
    class_ids = []

    for obj in objects:
        input_boxes.append(obj.bbox)
        confidences.append(obj.score)
        cls_name = obj.category.lower().strip()
        class_names.append(cls_name)
        class_ids.append(class_name_to_id[cls_name])

    # keep the (_, 4) box layout, consistent with the slice-inference callback
    input_boxes = np.array(input_boxes).reshape(-1, 4)
    class_ids = np.array(class_ids)
|
2024-08-01 21:30:56 +08:00
|
|
|
|
|
|
|
"""
Init SAM 2 Model and Predict Mask with Box Prompt
"""

# environment settings
# use bfloat16: the autocast context is entered once and intentionally never
# exited, so it stays active for the rest of the script
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()

# DEVICE falls back to "cpu" when CUDA is unavailable; querying CUDA device
# properties in that case raises, so only do it when CUDA is present.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# build SAM2 image predictor
sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model)

image = Image.open(IMG_PATH)

# the predictor is given the image as an RGB numpy array
sam2_predictor.set_image(np.array(image.convert("RGB")))

# prompt SAM 2 with the detected boxes only (no point prompts), one mask per box
masks, scores, logits = sam2_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
"""
Post-process the output of the model to get the masks, scores, and logits for visualization
"""
# drop the singleton axis 1 so that masks has shape (n, H, W)
masks = masks.squeeze(1) if masks.ndim == 4 else masks

"""
Visualization the Predict Results
"""

# one "<class name> <confidence>" caption per detection
labels = [f"{name} {score:.2f}" for name, score in zip(class_names, confidences)]

"""
Visualize image with supervision useful API
"""
img = cv2.imread(IMG_PATH)
detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    mask=masks.astype(bool),  # (n, h, w)
    class_id=class_ids,
)

# 1) boxes + labels, drawn on a copy of the input image
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(
    scene=annotated_frame,
    detections=detections,
    labels=labels,
)
boxes_image_path = str(OUTPUT_DIR / "groundingdino_annotated_image.jpg")
cv2.imwrite(boxes_image_path, annotated_frame)

# 2) masks, drawn on top of the box/label frame
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
masks_image_path = str(OUTPUT_DIR / "grounded_sam2_annotated_image_with_mask.jpg")
cv2.imwrite(masks_image_path, annotated_frame)

print(f'Annotated image has already been saved as to "{OUTPUT_DIR}"')
|
|
|
|
|
2024-08-31 20:22:17 +08:00
|
|
|
"""
|
|
|
|
Dump the results in standard format and save as json files
|
|
|
|
"""
|
|
|
|
|
|
|
|
def single_mask_to_rle(mask):
    """Encode a single mask as an RLE dict whose ``counts`` field is a str.

    Args:
        mask: 2-D mask array (indexed as ``mask[:, :, None]``); it is cast to
            uint8 in Fortran order before encoding.

    Returns:
        The first RLE dict returned by ``mask_util.encode``, with its
        ``counts`` bytes decoded as UTF-8 so the dict can be dumped to JSON.
    """
    # add a trailing axis and lay the data out as Fortran-ordered uint8 for the encoder
    encoder_input = np.array(mask[:, :, None], order="F", dtype="uint8")
    encoded = mask_util.encode(encoder_input)
    rle = encoded[0]
    rle["counts"] = rle["counts"].decode("utf-8")
    return rle
|
|
|
|
|
|
|
|
if DUMP_JSON_RESULTS:
    print("Start dumping the annotation...")
    # convert mask into rle format
    mask_rles = [single_mask_to_rle(mask) for mask in masks]

    input_boxes = input_boxes.tolist()
    # Flatten to one scalar score per box. With several boxes the predictor
    # output presumably keeps a trailing singleton axis (the same axis that is
    # squeezed out of `masks`), which would otherwise be written to the JSON
    # as one-element lists instead of numbers.
    scores = scores.reshape(-1).tolist()
    # FIXME: class_names should be a list of strings without spaces
    class_names = [class_name.strip() for class_name in class_names]
    # save the results in standard format
    results = {
        "image_path": IMG_PATH,
        "annotations" : [
            {
                "class_name": class_name,
                "bbox": box,
                "segmentation": mask_rle,
                "score": score,
            }
            for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
        ],
        "box_format": "xyxy",
        "img_width": image.width,
        "img_height": image.height,
    }

    with open(os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json"), "w") as f:
        json.dump(results, f, indent=4)

    print(f'Annotation has already been saved to "{OUTPUT_DIR}"')
|