Merge pull request #40 from IDEA-Research/dump_json_results

[Update] Support automatically dumping json results in image demos
This commit is contained in:
Ren Tianhe
2024-08-31 20:58:49 +08:00
committed by GitHub
5 changed files with 209 additions and 42 deletions

View File

@@ -14,6 +14,7 @@ Grounded SAM 2 does not introduce significant methodological changes compared to
## News ## News
- `2024/08/31`: Support `dump json results` in Grounded SAM 2 Image Demos (with Grounding DINO).
- `2024/08/20`: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`. - `2024/08/20`: Support **Florence-2 SAM 2 Image Demo** which includes `dense region caption`, `object detection`, `phrase grounding`, and cascaded auto-label pipeline `caption + phrase grounding`.
- `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe). - `2024/08/09`: Support **Ground and Track New Object** throughout the whole videos. This feature is still under development now. Credits to [Shuo Shen](https://github.com/ShuoShenDe).
- `2024/08/07`: Support **Custom Video Inputs**, users need only submit their video file (e.g. `.mp4` file) with specific text prompts to get an impressive demo videos. - `2024/08/07`: Support **Custom Video Inputs**, users need only submit their video file (e.g. `.mp4` file) with specific text prompts to get an impressive demo videos.

View File

@@ -109,7 +109,7 @@ def object_detection_and_segmentation(
text_input=None, text_input=None,
output_dir=OUTPUT_DIR output_dir=OUTPUT_DIR
): ):
assert text_input is None, "Text input should not be none when calling object detection pipeline." assert text_input is None, "Text input should be None when calling object detection pipeline."
# run florence-2 object detection in demo # run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image) results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -185,7 +185,7 @@ def dense_region_caption_and_segmentation(
text_input=None, text_input=None,
output_dir=OUTPUT_DIR output_dir=OUTPUT_DIR
): ):
assert text_input is None, "Text input should not be none when calling dense region caption pipeline." assert text_input is None, "Text input should be None when calling dense region caption pipeline."
# run florence-2 object detection in demo # run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image) results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -262,7 +262,7 @@ def region_proposal_and_segmentation(
text_input=None, text_input=None,
output_dir=OUTPUT_DIR output_dir=OUTPUT_DIR
): ):
assert text_input is None, "Text input should not be none when calling region proposal pipeline." assert text_input is None, "Text input should be None when calling region proposal pipeline."
# run florence-2 object detection in demo # run florence-2 object detection in demo
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image) results = run_florence2(task_prompt, text_input, florence2_model, florence2_processor, image)
@@ -355,7 +355,7 @@ def phrase_grounding_and_segmentation(
} }
} }
""" """
assert text_input is not None, "Text input should not be none when calling phrase grounding pipeline." assert text_input is not None, "Text input should not be None when calling phrase grounding pipeline."
results = results[task_prompt] results = results[task_prompt]
# parse florence-2 detection results # parse florence-2 detection results
input_boxes = np.array(results["bboxes"]) input_boxes = np.array(results["bboxes"])
@@ -428,7 +428,7 @@ def referring_expression_segmentation(
} }
} }
""" """
assert text_input is not None, "Text input should not be none when calling referring segmentation pipeline." assert text_input is not None, "Text input should not be None when calling referring segmentation pipeline."
results = results[task_prompt] results = results[task_prompt]
# parse florence-2 detection results # parse florence-2 detection results
polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2) polygon_points = np.array(results["polygons"][0], dtype=np.int32).reshape(-1, 2)
@@ -542,7 +542,7 @@ def open_vocabulary_detection_and_segmentation(
} }
} }
""" """
assert text_input is not None, "Text input should not be none when calling open-vocabulary detection pipeline." assert text_input is not None, "Text input should not be None when calling open-vocabulary detection pipeline."
results = results[task_prompt] results = results[task_prompt]
# parse florence-2 detection results # parse florence-2 detection results
input_boxes = np.array(results["bboxes"]) input_boxes = np.array(results["bboxes"])

