support dump results in local demo

This commit is contained in:
rentainhe
2024-08-31 20:55:49 +08:00
parent a99354bb25
commit e0daab208a

View File

@@ -1,35 +1,55 @@
import os
import cv2 import cv2
import json
import torch import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from torchvision.ops import box_convert from torchvision.ops import box_convert
from sam2.build_sam import build_sam2 from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
"""
Hyper parameters
"""
TEXT_PROMPT = "car. tire."
IMG_PATH = "notebooks/images/truck.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
DUMP_JSON_RESULTS = True
# create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# environment settings # environment settings
# use bfloat16 # use bfloat16
# build SAM2 image predictor # build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = "sam2_hiera_l.yaml" model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
# build grounding dino model # build grounding dino model
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
grounding_model = load_model( grounding_model = load_model(
model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py", model_config_path=GROUNDING_DINO_CONFIG,
model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth", model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
device=device device=DEVICE
) )
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "car. tire." text = TEXT_PROMPT
img_path = 'notebooks/images/truck.jpg' img_path = IMG_PATH
image_source, image = load_image(img_path) image_source, image = load_image(img_path)
@@ -39,8 +59,8 @@ boxes, confidences, labels = predict(
model=grounding_model, model=grounding_model,
image=image, image=image,
caption=text, caption=text,
box_threshold=0.35, box_threshold=BOX_THRESHOLD,
text_threshold=0.25 text_threshold=TEXT_THRESHOLD,
) )
# process the box prompt for SAM 2 # process the box prompt for SAM 2
@@ -98,8 +118,43 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
label_annotator = sv.LabelAnnotator() label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator() mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
"""
Dump the results in standard format and save as json files
"""
def single_mask_to_rle(mask):
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
if DUMP_JSON_RESULTS:
# convert mask into rle format
mask_rles = [single_mask_to_rle(mask) for mask in masks]
input_boxes = input_boxes.tolist()
scores = scores.tolist()
# save the results in standard format
results = {
"image_path": img_path,
"annotations" : [
{
"class_name": class_name,
"bbox": box,
"segmentation": mask_rle,
"score": score,
}
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
],
"box_format": "xyxy",
"img_width": w,
"img_height": h,
}
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
json.dump(results, f, indent=4)