diff --git a/grounded_sam2_gd1.5_demo.py b/grounded_sam2_gd1.5_demo.py
index d23ad85..7094bee 100644
--- a/grounded_sam2_gd1.5_demo.py
+++ b/grounded_sam2_gd1.5_demo.py
@@ -10,6 +10,7 @@ import os
 import cv2
 import json
 import torch
+import tempfile
 import numpy as np
 import supervision as sv
 import pycocotools.mask as mask_util
@@ -27,6 +28,9 @@ IMG_PATH = "notebooks/images/cars.jpg"
 SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
 SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
 GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
+WITH_SLICE_INFERENCE = False
+SLICE_WH = (480, 480)
+OVERLAP_RATIO = (0.2, 0.2)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
 DUMP_JSON_RESULTS = True
@@ -47,32 +51,90 @@ client = Client(config)
 # Step 3: run the task by DetectionTask class
 # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
-img_path = IMG_PATH
-image_url = client.upload_file(img_path)
 
-task = DetectionTask(
-    image_url=image_url,
-    prompts=[TextPrompt(text=TEXT_PROMPT)],
-    targets=[DetectionTarget.BBox],  # detect bbox
-    model=GROUNDING_MODEL,  # detect with GroundingDino-1.5-Pro model
-)
+classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x.strip()]
+class_name_to_id = {name: id for id, name in enumerate(classes)}
+class_id_to_name = {id: name for name, id in class_name_to_id.items()}
 
-client.run_task(task)
-result = task.result
+if WITH_SLICE_INFERENCE:
+    def callback(image_slice: np.ndarray) -> sv.Detections:
+        """Run the detection task on one image slice and return its detections."""
+        print("Inference on image slice")
+        # save the img as temp img file for GD-1.5 API usage
+        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
+            temp_filename = tmpfile.name
+        try:
+            cv2.imwrite(temp_filename, image_slice)
+            image_url = client.upload_file(temp_filename)
+            task = DetectionTask(
+                image_url=image_url,
+                prompts=[TextPrompt(text=TEXT_PROMPT)],
+                targets=[DetectionTarget.BBox],  # detect bbox
+                model=GROUNDING_MODEL,  # detect with GroundingDino-1.5-Pro model
+            )
+            client.run_task(task)
+            result = task.result
+        finally:
+            # delete the tempfile even if the upload or the task fails
+            os.remove(temp_filename)
+
+        input_boxes = []
+        confidences = []
+        class_ids = []
+        objects = result.objects
+        for idx, obj in enumerate(objects):
+            input_boxes.append(obj.bbox)
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_ids.append(class_name_to_id[cls_name])
+        # ensure input_boxes with shape (_, 4)
+        input_boxes = np.array(input_boxes).reshape(-1, 4)
+        class_ids = np.array(class_ids)
+        confidences = np.array(confidences)
+        return sv.Detections(xyxy=input_boxes, confidence=confidences, class_id=class_ids)
+
+    slicer = sv.InferenceSlicer(
+        callback=callback,
+        slice_wh=SLICE_WH,
+        overlap_ratio_wh=OVERLAP_RATIO,
+        iou_threshold=0.5,
+        overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
+    )
+    detections = slicer(cv2.imread(IMG_PATH))
+    class_names = [class_id_to_name[id] for id in detections.class_id]
+    confidences = detections.confidence
+    class_ids = detections.class_id
+    input_boxes = detections.xyxy
+else:
+    image_url = client.upload_file(IMG_PATH)
 
-objects = result.objects  # the list of detected objects
+    task = DetectionTask(
+        image_url=image_url,
+        prompts=[TextPrompt(text=TEXT_PROMPT)],
+        targets=[DetectionTarget.BBox],  # detect bbox
+        model=GROUNDING_MODEL,  # detect with GroundingDINO-1.5-Pro model
+    )
+
+    client.run_task(task)
+    result = task.result
+
+    objects = result.objects  # the list of detected objects
 
 
-input_boxes = []
-confidences = []
-class_names = []
+    input_boxes = []
+    confidences = []
+    class_names = []
+    class_ids = []
 
-for idx, obj in enumerate(objects):
-    input_boxes.append(obj.bbox)
-    confidences.append(obj.score)
-    class_names.append(obj.category)
+    for idx, obj in enumerate(objects):
+        input_boxes.append(obj.bbox)
+        confidences.append(obj.score)
+        cls_name = obj.category.lower().strip()
+        class_names.append(cls_name)
+        class_ids.append(class_name_to_id[cls_name])
 
-input_boxes = np.array(input_boxes)
+    input_boxes = np.array(input_boxes)
+    class_ids = np.array(class_ids)
 
 """
 Init SAM 2 Model and Predict Mask with Box Prompt
@@ -93,7 +155,7 @@ model_cfg = SAM2_MODEL_CONFIG
 sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
 sam2_predictor = SAM2ImagePredictor(sam2_model)
 
-image = Image.open(img_path)
+image = Image.open(IMG_PATH)
 
 sam2_predictor.set_image(np.array(image.convert("RGB")))
 
@@ -117,8 +179,6 @@ if masks.ndim == 4:
 Visualization the Predict Results
 """
 
-class_ids = np.array(list(range(len(class_names))))
-
 labels = [
     f"{class_name} {confidence:.2f}"
     for class_name, confidence
@@ -128,7 +188,7 @@ labels = [
 """
 Visualize image with supervision useful API
 """
-img = cv2.imread(img_path)
+img = cv2.imread(IMG_PATH)
 detections = sv.Detections(
     xyxy=input_boxes,  # (n, 4)
     mask=masks.astype(bool),  # (n, h, w)
@@ -168,7 +228,7 @@ if DUMP_JSON_RESULTS:
     class_names = [class_name.strip() for class_name in class_names]
     # save the results in standard format
     results = {
-        "image_path": img_path,
+        "image_path": IMG_PATH,
         "annotations" : [
             {
                 "class_name": class_name,