Merge pull request #40 from IDEA-Research/dump_json_results
[Update] Support automatically dumping json results in image demos
This commit is contained in:
@@ -14,6 +14,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
|
|||||||
|
|
||||||
## News
|
## News
|
||||||
|
|
||||||
|
- `2024/08/31`: Support `dump json results` in Grounded SAM 2 Image Demos (with Grounding DINO).
|
||||||
- `2024/08/20`: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`.
|
- `2024/08/20`: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`.
|
||||||
- `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
|
- `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
|
||||||
- `2024/08/07`: Support **Custom Video Inputs**, users need only submit their video file (e.g. `.mp4` file) with specific text prompts to get an impressive demo videos.
|
- `2024/08/07`: Support **Custom Video Inputs**, users need only submit their video file (e.g. `.mp4` file) with specific text prompts to get an impressive demo videos.
|
||||||
|
@@ -109,7 +109,7 @@ def object_detection_and_segmentation(
|
|||||||
text_input=None,
|
text_input=None,
|
||||||
output_dir=OUTPUT_DIR
|
output_dir=OUTPUT_DIR
|
||||||
):
|
):
|
||||||
assert text_input is None, "Text input should not be none when calling object detection pipeline."
|
assert text_input is None, "Text input should be None when calling object detection pipeline."
|
||||||
# run florence-2 object detection in demo
|
# run florence-2 object detection in demo
|
||||||
image = Image.open(image_path).convert("RGB")
|
image = Image.open(image_path).convert("RGB")
|
||||||
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
@@ -185,7 +185,7 @@ def dense_region_caption_and_segmentation(
|
|||||||
text_input=None,
|
text_input=None,
|
||||||
output_dir=OUTPUT_DIR
|
output_dir=OUTPUT_DIR
|
||||||
):
|
):
|
||||||
assert text_input is None, "Text input should not be none when calling dense region caption pipeline."
|
assert text_input is None, "Text input should be None when calling dense region caption pipeline."
|
||||||
# run florence-2 object detection in demo
|
# run florence-2 object detection in demo
|
||||||
image = Image.open(image_path).convert("RGB")
|
image = Image.open(image_path).convert("RGB")
|
||||||
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
@@ -262,7 +262,7 @@ def region_proposal_and_segmentation(
|
|||||||
text_input=None,
|
text_input=None,
|
||||||
output_dir=OUTPUT_DIR
|
output_dir=OUTPUT_DIR
|
||||||
):
|
):
|
||||||
assert text_input is None, "Text input should not be none when calling region proposal pipeline."
|
assert text_input is None, "Text input should be None when calling region proposal pipeline."
|
||||||
# run florence-2 object detection in demo
|
# run florence-2 object detection in demo
|
||||||
image = Image.open(image_path).convert("RGB")
|
image = Image.open(image_path).convert("RGB")
|
||||||
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
|
||||||
@@ -355,7 +355,7 @@ def phrase_grounding_and_segmentation(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
assert text_input is not None, "Text input should not be none when calling phrase grounding pipeline."
|
assert text_input is not None, "Text input should not be None when calling phrase grounding pipeline."
|
||||||
results = results[task_prompt]
|
results = results[task_prompt]
|
||||||
# parse florence-2 detection results
|
# parse florence-2 detection results
|
||||||
input_boxes = np.array(results["bboxes"])
|
input_boxes = np.array(results["bboxes"])
|
||||||
@@ -428,7 +428,7 @@ def referring_expression_segmentation(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
assert text_input is not None, "Text input should not be none when calling referring segmentation pipeline."
|
assert text_input is not None, "Text input should not be None when calling referring segmentation pipeline."
|
||||||
results = results[task_prompt]
|
results = results[task_prompt]
|
||||||
# parse florence-2 detection results
|
# parse florence-2 detection results
|
||||||
polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
|
polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
|
||||||
@@ -542,7 +542,7 @@ def open_vocabulary_detection_and_segmentation(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
assert text_input is not None, "Text input should not be none when calling open-vocabulary detection pipeline."
|
assert text_input is not None, "Text input should not be None when calling open-vocabulary detection pipeline."
|
||||||
results = results[task_prompt]
|
results = results[task_prompt]
|
||||||
# parse florence-2 detection results
|
# parse florence-2 detection results
|
||||||
input_boxes = np.array(results["bboxes"])
|
input_boxes = np.array(results["bboxes"])
|
||||||
|
@@ -6,19 +6,39 @@ from dds_cloudapi_sdk import TextPrompt
|
|||||||
from dds_cloudapi_sdk import DetectionModel
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
from dds_cloudapi_sdk import DetectionTarget
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import supervision as sv
|
import supervision as sv
|
||||||
|
import pycocotools.mask as mask_util
|
||||||
|
from pathlib import Path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sam2.build_sam import build_sam2
|
from sam2.build_sam import build_sam2
|
||||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
|
||||||
|
"""
|
||||||
|
Hyper parameters
|
||||||
|
"""
|
||||||
|
API_TOKEN = "Your API token"
|
||||||
|
TEXT_PROMPT = "car . building ."
|
||||||
|
IMG_PATH = "notebooks/images/cars.jpg"
|
||||||
|
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||||
|
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
|
||||||
|
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
||||||
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
|
||||||
|
DUMP_JSON_RESULTS = True
|
||||||
|
|
||||||
|
# create output directory
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
|
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
|
||||||
"""
|
"""
|
||||||
# Step 1: initialize the config
|
# Step 1: initialize the config
|
||||||
token = "Your API token"
|
token = API_TOKEN
|
||||||
config = Config(token)
|
config = Config(token)
|
||||||
|
|
||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
@@ -27,14 +47,14 @@ client = Client(config)
|
|||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task by DetectionTask class
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
img_path = "notebooks/images/cars.jpg"
|
img_path = IMG_PATH
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = DetectionTask(
|
task = DetectionTask(
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
prompts=[TextPrompt(text="car")],
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
|
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
@@ -60,7 +80,7 @@ Init SAM 2 Model and Predict Mask with Box Prompt
|
|||||||
|
|
||||||
# environment settings
|
# environment settings
|
||||||
# use bfloat16
|
# use bfloat16
|
||||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
@@ -68,9 +88,9 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
|||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
# build SAM2 image predictor
|
# build SAM2 image predictor
|
||||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
sam2_checkpoint = SAM2_CHECKPOINT
|
||||||
model_cfg = "sam2_hiera_l.yaml"
|
model_cfg = SAM2_MODEL_CONFIG
|
||||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||||
|
|
||||||
image = Image.open(img_path)
|
image = Image.open(img_path)
|
||||||
@@ -120,8 +140,45 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
|
|||||||
|
|
||||||
label_annotator = sv.LabelAnnotator()
|
label_annotator = sv.LabelAnnotator()
|
||||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
||||||
|
|
||||||
mask_annotator = sv.MaskAnnotator()
|
mask_annotator = sv.MaskAnnotator()
|
||||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Dump the results in standard format and save as json files
|
||||||
|
"""
|
||||||
|
|
||||||
|
def single_mask_to_rle(mask):
|
||||||
|
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
||||||
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
return rle
|
||||||
|
|
||||||
|
if DUMP_JSON_RESULTS:
|
||||||
|
# convert mask into rle format
|
||||||
|
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
||||||
|
|
||||||
|
input_boxes = input_boxes.tolist()
|
||||||
|
scores = scores.tolist()
|
||||||
|
# FIXME: class_names should be a list of strings without spaces
|
||||||
|
class_names = [class_name.strip() for class_name in class_names]
|
||||||
|
# save the results in standard format
|
||||||
|
results = {
|
||||||
|
"image_path": img_path,
|
||||||
|
"annotations" : [
|
||||||
|
{
|
||||||
|
"class_name": class_name,
|
||||||
|
"bbox": box,
|
||||||
|
"segmentation": mask_rle,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
||||||
|
],
|
||||||
|
"box_format": "xyxy",
|
||||||
|
"img_width": image.width,
|
||||||
|
"img_height": image.height,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json"), "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
@@ -1,7 +1,11 @@
|
|||||||
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import supervision as sv
|
import supervision as sv
|
||||||
|
import pycocotools.mask as mask_util
|
||||||
|
from pathlib import Path
|
||||||
from supervision.draw.color import ColorPalette
|
from supervision.draw.color import ColorPalette
|
||||||
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
from utils.supervision_utils import CUSTOM_COLOR_MAP
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -9,9 +13,24 @@ from sam2.build_sam import build_sam2
|
|||||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||||
|
|
||||||
|
"""
|
||||||
|
Hyper parameters
|
||||||
|
"""
|
||||||
|
GROUNDING_MODEL = "IDEA-Research/grounding-dino-tiny"
|
||||||
|
TEXT_PROMPT = "car. tire."
|
||||||
|
IMG_PATH = "notebooks/images/truck.jpg"
|
||||||
|
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||||
|
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
|
||||||
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
OUTPUT_DIR = Path("outputs/grounded_sam2_hf_model_demo")
|
||||||
|
DUMP_JSON_RESULTS = True
|
||||||
|
|
||||||
|
# create output directory
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# environment settings
|
# environment settings
|
||||||
# use bfloat16
|
# use bfloat16
|
||||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
@@ -19,28 +38,27 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
|||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
# build SAM2 image predictor
|
# build SAM2 image predictor
|
||||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
sam2_checkpoint = SAM2_CHECKPOINT
|
||||||
model_cfg = "sam2_hiera_l.yaml"
|
model_cfg = SAM2_MODEL_CONFIG
|
||||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||||
|
|
||||||
# build grounding dino from huggingface
|
# build grounding dino from huggingface
|
||||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
model_id = GROUNDING_MODEL
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
processor = AutoProcessor.from_pretrained(model_id)
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
|
||||||
|
|
||||||
|
|
||||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||||
# VERY important: text queries need to be lowercased + end with a dot
|
# VERY important: text queries need to be lowercased + end with a dot
|
||||||
text = "car. tire."
|
text = TEXT_PROMPT
|
||||||
img_path = 'notebooks/images/truck.jpg'
|
img_path = IMG_PATH
|
||||||
|
|
||||||
image = Image.open(img_path)
|
image = Image.open(img_path)
|
||||||
|
|
||||||
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
sam2_predictor.set_image(np.array(image.convert("RGB")))
|
||||||
|
|
||||||
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = grounding_model(**inputs)
|
outputs = grounding_model(**inputs)
|
||||||
|
|
||||||
@@ -114,8 +132,44 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
|
|||||||
|
|
||||||
label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
||||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
||||||
|
|
||||||
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
|
||||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Dump the results in standard format and save as json files
|
||||||
|
"""
|
||||||
|
|
||||||
|
def single_mask_to_rle(mask):
|
||||||
|
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
||||||
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
return rle
|
||||||
|
|
||||||
|
if DUMP_JSON_RESULTS:
|
||||||
|
# convert mask into rle format
|
||||||
|
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
||||||
|
|
||||||
|
input_boxes = input_boxes.tolist()
|
||||||
|
scores = scores.tolist()
|
||||||
|
# save the results in standard format
|
||||||
|
results = {
|
||||||
|
"image_path": img_path,
|
||||||
|
"annotations" : [
|
||||||
|
{
|
||||||
|
"class_name": class_name,
|
||||||
|
"bbox": box,
|
||||||
|
"segmentation": mask_rle,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
||||||
|
],
|
||||||
|
"box_format": "xyxy",
|
||||||
|
"img_width": image.width,
|
||||||
|
"img_height": image.height,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_hf_model_demo_results.json"), "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
@@ -1,35 +1,55 @@
|
|||||||
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import supervision as sv
|
import supervision as sv
|
||||||
|
import pycocotools.mask as mask_util
|
||||||
|
from pathlib import Path
|
||||||
from torchvision.ops import box_convert
|
from torchvision.ops import box_convert
|
||||||
from sam2.build_sam import build_sam2
|
from sam2.build_sam import build_sam2
|
||||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
|
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
|
||||||
|
|
||||||
|
"""
|
||||||
|
Hyper parameters
|
||||||
|
"""
|
||||||
|
TEXT_PROMPT = "car. tire."
|
||||||
|
IMG_PATH = "notebooks/images/truck.jpg"
|
||||||
|
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
|
||||||
|
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
|
||||||
|
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
||||||
|
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
|
||||||
|
BOX_THRESHOLD = 0.35
|
||||||
|
TEXT_THRESHOLD = 0.25
|
||||||
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
|
||||||
|
DUMP_JSON_RESULTS = True
|
||||||
|
|
||||||
|
# create output directory
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# environment settings
|
# environment settings
|
||||||
# use bfloat16
|
# use bfloat16
|
||||||
|
|
||||||
# build SAM2 image predictor
|
# build SAM2 image predictor
|
||||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
sam2_checkpoint = SAM2_CHECKPOINT
|
||||||
model_cfg = "sam2_hiera_l.yaml"
|
model_cfg = SAM2_MODEL_CONFIG
|
||||||
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
|
||||||
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
sam2_predictor = SAM2ImagePredictor(sam2_model)
|
||||||
|
|
||||||
# build grounding dino model
|
# build grounding dino model
|
||||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
grounding_model = load_model(
|
grounding_model = load_model(
|
||||||
model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py",
|
model_config_path=GROUNDING_DINO_CONFIG,
|
||||||
model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth",
|
model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
|
||||||
device=device
|
device=DEVICE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||||
# VERY important: text queries need to be lowercased + end with a dot
|
# VERY important: text queries need to be lowercased + end with a dot
|
||||||
text = "car. tire."
|
text = TEXT_PROMPT
|
||||||
img_path = 'notebooks/images/truck.jpg'
|
img_path = IMG_PATH
|
||||||
|
|
||||||
image_source, image = load_image(img_path)
|
image_source, image = load_image(img_path)
|
||||||
|
|
||||||
@@ -39,8 +59,8 @@ boxes, confidences, labels = predict(
|
|||||||
model=grounding_model,
|
model=grounding_model,
|
||||||
image=image,
|
image=image,
|
||||||
caption=text,
|
caption=text,
|
||||||
box_threshold=0.35,
|
box_threshold=BOX_THRESHOLD,
|
||||||
text_threshold=0.25
|
text_threshold=TEXT_THRESHOLD,
|
||||||
)
|
)
|
||||||
|
|
||||||
# process the box prompt for SAM 2
|
# process the box prompt for SAM 2
|
||||||
@@ -98,8 +118,43 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
|
|||||||
|
|
||||||
label_annotator = sv.LabelAnnotator()
|
label_annotator = sv.LabelAnnotator()
|
||||||
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
|
||||||
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
|
||||||
|
|
||||||
mask_annotator = sv.MaskAnnotator()
|
mask_annotator = sv.MaskAnnotator()
|
||||||
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)
|
cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Dump the results in standard format and save as json files
|
||||||
|
"""
|
||||||
|
|
||||||
|
def single_mask_to_rle(mask):
|
||||||
|
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
||||||
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
return rle
|
||||||
|
|
||||||
|
if DUMP_JSON_RESULTS:
|
||||||
|
# convert mask into rle format
|
||||||
|
mask_rles = [single_mask_to_rle(mask) for mask in masks]
|
||||||
|
|
||||||
|
input_boxes = input_boxes.tolist()
|
||||||
|
scores = scores.tolist()
|
||||||
|
# save the results in standard format
|
||||||
|
results = {
|
||||||
|
"image_path": img_path,
|
||||||
|
"annotations" : [
|
||||||
|
{
|
||||||
|
"class_name": class_name,
|
||||||
|
"bbox": box,
|
||||||
|
"segmentation": mask_rle,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
|
||||||
|
],
|
||||||
|
"box_format": "xyxy",
|
||||||
|
"img_width": w,
|
||||||
|
"img_height": h,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
Reference in New Issue
Block a user