View File

@@ -6,19 +6,39 @@ from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget from dds_cloudapi_sdk import DetectionTarget
import os
import cv2 import cv2
import json
import torch import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from PIL import Image from PIL import Image
from sam2.build_sam import build_sam2 from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
"""
Hyper parameters
"""
API_TOKEN = "Your API token"
TEXT_PROMPT = "car . building ."
IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
DUMP_JSON_RESULTS = True
# create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
""" """
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
""" """
# Step 1: initialize the config # Step 1: initialize the config
token = "Your API token" token = API_TOKEN
config = Config(token) config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
@@ -27,14 +47,14 @@ client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg" # image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
img_path = "notebooks/images/cars.jpg" img_path = IMG_PATH
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = DetectionTask(
image_url=image_url, image_url=image_url,
prompts=[TextPrompt(text="car")], prompts=[TextPrompt(text=TEXT_PROMPT)],
targets=[DetectionTarget.BBox], # detect bbox targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
) )
client.run_task(task) client.run_task(task)
@@ -60,7 +80,7 @@ Init SAM 2 Model and Predict Mask with Box Prompt
# environment settings # environment settings
# use bfloat16 # use bfloat16
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
@@ -68,9 +88,9 @@ if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
# build SAM2 image predictor # build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = "sam2_hiera_l.yaml" model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
image = Image.open(img_path) image = Image.open(img_path)
@@ -120,8 +140,45 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
label_annotator = sv.LabelAnnotator() label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator() mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
"""
Dump the results in standard format and save as json files
"""
def single_mask_to_rle(mask):
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
if DUMP_JSON_RESULTS:
# convert mask into rle format
mask_rles = [single_mask_to_rle(mask) for mask in masks]
input_boxes = input_boxes.tolist()
scores = scores.tolist()
# FIXME: class_names should be a list of strings without spaces
class_names = [class_name.strip() for class_name in class_names]
# save the results in standard format
results = {
"image_path": img_path,
"annotations" : [
{
"class_name": class_name,
"bbox": box,
"segmentation": mask_rle,
"score": score,
}
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
],
"box_format": "xyxy",
"img_width": image.width,
"img_height": image.height,
}
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_gd1.5_image_demo_results.json"), "w") as f:
json.dump(results, f, indent=4)

View File

@@ -1,7 +1,11 @@
import os
import cv2 import cv2
import json
import torch import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from supervision.draw.color import ColorPalette from supervision.draw.color import ColorPalette
from utils.supervision_utils import CUSTOM_COLOR_MAP from utils.supervision_utils import CUSTOM_COLOR_MAP
from PIL import Image from PIL import Image
@@ -9,9 +13,24 @@ from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
"""
Hyper parameters
"""
GROUNDING_MODEL = "IDEA-Research/grounding-dino-tiny"
TEXT_PROMPT = "car. tire."
IMG_PATH = "notebooks/images/truck.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_hf_model_demo")
DUMP_JSON_RESULTS = True
# create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# environment settings # environment settings
# use bfloat16 # use bfloat16
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
@@ -19,28 +38,27 @@ if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
# build SAM2 image predictor # build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = "sam2_hiera_l.yaml" model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
# build grounding dino from huggingface # build grounding dino from huggingface
model_id = "IDEA-Research/grounding-dino-tiny" model_id = GROUNDING_MODEL
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device) grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "car. tire." text = TEXT_PROMPT
img_path = 'notebooks/images/truck.jpg' img_path = IMG_PATH
image = Image.open(img_path) image = Image.open(img_path)
sam2_predictor.set_image(np.array(image.convert("RGB"))) sam2_predictor.set_image(np.array(image.convert("RGB")))
inputs = processor(images=image, text=text, return_tensors="pt").to(device) inputs = processor(images=image, text=text, return_tensors="pt").to(DEVICE)
with torch.no_grad(): with torch.no_grad():
outputs = grounding_model(**inputs) outputs = grounding_model(**inputs)
@@ -114,8 +132,44 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP)) label_annotator = sv.LabelAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP)) mask_annotator = sv.MaskAnnotator(color=ColorPalette.from_hex(CUSTOM_COLOR_MAP))
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
"""
Dump the results in standard format and save as json files
"""
def single_mask_to_rle(mask):
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
if DUMP_JSON_RESULTS:
# convert mask into rle format
mask_rles = [single_mask_to_rle(mask) for mask in masks]
input_boxes = input_boxes.tolist()
scores = scores.tolist()
# save the results in standard format
results = {
"image_path": img_path,
"annotations" : [
{
"class_name": class_name,
"bbox": box,
"segmentation": mask_rle,
"score": score,
}
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
],
"box_format": "xyxy",
"img_width": image.width,
"img_height": image.height,
}
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_hf_model_demo_results.json"), "w") as f:
json.dump(results, f, indent=4)

