support 1.5 image demo

This commit is contained in:
rentainhe
2024-08-01 21:30:56 +08:00
parent f5b99eea3d
commit d27829ff17
4 changed files with 157 additions and 10 deletions

View File

@@ -1,9 +1,12 @@
# Grounded-SAM-2
Grounded SAM 2: Ground and Track Anything with Grounding DINO and SAM 2
Grounded SAM 2: Ground and Track Anything with Grounding DINO, Grounding DINO 1.5 and SAM 2
## Contents
- [Installation](#installation)
- [Grounded-SAM-2 Demo](#grounded-sam-2-demo)
- [Grounded-SAM-2 Image Demo](#grounded-sam-2-image-demo-with-grounding-dino)
- [Grounded-SAM-2 Image Demo (with Grounding DINO 1.5)](#grounded-sam-2-image-demo-with-grounding-dino-15--16)
## Installation
@@ -13,34 +16,41 @@ Since we need the CUDA compilation environment to compile the `Deformable Attent
export CUDA_HOME=/path/to/cuda-12.1/
```
Install `segment-anything-2`:
Install `Segment Anything 2`:
```bash
pip install -e .
```
Install `grounding dino`:
Install `Grounding DINO`:
```bash
pip install --no-build-isolation -e grounding_dino
```
Download the pretrained `grounding dino` and `sam 2` checkpoints:
Downgrade the version of the `supervision` library to `0.6.0` to use its original API for visualization (we will update our code to be compatible with the latest version of `supervision` in the future release):
```bash
pip install supervision==0.6.0
```
Download the pretrained `SAM 2` checkpoints:
```bash
cd checkpoints
bash download_ckpts.sh
```
Download the pretrained `Grounding DINO` checkpoints:
```bash
cd gdino_checkpoints
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth
```
## Run demo
### Grounded-SAM-2 Image Demo
## Grounded-SAM-2 Demo
### Grounded-SAM-2 Image Demo (with Grounding DINO)
Note that `Grounding DINO` has already been supported in [Huggingface](https://huggingface.co/IDEA-Research/grounding-dino-tiny), so we provide two choices for running `Grounded-SAM-2` model:
- Use huggingface API to inference Grounding DINO (which is simple and clear)
@@ -53,3 +63,19 @@ python grounded_sam2_hf_model_demo.py
```bash
python grounded_sam2_local_demo.py
```
### Grounded-SAM-2 Image Demo (with Grounding DINO 1.5 & 1.6)
We've already released our most capable open-set detection model [Grounding DINO 1.5 & 1.6](https://github.com/IDEA-Research/Grounding-DINO-1.5-API), which can be combined with SAM 2 for stronger open-set detection and segmentation capability. You can apply the API token first and run Grounded-SAM-2 with Grounding DINO 1.5 as follows:
Install the latest DDS cloudapi:
```bash
pip install dds-cloudapi-sdk
```
Apply your API token from our official website here: [request API token](https://deepdataspace.com/request_api).
```bash
python grounded_sam2_gd1.5_demo.py
```

121
grounded_sam2_gd1.5_demo.py Normal file
View File

@@ -0,0 +1,121 @@
# dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
"""
Prompt Grounding DINO 1.5 with Text for Box Prompt Generation with Cloud API
"""
# Step 1: initialize the config
token = "3491a2a256fb7ed01b2e757b713c4cb0"
config = Config(token)
# Step 2: initialize the client
client = Client(config)
# Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url
img_path = "notebooks/images/cars.jpg"
image_url = client.upload_file(img_path)
task = DetectionTask(
image_url=image_url,
prompts=[TextPrompt(text="car")],
targets=[DetectionTarget.BBox], # detect bbox
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
)
client.run_task(task)
result = task.result
objects = result.objects # the list of detected objects
input_boxes = []
confidences = []
class_names = []
for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox)
confidences.append(obj.score)
class_names.append(obj.category)
input_boxes = np.array(input_boxes)
"""
Init SAM 2 Model and Predict Mask with Box Prompt
"""
# environment settings
# use bfloat16
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# build SAM2 image predictor
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_predictor = SAM2ImagePredictor(sam2_model)
image = Image.open(img_path)
sam2_predictor.set_image(np.array(image.convert("RGB")))
masks, scores, logits = sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
"""
Post-process the output of the model to get the masks, scores, and logits for visualization
"""
# convert the shape to (n, H, W)
if masks.ndim == 3:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
"""
Visualization the Predict Results
"""
labels = [
f"{class_name} {confidence:.2f}"
for class_name, confidence
in zip(class_names, confidences)
]
img = cv2.imread(img_path)
detections = sv.Detections(
xyxy=input_boxes, # (n, 4)
mask=masks, # (n, h, w)
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)
mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
cv2.imwrite("grounded_sam2_annotated_image_with_mask.jpg", annotated_frame)

View File

@@ -103,8 +103,8 @@ img = cv2.imread(img_path)
detections = sv.Detections(
xyxy=input_boxes, # (n, 4)
mask=masks, # (n, h, w)
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)

View File

@@ -94,8 +94,8 @@ img = cv2.imread(img_path)
detections = sv.Detections(
xyxy=input_boxes, # (n, 4)
mask=masks, # (n, h, w)
)
box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)