Compare commits
77 Commits
2111d9c52c
...
update_sam
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8b56c25344 | ||
![]() |
2b90b9f5ce | ||
![]() |
722d1d1511 | ||
![]() |
393ae336a7 | ||
![]() |
c2ec8e14a1 | ||
![]() |
c98aa6bea3 | ||
![]() |
ff9704fc0e | ||
![]() |
29267c8e39 | ||
![]() |
e22521832f | ||
![]() |
8bf0920e66 | ||
![]() |
52198ead0e | ||
![]() |
98fcb164bf | ||
![]() |
05d9e57fb3 | ||
![]() |
429a2c7360 | ||
![]() |
3a7889d905 | ||
![]() |
aa9b8722d0 | ||
![]() |
0f6515ae85 | ||
![]() |
7e1596c0b6 | ||
![]() |
0db838b117 | ||
![]() |
fd5125b97a | ||
![]() |
1191677e1e | ||
![]() |
dce7b5446f | ||
![]() |
1034ee2a1a | ||
![]() |
778e112740 | ||
![]() |
8f607e2de1 | ||
![]() |
46945a2122 | ||
![]() |
d421e0b040 | ||
![]() |
102ddb8899 | ||
![]() |
6186d1529a | ||
![]() |
6ecb5ff8d0 | ||
![]() |
086daf0641 | ||
![]() |
6ba4c65cb2 | ||
![]() |
9b58611e24 | ||
![]() |
6ec8560436 | ||
![]() |
43c385c263 | ||
![]() |
322aa3e7e5 | ||
![]() |
511199d7a9 | ||
![]() |
8f15c6255a | ||
![]() |
0bac418736 | ||
![]() |
27a167c004 | ||
![]() |
6f7e700c37 | ||
![]() |
a36edf1e01 | ||
![]() |
e815f70a38 | ||
![]() |
fbf7e3a664 | ||
![]() |
e9503c96fe | ||
![]() |
c3393d8b5f | ||
![]() |
0230c5ff93 | ||
![]() |
5e3d6ca6b5 | ||
![]() |
3b0fd9e4a9 | ||
![]() |
acd3939f88 | ||
![]() |
841cc1f015 | ||
![]() |
e93be7f6aa | ||
![]() |
cb48213066 | ||
![]() |
6aeee34775 | ||
![]() |
0c28c630c2 | ||
![]() |
3af4e82263 | ||
![]() |
17b74501fb | ||
![]() |
b72a8a97f0 | ||
![]() |
57bc94b739 | ||
![]() |
b744a3c084 | ||
![]() |
d1fc9a0686 | ||
![]() |
59550d4deb | ||
![]() |
de4db16676 | ||
![]() |
0e78a11899 | ||
![]() |
fa2796bb47 | ||
![]() |
86827e2fba | ||
![]() |
cd270ed4f1 | ||
![]() |
32750fa695 | ||
![]() |
e62ec497b8 | ||
![]() |
c8127182c1 | ||
![]() |
f882beb157 | ||
![]() |
82b026cd55 | ||
![]() |
de05a2e0c5 | ||
![]() |
b3011f0ea6 | ||
![]() |
662fd3d90e | ||
![]() |
658aaba327 | ||
![]() |
0c5f8c5432 |
17
.github/workflows/check_fmt.yml
vendored
Normal file
17
.github/workflows/check_fmt.yml
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
name: SAM2/fmt
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
jobs:
|
||||||
|
ufmt_check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check formatting
|
||||||
|
uses: omnilib/ufmt@action-v1
|
||||||
|
with:
|
||||||
|
path: sam2 tools
|
||||||
|
version: "2.0.0b2"
|
||||||
|
python-version: "3.10"
|
||||||
|
black-version: "24.2.0"
|
||||||
|
usort-version: "1.0.2"
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -144,4 +144,5 @@ dmypy.json
|
|||||||
*.pth
|
*.pth
|
||||||
outputs/
|
outputs/
|
||||||
|
|
||||||
.idea/
|
.idea/
|
||||||
|
demo/backend/checkpoints/*.pt
|
||||||
|
@@ -27,7 +27,7 @@ WORKDIR /home/appuser/Grounded-SAM-2
|
|||||||
|
|
||||||
|
|
||||||
# Install essential Python packages
|
# Install essential Python packages
|
||||||
RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
|
RUN python -m pip install --upgrade pip setuptools wheel numpy \
|
||||||
opencv-python transformers supervision pycocotools addict yapf timm
|
opencv-python transformers supervision pycocotools addict yapf timm
|
||||||
|
|
||||||
# Install segment_anything package in editable mode
|
# Install segment_anything package in editable mode
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
### Requirements
|
### Requirements
|
||||||
|
|
||||||
- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
|
- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
|
||||||
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
|
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
|
||||||
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
|
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
|
||||||
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
||||||
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar
|
|||||||
|
|
||||||
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
|
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
|
||||||
|
|
||||||
In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
|
In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
|
||||||
|
|
||||||
We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
|
We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
4
LICENSE
4
LICENSE
@@ -186,7 +186,7 @@
|
|||||||
same "printed page" as the copyright notice for easier
|
same "printed page" as the copyright notice for easier
|
||||||
identification within third-party archives.
|
identification within third-party archives.
|
||||||
|
|
||||||
Copyright 2023 - present, IDEA Research.
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@@ -198,4 +198,4 @@
|
|||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
|
11
README.md
11
README.md
@@ -20,7 +20,6 @@ In this repo, we've supported the following demo with **simple implementations**
|
|||||||
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
|
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
|
||||||
|
|
||||||
## Latest updates
|
## Latest updates
|
||||||
- **2025.04.20**: Update to `dds-cloudapi-sdk` API V2 version. The V1 version in the original API for `Grounding DINO 1.5` and `DINO-X` has been deprecated, please update to the latest `dds-cloudapi-sdk` by `pip install dds-cloudapi-sdk -U` to use `Grounding DINO 1.5 / 1.6` and `DINO-X` models. Please refer to [dds-cloudapi-sdk](https://github.com/deepdataspace/dds-cloudapi-sdk) and our [API docs](https://cloud.deepdataspace.com/docs) to view more details about the update.
|
|
||||||
|
|
||||||
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
|
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
|
||||||
|
|
||||||
@@ -335,16 +334,6 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
|
|
||||||
|
|
||||||
This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python grounded_sam2_tracking_camera_with_continuous_id.py
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Grounded SAM 2 Florence-2 Demos
|
## Grounded SAM 2 Florence-2 Demos
|
||||||
### Grounded SAM 2 Florence-2 Image Demo
|
### Grounded SAM 2 Florence-2 Image Demo
|
||||||
|
|
||||||
|
27
RELEASE_NOTES.md
Normal file
27
RELEASE_NOTES.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
## SAM 2 release notes
|
||||||
|
|
||||||
|
### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking
|
||||||
|
|
||||||
|
- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
|
||||||
|
* Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
|
||||||
|
* In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag.
|
||||||
|
* Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
|
||||||
|
* **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
|
||||||
|
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
|
||||||
|
* Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
|
||||||
|
* This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
|
||||||
|
* We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
|
||||||
|
|
||||||
|
### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released
|
||||||
|
|
||||||
|
- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
|
||||||
|
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
|
||||||
|
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
|
||||||
|
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
|
||||||
|
|
||||||
|
### 07/29/2024 -- SAM 2 is released
|
||||||
|
|
||||||
|
- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
|
||||||
|
* SAM 2 code: https://github.com/facebookresearch/sam2
|
||||||
|
* SAM 2 demo: https://sam2.metademolab.com/
|
||||||
|
* SAM 2 paper: https://arxiv.org/abs/2408.00714
|
@@ -1,4 +1,4 @@
|
|||||||
ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
|
ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
|
||||||
ARG MODEL_SIZE=base_plus
|
ARG MODEL_SIZE=base_plus
|
||||||
|
|
||||||
FROM ${BASE_IMAGE}
|
FROM ${BASE_IMAGE}
|
||||||
|
@@ -105,7 +105,7 @@ cd demo/backend/server/
|
|||||||
```bash
|
```bash
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK=1 \
|
PYTORCH_ENABLE_MPS_FALLBACK=1 \
|
||||||
APP_ROOT="$(pwd)/../../../" \
|
APP_ROOT="$(pwd)/../../../" \
|
||||||
APP_URL=http://localhost:7263 \
|
API_URL=http://localhost:7263 \
|
||||||
MODEL_SIZE=base_plus \
|
MODEL_SIZE=base_plus \
|
||||||
DATA_PATH="$(pwd)/../../data" \
|
DATA_PATH="$(pwd)/../../data" \
|
||||||
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
|
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
||||||
|
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -25,7 +27,6 @@ IMG_PATH = "notebooks/images/cars.jpg"
|
|||||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
IOU_THRESHOLD = 0.8
|
|
||||||
WITH_SLICE_INFERENCE = False
|
WITH_SLICE_INFERENCE = False
|
||||||
SLICE_WH = (480, 480)
|
SLICE_WH = (480, 480)
|
||||||
OVERLAP_RATIO = (0.2, 0.2)
|
OVERLAP_RATIO = (0.2, 0.2)
|
||||||
@@ -47,7 +48,7 @@ config = Config(token)
|
|||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task by DetectionTask class
|
||||||
# infer_image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
|
|
||||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||||
@@ -61,18 +62,13 @@ if WITH_SLICE_INFERENCE:
|
|||||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
||||||
temp_filename = tmpfile.name
|
temp_filename = tmpfile.name
|
||||||
cv2.imwrite(temp_filename, image_slice)
|
cv2.imwrite(temp_filename, image_slice)
|
||||||
infer_image_url = client.upload_file(temp_filename)
|
image_url = client.upload_file(temp_filename)
|
||||||
task = V2Task(api_path="/v2/task/dinox/detection", api_body={
|
task = DinoxTask(
|
||||||
"model": "DINO-X-1.0",
|
image_url=image_url,
|
||||||
"image": infer_image_url,
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"prompt": {
|
bbox_threshold=0.25,
|
||||||
"type":"text",
|
targets=[DetectionTarget.BBox],
|
||||||
"text":TEXT_PROMPT
|
)
|
||||||
},
|
|
||||||
"targets": ["bbox", "mask"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
})
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
# detele the tempfile
|
# detele the tempfile
|
||||||
@@ -81,7 +77,7 @@ if WITH_SLICE_INFERENCE:
|
|||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_ids = []
|
class_ids = []
|
||||||
objects = result["objects"]
|
objects = result.objects
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj.score)
|
confidences.append(obj.score)
|
||||||
@@ -106,26 +102,19 @@ if WITH_SLICE_INFERENCE:
|
|||||||
class_ids = detections.class_id
|
class_ids = detections.class_id
|
||||||
input_boxes = detections.xyxy
|
input_boxes = detections.xyxy
|
||||||
else:
|
else:
|
||||||
infer_image_url = client.upload_file(IMG_PATH)
|
image_url = client.upload_file(IMG_PATH)
|
||||||
|
|
||||||
task = V2Task(
|
task = DinoxTask(
|
||||||
api_path="/v2/task/dinox/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"model": "DINO-X-1.0",
|
bbox_threshold=0.25,
|
||||||
"image": infer_image_url,
|
targets=[DetectionTarget.BBox],
|
||||||
"prompt": {
|
|
||||||
"type":"text",
|
|
||||||
"text":TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox", "mask"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
objects = result["objects"] # the list of detected objects
|
|
||||||
|
objects = result.objects # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -134,9 +123,9 @@ else:
|
|||||||
class_ids = []
|
class_ids = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
cls_name = obj["category"].lower().strip()
|
cls_name = obj.category.lower().strip()
|
||||||
class_names.append(cls_name)
|
class_names.append(cls_name)
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
|
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk import DetectionTask
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -24,9 +27,8 @@ TEXT_PROMPT = "car . building ."
|
|||||||
IMG_PATH = "notebooks/images/cars.jpg"
|
IMG_PATH = "notebooks/images/cars.jpg"
|
||||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||||
GROUNDING_MODEL = "GroundingDino-1.5-Pro" # GroundingDino-1.6-Pro
|
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
IOU_THRESHOLD = 0.8
|
|
||||||
WITH_SLICE_INFERENCE = False
|
WITH_SLICE_INFERENCE = False
|
||||||
SLICE_WH = (480, 480)
|
SLICE_WH = (480, 480)
|
||||||
OVERLAP_RATIO = (0.2, 0.2)
|
OVERLAP_RATIO = (0.2, 0.2)
|
||||||
@@ -47,7 +49,8 @@ config = Config(token)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task using V2Task API
|
# Step 3: run the task by DetectionTask class
|
||||||
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
|
|
||||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||||
@@ -62,33 +65,26 @@ if WITH_SLICE_INFERENCE:
|
|||||||
temp_filename = tmpfile.name
|
temp_filename = tmpfile.name
|
||||||
cv2.imwrite(temp_filename, image_slice)
|
cv2.imwrite(temp_filename, image_slice)
|
||||||
image_url = client.upload_file(temp_filename)
|
image_url = client.upload_file(temp_filename)
|
||||||
task = V2Task(
|
task = DetectionTask(
|
||||||
api_path="/v2/task/grounding_dino/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"model": GROUNDING_MODEL,
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
"image": image_url,
|
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
||||||
"prompt": {
|
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
||||||
"type": "text",
|
|
||||||
"text": TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
# delete the tempfile
|
# detele the tempfile
|
||||||
os.remove(temp_filename)
|
os.remove(temp_filename)
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_ids = []
|
class_ids = []
|
||||||
objects = result["objects"]
|
objects = result.objects
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
cls_name = obj["category"].lower().strip()
|
cls_name = obj.category.lower().strip()
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
# ensure input_boxes with shape (_, 4)
|
# ensure input_boxes with shape (_, 4)
|
||||||
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
||||||
@@ -100,7 +96,7 @@ if WITH_SLICE_INFERENCE:
|
|||||||
callback=callback,
|
callback=callback,
|
||||||
slice_wh=SLICE_WH,
|
slice_wh=SLICE_WH,
|
||||||
overlap_ratio_wh=OVERLAP_RATIO,
|
overlap_ratio_wh=OVERLAP_RATIO,
|
||||||
iou_threshold=IOU_THRESHOLD,
|
iou_threshold=0.5,
|
||||||
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
||||||
)
|
)
|
||||||
detections = slicer(cv2.imread(IMG_PATH))
|
detections = slicer(cv2.imread(IMG_PATH))
|
||||||
@@ -111,25 +107,18 @@ if WITH_SLICE_INFERENCE:
|
|||||||
else:
|
else:
|
||||||
image_url = client.upload_file(IMG_PATH)
|
image_url = client.upload_file(IMG_PATH)
|
||||||
|
|
||||||
task = V2Task(
|
task = DetectionTask(
|
||||||
api_path="/v2/task/grounding_dino/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"model": GROUNDING_MODEL,
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
"image": image_url,
|
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
|
||||||
"prompt": {
|
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
||||||
"type": "text",
|
|
||||||
"text": TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result["objects"] # the list of detected objects
|
objects = result.objects # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -138,9 +127,9 @@ else:
|
|||||||
class_ids = []
|
class_ids = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
cls_name = obj["category"].lower().strip()
|
cls_name = obj.category.lower().strip()
|
||||||
class_names.append(cls_name)
|
class_names.append(cls_name)
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
|
|
||||||
|
@@ -23,7 +23,7 @@ parser.add_argument("--text-prompt", default="car. tire.")
|
|||||||
parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
|
parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
|
||||||
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
|
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
|
||||||
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
|
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
|
||||||
parser.add_argument("--output-dir", default="outputs/grounded_sam2_hf_demo")
|
parser.add_argument("--output-dir", default="outputs/test_sam2.1")
|
||||||
parser.add_argument("--no-dump-json", action="store_true")
|
parser.add_argument("--no-dump-json", action="store_true")
|
||||||
parser.add_argument("--force-cpu", action="store_true")
|
parser.add_argument("--force-cpu", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|||||||
# use bfloat16
|
# use bfloat16
|
||||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
@@ -61,7 +61,6 @@ boxes, confidences, labels = predict(
|
|||||||
caption=text,
|
caption=text,
|
||||||
box_threshold=BOX_THRESHOLD,
|
box_threshold=BOX_THRESHOLD,
|
||||||
text_threshold=TEXT_THRESHOLD,
|
text_threshold=TEXT_THRESHOLD,
|
||||||
device=DEVICE
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# process the box prompt for SAM 2
|
# process the box prompt for SAM 2
|
||||||
@@ -71,9 +70,9 @@ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
|||||||
|
|
||||||
|
|
||||||
# FIXME: figure how does this influence the G-DINO model
|
# FIXME: figure how does this influence the G-DINO model
|
||||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
@@ -1,536 +0,0 @@
|
|||||||
import copy
|
|
||||||
import os
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import supervision as sv
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from sam2.build_sam import build_sam2, build_sam2_video_predictor
|
|
||||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
||||||
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
|
||||||
from utils.common_utils import CommonUtils
|
|
||||||
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
|
||||||
from utils.track_utils import sample_points_from_masks
|
|
||||||
from utils.video_utils import create_video_from_images
|
|
||||||
|
|
||||||
# Setup environment
|
|
||||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
|
||||||
if torch.cuda.get_device_properties(0).major >= 8:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
|
|
||||||
class GroundingDinoPredictor:
|
|
||||||
"""
|
|
||||||
Wrapper for using a GroundingDINO model for zero-shot object detection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
|
|
||||||
"""
|
|
||||||
Initialize the GroundingDINO predictor.
|
|
||||||
Args:
|
|
||||||
model_id (str): HuggingFace model ID to load.
|
|
||||||
device (str): Device to run the model on ('cuda' or 'cpu').
|
|
||||||
"""
|
|
||||||
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
|
||||||
|
|
||||||
self.device = device
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
|
||||||
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict(
|
|
||||||
self,
|
|
||||||
image: "PIL.Image.Image",
|
|
||||||
text_prompts: str,
|
|
||||||
box_threshold=0.25,
|
|
||||||
text_threshold=0.25,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Perform object detection using text prompts.
|
|
||||||
Args:
|
|
||||||
image (PIL.Image.Image): Input RGB image.
|
|
||||||
text_prompts (str): Text prompt describing target objects.
|
|
||||||
box_threshold (float): Confidence threshold for box selection.
|
|
||||||
text_threshold (float): Confidence threshold for text match.
|
|
||||||
Returns:
|
|
||||||
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
|
|
||||||
"""
|
|
||||||
inputs = self.processor(
|
|
||||||
images=image, text=text_prompts, return_tensors="pt"
|
|
||||||
).to(self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
|
|
||||||
results = self.processor.post_process_grounded_object_detection(
|
|
||||||
outputs,
|
|
||||||
inputs.input_ids,
|
|
||||||
box_threshold=box_threshold,
|
|
||||||
text_threshold=text_threshold,
|
|
||||||
target_sizes=[image.size[::-1]],
|
|
||||||
)
|
|
||||||
|
|
||||||
return results[0]["boxes"], results[0]["labels"]
|
|
||||||
|
|
||||||
|
|
||||||
class SAM2ImageSegmentor:
|
|
||||||
"""
|
|
||||||
Wrapper class for SAM2-based segmentation given bounding boxes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
|
|
||||||
"""
|
|
||||||
Initialize the SAM2 image segmentor.
|
|
||||||
Args:
|
|
||||||
sam_model_cfg (str): Path to the SAM2 config file.
|
|
||||||
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
|
|
||||||
device (str): Device to load the model on ('cuda' or 'cpu').
|
|
||||||
"""
|
|
||||||
from sam2.build_sam import build_sam2
|
|
||||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
||||||
|
|
||||||
self.device = device
|
|
||||||
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
|
|
||||||
self.predictor = SAM2ImagePredictor(sam_model)
|
|
||||||
|
|
||||||
def set_image(self, image: np.ndarray):
|
|
||||||
"""
|
|
||||||
Set the input image for segmentation.
|
|
||||||
Args:
|
|
||||||
image (np.ndarray): RGB image array with shape (H, W, 3).
|
|
||||||
"""
|
|
||||||
self.predictor.set_image(image)
|
|
||||||
|
|
||||||
def predict_masks_from_boxes(self, boxes: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Predict segmentation masks from given bounding boxes.
|
|
||||||
Args:
|
|
||||||
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
|
|
||||||
Returns:
|
|
||||||
Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
- masks: Binary masks per box, shape (N, H, W)
|
|
||||||
- scores: Confidence scores for each mask
|
|
||||||
- logits: Raw logits from the model
|
|
||||||
"""
|
|
||||||
masks, scores, logits = self.predictor.predict(
|
|
||||||
point_coords=None,
|
|
||||||
point_labels=None,
|
|
||||||
box=boxes,
|
|
||||||
multimask_output=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize shape to (N, H, W)
|
|
||||||
if masks.ndim == 2:
|
|
||||||
masks = masks[None]
|
|
||||||
scores = scores[None]
|
|
||||||
logits = logits[None]
|
|
||||||
elif masks.ndim == 4:
|
|
||||||
masks = masks.squeeze(1)
|
|
||||||
|
|
||||||
return masks, scores, logits
|
|
||||||
|
|
||||||
|
|
||||||
class IncrementalObjectTracker:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
grounding_model_id="IDEA-Research/grounding-dino-tiny",
|
|
||||||
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
|
|
||||||
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
|
|
||||||
device="cuda",
|
|
||||||
prompt_text="car.",
|
|
||||||
detection_interval=20,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize an incremental object tracker using GroundingDINO and SAM2.
|
|
||||||
Args:
|
|
||||||
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
|
|
||||||
sam2_model_cfg (str): Path to SAM2 model config file.
|
|
||||||
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
|
|
||||||
device (str): Device to run the models on ('cuda' or 'cpu').
|
|
||||||
prompt_text (str): Initial text prompt for detection.
|
|
||||||
detection_interval (int): Frame interval between full detections.
|
|
||||||
"""
|
|
||||||
self.device = device
|
|
||||||
self.detection_interval = detection_interval
|
|
||||||
self.prompt_text = prompt_text
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
self.grounding_predictor = GroundingDinoPredictor(
|
|
||||||
model_id=grounding_model_id, device=device
|
|
||||||
)
|
|
||||||
self.sam2_segmentor = SAM2ImageSegmentor(
|
|
||||||
sam_model_cfg=sam2_model_cfg,
|
|
||||||
sam_model_ckpt=sam2_ckpt_path,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self.video_predictor = build_sam2_video_predictor(
|
|
||||||
sam2_model_cfg, sam2_ckpt_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize inference state
|
|
||||||
self.inference_state = self.video_predictor.init_state()
|
|
||||||
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
|
|
||||||
self.total_frames = 0
|
|
||||||
self.objects_count = 0
|
|
||||||
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
|
|
||||||
|
|
||||||
# Store tracking results
|
|
||||||
self.last_mask_dict = MaskDictionaryModel()
|
|
||||||
self.track_dict = MaskDictionaryModel()
|
|
||||||
|
|
||||||
def add_image(self, image_np: np.ndarray):
|
|
||||||
"""
|
|
||||||
Add a new image frame to the tracker and perform detection or tracking update.
|
|
||||||
Args:
|
|
||||||
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Annotated image with object masks and labels.
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
img_pil = Image.fromarray(image_np)
|
|
||||||
|
|
||||||
# Step 1: Perform detection every detection_interval frames
|
|
||||||
if self.total_frames % self.detection_interval == 0:
|
|
||||||
if (
|
|
||||||
self.inference_state["video_height"] is None
|
|
||||||
or self.inference_state["video_width"] is None
|
|
||||||
):
|
|
||||||
(
|
|
||||||
self.inference_state["video_height"],
|
|
||||||
self.inference_state["video_width"],
|
|
||||||
) = image_np.shape[:2]
|
|
||||||
|
|
||||||
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
|
|
||||||
print(
|
|
||||||
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
|
|
||||||
)
|
|
||||||
self.inference_state = self.video_predictor.init_state()
|
|
||||||
self.inference_state["images"] = torch.empty(
|
|
||||||
(0, 3, 1024, 1024), device=self.device
|
|
||||||
)
|
|
||||||
(
|
|
||||||
self.inference_state["video_height"],
|
|
||||||
self.inference_state["video_width"],
|
|
||||||
) = image_np.shape[:2]
|
|
||||||
|
|
||||||
# 1.1 GroundingDINO object detection
|
|
||||||
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
|
|
||||||
if boxes.shape[0] == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 1.2 SAM2 segmentation from detection boxes
|
|
||||||
self.sam2_segmentor.set_image(image_np)
|
|
||||||
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
|
|
||||||
|
|
||||||
# 1.3 Build MaskDictionaryModel
|
|
||||||
mask_dict = MaskDictionaryModel(
|
|
||||||
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
|
|
||||||
)
|
|
||||||
mask_dict.add_new_frame_annotation(
|
|
||||||
mask_list=torch.tensor(masks).to(self.device),
|
|
||||||
box_list=torch.tensor(boxes),
|
|
||||||
label_list=labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1.4 Object ID tracking and IOU-based update
|
|
||||||
self.objects_count = mask_dict.update_masks(
|
|
||||||
tracking_annotation_dict=self.last_mask_dict,
|
|
||||||
iou_threshold=0.3,
|
|
||||||
objects_count=self.objects_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1.5 Reset video tracker state
|
|
||||||
frame_idx = self.video_predictor.add_new_frame(
|
|
||||||
self.inference_state, image_np
|
|
||||||
)
|
|
||||||
self.video_predictor.reset_state(self.inference_state)
|
|
||||||
|
|
||||||
for object_id, object_info in mask_dict.labels.items():
|
|
||||||
frame_idx, _, _ = self.video_predictor.add_new_mask(
|
|
||||||
self.inference_state,
|
|
||||||
frame_idx,
|
|
||||||
object_id,
|
|
||||||
object_info.mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.track_dict = copy.deepcopy(mask_dict)
|
|
||||||
self.last_mask_dict = mask_dict
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Step 2: Use incremental tracking for intermediate frames
|
|
||||||
frame_idx = self.video_predictor.add_new_frame(
|
|
||||||
self.inference_state, image_np
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Tracking propagation using the video predictor
|
|
||||||
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
|
|
||||||
inference_state=self.inference_state,
|
|
||||||
frame_idx=frame_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 4: Update the mask dictionary based on tracked masks
|
|
||||||
frame_masks = MaskDictionaryModel()
|
|
||||||
for i, obj_id in enumerate(obj_ids):
|
|
||||||
out_mask = video_res_masks[i] > 0.0
|
|
||||||
object_info = ObjectInfo(
|
|
||||||
instance_id=obj_id,
|
|
||||||
mask=out_mask[0],
|
|
||||||
class_name=self.track_dict.get_target_class_name(obj_id),
|
|
||||||
logit=self.track_dict.get_target_logit(obj_id),
|
|
||||||
)
|
|
||||||
object_info.update_box()
|
|
||||||
frame_masks.labels[obj_id] = object_info
|
|
||||||
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
|
|
||||||
frame_masks.mask_height = out_mask.shape[-2]
|
|
||||||
frame_masks.mask_width = out_mask.shape[-1]
|
|
||||||
|
|
||||||
self.last_mask_dict = copy.deepcopy(frame_masks)
|
|
||||||
|
|
||||||
# Step 5: Build mask array
|
|
||||||
H, W = image_np.shape[:2]
|
|
||||||
mask_img = torch.zeros((H, W), dtype=torch.int32)
|
|
||||||
for obj_id, obj_info in self.last_mask_dict.labels.items():
|
|
||||||
mask_img[obj_info.mask == True] = obj_id
|
|
||||||
|
|
||||||
mask_array = mask_img.cpu().numpy()
|
|
||||||
|
|
||||||
# Step 6: Visualization
|
|
||||||
annotated_frame = self.visualize_frame_with_mask_and_metadata(
|
|
||||||
image_np=image_np,
|
|
||||||
mask_array=mask_array,
|
|
||||||
json_metadata=self.last_mask_dict.to_dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[Tracker] Total processed frames: {self.total_frames}")
|
|
||||||
self.total_frames += 1
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return annotated_frame
|
|
||||||
|
|
||||||
def set_prompt(self, new_prompt: str):
|
|
||||||
"""
|
|
||||||
Dynamically update the GroundingDINO prompt and reset tracking state
|
|
||||||
to force a new object detection.
|
|
||||||
"""
|
|
||||||
self.prompt_text = new_prompt
|
|
||||||
self.total_frames = 0 # Trigger immediate re-detection
|
|
||||||
self.inference_state = self.video_predictor.init_state()
|
|
||||||
self.inference_state["images"] = torch.empty(
|
|
||||||
(0, 3, 1024, 1024), device=self.device
|
|
||||||
)
|
|
||||||
self.inference_state["video_height"] = None
|
|
||||||
self.inference_state["video_width"] = None
|
|
||||||
|
|
||||||
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
|
|
||||||
|
|
||||||
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
|
|
||||||
"""
|
|
||||||
Save the current mask, metadata, raw image, and annotated result.
|
|
||||||
Args:
|
|
||||||
output_dir (str): The root output directory.
|
|
||||||
raw_image (np.ndarray, optional): The original input image (RGB).
|
|
||||||
"""
|
|
||||||
mask_data_dir = os.path.join(output_dir, "mask_data")
|
|
||||||
json_data_dir = os.path.join(output_dir, "json_data")
|
|
||||||
image_data_dir = os.path.join(output_dir, "images")
|
|
||||||
vis_data_dir = os.path.join(output_dir, "result")
|
|
||||||
|
|
||||||
os.makedirs(mask_data_dir, exist_ok=True)
|
|
||||||
os.makedirs(json_data_dir, exist_ok=True)
|
|
||||||
os.makedirs(image_data_dir, exist_ok=True)
|
|
||||||
os.makedirs(vis_data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
frame_masks = self.last_mask_dict
|
|
||||||
|
|
||||||
# Ensure mask_name is valid
|
|
||||||
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
|
|
||||||
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
|
|
||||||
|
|
||||||
base_name = f"image_{self.total_frames:05d}"
|
|
||||||
|
|
||||||
# Save segmentation mask
|
|
||||||
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
|
|
||||||
for obj_id, obj_info in frame_masks.labels.items():
|
|
||||||
mask_img[obj_info.mask == True] = obj_id
|
|
||||||
np.save(
|
|
||||||
os.path.join(mask_data_dir, frame_masks.mask_name),
|
|
||||||
mask_img.numpy().astype(np.uint16),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save metadata as JSON
|
|
||||||
json_path = os.path.join(json_data_dir, base_name + ".json")
|
|
||||||
frame_masks.to_json(json_path)
|
|
||||||
|
|
||||||
# Save raw input image
|
|
||||||
if raw_image is not None:
|
|
||||||
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
|
|
||||||
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
|
|
||||||
|
|
||||||
# Save annotated image with mask, bounding boxes, and labels
|
|
||||||
annotated_image = self.visualize_frame_with_mask_and_metadata(
|
|
||||||
image_np=raw_image,
|
|
||||||
mask_array=mask_img.numpy().astype(np.uint16),
|
|
||||||
json_metadata=frame_masks.to_dict(),
|
|
||||||
)
|
|
||||||
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
|
|
||||||
cv2.imwrite(
|
|
||||||
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
|
|
||||||
)
|
|
||||||
|
|
||||||
def visualize_frame_with_mask_and_metadata(
|
|
||||||
self,
|
|
||||||
image_np: np.ndarray,
|
|
||||||
mask_array: np.ndarray,
|
|
||||||
json_metadata: dict,
|
|
||||||
):
|
|
||||||
image = image_np.copy()
|
|
||||||
H, W = image.shape[:2]
|
|
||||||
|
|
||||||
# Step 1: Parse metadata and build object entries
|
|
||||||
metadata_lookup = json_metadata.get("labels", {})
|
|
||||||
|
|
||||||
all_object_ids = []
|
|
||||||
all_object_boxes = []
|
|
||||||
all_object_classes = []
|
|
||||||
all_object_masks = []
|
|
||||||
|
|
||||||
for obj_id_str, obj_info in metadata_lookup.items():
|
|
||||||
instance_id = obj_info.get("instance_id")
|
|
||||||
if instance_id is None or instance_id == 0:
|
|
||||||
continue
|
|
||||||
if instance_id not in np.unique(mask_array):
|
|
||||||
continue
|
|
||||||
|
|
||||||
object_mask = mask_array == instance_id
|
|
||||||
all_object_ids.append(instance_id)
|
|
||||||
x1 = obj_info.get("x1", 0)
|
|
||||||
y1 = obj_info.get("y1", 0)
|
|
||||||
x2 = obj_info.get("x2", 0)
|
|
||||||
y2 = obj_info.get("y2", 0)
|
|
||||||
all_object_boxes.append([x1, y1, x2, y2])
|
|
||||||
all_object_classes.append(obj_info.get("class_name", "unknown"))
|
|
||||||
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
|
|
||||||
|
|
||||||
# Step 2: Check if valid objects exist
|
|
||||||
if len(all_object_ids) == 0:
|
|
||||||
print("No valid object instances found in metadata.")
|
|
||||||
return image
|
|
||||||
|
|
||||||
# Step 3: Sort by instance ID
|
|
||||||
paired = list(
|
|
||||||
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
|
|
||||||
)
|
|
||||||
paired.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
all_object_ids = [p[0] for p in paired]
|
|
||||||
all_object_boxes = [p[1] for p in paired]
|
|
||||||
all_object_masks = [p[2] for p in paired]
|
|
||||||
all_object_classes = [p[3] for p in paired]
|
|
||||||
|
|
||||||
# Step 4: Build detections
|
|
||||||
all_object_masks = np.concatenate(all_object_masks, axis=0)
|
|
||||||
detections = sv.Detections(
|
|
||||||
xyxy=np.array(all_object_boxes),
|
|
||||||
mask=all_object_masks,
|
|
||||||
class_id=np.array(all_object_ids, dtype=np.int32),
|
|
||||||
)
|
|
||||||
labels = [
|
|
||||||
f"{instance_id}: {class_name}"
|
|
||||||
for instance_id, class_name in zip(all_object_ids, all_object_classes)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Step 5: Annotate image
|
|
||||||
annotated_frame = image.copy()
|
|
||||||
mask_annotator = sv.MaskAnnotator()
|
|
||||||
box_annotator = sv.BoxAnnotator()
|
|
||||||
label_annotator = sv.LabelAnnotator()
|
|
||||||
|
|
||||||
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
|
|
||||||
annotated_frame = box_annotator.annotate(annotated_frame, detections)
|
|
||||||
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
|
|
||||||
|
|
||||||
return annotated_frame
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
from utils.common_utils import CommonUtils
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Parameter settings
|
|
||||||
output_dir = "./outputs"
|
|
||||||
prompt_text = "hand."
|
|
||||||
detection_interval = 20
|
|
||||||
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
|
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Initialize the object tracker
|
|
||||||
tracker = IncrementalObjectTracker(
|
|
||||||
grounding_model_id="IDEA-Research/grounding-dino-tiny",
|
|
||||||
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
|
|
||||||
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
|
|
||||||
device="cuda",
|
|
||||||
prompt_text=prompt_text,
|
|
||||||
detection_interval=detection_interval,
|
|
||||||
)
|
|
||||||
tracker.set_prompt("person.")
|
|
||||||
|
|
||||||
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
|
|
||||||
cap = cv2.VideoCapture(0)
|
|
||||||
if not cap.isOpened():
|
|
||||||
print("[Error] Cannot open camera.")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("[Info] Camera opened. Press 'q' to quit.")
|
|
||||||
frame_idx = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
ret, frame = cap.read()
|
|
||||||
if not ret:
|
|
||||||
print("[Warning] Failed to capture frame.")
|
|
||||||
break
|
|
||||||
|
|
||||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
||||||
print(f"[Frame {frame_idx}] Processing live frame...")
|
|
||||||
process_image = tracker.add_image(frame_rgb)
|
|
||||||
|
|
||||||
if process_image is None or not isinstance(process_image, np.ndarray):
|
|
||||||
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
|
|
||||||
frame_idx += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
|
|
||||||
# cv2.imshow("Live Inference", process_image_bgr)
|
|
||||||
|
|
||||||
|
|
||||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
|
||||||
# print("[Info] Quit signal received.")
|
|
||||||
# break
|
|
||||||
|
|
||||||
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
|
|
||||||
frame_idx += 1
|
|
||||||
|
|
||||||
if frame_idx >= max_frames:
|
|
||||||
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
|
|
||||||
break
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("[Info] Interrupted by user (Ctrl+C).")
|
|
||||||
finally:
|
|
||||||
cap.release()
|
|
||||||
cv2.destroyAllWindows()
|
|
||||||
print("[Done] Live inference complete.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@@ -1,7 +1,9 @@
|
|||||||
# dds cloudapi for DINO-X - update to V2Task API
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
||||||
|
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -28,7 +30,6 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
|||||||
API_TOKEN_FOR_DINOX = "Your API token"
|
API_TOKEN_FOR_DINOX = "Your API token"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -97,29 +98,22 @@ config = Config(API_TOKEN_FOR_DINOX)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task using V2Task class
|
# Step 3: run the task by DetectionTask class
|
||||||
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = V2Task(
|
task = DinoxTask(
|
||||||
api_path="/v2/task/dinox/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"model": "DINO-X-1.0",
|
bbox_threshold=0.25,
|
||||||
"image": image_url,
|
targets=[DetectionTarget.BBox],
|
||||||
"prompt": {
|
|
||||||
"type": "text",
|
|
||||||
"text": TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result["objects"] # the list of detected objects
|
objects = result.objects # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -127,9 +121,9 @@ confidences = []
|
|||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
class_names.append(obj["category"])
|
class_names.append(obj.category)
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5 - Update to V2Task API
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk import DetectionTask
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -28,7 +31,6 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
|||||||
API_TOKEN_FOR_GD1_5 = "Your API token"
|
API_TOKEN_FOR_GD1_5 = "Your API token"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -97,38 +99,33 @@ config = Config(API_TOKEN_FOR_GD1_5)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task using V2Task class
|
# Step 3: run the task by DetectionTask class
|
||||||
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = V2Task(
|
task = DetectionTask(
|
||||||
api_path="/v2/task/grounding_dino/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
||||||
"model": "GroundingDino-1.5-Pro",
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
"image": image_url,
|
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
||||||
"prompt": {
|
bbox_threshold=BOX_THRESHOLD,
|
||||||
"type": "text",
|
|
||||||
"text": TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result["objects"] # the list of detected objects
|
objects = result.objects # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
class_names.append(obj["category"])
|
class_names.append(obj.category)
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -1,7 +1,11 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk import DetectionTask
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
@@ -47,9 +51,6 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
|
|||||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||||
# VERY important: text queries need to be lowercased + end with a dot
|
# VERY important: text queries need to be lowercased + end with a dot
|
||||||
text = "car."
|
text = "car."
|
||||||
BOX_THRESHOLD = 0.2
|
|
||||||
IOU_THRESHOLD = 0.8
|
|
||||||
GROUNDING_MODEL = "GroundingDino-1.6-Pro" # 使用字符串替代枚举值
|
|
||||||
|
|
||||||
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
||||||
video_dir = "notebooks/videos/car"
|
video_dir = "notebooks/videos/car"
|
||||||
@@ -101,32 +102,24 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
task = V2Task(
|
task = DetectionTask(
|
||||||
api_path="/v2/task/grounding_dino/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text=text)],
|
||||||
"model": GROUNDING_MODEL,
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
"image": image_url,
|
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
||||||
"prompt": {
|
|
||||||
"type": "text",
|
|
||||||
"text": text
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result["objects"] # the list of detected objects
|
objects = result.objects # the list of detected objects
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
class_names.append(obj["category"])
|
class_names.append(obj.category)
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
OBJECTS = class_names
|
OBJECTS = class_names
|
||||||
@@ -161,7 +154,7 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=IOU_THRESHOLD, objects_count=objects_count)
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
||||||
print("objects_count", objects_count)
|
print("objects_count", objects_count)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
from dds_cloudapi_sdk import DetectionTask
|
||||||
|
from dds_cloudapi_sdk import TextPrompt
|
||||||
|
from dds_cloudapi_sdk import DetectionModel
|
||||||
|
from dds_cloudapi_sdk import DetectionTarget
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -51,11 +54,6 @@ inference_state = video_predictor.init_state(video_path=video_dir)
|
|||||||
ann_frame_idx = 0 # the frame index we interact with
|
ann_frame_idx = 0 # the frame index we interact with
|
||||||
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
||||||
|
|
||||||
# 添加参数设置
|
|
||||||
TEXT_PROMPT = "children. pillow"
|
|
||||||
BOX_THRESHOLD = 0.2
|
|
||||||
IOU_THRESHOLD = 0.8
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
||||||
@@ -72,29 +70,23 @@ config = Config(token)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task using V2Task class
|
# Step 3: run the task by DetectionTask class
|
||||||
|
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = V2Task(
|
task = DetectionTask(
|
||||||
api_path="/v2/task/grounding_dino/detection",
|
image_url=image_url,
|
||||||
api_body={
|
prompts=[TextPrompt(text="children. pillow")],
|
||||||
"model": "GroundingDino-1.5-Pro",
|
targets=[DetectionTarget.BBox], # detect bbox
|
||||||
"image": image_url,
|
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
|
||||||
"prompt": {
|
bbox_threshold=0.2,
|
||||||
"type": "text",
|
|
||||||
"text": TEXT_PROMPT
|
|
||||||
},
|
|
||||||
"targets": ["bbox"],
|
|
||||||
"bbox_threshold": BOX_THRESHOLD,
|
|
||||||
"iou_threshold": IOU_THRESHOLD,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result["objects"] # the list of detected objects
|
objects = result.objects # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -102,9 +94,9 @@ confidences = []
|
|||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj["bbox"])
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj["score"])
|
confidences.append(obj.score)
|
||||||
class_names.append(obj["category"])
|
class_names.append(obj.category)
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -15,24 +15,11 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/extension.h>
|
|
||||||
#include <torch/version.h>
|
|
||||||
|
|
||||||
// Check PyTorch version and define appropriate macros
|
|
||||||
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
|
|
||||||
// PyTorch 2.x and above
|
|
||||||
#define GET_TENSOR_TYPE(x) x.scalar_type()
|
|
||||||
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
|
|
||||||
#else
|
|
||||||
// PyTorch 1.x
|
|
||||||
#define GET_TENSOR_TYPE(x) x.type()
|
|
||||||
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace groundingdino {
|
namespace groundingdino {
|
||||||
|
|
||||||
at::Tensor ms_deform_attn_cuda_forward(
|
at::Tensor ms_deform_attn_cuda_forward(
|
||||||
const at::Tensor &value,
|
const at::Tensor &value,
|
||||||
const at::Tensor &spatial_shapes,
|
const at::Tensor &spatial_shapes,
|
||||||
const at::Tensor &level_start_index,
|
const at::Tensor &level_start_index,
|
||||||
const at::Tensor &sampling_loc,
|
const at::Tensor &sampling_loc,
|
||||||
@@ -45,11 +32,11 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -64,7 +51,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
const int im2col_step_ = std::min(batch, im2col_step);
|
const int im2col_step_ = std::min(batch, im2col_step);
|
||||||
|
|
||||||
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
||||||
|
|
||||||
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
||||||
|
|
||||||
const int batch_n = im2col_step_;
|
const int batch_n = im2col_step_;
|
||||||
@@ -75,7 +62,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -95,7 +82,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
|
|
||||||
|
|
||||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
||||||
const at::Tensor &value,
|
const at::Tensor &value,
|
||||||
const at::Tensor &spatial_shapes,
|
const at::Tensor &spatial_shapes,
|
||||||
const at::Tensor &level_start_index,
|
const at::Tensor &level_start_index,
|
||||||
const at::Tensor &sampling_loc,
|
const at::Tensor &sampling_loc,
|
||||||
@@ -111,12 +98,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
||||||
AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
|
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -141,11 +128,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
||||||
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
||||||
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
||||||
|
|
||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
@@ -166,4 +153,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace groundingdino
|
} // namespace groundingdino
|
@@ -1,6 +1,6 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = [
|
requires = [
|
||||||
"setuptools>=62.3.0,<75.9",
|
"setuptools>=61.0",
|
||||||
"torch>=2.3.1",
|
"torch>=2.5.1",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
92
sam2/benchmark.py
Normal file
92
sam2/benchmark.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sam2.build_sam import build_sam2_video_predictor
|
||||||
|
|
||||||
|
# Only cuda supported
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||||
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# Config and checkpoint
|
||||||
|
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
|
||||||
|
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
||||||
|
|
||||||
|
# Build video predictor with vos_optimized=True setting
|
||||||
|
predictor = build_sam2_video_predictor(
|
||||||
|
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize with video
|
||||||
|
video_dir = "notebooks/videos/bedroom"
|
||||||
|
# scan all the JPEG frame names in this directory
|
||||||
|
frame_names = [
|
||||||
|
p
|
||||||
|
for p in os.listdir(video_dir)
|
||||||
|
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||||
|
]
|
||||||
|
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||||
|
inference_state = predictor.init_state(video_path=video_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# Number of runs, warmup etc
|
||||||
|
warm_up, runs = 5, 25
|
||||||
|
verbose = True
|
||||||
|
num_frames = len(frame_names)
|
||||||
|
total, count = 0, 0
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# We will select an object with a click.
|
||||||
|
# See video_predictor_example.ipynb for more detailed explanation
|
||||||
|
ann_frame_idx, ann_obj_id = 0, 1
|
||||||
|
# Add a positive click at (x, y) = (210, 350)
|
||||||
|
# For labels, `1` means positive click
|
||||||
|
points = np.array([[210, 350]], dtype=np.float32)
|
||||||
|
labels = np.array([1], np.int32)
|
||||||
|
|
||||||
|
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=ann_frame_idx,
|
||||||
|
obj_id=ann_obj_id,
|
||||||
|
points=points,
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup and then average FPS over several runs
|
||||||
|
with torch.autocast("cuda", torch.bfloat16):
|
||||||
|
with torch.inference_mode():
|
||||||
|
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
|
||||||
|
start = time.time()
|
||||||
|
# Start tracking
|
||||||
|
for (
|
||||||
|
out_frame_idx,
|
||||||
|
out_obj_ids,
|
||||||
|
out_mask_logits,
|
||||||
|
) in predictor.propagate_in_video(inference_state):
|
||||||
|
pass
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
total += end - start
|
||||||
|
count += 1
|
||||||
|
if i == warm_up - 1:
|
||||||
|
print("Warmup FPS: ", count * num_frames / total)
|
||||||
|
total = 0
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
print("FPS: ", count * num_frames / total)
|
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
|
|||||||
mode="eval",
|
mode="eval",
|
||||||
hydra_overrides_extra=[],
|
hydra_overrides_extra=[],
|
||||||
apply_postprocessing=True,
|
apply_postprocessing=True,
|
||||||
|
vos_optimized=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
hydra_overrides = [
|
hydra_overrides = [
|
||||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||||
]
|
]
|
||||||
|
if vos_optimized:
|
||||||
|
hydra_overrides = [
|
||||||
|
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
|
||||||
|
"++model.compile_image_encoder=True", # Let sam2_base handle this
|
||||||
|
]
|
||||||
|
|
||||||
if apply_postprocessing:
|
if apply_postprocessing:
|
||||||
hydra_overrides_extra = hydra_overrides_extra.copy()
|
hydra_overrides_extra = hydra_overrides_extra.copy()
|
||||||
hydra_overrides_extra += [
|
hydra_overrides_extra += [
|
||||||
|
@@ -36,7 +36,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -47,7 +47,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -40,7 +40,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -51,7 +51,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -39,7 +39,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -50,7 +50,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -39,7 +39,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -50,7 +50,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -97,7 +97,7 @@ trainer:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -108,7 +108,7 @@ trainer:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -36,7 +36,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -47,7 +47,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -40,7 +40,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -51,7 +51,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -39,7 +39,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -50,7 +50,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -39,7 +39,7 @@ model:
|
|||||||
self_attention:
|
self_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
downsample_rate: 1
|
downsample_rate: 1
|
||||||
@@ -50,7 +50,7 @@ model:
|
|||||||
cross_attention:
|
cross_attention:
|
||||||
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
||||||
rope_theta: 10000.0
|
rope_theta: 10000.0
|
||||||
feat_sizes: [32, 32]
|
feat_sizes: [64, 64]
|
||||||
rope_k_repeat: True
|
rope_k_repeat: True
|
||||||
embedding_dim: 256
|
embedding_dim: 256
|
||||||
num_heads: 1
|
num_heads: 1
|
||||||
|
@@ -32,9 +32,7 @@ def window_partition(x, window_size):
|
|||||||
Hp, Wp = H + pad_h, W + pad_w
|
Hp, Wp = H + pad_h, W + pad_w
|
||||||
|
|
||||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||||
windows = (
|
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
|
||||||
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
||||||
)
|
|
||||||
return windows, (Hp, Wp)
|
return windows, (Hp, Wp)
|
||||||
|
|
||||||
|
|
||||||
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
|
|||||||
Hp, Wp = pad_hw
|
Hp, Wp = pad_hw
|
||||||
H, W = hw
|
H, W = hw
|
||||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||||
x = windows.view(
|
x = windows.reshape(
|
||||||
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
||||||
)
|
)
|
||||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
||||||
|
|
||||||
if Hp > H or Wp > W:
|
if Hp > H or Wp > W:
|
||||||
x = x[:, :H, :W, :].contiguous()
|
x = x[:, :H, :W, :]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
|
|||||||
temperature: int = 10000,
|
temperature: int = 10000,
|
||||||
normalize: bool = True,
|
normalize: bool = True,
|
||||||
scale: Optional[float] = None,
|
scale: Optional[float] = None,
|
||||||
|
# Following settings only relevant
|
||||||
|
# for warmping up cache for compilation
|
||||||
|
warmup_cache: bool = True,
|
||||||
|
image_size: int = 1024,
|
||||||
|
strides: Tuple[int] = (4, 8, 16, 32),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
||||||
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
|
|||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
if warmup_cache and torch.cuda.is_available():
|
||||||
|
# Warmup cache for cuda, to help with compilation
|
||||||
|
device = torch.device("cuda")
|
||||||
|
for stride in strides:
|
||||||
|
cache_key = (image_size // stride, image_size // stride)
|
||||||
|
self._pe(1, device, *cache_key)
|
||||||
|
|
||||||
def _encode_xy(self, x, y):
|
def _encode_xy(self, x, y):
|
||||||
# The positions are expected to be normalized
|
# The positions are expected to be normalized
|
||||||
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
|
|||||||
return pos
|
return pos
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, x: torch.Tensor):
|
def _pe(self, B, device, *cache_key):
|
||||||
cache_key = (x.shape[-2], x.shape[-1])
|
H, W = cache_key
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
||||||
|
|
||||||
y_embed = (
|
y_embed = (
|
||||||
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
|
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
||||||
.view(1, -1, 1)
|
.view(1, -1, 1)
|
||||||
.repeat(x.shape[0], 1, x.shape[-1])
|
.repeat(B, 1, W)
|
||||||
)
|
)
|
||||||
x_embed = (
|
x_embed = (
|
||||||
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
|
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
||||||
.view(1, 1, -1)
|
.view(1, 1, -1)
|
||||||
.repeat(x.shape[0], x.shape[-2], 1)
|
.repeat(B, H, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
|
|||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
pos_x = x_embed[:, :, :, None] / dim_t
|
||||||
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
|
|||||||
self.cache[cache_key] = pos[0]
|
self.cache[cache_key] = pos[0]
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
B = x.shape[0]
|
||||||
|
cache_key = (x.shape[-2], x.shape[-1])
|
||||||
|
return self._pe(B, x.device, *cache_key)
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingRandom(nn.Module):
|
class PositionEmbeddingRandom(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
|
|||||||
point_embedding = self.pe_layer.forward_with_coords(
|
point_embedding = self.pe_layer.forward_with_coords(
|
||||||
points, self.input_image_size
|
points, self.input_image_size
|
||||||
)
|
)
|
||||||
point_embedding[labels == -1] = 0.0
|
|
||||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
point_embedding = torch.where(
|
||||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
(labels == -1).unsqueeze(-1),
|
||||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
|
||||||
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
point_embedding,
|
||||||
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
)
|
||||||
|
point_embedding = torch.where(
|
||||||
|
(labels == 0).unsqueeze(-1),
|
||||||
|
point_embedding + self.point_embeddings[0].weight,
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
|
point_embedding = torch.where(
|
||||||
|
(labels == 1).unsqueeze(-1),
|
||||||
|
point_embedding + self.point_embeddings[1].weight,
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
|
point_embedding = torch.where(
|
||||||
|
(labels == 2).unsqueeze(-1),
|
||||||
|
point_embedding + self.point_embeddings[2].weight,
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
|
point_embedding = torch.where(
|
||||||
|
(labels == 3).unsqueeze(-1),
|
||||||
|
point_embedding + self.point_embeddings[3].weight,
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
return point_embedding
|
return point_embedding
|
||||||
|
|
||||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||||
|
@@ -4,9 +4,7 @@
|
|||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Tuple, Type
|
from typing import Tuple, Type
|
||||||
|
|
||||||
@@ -16,29 +14,6 @@ from torch import nn, Tensor
|
|||||||
|
|
||||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||||
from sam2.modeling.sam2_utils import MLP
|
from sam2.modeling.sam2_utils import MLP
|
||||||
from sam2.utils.misc import get_sdpa_settings
|
|
||||||
|
|
||||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
||||||
# Check whether Flash Attention is available (and use it by default)
|
|
||||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
|
||||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
|
||||||
ALLOW_ALL_KERNELS = False
|
|
||||||
|
|
||||||
|
|
||||||
def sdp_kernel_context(dropout_p):
|
|
||||||
"""
|
|
||||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
|
||||||
by default, but fall back to all available kernels if Flash Attention fails.
|
|
||||||
"""
|
|
||||||
if ALLOW_ALL_KERNELS:
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
return torch.backends.cuda.sdp_kernel(
|
|
||||||
enable_flash=USE_FLASH_ATTN,
|
|
||||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
|
||||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
|
||||||
enable_mem_efficient=OLD_GPU,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TwoWayTransformer(nn.Module):
|
class TwoWayTransformer(nn.Module):
|
||||||
@@ -265,20 +240,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
dropout_p = self.dropout_p if self.training else 0.0
|
dropout_p = self.dropout_p if self.training else 0.0
|
||||||
# Attention
|
# Attention
|
||||||
try:
|
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||||
with sdp_kernel_context(dropout_p):
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
|
||||||
except Exception as e:
|
|
||||||
# Fall back to all kernels if the Flash attention kernel fails
|
|
||||||
warnings.warn(
|
|
||||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
|
||||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
|
||||||
category=UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
global ALLOW_ALL_KERNELS
|
|
||||||
ALLOW_ALL_KERNELS = True
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
|
||||||
|
|
||||||
out = self._recombine_heads(out)
|
out = self._recombine_heads(out)
|
||||||
out = self.out_proj(out)
|
out = self.out_proj(out)
|
||||||
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
|
|||||||
# whether to repeat q rope to match k length
|
# whether to repeat q rope to match k length
|
||||||
# this is needed for cross-attention to memories
|
# this is needed for cross-attention to memories
|
||||||
rope_k_repeat=False,
|
rope_k_repeat=False,
|
||||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
|
|||||||
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
||||||
)
|
)
|
||||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
||||||
self.freqs_cis = freqs_cis
|
self.freqs_cis = (
|
||||||
|
freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
|
||||||
|
)
|
||||||
self.rope_k_repeat = rope_k_repeat
|
self.rope_k_repeat = rope_k_repeat
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
|
|||||||
|
|
||||||
dropout_p = self.dropout_p if self.training else 0.0
|
dropout_p = self.dropout_p if self.training else 0.0
|
||||||
# Attention
|
# Attention
|
||||||
try:
|
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||||
with sdp_kernel_context(dropout_p):
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
|
||||||
except Exception as e:
|
|
||||||
# Fall back to all kernels if the Flash attention kernel fails
|
|
||||||
warnings.warn(
|
|
||||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
|
||||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
|
||||||
category=UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
global ALLOW_ALL_KERNELS
|
|
||||||
ALLOW_ALL_KERNELS = True
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
|
||||||
|
|
||||||
out = self._recombine_heads(out)
|
out = self._recombine_heads(out)
|
||||||
out = self.out_proj(out)
|
out = self.out_proj(out)
|
||||||
|
@@ -628,7 +628,9 @@ class SAM2Base(torch.nn.Module):
|
|||||||
if self.add_tpos_enc_to_obj_ptrs:
|
if self.add_tpos_enc_to_obj_ptrs:
|
||||||
t_diff_max = max_obj_ptrs_in_encoder - 1
|
t_diff_max = max_obj_ptrs_in_encoder - 1
|
||||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
||||||
obj_pos = torch.tensor(pos_list, device=device)
|
obj_pos = torch.tensor(pos_list).to(
|
||||||
|
device=device, non_blocking=True
|
||||||
|
)
|
||||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
||||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
||||||
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
||||||
|
File diff suppressed because it is too large
Load Diff
1172
sam2/sam2_video_predictor_legacy.py
Normal file
1172
sam2/sam2_video_predictor_legacy.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,6 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -210,74 +209,6 @@ def load_video_frames(
|
|||||||
"Only MP4 video and JPEG folder are supported at this moment"
|
"Only MP4 video and JPEG folder are supported at this moment"
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_stream_frame(
|
|
||||||
img_array: np.ndarray,
|
|
||||||
image_size: int,
|
|
||||||
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
|
|
||||||
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
|
|
||||||
offload_to_cpu: bool = False,
|
|
||||||
compute_device: torch.device = torch.device("cuda"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Convert a raw image array (H,W,3 or 3,H,W) into a model‑ready tensor.
|
|
||||||
Steps
|
|
||||||
-----
|
|
||||||
1. Resize the shorter side to `image_size`, keeping aspect ratio,
|
|
||||||
then center‑crop/pad to `image_size` × `image_size`.
|
|
||||||
2. Change layout to [3, H, W] and cast to float32 in [0,1].
|
|
||||||
3. Normalise with ImageNet statistics.
|
|
||||||
4. Optionally move to `compute_device`.
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
|
|
||||||
orig_h : int
|
|
||||||
orig_w : int
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ↪ uses your existing helper so behaviour matches the batch loader
|
|
||||||
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
|
|
||||||
|
|
||||||
# Normalisation (done *after* potential device move for efficiency)
|
|
||||||
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
|
|
||||||
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
|
|
||||||
|
|
||||||
if not offload_to_cpu:
|
|
||||||
img_tensor = img_tensor.to(compute_device)
|
|
||||||
img_mean_t = img_mean_t.to(compute_device)
|
|
||||||
img_std_t = img_std_t.to(compute_device)
|
|
||||||
|
|
||||||
img_tensor.sub_(img_mean_t).div_(img_std_t)
|
|
||||||
|
|
||||||
return img_tensor, orig_h, orig_w
|
|
||||||
|
|
||||||
|
|
||||||
def _resize_and_convert_to_tensor(img_array, image_size):
|
|
||||||
"""
|
|
||||||
Resize the input image array and convert it into a tensor.
|
|
||||||
Also return original image height and width.
|
|
||||||
"""
|
|
||||||
# Convert numpy array to PIL image and ensure RGB
|
|
||||||
img_pil = Image.fromarray(img_array).convert("RGB")
|
|
||||||
|
|
||||||
# Save original size (PIL: size = (width, height))
|
|
||||||
video_width, video_height = img_pil.size
|
|
||||||
|
|
||||||
# Resize with high-quality LANCZOS filter
|
|
||||||
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
|
||||||
|
|
||||||
# Convert resized image back to numpy and then to float tensor
|
|
||||||
img_resized_array = np.array(img_resized)
|
|
||||||
|
|
||||||
if img_resized_array.dtype == np.uint8:
|
|
||||||
img_resized_array = img_resized_array / 255.0
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
|
|
||||||
|
|
||||||
# Convert to PyTorch tensor and permute to [C, H, W]
|
|
||||||
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
|
|
||||||
|
|
||||||
return img_tensor, video_height, video_width
|
|
||||||
|
|
||||||
|
|
||||||
def load_video_frames_from_jpg_images(
|
def load_video_frames_from_jpg_images(
|
||||||
video_path,
|
video_path,
|
||||||
|
6
setup.py
6
setup.py
@@ -22,8 +22,8 @@ with open("README.md", "r", encoding="utf-8") as f:
|
|||||||
|
|
||||||
# Required dependencies
|
# Required dependencies
|
||||||
REQUIRED_PACKAGES = [
|
REQUIRED_PACKAGES = [
|
||||||
"torch>=2.3.1",
|
"torch>=2.5.1",
|
||||||
"torchvision>=0.18.1",
|
"torchvision>=0.20.1",
|
||||||
"numpy>=1.24.4",
|
"numpy>=1.24.4",
|
||||||
"tqdm>=4.66.1",
|
"tqdm>=4.66.1",
|
||||||
"hydra-core>=1.3.2",
|
"hydra-core>=1.3.2",
|
||||||
@@ -58,7 +58,7 @@ EXTRA_PACKAGES = {
|
|||||||
"scikit-image>=0.24.0",
|
"scikit-image>=0.24.0",
|
||||||
"tensorboard>=2.17.0",
|
"tensorboard>=2.17.0",
|
||||||
"pycocotools>=2.0.8",
|
"pycocotools>=2.0.8",
|
||||||
"tensordict>=0.5.0",
|
"tensordict>=0.6.0",
|
||||||
"opencv-python>=4.7.0",
|
"opencv-python>=4.7.0",
|
||||||
"submitit>=1.5.1",
|
"submitit>=1.5.1",
|
||||||
],
|
],
|
||||||
|
@@ -375,7 +375,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sam2_checkpoint",
|
"--sam2_checkpoint",
|
||||||
type=str,
|
type=str,
|
||||||
default="./checkpoints/sam2.1_hiera_b+.pt",
|
default="./checkpoints/sam2.1_hiera_base_plus.pt",
|
||||||
help="path to the SAM 2 model checkpoint",
|
help="path to the SAM 2 model checkpoint",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -434,6 +434,11 @@ def main():
|
|||||||
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
|
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
|
||||||
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
|
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_vos_optimized_video_predictor",
|
||||||
|
action="store_true",
|
||||||
|
help="whether to use vos optimized video predictor with all modules compiled",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
|
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
|
||||||
@@ -445,6 +450,7 @@ def main():
|
|||||||
ckpt_path=args.sam2_checkpoint,
|
ckpt_path=args.sam2_checkpoint,
|
||||||
apply_postprocessing=args.apply_postprocessing,
|
apply_postprocessing=args.apply_postprocessing,
|
||||||
hydra_overrides_extra=hydra_overrides_extra,
|
hydra_overrides_extra=hydra_overrides_extra,
|
||||||
|
vos_optimized=args.use_vos_optimized_video_predictor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.use_all_masks:
|
if args.use_all_masks:
|
||||||
|
Reference in New Issue
Block a user