View File

@@ -1,35 +1,55 @@
import os
import cv2 import cv2
import json
import torch import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import pycocotools.mask as mask_util
from pathlib import Path
from torchvision.ops import box_convert from torchvision.ops import box_convert
from sam2.build_sam import build_sam2 from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
"""
Hyper parameters
"""
TEXT_PROMPT = "car. tire."
IMG_PATH = "notebooks/images/truck.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_local_demo")
DUMP_JSON_RESULTS = True
# create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# environment settings # environment settings
# use bfloat16 # use bfloat16
# build SAM2 image predictor # build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = "sam2_hiera_l.yaml" model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model) sam2_predictor = SAM2ImagePredictor(sam2_model)
# build grounding dino model # build grounding dino model
model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
grounding_model = load_model( grounding_model = load_model(
model_config_path="grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py", model_config_path=GROUNDING_DINO_CONFIG,
model_checkpoint_path="gdino_checkpoints/groundingdino_swint_ogc.pth", model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
device=device device=DEVICE
) )
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "car. tire." text = TEXT_PROMPT
img_path = 'notebooks/images/truck.jpg' img_path = IMG_PATH
image_source, image = load_image(img_path) image_source, image = load_image(img_path)
@@ -39,8 +59,8 @@ boxes, confidences, labels = predict(
model=grounding_model, model=grounding_model,
image=image, image=image,
caption=text, caption=text,
box_threshold=0.35, box_threshold=BOX_THRESHOLD,
text_threshold=0.25 text_threshold=TEXT_THRESHOLD,
) )
# process the box prompt for SAM 2 # process the box prompt for SAM 2
@@ -98,8 +118,43 @@ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections
label_annotator = sv.LabelAnnotator() label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "groundingdino_annotated_image.jpg"), annotated_frame)
mask_annotator = sv.MaskAnnotator() mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections) annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame) cv2.imwrite(os.path.join(OUTPUT_DIR, "grounded_sam2_annotated_image_with_mask.jpg"), annotated_frame)
"""
Dump the results in standard format and save as json files
"""
def single_mask_to_rle(mask):
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
if DUMP_JSON_RESULTS:
# convert mask into rle format
mask_rles = [single_mask_to_rle(mask) for mask in masks]
input_boxes = input_boxes.tolist()
scores = scores.tolist()
# save the results in standard format
results = {
"image_path": img_path,
"annotations" : [
{
"class_name": class_name,
"bbox": box,
"segmentation": mask_rle,
"score": score,
}
for class_name, box, mask_rle, score in zip(class_names, input_boxes, mask_rles, scores)
],
"box_format": "xyxy",
"img_width": w,
"img_height": h,
}
with open(os.path.join(OUTPUT_DIR, "grounded_sam2_local_image_demo_results.json"), "w") as f:
json.dump(results, f, indent=4)