support dump results in 1.5 image demo

This commit is contained in:
rentainhe
2024-08-31 20:22:17 +08:00
parent 5d27e4f4f4
commit 4f3adf3222

View File

@@ -6,19 +6,38 @@ from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget from dds_cloudapi_sdk import DetectionTarget
import os
import cv2 import cv2
import json
import torch import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from PIL import Image from PIL import Image
from sam2.build_sam import build_sam2 from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
"""
Hyper parameters
"""
# Token handed to the DDS Cloud API `Config`; replace the placeholder with a real token.
API_TOKEN = "Your API token"
# Text used to build the `TextPrompt` for the detection task.
TEXT_PROMPT = "car"
# Local image that is uploaded through `client.upload_file` and recorded as "image_path" in the JSON dump.
IMG_PATH = "notebooks/images/cars.jpg"
# Checkpoint file and model config passed to `build_sam2` when building the SAM2 image predictor.
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
# Detection model passed as `model=` to `DetectionTask`; the trailing comment names an alternative value.
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
# Directory that receives the annotated images and the JSON results file.
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
# When true, the detections are additionally written out as a JSON file at the end of the script.
DUMP_JSON_RESULTS = True
# create output directory (including missing parents; no error if it already exists)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
""" """
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
""" """
# Step 1: initialize the config # Step 1: initialize the config
token = "Your API token" token = API_TOKEN
config = Config(token) config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
@@ -27,14 +46,14 @@ client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg" # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
img_path = "notebooks/images/cars.jpg" img_path = IMG_PATH
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = DetectionTask(
image_url=image_url, image_url=image_url,
prompts=[TextPrompt(text="car")], prompts=[TextPrompt(text=TEXT_PROMPT)],
targets=[DetectionTarget.BBox], # detect bbox targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
) )
client.run_task(task) client.run_task(task)
@@ -68,8 +87,8 @@ if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
# build SAM2 image predictor # build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = "sam2_hiera_l.yaml" model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
@@ -120,8 +139,43 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
label_annotator = sv.LabelAnnotator() label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator() mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
"""
Dump the results in standard format and save as json files
"""
def single_mask_to_rle(mask):
    """Encode a single 2D mask as an RLE dict whose ``counts`` field is text.

    The mask is copied into a Fortran-ordered ``uint8`` array with an extra
    trailing axis, run through ``mask_util.encode``, and the first (only)
    encoded entry is returned with ``counts`` decoded as UTF-8.
    """
    # add a trailing axis and force column-major uint8 before encoding
    # NOTE(review): presumably the layout pycocotools requires - confirm against its docs
    column_major = np.array(mask[:, :, None], order="F", dtype="uint8")
    encoded = mask_util.encode(column_major)[0]
    # decode the counts payload so the dict holds a str rather than bytes
    encoded["counts"] = encoded["counts"].decode("utf-8")
    return encoded
if DUMP_JSON_RESULTS:
    # RLE-encode every mask so it can be embedded in the JSON output
    mask_rles = list(map(single_mask_to_rle, masks))

    # replace the box and score containers with their plain-list forms
    input_boxes = input_boxes.tolist()
    scores = scores.tolist()

    # one record per detection; the four sequences are paired by position
    annotations = [
        {
            "class_name": class_name,
            "bbox": box,
            "segmentation": mask_rle,
            "score": score,
        }
        for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
    ]

    # assemble the full result record for this image
    results = {
        "image_path": img_path,
        "annotations": annotations,
        "box_format": "xyxy",
        "img_width": image.width,
        "img_height": image.height,
    }

    # write the record into the output directory as indented JSON
    json_path = os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json")
    with open(json_path, "w") as f:
        json.dump(results, f, indent=4)