support slice inference on gd1.5 sam2 demo
This commit is contained in:
@@ -10,6 +10,7 @@ import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import supervision as sv
|
||||
import pycocotools.mask as mask_util
|
||||
@@ -27,6 +28,9 @@ IMG_PATH = "notebooks/images/cars.jpg"
|
||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
||||
WITH_SLICE_INFERENCE = False
|
||||
SLICE_WH = (480, 480)
|
||||
OVERLAP_RATIO = (0.2, 0.2)
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
|
||||
DUMP_JSON_RESULTS = True
|
||||
@@ -47,32 +51,88 @@ client = Client(config)
|
||||
# Step 3: run the task by DetectionTask class
|
||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||
# if you are processing local image file, upload them to DDS server to get the image url
|
||||
img_path = IMG_PATH
|
||||
image_url = client.upload_file(img_path)
|
||||
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
||||
)
|
||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||
class_name_to_id = {name: id for id, name in enumerate(classes)}
|
||||
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
|
||||
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
if WITH_SLICE_INFERENCE:
|
||||
def callback(image_slice: np.ndarray) -> sv.Detections:
|
||||
print("Inference on image slice")
|
||||
# save the img as temp img file for GD-1.5 API usage
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
||||
temp_filename = tmpfile.name
|
||||
cv2.imwrite(temp_filename, image_slice)
|
||||
image_url = client.upload_file(temp_filename)
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
||||
)
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
# detele the tempfile
|
||||
os.remove(temp_filename)
|
||||
|
||||
objects = result.objects # the list of detected objects
|
||||
input_boxes = []
|
||||
confidences = []
|
||||
class_ids = []
|
||||
objects = result.objects
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
cls_name = obj.category.lower().strip()
|
||||
class_ids.append(class_name_to_id[cls_name])
|
||||
# ensure input_boxes with shape (_, 4)
|
||||
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
||||
class_ids = np.array(class_ids)
|
||||
confidences = np.array(confidences)
|
||||
return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
|
||||
|
||||
slicer = sv.InferenceSlicer(
|
||||
callback=callback,
|
||||
slice_wh=SLICE_WH,
|
||||
overlap_ratio_wh=OVERLAP_RATIO,
|
||||
iou_threshold=0.5,
|
||||
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
||||
)
|
||||
detections = slicer(cv2.imread(IMG_PATH))
|
||||
class_names = [class_id_to_name[id] for id in detections.class_id]
|
||||
confidences = detections.confidence
|
||||
class_ids = detections.class_id
|
||||
import pdb; pdb.set_trace()
|
||||
input_boxes = detections.xyxy
|
||||
else:
|
||||
image_url = client.upload_file(IMG_PATH)
|
||||
|
||||
task = DetectionTask(
|
||||
image_url=image_url,
|
||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||
targets=[DetectionTarget.BBox], # detect bbox
|
||||
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
|
||||
)
|
||||
|
||||
client.run_task(task)
|
||||
result = task.result
|
||||
|
||||
objects = result.objects # the list of detected objects
|
||||
|
||||
|
||||
input_boxes = []
|
||||
confidences = []
|
||||
class_names = []
|
||||
input_boxes = []
|
||||
confidences = []
|
||||
class_names = []
|
||||
class_ids = []
|
||||
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
class_names.append(obj.category)
|
||||
for idx, obj in enumerate(objects):
|
||||
input_boxes.append(obj.bbox)
|
||||
confidences.append(obj.score)
|
||||
cls_name = obj.category.lower().strip()
|
||||
class_names.append(cls_name)
|
||||
class_ids.append(class_name_to_id[cls_name])
|
||||
|
||||
input_boxes = np.array(input_boxes)
|
||||
input_boxes = np.array(input_boxes)
|
||||
class_ids = np.array(class_ids)
|
||||
|
||||
"""
|
||||
Init SAM 2 Model and Predict Mask with Box Prompt
|
||||
@@ -93,7 +153,7 @@ model_cfg = SAM2_MODEL_CONFIG
|
||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||
|
||||
image = Image.open(img_path)
|
||||
image = Image.open(IMG_PATH)
|
||||
|
||||
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
||||
|
||||
@@ -117,8 +177,6 @@ if masks.ndim == 4:
|
||||
Visualization the Predict Results
|
||||
"""
|
||||
|
||||
class_ids = np.array(list(range(len(class_names))))
|
||||
|
||||
labels = [
|
||||
f"{class_name} {confidence:.2f}"
|
||||
for class_name, confidence
|
||||
@@ -128,7 +186,7 @@ labels = [
|
||||
"""
|
||||
Visualize image with supervision useful API
|
||||
"""
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.imread(IMG_PATH)
|
||||
detections = sv.Detections(
|
||||
xyxy=input_boxes, # (n, 4)
|
||||
mask=masks.astype(bool), # (n, h, w)
|
||||
@@ -168,7 +226,7 @@ if DUMP_JSON_RESULTS:
|
||||
class_names = [class_name.strip() for class_name in class_names]
|
||||
# save the results in standard format
|
||||
results = {
|
||||
"image_path": img_path,
|
||||
"image_path": IMG_PATH,
|
||||
"annotations" : [
|
||||
{
|
||||
"class_name": class_name,
|
||||
|
Reference in New Issue
Block a user