update DINO-X api to V2
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
# dds cloudapi for Grounding DINO 1.5
|
||||
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
||||
from dds_cloudapi_sdk import Config
|
||||
from dds_cloudapi_sdk import Client
|
||||
from dds_cloudapi_sdk import DetectionTask
|
||||
from dds_cloudapi_sdk import TextPrompt
|
||||
from dds_cloudapi_sdk import DetectionModel
|
||||
from dds_cloudapi_sdk import DetectionTarget
|
||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||
|
||||
import os
|
||||
import cv2
|
||||
@@ -27,8 +24,9 @@ TEXT_PROMPT = "car . building ."
|
||||
IMG_PATH = "notebooks/images/cars.jpg"
|
||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
||||
GROUNDING_MODEL = "GroundingDino-1.5-Pro" # GroundingDino-1.6-Pro
|
||||
BOX_THRESHOLD = 0.2
|
||||
IOU_THRESHOLD = 0.8
|
||||
WITH_SLICE_INFERENCE = False
|
||||
SLICE_WH = (480, 480)
|
||||
OVERLAP_RATIO = (0.2, 0.2)
|
||||
@@ -49,8 +47,7 @@ config = Config(token)
|
||||
# Step 2: initialize the client
|
||||
client = Client(config)
|
||||
|
||||
# Step 3: run the task by DetectionTask class
|
||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||
# Step 3: run the task using V2Task API
|
||||
# if you are processing local image file, upload them to DDS server to get the image url
|
||||
|
||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||
@@ -65,26 +62,33 @@ if WITH_SLICE_INFERENCE:
|
||||
temp_filename = tmpfile.name
|
||||
cv2.imwrite(temp_filename, image_slice)
|
||||
image_url = client.upload_file(temp_filename)
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
||||
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
||||
task = V2Task(
|
||||
api_path="/v2/task/grounding_dino/detection",
|
||||
api_body={
|
||||
"model": GROUNDING_MODEL,
|
||||
"image": image_url,
|
||||
"prompt": {
|
||||
"type": "text",
|
||||
"text": TEXT_PROMPT
|
||||
},
|
||||
"targets": ["bbox"],
|
||||
"bbox_threshold": BOX_THRESHOLD,
|
||||
"iou_threshold": IOU_THRESHOLD,
|
||||
}
|
||||
)
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
# detele the tempfile
|
||||
# delete the tempfile
|
||||
os.remove(temp_filename)
|
||||
|
||||
input_boxes = []
|
||||
confidences = []
|
||||
class_ids = []
|
||||
objects = result.objects
|
||||
objects = result["objects"]
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
cls_name = obj.category.lower().strip()
|
||||
input_boxes.append(obj["bbox"])
|
||||
confidences.append(obj["score"])
|
||||
cls_name = obj["category"].lower().strip()
|
||||
class_ids.append(class_name_to_id[cls_name])
|
||||
# ensure input_boxes with shape (_, 4)
|
||||
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
||||
@@ -96,7 +100,7 @@ if WITH_SLICE_INFERENCE:
|
||||
callback=callback,
|
||||
slice_wh=SLICE_WH,
|
||||
overlap_ratio_wh=OVERLAP_RATIO,
|
||||
iou_threshold=0.5,
|
||||
iou_threshold=IOU_THRESHOLD,
|
||||
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
||||
)
|
||||
detections = slicer(cv2.imread(IMG_PATH))
|
||||
@@ -107,18 +111,25 @@ if WITH_SLICE_INFERENCE:
|
||||
else:
|
||||
image_url = client.upload_file(IMG_PATH)
|
||||
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
|
||||
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
||||
task = V2Task(
|
||||
api_path="/v2/task/grounding_dino/detection",
|
||||
api_body={
|
||||
"model": GROUNDING_MODEL,
|
||||
"image": image_url,
|
||||
"prompt": {
|
||||
"type": "text",
|
||||
"text": TEXT_PROMPT
|
||||
},
|
||||
"targets": ["bbox"],
|
||||
"bbox_threshold": BOX_THRESHOLD,
|
||||
"iou_threshold": IOU_THRESHOLD,
|
||||
}
|
||||
)
|
||||
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
|
||||
objects = result.objects # the list of detected objects
|
||||
objects = result["objects"] # the list of detected objects
|
||||
|
||||
|
||||
input_boxes = []
|
||||
@@ -127,9 +138,9 @@ else:
|
||||
class_ids = []
|
||||
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
cls_name = obj.category.lower().strip()
|
||||
input_boxes.append(obj["bbox"])
|
||||
confidences.append(obj["score"])
|
||||
cls_name = obj["category"].lower().strip()
|
||||
class_names.append(cls_name)
|
||||
class_ids.append(class_name_to_id[cls_name])
|
||||
|
||||
|
Reference in New Issue
Block a user