9 Commits

Author SHA1 Message Date
will ye
2111d9c52c Fix demos for CPU inference (#104) 2025-05-27 00:24:30 +08:00
will ye
75aaf0c3ae Change default output dir for HF demo (#105) 2025-05-27 00:24:17 +08:00
Embodied Learner
c5780dabeb feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes … (#97)
* feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes #74)

* update README
2025-05-08 11:02:33 +08:00
Sami Haidar
7fec804683 Pinned setuptools in Dockerfile (#99)
Co-authored-by: Sami Haidar Wehbe <sami@autoenhance.ai>
2025-05-08 11:02:04 +08:00
rentainhe
9412a16276 update DINO-X api to V2 2025-04-21 01:06:01 +08:00
rentainhe
d49257700a update DINO-X api usage to dds v2 2025-04-20 01:04:26 +08:00
rentainhe
3c5a4136d4 update DINO-X api usage to dds v2 2025-04-20 00:38:38 +08:00
Andrew Choi
8238557f52 Add torch2.6 support for ms_deform_attn_cuda (#94) 2025-04-18 00:38:51 +08:00
Reuben Feinman
0bc3970292 update setuptools build requirement to fix build error (#91) 2025-03-24 22:26:04 +08:00
41 changed files with 1404 additions and 1951 deletions

View File

@@ -1,17 +0,0 @@
name: SAM2/fmt
on:
pull_request:
branches:
- main
jobs:
ufmt_check:
runs-on: ubuntu-latest
steps:
- name: Check formatting
uses: omnilib/ufmt@action-v1
with:
path: sam2 tools
version: "2.0.0b2"
python-version: "3.10"
black-version: "24.2.0"
usort-version: "1.0.2"

1
.gitignore vendored
View File

@@ -145,4 +145,3 @@ dmypy.json
outputs/ outputs/
.idea/ .idea/
demo/backend/checkpoints/*.pt

View File

@@ -27,7 +27,7 @@ WORKDIR /home/appuser/Grounded-SAM-2
# Install essential Python packages # Install essential Python packages
RUN python -m pip install --upgrade pip setuptools wheel numpy \ RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
opencv-python transformers supervision pycocotools addict yapf timm opencv-python transformers supervision pycocotools addict yapf timm
# Install segment_anything package in editable mode # Install segment_anything package in editable mode

View File

@@ -2,7 +2,7 @@
### Requirements ### Requirements
- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this. - Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`. * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command. - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu. - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version. This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`. In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0. We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
</details> </details>
<details> <details>

View File

@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier same "printed page" as the copyright notice for easier
identification within third-party archives. identification within third-party archives.
Copyright [yyyy] [name of copyright owner] Copyright 2023 - present, IDEA Research.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View File

@@ -20,6 +20,7 @@ In this repo, we've supported the following demo with **simple implementations**
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience. Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
## Latest updates ## Latest updates
- **2025.04.20**: Update to `dds-cloudapi-sdk` API V2 version. The V1 version in the original API for `Grounding DINO 1.5` and `DINO-X` has been deprecated, please update to the latest `dds-cloudapi-sdk` by `pip install dds-cloudapi-sdk -U` to use `Grounding DINO 1.5 / 1.6` and `DINO-X` models. Please refer to [dds-cloudapi-sdk](https://github.com/deepdataspace/dds-cloudapi-sdk) and our [API docs](https://cloud.deepdataspace.com/docs) to view more details about the update.
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details. - **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
@@ -334,6 +335,16 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
``` ```
### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
```bash
python grounded_sam2_tracking_camera_with_continuous_id.py
```
## Grounded SAM 2 Florence-2 Demos ## Grounded SAM 2 Florence-2 Demos
### Grounded SAM 2 Florence-2 Image Demo ### Grounded SAM 2 Florence-2 Image Demo

View File

@@ -1,27 +0,0 @@
## SAM 2 release notes
### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking
- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
* Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
* In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag.
* Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
* **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
* Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
* This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
* We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released
- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
### 07/29/2024 -- SAM 2 is released
- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
* SAM 2 code: https://github.com/facebookresearch/sam2
* SAM 2 demo: https://sam2.metademolab.com/
* SAM 2 paper: https://arxiv.org/abs/2408.00714

View File

@@ -1,4 +1,4 @@
ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
ARG MODEL_SIZE=base_plus ARG MODEL_SIZE=base_plus
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}

View File

@@ -105,7 +105,7 @@ cd demo/backend/server/
```bash ```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 \ PYTORCH_ENABLE_MPS_FALLBACK=1 \
APP_ROOT="$(pwd)/../../../" \ APP_ROOT="$(pwd)/../../../" \
API_URL=http://localhost:7263 \ APP_URL=http://localhost:7263 \
MODEL_SIZE=base_plus \ MODEL_SIZE=base_plus \
DATA_PATH="$(pwd)/../../data" \ DATA_PATH="$(pwd)/../../data" \
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \ DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \

View File

@@ -1,9 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk.tasks.types import DetectionTarget
from dds_cloudapi_sdk import TextPrompt
import os import os
import cv2 import cv2
@@ -27,6 +25,7 @@ IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
WITH_SLICE_INFERENCE = False WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480) SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2) OVERLAP_RATIO = (0.2, 0.2)
@@ -48,7 +47,7 @@ config = Config(token)
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg" # infer_image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x] classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -62,13 +61,18 @@ if WITH_SLICE_INFERENCE:
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile: with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
temp_filename = tmpfile.name temp_filename = tmpfile.name
cv2.imwrite(temp_filename, image_slice) cv2.imwrite(temp_filename, image_slice)
image_url = client.upload_file(temp_filename) infer_image_url = client.upload_file(temp_filename)
task = DinoxTask( task = V2Task(api_path="/v2/task/dinox/detection", api_body={
image_url=image_url, "model": "DINO-X-1.0",
prompts=[TextPrompt(text=TEXT_PROMPT)], "image": infer_image_url,
bbox_threshold=0.25, "prompt": {
targets=[DetectionTarget.BBox], "type":"text",
) "text":TEXT_PROMPT
},
"targets": ["bbox", "mask"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
})
client.run_task(task) client.run_task(task)
result = task.result result = task.result
# detele the tempfile # detele the tempfile
@@ -77,7 +81,7 @@ if WITH_SLICE_INFERENCE:
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_ids = [] class_ids = []
objects = result.objects objects = result["objects"]
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj.bbox)
confidences.append(obj.score) confidences.append(obj.score)
@@ -102,19 +106,26 @@ if WITH_SLICE_INFERENCE:
class_ids = detections.class_id class_ids = detections.class_id
input_boxes = detections.xyxy input_boxes = detections.xyxy
else: else:
image_url = client.upload_file(IMG_PATH) infer_image_url = client.upload_file(IMG_PATH)
task = DinoxTask( task = V2Task(
image_url=image_url, api_path="/v2/task/dinox/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
bbox_threshold=0.25, "model": "DINO-X-1.0",
targets=[DetectionTarget.BBox], "image": infer_image_url,
"prompt": {
"type":"text",
"text":TEXT_PROMPT
},
"targets": ["bbox", "mask"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result["objects"] # the list of detected objects
objects = result.objects # the list of detected objects
input_boxes = [] input_boxes = []
@@ -123,9 +134,9 @@ else:
class_ids = [] class_ids = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_names.append(cls_name) class_names.append(cls_name)
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -27,8 +24,9 @@ TEXT_PROMPT = "car . building ."
IMG_PATH = "notebooks/images/cars.jpg" IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro GROUNDING_MODEL = "GroundingDino-1.5-Pro" # GroundingDino-1.6-Pro
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
WITH_SLICE_INFERENCE = False WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480) SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2) OVERLAP_RATIO = (0.2, 0.2)
@@ -49,8 +47,7 @@ config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task API
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x] classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -65,26 +62,33 @@ if WITH_SLICE_INFERENCE:
temp_filename = tmpfile.name temp_filename = tmpfile.name
cv2.imwrite(temp_filename, image_slice) cv2.imwrite(temp_filename, image_slice)
image_url = client.upload_file(temp_filename) image_url = client.upload_file(temp_filename)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, # box confidence threshold "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
# detele the tempfile # delete the tempfile
os.remove(temp_filename) os.remove(temp_filename)
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_ids = [] class_ids = []
objects = result.objects objects = result["objects"]
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])
# ensure input_boxes with shape (_, 4) # ensure input_boxes with shape (_, 4)
input_boxes = np.array(input_boxes).reshape(-1, 4) input_boxes = np.array(input_boxes).reshape(-1, 4)
@@ -96,7 +100,7 @@ if WITH_SLICE_INFERENCE:
callback=callback, callback=callback,
slice_wh=SLICE_WH, slice_wh=SLICE_WH,
overlap_ratio_wh=OVERLAP_RATIO, overlap_ratio_wh=OVERLAP_RATIO,
iou_threshold=0.5, iou_threshold=IOU_THRESHOLD,
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
) )
detections = slicer(cv2.imread(IMG_PATH)) detections = slicer(cv2.imread(IMG_PATH))
@@ -107,18 +111,25 @@ if WITH_SLICE_INFERENCE:
else: else:
image_url = client.upload_file(IMG_PATH) image_url = client.upload_file(IMG_PATH)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, # box confidence threshold "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -127,9 +138,9 @@ else:
class_ids = [] class_ids = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_names.append(cls_name) class_names.append(cls_name)
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])

View File

@@ -23,7 +23,7 @@ parser.add_argument("--text-prompt", default="car. tire.")
parser.add_argument("--img-path", default="notebooks/images/truck.jpg") parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt") parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml") parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
parser.add_argument("--output-dir", default="outputs/test_sam2.1") parser.add_argument("--output-dir", default="outputs/grounded_sam2_hf_demo")
parser.add_argument("--no-dump-json", action="store_true") parser.add_argument("--no-dump-json", action="store_true")
parser.add_argument("--force-cpu", action="store_true") parser.add_argument("--force-cpu", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# use bfloat16 # use bfloat16
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True

View File

@@ -61,6 +61,7 @@ boxes, confidences, labels = predict(
caption=text, caption=text,
box_threshold=BOX_THRESHOLD, box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD, text_threshold=TEXT_THRESHOLD,
device=DEVICE
) )
# process the box prompt for SAM 2 # process the box prompt for SAM 2
@@ -70,9 +71,9 @@ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
# FIXME: figure how does this influence the G-DINO model # FIXME: figure how does this influence the G-DINO model
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True

View File

@@ -0,0 +1,536 @@
import copy
import os
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
# Setup environment
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class GroundingDinoPredictor:
"""
Wrapper for using a GroundingDINO model for zero-shot object detection.
"""
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
"""
Initialize the GroundingDINO predictor.
Args:
model_id (str): HuggingFace model ID to load.
device (str): Device to run the model on ('cuda' or 'cpu').
"""
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
self.device = device
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
device
)
def predict(
self,
image: "PIL.Image.Image",
text_prompts: str,
box_threshold=0.25,
text_threshold=0.25,
):
"""
Perform object detection using text prompts.
Args:
image (PIL.Image.Image): Input RGB image.
text_prompts (str): Text prompt describing target objects.
box_threshold (float): Confidence threshold for box selection.
text_threshold (float): Confidence threshold for text match.
Returns:
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
"""
inputs = self.processor(
images=image, text=text_prompts, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[image.size[::-1]],
)
return results[0]["boxes"], results[0]["labels"]
class SAM2ImageSegmentor:
"""
Wrapper class for SAM2-based segmentation given bounding boxes.
"""
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
"""
Initialize the SAM2 image segmentor.
Args:
sam_model_cfg (str): Path to the SAM2 config file.
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
device (str): Device to load the model on ('cuda' or 'cpu').
"""
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
self.device = device
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
self.predictor = SAM2ImagePredictor(sam_model)
def set_image(self, image: np.ndarray):
"""
Set the input image for segmentation.
Args:
image (np.ndarray): RGB image array with shape (H, W, 3).
"""
self.predictor.set_image(image)
def predict_masks_from_boxes(self, boxes: torch.Tensor):
"""
Predict segmentation masks from given bounding boxes.
Args:
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
- masks: Binary masks per box, shape (N, H, W)
- scores: Confidence scores for each mask
- logits: Raw logits from the model
"""
masks, scores, logits = self.predictor.predict(
point_coords=None,
point_labels=None,
box=boxes,
multimask_output=False,
)
# Normalize shape to (N, H, W)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
return masks, scores, logits
class IncrementalObjectTracker:
def __init__(
self,
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text="car.",
detection_interval=20,
):
"""
Initialize an incremental object tracker using GroundingDINO and SAM2.
Args:
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
sam2_model_cfg (str): Path to SAM2 model config file.
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
device (str): Device to run the models on ('cuda' or 'cpu').
prompt_text (str): Initial text prompt for detection.
detection_interval (int): Frame interval between full detections.
"""
self.device = device
self.detection_interval = detection_interval
self.prompt_text = prompt_text
# Load models
self.grounding_predictor = GroundingDinoPredictor(
model_id=grounding_model_id, device=device
)
self.sam2_segmentor = SAM2ImageSegmentor(
sam_model_cfg=sam2_model_cfg,
sam_model_ckpt=sam2_ckpt_path,
device=device,
)
self.video_predictor = build_sam2_video_predictor(
sam2_model_cfg, sam2_ckpt_path
)
# Initialize inference state
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
self.total_frames = 0
self.objects_count = 0
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
# Store tracking results
self.last_mask_dict = MaskDictionaryModel()
self.track_dict = MaskDictionaryModel()
def add_image(self, image_np: np.ndarray):
"""
Add a new image frame to the tracker and perform detection or tracking update.
Args:
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
Returns:
np.ndarray: Annotated image with object masks and labels.
"""
import numpy as np
from PIL import Image
img_pil = Image.fromarray(image_np)
# Step 1: Perform detection every detection_interval frames
if self.total_frames % self.detection_interval == 0:
if (
self.inference_state["video_height"] is None
or self.inference_state["video_width"] is None
):
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
print(
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
)
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
# 1.1 GroundingDINO object detection
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
if boxes.shape[0] == 0:
return
# 1.2 SAM2 segmentation from detection boxes
self.sam2_segmentor.set_image(image_np)
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
# 1.3 Build MaskDictionaryModel
mask_dict = MaskDictionaryModel(
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
)
mask_dict.add_new_frame_annotation(
mask_list=torch.tensor(masks).to(self.device),
box_list=torch.tensor(boxes),
label_list=labels,
)
# 1.4 Object ID tracking and IOU-based update
self.objects_count = mask_dict.update_masks(
tracking_annotation_dict=self.last_mask_dict,
iou_threshold=0.3,
objects_count=self.objects_count,
)
# 1.5 Reset video tracker state
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
self.video_predictor.reset_state(self.inference_state)
for object_id, object_info in mask_dict.labels.items():
frame_idx, _, _ = self.video_predictor.add_new_mask(
self.inference_state,
frame_idx,
object_id,
object_info.mask,
)
self.track_dict = copy.deepcopy(mask_dict)
self.last_mask_dict = mask_dict
else:
# Step 2: Use incremental tracking for intermediate frames
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
# Step 3: Tracking propagation using the video predictor
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
inference_state=self.inference_state,
frame_idx=frame_idx,
)
# Step 4: Update the mask dictionary based on tracked masks
frame_masks = MaskDictionaryModel()
for i, obj_id in enumerate(obj_ids):
out_mask = video_res_masks[i] > 0.0
object_info = ObjectInfo(
instance_id=obj_id,
mask=out_mask[0],
class_name=self.track_dict.get_target_class_name(obj_id),
logit=self.track_dict.get_target_logit(obj_id),
)
object_info.update_box()
frame_masks.labels[obj_id] = object_info
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
self.last_mask_dict = copy.deepcopy(frame_masks)
# Step 5: Build mask array
H, W = image_np.shape[:2]
mask_img = torch.zeros((H, W), dtype=torch.int32)
for obj_id, obj_info in self.last_mask_dict.labels.items():
mask_img[obj_info.mask == True] = obj_id
mask_array = mask_img.cpu().numpy()
# Step 6: Visualization
annotated_frame = self.visualize_frame_with_mask_and_metadata(
image_np=image_np,
mask_array=mask_array,
json_metadata=self.last_mask_dict.to_dict(),
)
print(f"[Tracker] Total processed frames: {self.total_frames}")
self.total_frames += 1
torch.cuda.empty_cache()
return annotated_frame
def set_prompt(self, new_prompt: str):
"""
Dynamically update the GroundingDINO prompt and reset tracking state
to force a new object detection.
"""
self.prompt_text = new_prompt
self.total_frames = 0 # Trigger immediate re-detection
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
self.inference_state["video_height"] = None
self.inference_state["video_width"] = None
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
"""
Save the current mask, metadata, raw image, and annotated result.
Args:
output_dir (str): The root output directory.
raw_image (np.ndarray, optional): The original input image (RGB).
"""
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
image_data_dir = os.path.join(output_dir, "images")
vis_data_dir = os.path.join(output_dir, "result")
os.makedirs(mask_data_dir, exist_ok=True)
os.makedirs(json_data_dir, exist_ok=True)
os.makedirs(image_data_dir, exist_ok=True)
os.makedirs(vis_data_dir, exist_ok=True)
frame_masks = self.last_mask_dict
# Ensure mask_name is valid
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
base_name = f"image_{self.total_frames:05d}"
# Save segmentation mask
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
for obj_id, obj_info in frame_masks.labels.items():
mask_img[obj_info.mask == True] = obj_id
np.save(
os.path.join(mask_data_dir, frame_masks.mask_name),
mask_img.numpy().astype(np.uint16),
)
# Save metadata as JSON
json_path = os.path.join(json_data_dir, base_name + ".json")
frame_masks.to_json(json_path)
# Save raw input image
if raw_image is not None:
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
# Save annotated image with mask, bounding boxes, and labels
annotated_image = self.visualize_frame_with_mask_and_metadata(
image_np=raw_image,
mask_array=mask_img.numpy().astype(np.uint16),
json_metadata=frame_masks.to_dict(),
)
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
)
print(
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
)
def visualize_frame_with_mask_and_metadata(
self,
image_np: np.ndarray,
mask_array: np.ndarray,
json_metadata: dict,
):
image = image_np.copy()
H, W = image.shape[:2]
# Step 1: Parse metadata and build object entries
metadata_lookup = json_metadata.get("labels", {})
all_object_ids = []
all_object_boxes = []
all_object_classes = []
all_object_masks = []
for obj_id_str, obj_info in metadata_lookup.items():
instance_id = obj_info.get("instance_id")
if instance_id is None or instance_id == 0:
continue
if instance_id not in np.unique(mask_array):
continue
object_mask = mask_array == instance_id
all_object_ids.append(instance_id)
x1 = obj_info.get("x1", 0)
y1 = obj_info.get("y1", 0)
x2 = obj_info.get("x2", 0)
y2 = obj_info.get("y2", 0)
all_object_boxes.append([x1, y1, x2, y2])
all_object_classes.append(obj_info.get("class_name", "unknown"))
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
# Step 2: Check if valid objects exist
if len(all_object_ids) == 0:
print("No valid object instances found in metadata.")
return image
# Step 3: Sort by instance ID
paired = list(
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
)
paired.sort(key=lambda x: x[0])
all_object_ids = [p[0] for p in paired]
all_object_boxes = [p[1] for p in paired]
all_object_masks = [p[2] for p in paired]
all_object_classes = [p[3] for p in paired]
# Step 4: Build detections
all_object_masks = np.concatenate(all_object_masks, axis=0)
detections = sv.Detections(
xyxy=np.array(all_object_boxes),
mask=all_object_masks,
class_id=np.array(all_object_ids, dtype=np.int32),
)
labels = [
f"{instance_id}: {class_name}"
for instance_id, class_name in zip(all_object_ids, all_object_classes)
]
# Step 5: Annotate image
annotated_frame = image.copy()
mask_annotator = sv.MaskAnnotator()
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
annotated_frame = box_annotator.annotate(annotated_frame, detections)
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
return annotated_frame
import os
import cv2
import torch
from utils.common_utils import CommonUtils
def main():
# Parameter settings
output_dir = "./outputs"
prompt_text = "hand."
detection_interval = 20
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
os.makedirs(output_dir, exist_ok=True)
# Initialize the object tracker
tracker = IncrementalObjectTracker(
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text=prompt_text,
detection_interval=detection_interval,
)
tracker.set_prompt("person.")
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("[Error] Cannot open camera.")
return
print("[Info] Camera opened. Press 'q' to quit.")
frame_idx = 0
try:
while True:
ret, frame = cap.read()
if not ret:
print("[Warning] Failed to capture frame.")
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
print(f"[Frame {frame_idx}] Processing live frame...")
process_image = tracker.add_image(frame_rgb)
if process_image is None or not isinstance(process_image, np.ndarray):
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
frame_idx += 1
continue
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
# cv2.imshow("Live Inference", process_image_bgr)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("[Info] Quit signal received.")
# break
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
frame_idx += 1
if frame_idx >= max_frames:
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
break
except KeyboardInterrupt:
print("[Info] Interrupted by user (Ctrl+C).")
finally:
cap.release()
cv2.destroyAllWindows()
print("[Done] Live inference complete.")
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for DINO-X - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk.tasks.types import DetectionTarget
from dds_cloudapi_sdk import TextPrompt
import os import os
import cv2 import cv2
@@ -30,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_DINOX = "Your API token" API_TOKEN_FOR_DINOX = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
""" """
Step 1: Environment settings and model initialization for SAM 2 Step 1: Environment settings and model initialization for SAM 2
@@ -98,22 +97,29 @@ config = Config(API_TOKEN_FOR_DINOX)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DinoxTask( task = V2Task(
image_url=image_url, api_path="/v2/task/dinox/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
bbox_threshold=0.25, "model": "DINO-X-1.0",
targets=[DetectionTarget.BBox], "image": image_url,
"prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -121,9 +127,9 @@ confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - Update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -31,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_GD1_5 = "Your API token" API_TOKEN_FOR_GD1_5 = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
""" """
Step 1: Environment settings and model initialization for SAM 2 Step 1: Environment settings and model initialization for SAM 2
@@ -99,33 +97,38 @@ config = Config(API_TOKEN_FOR_GD1_5)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": "GroundingDino-1.5-Pro",
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -1,11 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import torch import torch
@@ -51,6 +47,9 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "car." text = "car."
BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
GROUNDING_MODEL = "GroundingDino-1.6-Pro" # 使用字符串替代枚举值
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg` # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/car" video_dir = "notebooks/videos/car"
@@ -102,24 +101,32 @@ for start_frame_idx in range(0, len(frame_names), step):
client = Client(config) client = Client(config)
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=text)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
"prompt": {
"type": "text",
"text": text
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)
OBJECTS = class_names OBJECTS = class_names
@@ -154,7 +161,7 @@ for start_frame_idx in range(0, len(frame_names), step):
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count) objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=IOU_THRESHOLD, objects_count=objects_count)
print("objects_count", objects_count) print("objects_count", objects_count)
else: else:

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -54,6 +51,11 @@ inference_state = video_predictor.init_state(video_path=video_dir)
ann_frame_idx = 0 # the frame index we interact with ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# 添加参数设置
TEXT_PROMPT = "children. pillow"
BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
""" """
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
@@ -70,23 +72,29 @@ config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text="children. pillow")], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": "GroundingDino-1.5-Pro",
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=0.2, "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -94,9 +102,9 @@ confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -15,6 +15,19 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/version.h>
// Check PyTorch version and define appropriate macros
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
// PyTorch 2.x and above
#define GET_TENSOR_TYPE(x) x.scalar_type()
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
#else
// PyTorch 1.x
#define GET_TENSOR_TYPE(x) x.type()
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
#endif
namespace groundingdino { namespace groundingdino {
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
const int batch = value.size(0); const int batch = value.size(0);
const int spatial_size = value.size(1); const int spatial_size = value.size(1);
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n) for (int n = 0; n < batch/im2col_step_; ++n)
{ {
auto columns = output_n.select(0, n); auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size, value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), spatial_shapes.data<int64_t>(),
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
const int batch = value.size(0); const int batch = value.size(0);
const int spatial_size = value.size(1); const int spatial_size = value.size(1);
@@ -132,7 +145,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n) for (int n = 0; n < batch/im2col_step_; ++n)
{ {
auto grad_output_g = grad_output_n.select(0, n); auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size, value.data<scalar_t>() + n * im2col_step_ * per_value_size,

View File

@@ -1,6 +1,6 @@
[build-system] [build-system]
requires = [ requires = [
"setuptools>=61.0", "setuptools>=62.3.0,<75.9",
"torch>=2.5.1", "torch>=2.3.1",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,92 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
import torch
from tqdm import tqdm
from sam2.build_sam import build_sam2_video_predictor
# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)
# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
p
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)
# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()
# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
with torch.inference_mode():
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
start = time.time()
# Start tracking
for (
out_frame_idx,
out_obj_ids,
out_mask_logits,
) in predictor.propagate_in_video(inference_state):
pass
end = time.time()
total += end - start
count += 1
if i == warm_up - 1:
print("Warmup FPS: ", count * num_frames / total)
total = 0
count = 0
print("FPS: ", count * num_frames / total)

View File

@@ -104,18 +104,11 @@ def build_sam2_video_predictor(
mode="eval", mode="eval",
hydra_overrides_extra=[], hydra_overrides_extra=[],
apply_postprocessing=True, apply_postprocessing=True,
vos_optimized=False,
**kwargs, **kwargs,
): ):
hydra_overrides = [ hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
] ]
if vos_optimized:
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
"++model.compile_image_encoder=True", # Let sam2_base handle this
]
if apply_postprocessing: if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [ hydra_overrides_extra += [

View File

@@ -36,7 +36,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -40,7 +40,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -97,7 +97,7 @@ trainer:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -36,7 +36,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -40,7 +40,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
self_attention: self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1
downsample_rate: 1 downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention: cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention _target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0 rope_theta: 10000.0
feat_sizes: [64, 64] feat_sizes: [32, 32]
rope_k_repeat: True rope_k_repeat: True
embedding_dim: 256 embedding_dim: 256
num_heads: 1 num_heads: 1

View File

@@ -32,7 +32,9 @@ def window_partition(x, window_size):
Hp, Wp = H + pad_h, W + pad_w Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp) return windows, (Hp, Wp)
@@ -50,13 +52,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
Hp, Wp = pad_hw Hp, Wp = pad_hw
H, W = hw H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size) B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.reshape( x = windows.view(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1 B, Hp // window_size, Wp // window_size, window_size, window_size, -1
) )
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W: if Hp > H or Wp > W:
x = x[:, :H, :W, :] x = x[:, :H, :W, :].contiguous()
return x return x

View File

@@ -25,11 +25,6 @@ class PositionEmbeddingSine(nn.Module):
temperature: int = 10000, temperature: int = 10000,
normalize: bool = True, normalize: bool = True,
scale: Optional[float] = None, scale: Optional[float] = None,
# Following settings only relevant
# for warmping up cache for compilation
warmup_cache: bool = True,
image_size: int = 1024,
strides: Tuple[int] = (4, 8, 16, 32),
): ):
super().__init__() super().__init__()
assert num_pos_feats % 2 == 0, "Expecting even model width" assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -43,12 +38,6 @@ class PositionEmbeddingSine(nn.Module):
self.scale = scale self.scale = scale
self.cache = {} self.cache = {}
if warmup_cache and torch.cuda.is_available():
# Warmup cache for cuda, to help with compilation
device = torch.device("cuda")
for stride in strides:
cache_key = (image_size // stride, image_size // stride)
self._pe(1, device, *cache_key)
def _encode_xy(self, x, y): def _encode_xy(self, x, y):
# The positions are expected to be normalized # The positions are expected to be normalized
@@ -87,20 +76,19 @@ class PositionEmbeddingSine(nn.Module):
return pos return pos
@torch.no_grad() @torch.no_grad()
def _pe(self, B, device, *cache_key): def forward(self, x: torch.Tensor):
H, W = cache_key cache_key = (x.shape[-2], x.shape[-1])
if cache_key in self.cache: if cache_key in self.cache:
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
y_embed = ( y_embed = (
torch.arange(1, H + 1, dtype=torch.float32, device=device) torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
.view(1, -1, 1) .view(1, -1, 1)
.repeat(B, 1, W) .repeat(x.shape[0], 1, x.shape[-1])
) )
x_embed = ( x_embed = (
torch.arange(1, W + 1, dtype=torch.float32, device=device) torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
.view(1, 1, -1) .view(1, 1, -1)
.repeat(B, H, 1) .repeat(x.shape[0], x.shape[-2], 1)
) )
if self.normalize: if self.normalize:
@@ -108,7 +96,7 @@ class PositionEmbeddingSine(nn.Module):
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
@@ -123,12 +111,6 @@ class PositionEmbeddingSine(nn.Module):
self.cache[cache_key] = pos[0] self.cache[cache_key] = pos[0]
return pos return pos
@torch.no_grad()
def forward(self, x: torch.Tensor):
B = x.shape[0]
cache_key = (x.shape[-2], x.shape[-1])
return self._pe(B, x.device, *cache_key)
class PositionEmbeddingRandom(nn.Module): class PositionEmbeddingRandom(nn.Module):
""" """

View File

@@ -92,32 +92,12 @@ class PromptEncoder(nn.Module):
point_embedding = self.pe_layer.forward_with_coords( point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size points, self.input_image_size
) )
point_embedding[labels == -1] = 0.0
point_embedding = torch.where( point_embedding[labels == -1] += self.not_a_point_embed.weight
(labels == -1).unsqueeze(-1), point_embedding[labels == 0] += self.point_embeddings[0].weight
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding, point_embedding[labels == 2] += self.point_embeddings[2].weight
) point_embedding[labels == 3] += self.point_embeddings[3].weight
point_embedding = torch.where(
(labels == 0).unsqueeze(-1),
point_embedding + self.point_embeddings[0].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 1).unsqueeze(-1),
point_embedding + self.point_embeddings[1].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 2).unsqueeze(-1),
point_embedding + self.point_embeddings[2].weight,
point_embedding,
)
point_embedding = torch.where(
(labels == 3).unsqueeze(-1),
point_embedding + self.point_embeddings[3].weight,
point_embedding,
)
return point_embedding return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:

View File

@@ -4,7 +4,9 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import contextlib
import math import math
import warnings
from functools import partial from functools import partial
from typing import Tuple, Type from typing import Tuple, Type
@@ -14,6 +16,29 @@ from torch import nn, Tensor
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings
warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False
def sdp_kernel_context(dropout_p):
"""
Get the context for the attention scaled dot-product kernel. We use Flash Attention
by default, but fall back to all available kernels if Flash Attention fails.
"""
if ALLOW_ALL_KERNELS:
return contextlib.nullcontext()
return torch.backends.cuda.sdp_kernel(
enable_flash=USE_FLASH_ATTN,
# if Flash attention kernel is off, then math kernel needs to be enabled
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
enable_mem_efficient=OLD_GPU,
)
class TwoWayTransformer(nn.Module): class TwoWayTransformer(nn.Module):
@@ -240,6 +265,19 @@ class Attention(nn.Module):
dropout_p = self.dropout_p if self.training else 0.0 dropout_p = self.dropout_p if self.training else 0.0
# Attention # Attention
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out) out = self._recombine_heads(out)
@@ -258,7 +296,7 @@ class RoPEAttention(Attention):
# whether to repeat q rope to match k length # whether to repeat q rope to match k length
# this is needed for cross-attention to memories # this is needed for cross-attention to memories
rope_k_repeat=False, rope_k_repeat=False,
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -267,9 +305,7 @@ class RoPEAttention(Attention):
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
) )
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
self.freqs_cis = ( self.freqs_cis = freqs_cis
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
)
self.rope_k_repeat = rope_k_repeat self.rope_k_repeat = rope_k_repeat
def forward( def forward(
@@ -303,6 +339,19 @@ class RoPEAttention(Attention):
dropout_p = self.dropout_p if self.training else 0.0 dropout_p = self.dropout_p if self.training else 0.0
# Attention # Attention
try:
with sdp_kernel_context(dropout_p):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
except Exception as e:
# Fall back to all kernels if the Flash attention kernel fails
warnings.warn(
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
category=UserWarning,
stacklevel=2,
)
global ALLOW_ALL_KERNELS
ALLOW_ALL_KERNELS = True
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out) out = self._recombine_heads(out)

View File

@@ -628,9 +628,7 @@ class SAM2Base(torch.nn.Module):
if self.add_tpos_enc_to_obj_ptrs: if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1 t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = torch.tensor(pos_list).to( obj_pos = torch.tensor(pos_list, device=device)
device=device, non_blocking=True
)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,7 @@ import os
import warnings import warnings
from threading import Thread from threading import Thread
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
@@ -209,6 +210,74 @@ def load_video_frames(
"Only MP4 video and JPEG folder are supported at this moment" "Only MP4 video and JPEG folder are supported at this moment"
) )
def process_stream_frame(
img_array: np.ndarray,
image_size: int,
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
offload_to_cpu: bool = False,
compute_device: torch.device = torch.device("cuda"),
):
"""
Convert a raw image array (H,W,3 or 3,H,W) into a modelready tensor.
Steps
-----
1. Resize the shorter side to `image_size`, keeping aspect ratio,
then centercrop/pad to `image_size` × `image_size`.
2. Change layout to [3, H, W] and cast to float32 in [0,1].
3. Normalise with ImageNet statistics.
4. Optionally move to `compute_device`.
Returns
-------
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
orig_h : int
orig_w : int
"""
# ↪ uses your existing helper so behaviour matches the batch loader
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
# Normalisation (done *after* potential device move for efficiency)
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if not offload_to_cpu:
img_tensor = img_tensor.to(compute_device)
img_mean_t = img_mean_t.to(compute_device)
img_std_t = img_std_t.to(compute_device)
img_tensor.sub_(img_mean_t).div_(img_std_t)
return img_tensor, orig_h, orig_w
def _resize_and_convert_to_tensor(img_array, image_size):
"""
Resize the input image array and convert it into a tensor.
Also return original image height and width.
"""
# Convert numpy array to PIL image and ensure RGB
img_pil = Image.fromarray(img_array).convert("RGB")
# Save original size (PIL: size = (width, height))
video_width, video_height = img_pil.size
# Resize with high-quality LANCZOS filter
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
# Convert resized image back to numpy and then to float tensor
img_resized_array = np.array(img_resized)
if img_resized_array.dtype == np.uint8:
img_resized_array = img_resized_array / 255.0
else:
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
# Convert to PyTorch tensor and permute to [C, H, W]
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
return img_tensor, video_height, video_width
def load_video_frames_from_jpg_images( def load_video_frames_from_jpg_images(
video_path, video_path,

View File

@@ -22,8 +22,8 @@ with open("README.md", "r", encoding="utf-8") as f:
# Required dependencies # Required dependencies
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
"torch>=2.5.1", "torch>=2.3.1",
"torchvision>=0.20.1", "torchvision>=0.18.1",
"numpy>=1.24.4", "numpy>=1.24.4",
"tqdm>=4.66.1", "tqdm>=4.66.1",
"hydra-core>=1.3.2", "hydra-core>=1.3.2",
@@ -58,7 +58,7 @@ EXTRA_PACKAGES = {
"scikit-image>=0.24.0", "scikit-image>=0.24.0",
"tensorboard>=2.17.0", "tensorboard>=2.17.0",
"pycocotools>=2.0.8", "pycocotools>=2.0.8",
"tensordict>=0.6.0", "tensordict>=0.5.0",
"opencv-python>=4.7.0", "opencv-python>=4.7.0",
"submitit>=1.5.1", "submitit>=1.5.1",
], ],

View File

@@ -375,7 +375,7 @@ def main():
parser.add_argument( parser.add_argument(
"--sam2_checkpoint", "--sam2_checkpoint",
type=str, type=str,
default="./checkpoints/sam2.1_hiera_base_plus.pt", default="./checkpoints/sam2.1_hiera_b+.pt",
help="path to the SAM 2 model checkpoint", help="path to the SAM 2 model checkpoint",
) )
parser.add_argument( parser.add_argument(
@@ -434,11 +434,6 @@ def main():
help="whether to track objects that appear later in the video (i.e. not on the first frame; " help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)", "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
) )
parser.add_argument(
"--use_vos_optimized_video_predictor",
action="store_true",
help="whether to use vos optimized video predictor with all modules compiled",
)
args = parser.parse_args() args = parser.parse_args()
# if we use per-object PNG files, they could possibly overlap in inputs and outputs # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -450,7 +445,6 @@ def main():
ckpt_path=args.sam2_checkpoint, ckpt_path=args.sam2_checkpoint,
apply_postprocessing=args.apply_postprocessing, apply_postprocessing=args.apply_postprocessing,
hydra_overrides_extra=hydra_overrides_extra, hydra_overrides_extra=hydra_overrides_extra,
vos_optimized=args.use_vos_optimized_video_predictor,
) )
if args.use_all_masks: if args.use_all_masks: