add dump results to hf model demo
This commit is contained in:
@@ -22,11 +22,12 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
Hyper parameters
|
||||
"""
|
||||
API_TOKEN = "Your API token"
|
||||
TEXT_PROMPT = "car"
|
||||
TEXT_PROMPT = "car . building ."
|
||||
IMG_PATH = "notebooks/images/cars.jpg"
|
||||
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
|
||||
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
|
||||
DUMP_JSON_RESULTS = True
|
||||
|
||||
@@ -79,7 +80,7 @@ Init SAM 2 Model and Predict Mask with Box Prompt
|
||||
|
||||
# environment settings
|
||||
# use bfloat16
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
@@ -89,7 +90,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
||||
# build SAM2 image predictor
|
||||
sam2_checkpoint = SAM2_CHECKPOINT
|
||||
model_cfg = SAM2_MODEL_CONFIG
|
||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||
|
||||
image = Image.open(img_path)
|
||||
@@ -160,6 +161,8 @@ if DUMP_JSON_RESULTS:
|
||||
|
||||
input_boxes = input_boxes.tolist()
|
||||
scores = scores.tolist()
|
||||
# FIXME: class_names should be a list of strings without spaces
|
||||
class_names = [class_name.strip() for class_name in class_names]
|
||||
# save the results in standard format
|
||||
results = {
|
||||
"image_path": img_path,
|
||||
|
@@ -1,7 +1,11 @@
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import supervision as sv
|
||||
import pycocotools.mask as mask_util
|
||||
from pathlib import Path
|
||||
from supervision.draw.color import ColorPalette
|
||||
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
||||
from PIL import Image
|
||||
@@ -9,9 +13,24 @@ from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
|
||||
"""
|
||||
Hyper parameters
|
||||
"""
|
||||
GROUNDING_MODEL = "IDEA-Research/grounding-dino-tiny"
|
||||
TEXT_PROMPT = "car. tire."
|
||||
IMG_PATH = "notebooks/images/truck.jpg"
|
||||
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
OUTPUT_DIR = Path("outputs/grounded_sam2_hf_model_demo")
|
||||
DUMP_JSON_RESULTS = True
|
||||
|
||||
# create output directory
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# environment settings
|
||||
# use bfloat16
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
@@ -19,28 +38,27 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# build SAM2 image predictor
|
||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
||||
model_cfg = "sam2_hiera_l.yaml"
|
||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
||||
sam2_checkpoint = SAM2_CHECKPOINT
|
||||
model_cfg = SAM2_MODEL_CONFIG
|
||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||
|
||||
# build grounding dino from huggingface
|
||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model_id = GROUNDING_MODEL
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
|
||||
|
||||
|
||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||
# VERY important: text queries need to be lowercased + end with a dot
|
||||
text = "car. tire."
|
||||
img_path = 'notebooks/images/truck.jpg'
|
||||
text = TEXT_PROMPT
|
||||
img_path = IMG_PATH
|
||||
|
||||
image = Image.open(img_path)
|
||||
|
||||
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
||||
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
|
||||
with torch.no_grad():
|
||||
outputs = grounding_model(**inputs)
|
||||
|
||||
@@ -114,8 +132,44 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
|
||||
|
||||
label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
|
||||
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
||||
|
||||
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|
||||
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
||||
|
||||
|
||||
"""
|
||||
Dump the results in standard format and save as json files
|
||||
"""
|
||||
|
||||
def single_mask_to_rle(mask):
|
||||
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
return rle
|
||||
|
||||
if DUMP_JSON_RESULTS:
|
||||
# convert mask into rle format
|
||||
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
||||
|
||||
input_boxes = input_boxes.tolist()
|
||||
scores = scores.tolist()
|
||||
# save the results in standard format
|
||||
results = {
|
||||
"image_path": img_path,
|
||||
"annotations" : [
|
||||
{
|
||||
"class_name": class_name,
|
||||
"bbox": box,
|
||||
"segmentation": mask_rle,
|
||||
"score": score,
|
||||
}
|
||||
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
||||
],
|
||||
"box_format": "xyxy",
|
||||
"img_width": image.width,
|
||||
"img_height": image.height,
|
||||
}
|
||||
|
||||
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_hf_model_demo_results.json"), "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
Reference in New Issue
Block a user