support slice inference on gd1.5 sam2 demo

This commit is contained in:
rentainhe
2024-10-24 16:57:04 +08:00
parent be550a93b1
commit 041bb0bfa4

View File

@@ -10,6 +10,7 @@ import os
import cv2 import cv2
import json import json
import torch import torch
import tempfile
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
@@ -27,6 +28,9 @@ IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo") OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
DUMP_JSON_RESULTS = True DUMP_JSON_RESULTS = True
@@ -47,32 +51,88 @@ client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg" # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
img_path = IMG_PATH
image_url = client.upload_file(img_path)
task = DetectionTask( classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
class_name_to_id = {name: id for id, name in enumerate(classes)}
class_id_to_name = {id: name for name, id in class_name_to_id.items()}
if WITH_SLICE_INFERENCE:
def callback(image_slice: np.ndarray) -> sv.Detections:
print("Inference on image slice")
# save the img as temp img file for GD-1.5 API usage
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
temp_filename = tmpfile.name
cv2.imwrite(temp_filename, image_slice)
image_url = client.upload_file(temp_filename)
task = DetectionTask(
image_url=image_url, image_url=image_url,
prompts=[TextPrompt(text=TEXT_PROMPT)], prompts=[TextPrompt(text=TEXT_PROMPT)],
targets=[DetectionTarget.BBox], # detect bbox targets=[DetectionTarget.BBox], # detect bbox
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
) )
client.run_task(task)
result = task.result
# detele the tempfile
os.remove(temp_filename)
client.run_task(task) input_boxes = []
result = task.result confidences = []
class_ids = []
objects = result.objects # the list of detected objects objects = result.objects
for idx, obj in enumerate(objects):
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj.bbox)
confidences.append(obj.score) confidences.append(obj.score)
class_names.append(obj.category) cls_name = obj.category.lower().strip()
class_ids.append(class_name_to_id[cls_name])
# ensure input_boxes with shape (_, 4)
input_boxes = np.array(input_boxes).reshape(-1, 4)
class_ids = np.array(class_ids)
confidences = np.array(confidences)
return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
input_boxes = np.array(input_boxes) slicer = sv.InferenceSlicer(
callback=callback,
slice_wh=SLICE_WH,
overlap_ratio_wh=OVERLAP_RATIO,
iou_threshold=0.5,
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
)
detections = slicer(cv2.imread(IMG_PATH))
class_names = [class_id_to_name[id] for id in detections.class_id]
confidences = detections.confidence
class_ids = detections.class_id
import pdb; pdb.set_trace()
input_boxes = detections.xyxy
else:
image_url = client.upload_file(IMG_PATH)
task = DetectionTask(
image_url=image_url,
prompts=[TextPrompt(text=TEXT_PROMPT)],
targets=[DetectionTarget.BBox], # detect bbox
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
class_ids = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
cls_name = obj.category.lower().strip()
class_names.append(cls_name)
class_ids.append(class_name_to_id[cls_name])
input_boxes = np.array(input_boxes)
class_ids = np.array(class_ids)
""" """
Init SAM 2 Model and Predict Mask with Box Prompt Init SAM 2 Model and Predict Mask with Box Prompt
@@ -93,7 +153,7 @@ model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE) sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
image = Image.open(img_path) image = Image.open(IMG_PATH)
sam2_predictor.set_image(np.array(image.convert("RGB"))) sam2_predictor.set_image(np.array(image.convert("RGB")))
@@ -117,8 +177,6 @@ if masks.ndim == 4:
Visualization the Predict Results Visualization the Predict Results
""" """
class_ids = np.array(list(range(len(class_names))))
labels = [ labels = [
f"{class_name} {confidence:.2f}" f"{class_name} {confidence:.2f}"
for class_name, confidence for class_name, confidence
@@ -128,7 +186,7 @@ labels = [
""" """
Visualize image with supervision useful API Visualize image with supervision useful API
""" """
img = cv2.imread(img_path) img = cv2.imread(IMG_PATH)
detections = sv.Detections( detections = sv.Detections(
xyxy=input_boxes, # (n, 4) xyxy=input_boxes, # (n, 4)
mask=masks.astype(bool), # (n, h, w) mask=masks.astype(bool), # (n, h, w)
@@ -168,7 +226,7 @@ if DUMP_JSON_RESULTS:
class_names = [class_name.strip() for class_name in class_names] class_names = [class_name.strip() for class_name in class_names]
# save the results in standard format # save the results in standard format
results = { results = {
"image_path": img_path, "image_path": IMG_PATH,
"annotations" : [ "annotations" : [
{ {
"class_name": class_name, "class_name": class_name,