77 Commits

Author SHA1 Message Date
rentainhe
8b56c25344 update to latest sam2 2024-12-21 11:21:48 +08:00
Ronghang Hu
2b90b9f5ce remove .pin_memory() in obj_pos of SAM2Base to resolve an error in MPS (#495)
In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve an error in MPS. Investigations show that `.pin_memory()` causes the error `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487.

(close https://github.com/facebookresearch/sam2/issues/487)
2024-12-15 16:47:17 -08:00
Ronghang Hu
722d1d1511 patch for the case of offload_state_to_cpu=True in the new SAM2VideoPredictor (#490)
This PR adds a patch for the case of `offload_state_to_cpu=True`, where `pred_masks` might have been offloaded to the CPU device (close https://github.com/facebookresearch/sam2/issues/489)
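The pattern is roughly the following (a hedged sketch; names are illustrative, not the repo's exact code):

```python
import torch

def to_inference_device(pred_masks: torch.Tensor, device: torch.device) -> torch.Tensor:
    # With offload_state_to_cpu=True, cached masks may live on the CPU; move
    # them back to the inference device before any device-side computation.
    return pred_masks.to(device, non_blocking=True) if pred_masks.device != device else pred_masks
```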
2024-12-12 15:12:13 -08:00
Ronghang Hu
393ae336a7 SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can specify this option via the `--use_vos_optimized_video_predictor` flag (a minimal usage sketch follows this list).
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
  * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
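A minimal usage sketch of the new option (assuming the SAM 2.1 config and checkpoint paths used elsewhere in this repo):

```python
from sam2.build_sam import build_sam2_video_predictor

# vos_optimized=True selects the new SAM2VideoPredictorVOS class and compiles
# the full model with torch.compile (requires PyTorch >= 2.5.1).
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "./checkpoints/sam2.1_hiera_large.pt",
    vos_optimized=True,
)
```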
2024-12-11 15:00:55 -08:00
Haitham Khedr
c2ec8e14a1 remove unused paths (#384) 2024-10-14 10:40:54 -04:00
Roman Rädle
c98aa6bea3 Merge pull request #364 from facebookresearch/pr364
[sam2][demo][1/x] Fix file upload

Summary:

The Strawberry GraphQL library recently disabled multipart requests by default. This resulted in a video upload request returning "Unsupported content type" instead of uploading the video, processing it, and returning the video path.

This issue was raised in #361. A forward fix is to add `multipart_uploads_enabled=True` to the endpoint view.
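A sketch of the forward fix, assuming Strawberry's ASGI view (the demo backend's actual wiring may differ):

```python
import strawberry
from strawberry.asgi import GraphQL

@strawberry.type
class Query:
    @strawberry.field
    def hello(self) -> str:
        return "world"

schema = strawberry.Schema(query=Query)
# Re-enable multipart (file upload) requests, which newer Strawberry versions
# reject by default with "Unsupported content type".
app = GraphQL(schema, multipart_uploads_enabled=True)
```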

Test Plan:

Tested locally with cURL and upload succeeds

*Request*

```
curl http://localhost:7263/graphql \
  -F operations='{ "query": "mutation($file: Upload!){ uploadVideo(file: $file) { path } }", "variables": { "file": null } }' \
  -F map='{ "file": ["variables.file"] }' \
  -F file=@video.mov
```

*Response*

```
{"data": {"uploadVideo": {"path": "uploads/<HASH>.mp4"}}}
```
2024-10-08 15:28:14 -07:00
Roman Rädle
ff9704fc0e [sam2][demo][1/x] Fix file upload
Summary:

The Strawberry GraphQL library recently disabled multipart requests by default. This resulted in a video upload request returning "Unsupported content type" instead of uploading the video, processing it, and returning the video path.

This issue was raised in #361. A forward fix is to add `multipart_uploads_enabled=True` to the endpoint view.

Test Plan:

Tested locally with cURL and upload succeeds

*Request*

```
curl http://localhost:7263/graphql \
  -F operations='{ "query": "mutation($file: Upload!){ uploadVideo(file: $file) { path } }", "variables": { "file": null } }' \
  -F map='{ "file": ["variables.file"] }' \
  -F file=@video.mov
```

*Response*

```
{"data": {"uploadVideo": {"path": "uploads/<HASH>.mp4"}}}
```
2024-10-08 14:58:28 -07:00
Ronghang Hu
29267c8e39 [doc] Check and raise an error if the user is running Python from the parent directory of the sam2 repo (#359)
If the user has "sam2/sam2" in their path, they are likely importing the repo itself as "sam2" rather than importing the "sam2" python package (i.e. the "sam2/sam2" directory). This typically happens because the user is running Python from the parent directory that contains the sam2 repo they cloned.

In general, the user should not run Python from the parent directory into which the repo is cloned (the same is true for e.g. the NumPy repo, which contains names like `numpy/numpy` where the module and the repo have the same name), as the user encountered in https://github.com/facebookresearch/sam2/issues/346.

(close https://github.com/facebookresearch/sam2/issues/346)
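A hypothetical sketch of such a guard (illustrative only, not the repo's exact code):

```python
import os

# If the working directory contains the cloned repo, `import sam2` resolves to
# the repo folder rather than the installed package.
if os.path.isfile(os.path.join(os.getcwd(), "sam2", "setup.py")):
    raise RuntimeError(
        "You appear to be running Python from the parent directory of the sam2 "
        "repo, so `import sam2` picks up the repo instead of the installed "
        "package. Please run Python from a different directory."
    )
```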
2024-10-05 00:34:06 -07:00
Ronghang Hu
e22521832f [demo] add GPU to resources (#355)
This small PR adds a GPU specification in `docker-compose.yaml` for the SAM 2 interactive web demo, following https://docs.docker.com/compose/how-tos/gpu-support/#example-of-a-compose-file-for-running-a-service-with-access-to-1-gpu-device. It fixes a GPU access error reported in https://github.com/facebookresearch/sam2/issues/354.

(close https://github.com/facebookresearch/sam2/issues/354)
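Following that Docker Compose documentation, the added stanza looks roughly like this (the service name is illustrative):

```yaml
services:
  backend:
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
```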
2024-10-03 16:48:56 -07:00
Haitham Khedr
8bf0920e66 Add MANIFEST.in (#353) 2024-10-03 10:40:13 -07:00
Haitham Khedr
52198ead0e Merge pull request #2 from kit1980/patch-1
Use `weights_only` for loading
2024-10-01 22:32:58 +02:00
Ronghang Hu
98fcb164bf Update links after renaming the repo from segment-anything-2 to sam2 (#341)
This PR updates repo links after we renamed the repo from `segment-anything-2` to `sam2`. It also changes `NAME` in setup.py to `SAM-2` (which is already the name used in the pip setup, since Python package names don't allow whitespace).
2024-09-30 20:27:44 -07:00
Ronghang Hu
05d9e57fb3 [docs] add a release note and new installation instructions for SAM 2.1 (#338) 2024-09-30 09:55:58 -07:00
Ronghang Hu
429a2c7360 minor update README.md 2024-09-28 23:32:25 -07:00
Chay Ryali
3a7889d905 Merge pull request #335 from facebookresearch/sam2.1
SAM 2.1
2024-09-28 23:01:29 -07:00
Haitham Khedr
aa9b8722d0 SAM2.1
SAM2.1 checkpoints + training code + Demo
2024-09-29 05:49:56 +00:00
Sergii Dymchenko
0f6515ae85 Merge branch 'main' into patch-1 2024-08-26 15:49:40 -07:00
Ronghang Hu
7e1596c0b6 open README.md with unicode (to support Hugging Face emoji); fix various typos (#218)
(close #217, #66, #67, #69, #91, #126, #127, #145)
2024-08-14 09:06:25 -07:00
Haitham Khedr
0db838b117 Merge pull request #205 from facebookresearch/haitham/fix_hf_image_predictor
Fix HF image predictor
2024-08-12 17:04:04 -07:00
Haitham Khedr
fd5125b97a accept kwargs in auto_mask_generator 2024-08-13 00:02:36 +00:00
Haitham Khedr
1191677e1e Fix HF image predictor 2024-08-12 23:41:41 +00:00
Ronghang Hu
dce7b5446f improving warning message and adding further tips for installation (#204) 2024-08-12 11:37:41 -07:00
Ronghang Hu
1034ee2a1a better support for non-CUDA devices (CPU, MPS) (#192) 2024-08-12 10:46:50 -07:00
Chay Ryali
778e112740 Merge pull request #167 from arun477/patch-1
remove unused attributes from hieradet.py
2024-08-09 10:31:56 -07:00
Arun
8f607e2de1 Merge branch 'main' into patch-1 2024-08-09 11:14:43 +05:30
Arun
46945a2122 Update hieradet.py
ufmt formatting fixed.
2024-08-09 11:14:11 +05:30
Ronghang Hu
d421e0b040 add Colab support to the notebooks; pack config files in sam2_configs package during installation (#176) 2024-08-08 11:03:22 -07:00
Arun
102ddb8899 Merge branch 'main' into patch-1 2024-08-08 09:59:47 +05:30
Ronghang Hu
6186d1529a also catch errors during installation in case CUDAExtension cannot be loaded (#175)
Previously, we only caught build errors in `BuildExtension` (https://github.com/facebookresearch/segment-anything-2/pull/155). However, in some cases the `CUDAExtension` instance itself might fail to load, so in this PR we also catch such errors for `CUDAExtension`.
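The pattern is roughly the following (a hedged sketch; the actual setup.py differs in its details):

```python
try:
    from torch.utils.cpp_extension import CUDAExtension

    ext_modules = [CUDAExtension("sam2._C", ["sam2/csrc/connected_components.cu"])]
except Exception as e:
    # The import or the CUDAExtension instantiation itself may fail (e.g. when
    # CUDA_HOME is missing); warn and fall back to a pure-Python install.
    print(f"Skipping the SAM 2 CUDA extension: {e}")
    ext_modules = []
```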
2024-08-07 12:26:11 -07:00
Ronghang Hu
6ecb5ff8d0 Add interface for box prompt in SAM 2 video predictor (#174)
This PR adds an example of providing a box prompt in SAM 2 as input to the `add_new_points_or_box` API (renamed from `add_new_points`, which is kept for backward compatibility). If `box` is provided, we add it as the first two points with labels 2 and 3, along with the user-provided points (consistent with how SAM 2 is trained).

The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from box prompt as an example.
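A sketch of the call (hedged: `predictor` and `state` are assumed to come from `build_sam2_video_predictor` and `predictor.init_state(...)` as in the notebook; the box values are illustrative):

```python
import numpy as np

# A box prompt is XYXY in pixel coordinates; internally it is encoded as two
# extra points with labels 2 and 3, consistent with how SAM 2 was trained.
frame_idx, obj_ids, mask_logits = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    box=np.array([300, 0, 500, 400], dtype=np.float32),
)
```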
2024-08-07 11:54:30 -07:00
Arun
086daf0641 Merge branch 'main' into patch-1 2024-08-07 21:50:26 +05:30
Haitham Khedr
6ba4c65cb2 Merge pull request #128 from NielsRogge/add_hf
Integrate with Hugging Face
2024-08-07 08:54:49 -07:00
Niels
9b58611e24 Address comment 2024-08-07 17:48:12 +02:00
Arun
6ec8560436 Update hieradet.py
Not used: `head_dim = dim_out // num_heads` and `self.scale = head_dim**-0.5`.

`F.scaled_dot_product_attention` takes care of this automatically.
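For reference, `F.scaled_dot_product_attention` applies the `1 / sqrt(head_dim)` factor internally when its `scale` argument is left at the default, which is why the removed attributes were dead code:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# scale=None (the default) applies 1/sqrt(head_dim) inside the kernel,
# so a manually precomputed self.scale is unnecessary.
out = F.scaled_dot_product_attention(q, k, v)
```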
2024-08-07 11:35:46 +05:30
Niels
43c385c263 Update docstrings 2024-08-06 23:00:26 +02:00
Niels
322aa3e7e5 Revert code snippet 2024-08-06 22:57:07 +02:00
Diego Garcia
511199d7a9 Updated INSTALL.md with CUDA_HOME-related troubleshooting (#140)
This refers to https://github.com/facebookresearch/segment-anything-2/issues/137, which itself describes a common installation problem also mentioned in https://github.com/facebookresearch/segment-anything-2/issues/19.

Some users may encounter significant trouble installing the project, running into the error `OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.` Simply adding the `--no-build-isolation` flag to the pip install (e.g. `pip install --no-build-isolation -e .`) usually solves this problem. However, this fix is not mentioned anywhere in the READMEs or installation troubleshooting docs. This PR adds the recommendation to INSTALL.md under the "My installation failed with `CUDA_HOME environment variable is not set`" section, ensuring that more users are aware of this potential fix.

Examples of users experiencing related difficulties when installing:
https://github.com/facebookresearch/segment-anything-2/issues/19
https://github.com/facebookresearch/segment-anything-2/issues/41
https://github.com/facebookresearch/segment-anything-2/issues/99
https://github.com/facebookresearch/segment-anything-2/issues/133
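As a concrete example of the workaround (the `CUDA_HOME` path is illustrative; adjust it to your install):

```bash
export CUDA_HOME=/usr/local/cuda  # only needed if CUDA is installed but not found
pip install --no-build-isolation -e .
```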
2024-08-06 13:45:15 -07:00
Niels
8f15c6255a Format using ufmt 2024-08-06 22:43:35 +02:00
jhj0517
0bac418736 Update INSTALL.md (#156)
This PR suggests a way to resolve the `unsupported Microsoft Visual Studio version!` error in INSTALL.md.
Adding the `-allow-unsupported-compiler` argument for `nvcc` worked.

Editing [setup.py](https://github.com/facebookresearch/segment-anything-2/blob/main/setup.py) is required to add the `-allow-unsupported-compiler` argument for `nvcc`.

```python
def get_extensions():
    srcs = ["sam2/csrc/connected_components.cu"]
    compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
            "-allow-unsupported-compiler"  # Add this argument
        ],
    }
    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
    return ext_modules
```
2024-08-06 13:43:12 -07:00
Niels
27a167c004 Update README 2024-08-06 22:41:32 +02:00
Ronghang Hu
6f7e700c37 Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, since we have observed that many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension build with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post-processing step.

We also fall back to all available kernels if the Flash Attention kernel fails.
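For example, to install without attempting the CUDA extension build at all:

```bash
SAM2_BUILD_CUDA=0 pip install -e .
```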
2024-08-06 10:52:01 -07:00
Niels
a36edf1e01 Clean up 2024-08-06 08:34:42 +02:00
Niels
e815f70a38 Address comment 2024-08-06 08:32:36 +02:00
Niels
fbf7e3a664 Add link 2024-08-05 22:12:15 +02:00
Niels
e9503c96fe Move HF to separate section 2024-08-05 22:10:57 +02:00
Niels
c3393d8b5f Include original code snippet 2024-08-05 22:08:54 +02:00
Haitham Khedr
0230c5ff93 Merge pull request #152 from haithamkhedr/main
Configure a workflow for format checking
2024-08-05 10:02:32 -07:00
Haitham Khedr
5e3d6ca6b5 Merge pull request #1 from haithamkhedr/CI 2024-08-05 09:36:54 -07:00
Haitham Khedr
3b0fd9e4a9 Update workflow 2024-08-05 09:28:28 -07:00
Haitham Khedr
acd3939f88 Add workflow 2024-08-05 09:16:29 -07:00
Niels
841cc1f015 Update docstring 2024-08-05 09:44:03 +02:00
Niels
e93be7f6aa Update README 2024-08-05 09:43:04 +02:00
Niels
cb48213066 Update links 2024-08-05 09:41:40 +02:00
Niels
6aeee34775 Make huggingface_hub soft dependency 2024-08-05 09:37:53 +02:00
Niels
0c28c630c2 Do not load config from the hub 2024-08-03 14:45:20 +02:00
Niels
3af4e82263 Add model_id_to_filenames 2024-08-03 14:18:23 +02:00
Niels
17b74501fb Use classmethod 2024-08-03 14:14:12 +02:00
Niels
b72a8a97f0 First draft 2024-08-03 12:57:05 +02:00
Ronghang Hu
57bc94b739 Merge pull request #119 from facebookresearch/ronghanghu/installation_faq
[doc] add `INSTALL.md` as an installation FAQ page
2024-08-02 15:07:25 -07:00
Ronghang Hu
b744a3c084 [doc] add INSTALL.md as an installation FAQ page 2024-08-02 22:05:46 +00:00
Haitham Khedr
d1fc9a0686 Merge pull request #116 from facebookresearch/arXiv-paper
Update arXiv paper citation
2024-08-02 13:02:08 -07:00
Haitham Khedr
59550d4deb Update README.md 2024-08-02 12:58:23 -07:00
Haitham Khedr
de4db16676 Update README.md 2024-08-02 12:56:06 -07:00
Ronghang Hu
0e78a11899 Merge pull request #61 from DanBrown47/main
Change git repo url from SSH to HTTPS
2024-07-31 21:30:11 -07:00
Danwand
fa2796bb47 Change git repo url from SSH to HTTPS
This change lets researchers easily clone the project from systems and platforms where SSH is not set up with GitHub, e.g. Google Colab, remote GPU servers, etc.
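For example:

```bash
# HTTPS requires no SSH key setup, so this also works on Colab or remote GPU servers
git clone https://github.com/facebookresearch/segment-anything-2.git
```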
2024-07-31 13:55:06 +05:30
Haitham Khedr
86827e2fba Merge pull request #32 from CharlesCNorton/patch-5
Fix: Hyphenate to "model-in-the-loop"
2024-07-30 07:54:23 -07:00
Ronghang Hu
cd270ed4f1 Merge pull request #29 from CharlesCNorton/patch-3
fix: correct spelling
2024-07-30 07:49:01 -07:00
Ronghang Hu
32750fa695 Merge pull request #30 from CharlesCNorton/patch-4
Fix typo in comment: "evalaution" to "evaluation"
2024-07-30 07:48:16 -07:00
CharlesCNorton
e62ec497b8 Fix: Hyphenate to "model-in-the-loop"
The phrase "model-in-the-loop" is now hyphenated to align with standard practices in technical literature, where hyphenation of compound adjectives clarifies their function as a single descriptor.
2024-07-30 08:35:59 -04:00
CharlesCNorton
c8127182c1 Fix typo in comment: "evalaution" to "evaluation"
Corrected the spelling of "evaluation" in the comment describing the use of the `--use_all_masks` flag in the `main` function.
2024-07-30 08:24:39 -04:00
CharlesCNorton
f882beb157 fix: correct spelling
Corrected two instances of "a even number" to "an even number" for correct article usage before a vowel sound.
2024-07-30 08:13:47 -04:00
Ronghang Hu
82b026cd55 Merge pull request #7 from CharlesCNorton/patch-2
Correct typo in sav_dataset README.md
2024-07-29 19:28:17 -07:00
CharlesCNorton
de05a2e0c5 Correct typo in sav_dataset README.md
Change "there licenses" to "their licenses" in the README.MD
2024-07-29 22:25:56 -04:00
Ronghang Hu
b3011f0ea6 Merge pull request #5 from CharlesCNorton/patch-1
Fix typo in README: "Aything" corrected to "Anything"
2024-07-29 19:24:31 -07:00
CharlesCNorton
662fd3d90e Fix typo in README: "Aything" corrected to "Anything"
Corrected a typo in the Segment Anything Video Dataset section of the README file. The word "Aything" has been updated to "Anything."
2024-07-29 22:15:41 -04:00
Sergii Dymchenko
658aaba327 Use weights_only for loading
sam2/build_sam.py:81:14: TOR102 [*] `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.

Found with https://github.com/pytorch-labs/torchfix/
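The pattern TorchFix recommends looks like this (checkpoint path as used elsewhere in this repo):

```python
import torch

# weights_only=True restricts unpickling to tensors and primitive containers,
# so loading an untrusted checkpoint cannot execute arbitrary code.
state_dict = torch.load(
    "./checkpoints/sam2.1_hiera_large.pt", map_location="cpu", weights_only=True
)
```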
2024-07-29 16:54:54 -07:00
Haitham Khedr
0c5f8c5432 Initial commit 2024-07-29 21:54:20 +00:00
49 changed files with 1961 additions and 5870 deletions

.github/workflows/check_fmt.yml (new file)

@@ -0,0 +1,17 @@
name: SAM2/fmt
on:
  pull_request:
    branches:
      - main
jobs:
  ufmt_check:
    runs-on: ubuntu-latest
    steps:
      - name: Check formatting
        uses: omnilib/ufmt@action-v1
        with:
          path: sam2 tools
          version: "2.0.0b2"
          python-version: "3.10"
          black-version: "24.2.0"
          usort-version: "1.0.2"

.gitignore

@@ -145,5 +145,4 @@ dmypy.json
 outputs/
 .idea/
-tmp/
+demo/backend/checkpoints/*.pt
 data/

@@ -27,7 +27,7 @@ WORKDIR /home/appuser/Grounded-SAM-2
 # Install essential Python packages
-RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
+RUN python -m pip install --upgrade pip setuptools wheel numpy \
     opencv-python transformers supervision pycocotools addict yapf timm
 # Install segment_anything package in editable mode

@@ -2,7 +2,7 @@
 ### Requirements
-- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
 * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
 - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
 - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
 This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
-In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
-We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
 </details>
 <details>

@@ -186,7 +186,7 @@
 same "printed page" as the copyright notice for easier
 identification within third-party archives.
-Copyright 2023 - present, IDEA Research.
+Copyright [yyyy] [name of copyright owner]
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -198,4 +198,4 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.

@@ -20,7 +20,6 @@ In this repo, we've supported the following demo with **simple implementations**
 Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
 ## Latest updates
-- **2025.04.20**: Update to `dds-cloudapi-sdk` API V2 version. The V1 version in the original API for `Grounding DINO 1.5` and `DINO-X` has been deprecated, please update to the latest `dds-cloudapi-sdk` by `pip install dds-cloudapi-sdk -U` to use `Grounding DINO 1.5 / 1.6` and `DINO-X` models. Please refer to [dds-cloudapi-sdk](https://github.com/deepdataspace/dds-cloudapi-sdk) and our [API docs](https://cloud.deepdataspace.com/docs) to view more details about the update.
 - **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
@@ -335,16 +334,6 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
 ```
-### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
-This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
-```bash
-python grounded_sam2_tracking_camera_with_continuous_id.py
-```
 ## Grounded SAM 2 Florence-2 Demos
 ### Grounded SAM 2 Florence-2 Image Demo

RELEASE_NOTES.md (new file)

@@ -0,0 +1,27 @@
## SAM 2 release notes

### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can specify this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
  * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.

### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released

- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
  * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.

### 07/29/2024 -- SAM 2 is released

- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
  * SAM 2 code: https://github.com/facebookresearch/sam2
  * SAM 2 demo: https://sam2.metademolab.com/
  * SAM 2 paper: https://arxiv.org/abs/2408.00714

@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
 ARG MODEL_SIZE=base_plus
 FROM ${BASE_IMAGE}

@@ -105,7 +105,7 @@ cd demo/backend/server/
 ```bash
 PYTORCH_ENABLE_MPS_FALLBACK=1 \
 APP_ROOT="$(pwd)/../../../" \
-APP_URL=http://localhost:7263 \
+API_URL=http://localhost:7263 \
 MODEL_SIZE=base_plus \
 DATA_PATH="$(pwd)/../../data" \
 DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \

@@ -1,7 +1,9 @@
 # dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk.tasks.dinox import DinoxTask
+from dds_cloudapi_sdk.tasks.types import DetectionTarget
+from dds_cloudapi_sdk import TextPrompt
 import os
 import cv2
@@ -25,7 +27,6 @@ IMG_PATH = "notebooks/images/cars.jpg"
 SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
 SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
 BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8
 WITH_SLICE_INFERENCE = False
 SLICE_WH = (480, 480)
 OVERLAP_RATIO = (0.2, 0.2)
@@ -47,7 +48,7 @@ config = Config(token)
 client = Client(config)
 # Step 3: run the task by DetectionTask class
-# infer_image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
+# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
 classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -61,18 +62,13 @@ if WITH_SLICE_INFERENCE:
         with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
             temp_filename = tmpfile.name
         cv2.imwrite(temp_filename, image_slice)
-        infer_image_url = client.upload_file(temp_filename)
-        task = V2Task(api_path="/v2/task/dinox/detection", api_body={
-            "model": "DINO-X-1.0",
-            "image": infer_image_url,
-            "prompt": {
-                "type":"text",
-                "text":TEXT_PROMPT
-            },
-            "targets": ["bbox", "mask"],
-            "bbox_threshold": BOX_THRESHOLD,
-            "iou_threshold": IOU_THRESHOLD,
-        })
+        image_url = client.upload_file(temp_filename)
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=TEXT_PROMPT)],
+            bbox_threshold=0.25,
+            targets=[DetectionTarget.BBox],
+        )
         client.run_task(task)
         result = task.result
         # detele the tempfile
@@ -81,7 +77,7 @@ if WITH_SLICE_INFERENCE:
         input_boxes = []
         confidences = []
         class_ids = []
-        objects = result["objects"]
+        objects = result.objects
         for idx, obj in enumerate(objects):
             input_boxes.append(obj.bbox)
             confidences.append(obj.score)
@@ -106,26 +102,19 @@ if WITH_SLICE_INFERENCE:
     class_ids = detections.class_id
     input_boxes = detections.xyxy
 else:
-    infer_image_url = client.upload_file(IMG_PATH)
-    task = V2Task(
-        api_path="/v2/task/dinox/detection",
-        api_body={
-            "model": "DINO-X-1.0",
-            "image": infer_image_url,
-            "prompt": {
-                "type":"text",
-                "text":TEXT_PROMPT
-            },
-            "targets": ["bbox", "mask"],
-            "bbox_threshold": BOX_THRESHOLD,
-            "iou_threshold": IOU_THRESHOLD,
-        }
+    image_url = client.upload_file(IMG_PATH)
+    task = DinoxTask(
+        image_url=image_url,
+        prompts=[TextPrompt(text=TEXT_PROMPT)],
+        bbox_threshold=0.25,
+        targets=[DetectionTarget.BBox],
     )
     client.run_task(task)
     result = task.result
-    objects = result["objects"]  # the list of detected objects
+    objects = result.objects  # the list of detected objects
     input_boxes = []
@@ -134,9 +123,9 @@ else:
     class_ids = []
     for idx, obj in enumerate(objects):
-        input_boxes.append(obj["bbox"])
-        confidences.append(obj["score"])
-        cls_name = obj["category"].lower().strip()
+        input_boxes.append(obj.bbox)
+        confidences.append(obj.score)
+        cls_name = obj.category.lower().strip()
         class_names.append(cls_name)
         class_ids.append(class_name_to_id[cls_name])

@@ -1,7 +1,10 @@
-# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
+# dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk import DetectionTask
+from dds_cloudapi_sdk import TextPrompt
+from dds_cloudapi_sdk import DetectionModel
+from dds_cloudapi_sdk import DetectionTarget
 import os
 import cv2
@@ -24,9 +27,8 @@ TEXT_PROMPT = "car . building ."
 IMG_PATH = "notebooks/images/cars.jpg"
 SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
 SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
-GROUNDING_MODEL = "GroundingDino-1.5-Pro"  # GroundingDino-1.6-Pro
+GROUNDING_MODEL = DetectionModel.GDino1_5_Pro  # DetectionModel.GDino1_6_Pro
 BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8
 WITH_SLICE_INFERENCE = False
 SLICE_WH = (480, 480)
 OVERLAP_RATIO = (0.2, 0.2)
@@ -47,7 +49,8 @@ config = Config(token)
 # Step 2: initialize the client
 client = Client(config)
-# Step 3: run the task using V2Task API
+# Step 3: run the task by DetectionTask class
+# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
 classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -62,33 +65,26 @@ if WITH_SLICE_INFERENCE:
         temp_filename = tmpfile.name
         cv2.imwrite(temp_filename, image_slice)
         image_url = client.upload_file(temp_filename)
-        task = V2Task(
-            api_path="/v2/task/grounding_dino/detection",
-            api_body={
-                "model": GROUNDING_MODEL,
-                "image": image_url,
-                "prompt": {
-                    "type": "text",
-                    "text": TEXT_PROMPT
-                },
-                "targets": ["bbox"],
-                "bbox_threshold": BOX_THRESHOLD,
-                "iou_threshold": IOU_THRESHOLD,
-            }
+        task = DetectionTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=TEXT_PROMPT)],
+            targets=[DetectionTarget.BBox],  # detect bbox
+            model=GROUNDING_MODEL,  # detect with GroundingDino-1.5-Pro model
+            bbox_threshold=BOX_THRESHOLD,  # box confidence threshold
         )
         client.run_task(task)
         result = task.result
-        # delete the tempfile
+        # detele the tempfile
         os.remove(temp_filename)
         input_boxes = []
         confidences = []
         class_ids = []
-        objects = result["objects"]
+        objects = result.objects
         for idx, obj in enumerate(objects):
-            input_boxes.append(obj["bbox"])
-            confidences.append(obj["score"])
-            cls_name = obj["category"].lower().strip()
+            input_boxes.append(obj.bbox)
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
             class_ids.append(class_name_to_id[cls_name])
         # ensure input_boxes with shape (_, 4)
         input_boxes = np.array(input_boxes).reshape(-1, 4)
@@ -100,7 +96,7 @@ if WITH_SLICE_INFERENCE:
         callback=callback,
         slice_wh=SLICE_WH,
         overlap_ratio_wh=OVERLAP_RATIO,
-        iou_threshold=IOU_THRESHOLD,
+        iou_threshold=0.5,
         overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
     )
     detections = slicer(cv2.imread(IMG_PATH))
@@ -111,25 +107,18 @@ if WITH_SLICE_INFERENCE:
 else:
     image_url = client.upload_file(IMG_PATH)
-    task = V2Task(
-        api_path="/v2/task/grounding_dino/detection",
-        api_body={
-            "model": GROUNDING_MODEL,
-            "image": image_url,
-            "prompt": {
-                "type": "text",
-                "text": TEXT_PROMPT
-            },
-            "targets": ["bbox"],
-            "bbox_threshold": BOX_THRESHOLD,
-            "iou_threshold": IOU_THRESHOLD,
-        }
+    task = DetectionTask(
+        image_url=image_url,
+        prompts=[TextPrompt(text=TEXT_PROMPT)],
+        targets=[DetectionTarget.BBox],  # detect bbox
+        model=GROUNDING_MODEL,  # detect with GroundingDINO-1.5-Pro model
+        bbox_threshold=BOX_THRESHOLD,  # box confidence threshold
     )
     client.run_task(task)
     result = task.result
-    objects = result["objects"]  # the list of detected objects
+    objects = result.objects  # the list of detected objects
     input_boxes = []
@@ -138,9 +127,9 @@ else:
     class_ids = []
     for idx, obj in enumerate(objects):
-        input_boxes.append(obj["bbox"])
-        confidences.append(obj["score"])
-        cls_name = obj["category"].lower().strip()
+        input_boxes.append(obj.bbox)
+        confidences.append(obj.score)
+        cls_name = obj.category.lower().strip()
         class_names.append(cls_name)
         class_ids.append(class_name_to_id[cls_name])

@@ -23,7 +23,7 @@ parser.add_argument("--text-prompt", default="car. tire.")
 parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
 parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
 parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
-parser.add_argument("--output-dir", default="outputs/grounded_sam2_hf_demo")
+parser.add_argument("--output-dir", default="outputs/test_sam2.1")
 parser.add_argument("--no-dump-json", action="store_true")
 parser.add_argument("--force-cpu", action="store_true")
 args = parser.parse_args()
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 # use bfloat16
 torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+if torch.cuda.get_device_properties(0).major >= 8:
     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True

@@ -61,7 +61,6 @@ boxes, confidences, labels = predict(
     caption=text,
     box_threshold=BOX_THRESHOLD,
     text_threshold=TEXT_THRESHOLD,
-    device=DEVICE
 )
 # process the box prompt for SAM 2
@@ -71,9 +70,9 @@ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
 # FIXME: figure how does this influence the G-DINO model
-torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True

View File

@@ -1,536 +0,0 @@
import copy
import os
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
# Setup environment
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class GroundingDinoPredictor:
"""
Wrapper for using a GroundingDINO model for zero-shot object detection.
"""
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
"""
Initialize the GroundingDINO predictor.
Args:
model_id (str): HuggingFace model ID to load.
device (str): Device to run the model on ('cuda' or 'cpu').
"""
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
self.device = device
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
device
)
def predict(
self,
image: "PIL.Image.Image",
text_prompts: str,
box_threshold=0.25,
text_threshold=0.25,
):
"""
Perform object detection using text prompts.
Args:
image (PIL.Image.Image): Input RGB image.
text_prompts (str): Text prompt describing target objects.
box_threshold (float): Confidence threshold for box selection.
text_threshold (float): Confidence threshold for text match.
Returns:
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
"""
inputs = self.processor(
images=image, text=text_prompts, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[image.size[::-1]],
)
return results[0]["boxes"], results[0]["labels"]
class SAM2ImageSegmentor:
"""
Wrapper class for SAM2-based segmentation given bounding boxes.
"""
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
"""
Initialize the SAM2 image segmentor.
Args:
sam_model_cfg (str): Path to the SAM2 config file.
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
device (str): Device to load the model on ('cuda' or 'cpu').
"""
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
self.device = device
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
self.predictor = SAM2ImagePredictor(sam_model)
def set_image(self, image: np.ndarray):
"""
Set the input image for segmentation.
Args:
image (np.ndarray): RGB image array with shape (H, W, 3).
"""
self.predictor.set_image(image)
def predict_masks_from_boxes(self, boxes: torch.Tensor):
"""
Predict segmentation masks from given bounding boxes.
Args:
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
- masks: Binary masks per box, shape (N, H, W)
- scores: Confidence scores for each mask
- logits: Raw logits from the model
"""
masks, scores, logits = self.predictor.predict(
point_coords=None,
point_labels=None,
box=boxes,
multimask_output=False,
)
# Normalize shape to (N, H, W)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
return masks, scores, logits
class IncrementalObjectTracker:
def __init__(
self,
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text="car.",
detection_interval=20,
):
"""
Initialize an incremental object tracker using GroundingDINO and SAM2.
Args:
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
sam2_model_cfg (str): Path to SAM2 model config file.
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
device (str): Device to run the models on ('cuda' or 'cpu').
prompt_text (str): Initial text prompt for detection.
detection_interval (int): Frame interval between full detections.
"""
self.device = device
self.detection_interval = detection_interval
self.prompt_text = prompt_text
# Load models
self.grounding_predictor = GroundingDinoPredictor(
model_id=grounding_model_id, device=device
)
self.sam2_segmentor = SAM2ImageSegmentor(
sam_model_cfg=sam2_model_cfg,
sam_model_ckpt=sam2_ckpt_path,
device=device,
)
self.video_predictor = build_sam2_video_predictor(
sam2_model_cfg, sam2_ckpt_path
)
# Initialize inference state
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
self.total_frames = 0
self.objects_count = 0
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
# Store tracking results
self.last_mask_dict = MaskDictionaryModel()
self.track_dict = MaskDictionaryModel()
def add_image(self, image_np: np.ndarray):
"""
Add a new image frame to the tracker and perform detection or tracking update.
Args:
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
Returns:
np.ndarray: Annotated image with object masks and labels.
"""
import numpy as np
from PIL import Image
img_pil = Image.fromarray(image_np)
# Step 1: Perform detection every detection_interval frames
if self.total_frames % self.detection_interval == 0:
if (
self.inference_state["video_height"] is None
or self.inference_state["video_width"] is None
):
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
print(
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
)
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
# 1.1 GroundingDINO object detection
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
if boxes.shape[0] == 0:
return
# 1.2 SAM2 segmentation from detection boxes
self.sam2_segmentor.set_image(image_np)
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
# 1.3 Build MaskDictionaryModel
mask_dict = MaskDictionaryModel(
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
)
mask_dict.add_new_frame_annotation(
mask_list=torch.tensor(masks).to(self.device),
box_list=torch.tensor(boxes),
label_list=labels,
)
# 1.4 Object ID tracking and IOU-based update
self.objects_count = mask_dict.update_masks(
tracking_annotation_dict=self.last_mask_dict,
iou_threshold=0.3,
objects_count=self.objects_count,
)
# 1.5 Reset video tracker state
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
self.video_predictor.reset_state(self.inference_state)
for object_id, object_info in mask_dict.labels.items():
frame_idx, _, _ = self.video_predictor.add_new_mask(
self.inference_state,
frame_idx,
object_id,
object_info.mask,
)
self.track_dict = copy.deepcopy(mask_dict)
self.last_mask_dict = mask_dict
else:
# Step 2: Use incremental tracking for intermediate frames
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
# Step 3: Tracking propagation using the video predictor
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
inference_state=self.inference_state,
frame_idx=frame_idx,
)
# Step 4: Update the mask dictionary based on tracked masks
frame_masks = MaskDictionaryModel()
for i, obj_id in enumerate(obj_ids):
out_mask = video_res_masks[i] > 0.0
object_info = ObjectInfo(
instance_id=obj_id,
mask=out_mask[0],
class_name=self.track_dict.get_target_class_name(obj_id),
logit=self.track_dict.get_target_logit(obj_id),
)
object_info.update_box()
frame_masks.labels[obj_id] = object_info
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
self.last_mask_dict = copy.deepcopy(frame_masks)
# Step 5: Build mask array
H, W = image_np.shape[:2]
mask_img = torch.zeros((H, W), dtype=torch.int32)
for obj_id, obj_info in self.last_mask_dict.labels.items():
mask_img[obj_info.mask == True] = obj_id
mask_array = mask_img.cpu().numpy()
# Step 6: Visualization
annotated_frame = self.visualize_frame_with_mask_and_metadata(
image_np=image_np,
mask_array=mask_array,
json_metadata=self.last_mask_dict.to_dict(),
)
print(f"[Tracker] Total processed frames: {self.total_frames}")
self.total_frames += 1
torch.cuda.empty_cache()
return annotated_frame
def set_prompt(self, new_prompt: str):
"""
Dynamically update the GroundingDINO prompt and reset tracking state
to force a new object detection.
"""
self.prompt_text = new_prompt
self.total_frames = 0 # Trigger immediate re-detection
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
self.inference_state["video_height"] = None
self.inference_state["video_width"] = None
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
"""
Save the current mask, metadata, raw image, and annotated result.
Args:
output_dir (str): The root output directory.
raw_image (np.ndarray, optional): The original input image (RGB).
"""
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
image_data_dir = os.path.join(output_dir, "images")
vis_data_dir = os.path.join(output_dir, "result")
os.makedirs(mask_data_dir, exist_ok=True)
os.makedirs(json_data_dir, exist_ok=True)
os.makedirs(image_data_dir, exist_ok=True)
os.makedirs(vis_data_dir, exist_ok=True)
frame_masks = self.last_mask_dict
# Ensure mask_name is valid
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
base_name = f"image_{self.total_frames:05d}"
# Save segmentation mask
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
for obj_id, obj_info in frame_masks.labels.items():
mask_img[obj_info.mask == True] = obj_id
np.save(
os.path.join(mask_data_dir, frame_masks.mask_name),
mask_img.numpy().astype(np.uint16),
)
# Save metadata as JSON
json_path = os.path.join(json_data_dir, base_name + ".json")
frame_masks.to_json(json_path)
# Save raw input image
if raw_image is not None:
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
# Save annotated image with mask, bounding boxes, and labels
annotated_image = self.visualize_frame_with_mask_and_metadata(
image_np=raw_image,
mask_array=mask_img.numpy().astype(np.uint16),
json_metadata=frame_masks.to_dict(),
)
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
)
print(
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
)
def visualize_frame_with_mask_and_metadata(
self,
image_np: np.ndarray,
mask_array: np.ndarray,
json_metadata: dict,
):
image = image_np.copy()
H, W = image.shape[:2]
# Step 1: Parse metadata and build object entries
metadata_lookup = json_metadata.get("labels", {})
all_object_ids = []
all_object_boxes = []
all_object_classes = []
all_object_masks = []
for obj_id_str, obj_info in metadata_lookup.items():
instance_id = obj_info.get("instance_id")
if instance_id is None or instance_id == 0:
continue
if instance_id not in np.unique(mask_array):
continue
object_mask = mask_array == instance_id
all_object_ids.append(instance_id)
x1 = obj_info.get("x1", 0)
y1 = obj_info.get("y1", 0)
x2 = obj_info.get("x2", 0)
y2 = obj_info.get("y2", 0)
all_object_boxes.append([x1, y1, x2, y2])
all_object_classes.append(obj_info.get("class_name", "unknown"))
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
# Step 2: Check if valid objects exist
if len(all_object_ids) == 0:
print("No valid object instances found in metadata.")
return image
# Step 3: Sort by instance ID
paired = list(
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
)
paired.sort(key=lambda x: x[0])
all_object_ids = [p[0] for p in paired]
all_object_boxes = [p[1] for p in paired]
all_object_masks = [p[2] for p in paired]
all_object_classes = [p[3] for p in paired]
# Step 4: Build detections
all_object_masks = np.concatenate(all_object_masks, axis=0)
detections = sv.Detections(
xyxy=np.array(all_object_boxes),
mask=all_object_masks,
class_id=np.array(all_object_ids, dtype=np.int32),
)
labels = [
f"{instance_id}: {class_name}"
for instance_id, class_name in zip(all_object_ids, all_object_classes)
]
# Step 5: Annotate image
annotated_frame = image.copy()
mask_annotator = sv.MaskAnnotator()
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
annotated_frame = box_annotator.annotate(annotated_frame, detections)
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
return annotated_frame
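Since visualize_frame_with_mask_and_metadata only needs an RGB array, a mask array, and the metadata dict, the saved outputs can be re-rendered offline. A minimal sketch, assuming a constructed IncrementalObjectTracker instance `tracker` and the directory layout written by save_current_state above (the frame index `idx` is a placeholder):
import json
import os
import cv2
import numpy as np

idx = 0  # hypothetical frame index
output_dir = "./outputs"
mask = np.load(os.path.join(output_dir, "mask_data", f"mask_{idx:05d}.npy"))
with open(os.path.join(output_dir, "json_data", f"image_{idx:05d}.json")) as f:
    metadata = json.load(f)  # assumed to match the to_dict() layout used above
image_bgr = cv2.imread(os.path.join(output_dir, "images", f"image_{idx:05d}.jpg"))
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
annotated = tracker.visualize_frame_with_mask_and_metadata(image_rgb, mask, metadata)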
import os
import cv2
import numpy as np  # needed for the isinstance check in the loop below
import torch
from utils.common_utils import CommonUtils
def main():
# Parameter settings
output_dir = "./outputs"
prompt_text = "hand."
detection_interval = 20
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
os.makedirs(output_dir, exist_ok=True)
# Initialize the object tracker
tracker = IncrementalObjectTracker(
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text=prompt_text,
detection_interval=detection_interval,
)
    # Switch the prompt at runtime (overrides the initial "hand." prompt and
    # resets the tracking state)
    tracker.set_prompt("person.")
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("[Error] Cannot open camera.")
return
print("[Info] Camera opened. Press 'q' to quit.")
frame_idx = 0
try:
while True:
ret, frame = cap.read()
if not ret:
print("[Warning] Failed to capture frame.")
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
print(f"[Frame {frame_idx}] Processing live frame...")
process_image = tracker.add_image(frame_rgb)
if not isinstance(process_image, np.ndarray):  # also covers None
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
frame_idx += 1
continue
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
# cv2.imshow("Live Inference", process_image_bgr)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("[Info] Quit signal received.")
# break
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
frame_idx += 1
if frame_idx >= max_frames:
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
break
except KeyboardInterrupt:
print("[Info] Interrupted by user (Ctrl+C).")
finally:
cap.release()
cv2.destroyAllWindows()
print("[Done] Live inference complete.")
if __name__ == "__main__":
main()
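The same loop works for offline files. A minimal sketch under the assumptions of the tracker API above; "input.mp4" and "annotated.mp4" are placeholder paths:
def run_on_video(tracker, src="input.mp4", dst="annotated.mp4", fps=30):
    cap = cv2.VideoCapture(src)
    writer = None
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        annotated = tracker.add_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not isinstance(annotated, np.ndarray):
            continue
        if writer is None:  # lazily create the writer once the output size is known
            h, w = annotated.shape[:2]
            writer = cv2.VideoWriter(dst, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        writer.write(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    cap.release()
    if writer is not None:
        writer.release()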

View File

@@ -1,7 +1,9 @@
-# dds cloudapi for DINO-X - update to V2Task API
+# dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk.tasks.dinox import DinoxTask
+from dds_cloudapi_sdk.tasks.types import DetectionTarget
+from dds_cloudapi_sdk import TextPrompt
 import os
 import cv2
@@ -28,7 +30,6 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
 API_TOKEN_FOR_DINOX = "Your API token"
 PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
 BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8 # add IOU threshold parameter
 """
 Step 1: Environment settings and model initialization for SAM 2
@@ -97,29 +98,22 @@ config = Config(API_TOKEN_FOR_DINOX)
 # Step 2: initialize the client
 client = Client(config)
-# Step 3: run the task using V2Task class
+# Step 3: run the task by DetectionTask class
+# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
 image_url = client.upload_file(img_path)
-task = V2Task(
-    api_path="/v2/task/dinox/detection",
-    api_body={
-        "model": "DINO-X-1.0",
-        "image": image_url,
-        "prompt": {
-            "type": "text",
-            "text": TEXT_PROMPT
-        },
-        "targets": ["bbox"],
-        "bbox_threshold": BOX_THRESHOLD,
-        "iou_threshold": IOU_THRESHOLD,
-    }
-)
+task = DinoxTask(
+    image_url=image_url,
+    prompts=[TextPrompt(text=TEXT_PROMPT)],
+    bbox_threshold=0.25,
+    targets=[DetectionTarget.BBox],
+)
 client.run_task(task)
 result = task.result
-objects = result["objects"] # the list of detected objects
+objects = result.objects # the list of detected objects
 input_boxes = []
@@ -127,9 +121,9 @@ confidences = []
 class_names = []
 for idx, obj in enumerate(objects):
-    input_boxes.append(obj["bbox"])
-    confidences.append(obj["score"])
-    class_names.append(obj["category"])
+    input_boxes.append(obj.bbox)
+    confidences.append(obj.score)
+    class_names.append(obj.category)
 input_boxes = np.array(input_boxes)

View File

@@ -1,7 +1,10 @@
-# dds cloudapi for Grounding DINO 1.5 - Update to V2Task API
+# dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk import DetectionTask
+from dds_cloudapi_sdk import TextPrompt
+from dds_cloudapi_sdk import DetectionModel
+from dds_cloudapi_sdk import DetectionTarget
 import os
 import cv2
@@ -28,7 +31,6 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
 API_TOKEN_FOR_GD1_5 = "Your API token"
 PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
 BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8 # add IOU threshold parameter
 """
 Step 1: Environment settings and model initialization for SAM 2
@@ -97,38 +99,33 @@ config = Config(API_TOKEN_FOR_GD1_5)
 # Step 2: initialize the client
 client = Client(config)
-# Step 3: run the task using V2Task class
+# Step 3: run the task by DetectionTask class
+# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
 image_url = client.upload_file(img_path)
-task = V2Task(
-    api_path="/v2/task/grounding_dino/detection",
-    api_body={
-        "model": "GroundingDino-1.5-Pro",
-        "image": image_url,
-        "prompt": {
-            "type": "text",
-            "text": TEXT_PROMPT
-        },
-        "targets": ["bbox"],
-        "bbox_threshold": BOX_THRESHOLD,
-        "iou_threshold": IOU_THRESHOLD,
-    }
-)
+task = DetectionTask(
+    image_url=image_url,
+    prompts=[TextPrompt(text=TEXT_PROMPT)],
+    targets=[DetectionTarget.BBox], # detect bbox
+    model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
+    bbox_threshold=BOX_THRESHOLD,
+)
 client.run_task(task)
 result = task.result
-objects = result["objects"] # the list of detected objects
+objects = result.objects # the list of detected objects
 input_boxes = []
 confidences = []
 class_names = []
 for idx, obj in enumerate(objects):
-    input_boxes.append(obj["bbox"])
-    confidences.append(obj["score"])
-    class_names.append(obj["category"])
+    input_boxes.append(obj.bbox)
+    confidences.append(obj.score)
+    class_names.append(obj.category)
 input_boxes = np.array(input_boxes)

View File

@@ -1,7 +1,11 @@
-# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
+# dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk import DetectionTask
+from dds_cloudapi_sdk import TextPrompt
+from dds_cloudapi_sdk import DetectionModel
+from dds_cloudapi_sdk import DetectionTarget
 import os
 import torch
@@ -47,9 +51,6 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
 # setup the input image and text prompt for SAM 2 and Grounding DINO
 # VERY important: text queries need to be lowercased + end with a dot
 text = "car."
-BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8
-GROUNDING_MODEL = "GroundingDino-1.6-Pro" # use a string instead of the enum value
 # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
 video_dir = "notebooks/videos/car"
@@ -101,32 +102,24 @@ for start_frame_idx in range(0, len(frame_names), step):
 client = Client(config)
 image_url = client.upload_file(img_path)
-task = V2Task(
-    api_path="/v2/task/grounding_dino/detection",
-    api_body={
-        "model": GROUNDING_MODEL,
-        "image": image_url,
-        "prompt": {
-            "type": "text",
-            "text": text
-        },
-        "targets": ["bbox"],
-        "bbox_threshold": BOX_THRESHOLD,
-        "iou_threshold": IOU_THRESHOLD,
-    }
-)
+task = DetectionTask(
+    image_url=image_url,
+    prompts=[TextPrompt(text=text)],
+    targets=[DetectionTarget.BBox], # detect bbox
+    model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
+)
 client.run_task(task)
 result = task.result
-objects = result["objects"] # the list of detected objects
+objects = result.objects # the list of detected objects
 input_boxes = []
 confidences = []
 class_names = []
 for idx, obj in enumerate(objects):
-    input_boxes.append(obj["bbox"])
-    confidences.append(obj["score"])
-    class_names.append(obj["category"])
+    input_boxes.append(obj.bbox)
+    confidences.append(obj.score)
+    class_names.append(obj.category)
 input_boxes = np.array(input_boxes)
 OBJECTS = class_names
@@ -161,7 +154,7 @@ for start_frame_idx in range(0, len(frame_names), step):
-objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=IOU_THRESHOLD, objects_count=objects_count)
+objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
 print("objects_count", objects_count)
 else:

View File

@@ -1,7 +1,10 @@
-# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
+# dds cloudapi for Grounding DINO 1.5
 from dds_cloudapi_sdk import Config
 from dds_cloudapi_sdk import Client
-from dds_cloudapi_sdk.tasks.v2_task import V2Task
+from dds_cloudapi_sdk import DetectionTask
+from dds_cloudapi_sdk import TextPrompt
+from dds_cloudapi_sdk import DetectionModel
+from dds_cloudapi_sdk import DetectionTarget
 import os
 import cv2
@@ -51,11 +54,6 @@ inference_state = video_predictor.init_state(video_path=video_dir)
 ann_frame_idx = 0 # the frame index we interact with
 ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
-# parameter settings
-TEXT_PROMPT = "children. pillow"
-BOX_THRESHOLD = 0.2
-IOU_THRESHOLD = 0.8
 """
 Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
@@ -72,29 +70,23 @@ config = Config(token)
 # Step 2: initialize the client
 client = Client(config)
-# Step 3: run the task using V2Task class
+# Step 3: run the task by DetectionTask class
+# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
 # if you are processing local image file, upload them to DDS server to get the image url
 image_url = client.upload_file(img_path)
-task = V2Task(
-    api_path="/v2/task/grounding_dino/detection",
-    api_body={
-        "model": "GroundingDino-1.5-Pro",
-        "image": image_url,
-        "prompt": {
-            "type": "text",
-            "text": TEXT_PROMPT
-        },
-        "targets": ["bbox"],
-        "bbox_threshold": BOX_THRESHOLD,
-        "iou_threshold": IOU_THRESHOLD,
-    }
-)
+task = DetectionTask(
+    image_url=image_url,
+    prompts=[TextPrompt(text="children. pillow")],
+    targets=[DetectionTarget.BBox], # detect bbox
+    model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
+    bbox_threshold=0.2,
+)
 client.run_task(task)
 result = task.result
-objects = result["objects"] # the list of detected objects
+objects = result.objects # the list of detected objects
 input_boxes = []
@@ -102,9 +94,9 @@ confidences = []
 class_names = []
 for idx, obj in enumerate(objects):
-    input_boxes.append(obj["bbox"])
-    confidences.append(obj["score"])
-    class_names.append(obj["category"])
+    input_boxes.append(obj.bbox)
+    confidences.append(obj.score)
+    class_names.append(obj.category)
 input_boxes = np.array(input_boxes)

View File

@@ -16,7 +16,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from grounding_dino.groundingdino.util.misc import NestedTensor
@@ -113,7 +113,7 @@ class WindowAttention(nn.Module):
 # get pair-wise relative position index for each token inside the window
 coords_h = torch.arange(self.window_size[0])
 coords_w = torch.arange(self.window_size[1])
-coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
+coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
 coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
 relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

View File

@@ -15,24 +15,11 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <torch/extension.h>
-#include <torch/version.h>
-
-// Check PyTorch version and define appropriate macros
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
-// PyTorch 2.x and above
-#define GET_TENSOR_TYPE(x) x.scalar_type()
-#define IS_CUDA_TENSOR(x) x.device().is_cuda()
-#else
-// PyTorch 1.x
-#define GET_TENSOR_TYPE(x) x.type()
-#define IS_CUDA_TENSOR(x) x.type().is_cuda()
-#endif
 namespace groundingdino {
 at::Tensor ms_deform_attn_cuda_forward(
     const at::Tensor &value,
     const at::Tensor &spatial_shapes,
     const at::Tensor &level_start_index,
     const at::Tensor &sampling_loc,
@@ -45,11 +32,11 @@ at::Tensor ms_deform_attn_cuda_forward(
 AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
 AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
+AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
 const int batch = value.size(0);
 const int spatial_size = value.size(1);
@@ -64,7 +51,7 @@ at::Tensor ms_deform_attn_cuda_forward(
 const int im2col_step_ = std::min(batch, im2col_step);
 AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
 auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
 const int batch_n = im2col_step_;
@@ -75,7 +62,7 @@ at::Tensor ms_deform_attn_cuda_forward(
 for (int n = 0; n < batch/im2col_step_; ++n)
 {
 auto columns = output_n.select(0, n);
-AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
+AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
 ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
     value.data<scalar_t>() + n * im2col_step_ * per_value_size,
     spatial_shapes.data<int64_t>(),
@@ -95,7 +82,7 @@ at::Tensor ms_deform_attn_cuda_forward(
 std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &value,
     const at::Tensor &spatial_shapes,
     const at::Tensor &level_start_index,
     const at::Tensor &sampling_loc,
@@ -111,12 +98,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
 AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
 AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
-AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
+AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
 const int batch = value.size(0);
 const int spatial_size = value.size(1);
@@ -141,11 +128,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
 auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
 auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
 auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
 for (int n = 0; n < batch/im2col_step_; ++n)
 {
 auto grad_output_g = grad_output_n.select(0, n);
-AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
+AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
 ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
     grad_output_g.data<scalar_t>(),
     value.data<scalar_t>() + n * im2col_step_ * per_value_size,
@@ -166,4 +153,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
 };
 }
 } // namespace groundingdino

View File

@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from timm.layers import DropPath
+from timm.models.layers import DropPath
 class FeatureResizer(nn.Module):

View File

@@ -470,7 +470,6 @@ class TransformerEncoder(nn.Module):
 ref_y, ref_x = torch.meshgrid(
     torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
     torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
-    indexing="ij"
 )
 ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
 ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
@@ -860,7 +859,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
 return tensor if pos is None else tensor + pos
 def forward_ffn(self, tgt):
-    with torch.amp.autocast("cuda", enabled=False):
+    with torch.cuda.amp.autocast(enabled=False):
     tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
 tgt = tgt + self.dropout4(tgt2)
 tgt = self.norm3(tgt)

View File

@@ -79,7 +79,6 @@ def gen_encoder_output_proposals(
 grid_y, grid_x = torch.meshgrid(
     torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
     torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
-    indexing="ij"
 )
 grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2

View File

@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
 y = torch.arange(0, h, dtype=torch.float)
 x = torch.arange(0, w, dtype=torch.float)
-y, x = torch.meshgrid(y, x, indexing="ij")
+y, x = torch.meshgrid(y, x)
 x_mask = masks * x.unsqueeze(0)
 x_max = x_mask.flatten(1).max(-1)[0]

View File

@@ -63,7 +63,6 @@ def predict(
 model = model.to(device)
 image = image.to(device)
-model.eval()
 with torch.no_grad():
     outputs = model(image[None], captions=[caption])
@@ -77,10 +76,10 @@ def predict(
 tokenizer = model.tokenizer
 tokenized = tokenizer(caption)
 if remove_combined:
     sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
 phrases = []
 for logit in logits:
     max_idx = logit.argmax()

View File

@@ -1,68 +1,6 @@
 [build-system]
-requires = ["setuptools>=61.0", "wheel"]
+requires = [
+    "setuptools>=61.0",
+    "torch>=2.5.1",
+]
 build-backend = "setuptools.build_meta"
-
-[project]
-name = "Grounded-SAM-2"
-version = "1.0"
-description = "Grounded SAM 2: Ground and Track Anything in Videos"
-readme = "README.md"
-requires-python = ">=3.10.0"
-license = { text = "Apache 2.0" }
-authors = [{ name = "Meta AI", email = "segment-anything@meta.com" }]
-keywords = ["segmentation", "computer vision", "deep learning"]
-dependencies = [
-    "torch>=2.3.1",
-    "torchvision>=0.18.1",
-    "numpy>=1.24.4",
-    "tqdm>=4.66.1",
-    "hydra-core>=1.3.2",
-    "iopath>=0.1.10",
-    "pillow>=9.4.0",
-    "opencv-python-headless>=4.11.0.86",
-    "supervision>=0.26.1",
-    "pycocotools>=2.0.10",
-    "transformers>=4.55.1",
-    "addict>=2.4.0",
-    "yapf>=0.43.0",
-    "timm>=1.0.19",
-    "pdf2image>=1.17.0",
-]
-
-[project.optional-dependencies]
-notebooks = [
-    "matplotlib>=3.9.1",
-    "jupyter>=1.0.0",
-    "opencv-python>=4.7.0",
-    "eva-decord>=0.6.1",
-]
-interactive-demo = [
-    "Flask>=3.0.3",
-    "Flask-Cors>=5.0.0",
-    "av>=13.0.0",
-    "dataclasses-json>=0.6.7",
-    "eva-decord>=0.6.1",
-    "gunicorn>=23.0.0",
-    "imagesize>=1.4.1",
-    "pycocotools>=2.0.8",
-    "strawberry-graphql>=0.243.0",
-]
-dev = [
-    "black==24.2.0",
-    "usort==1.0.2",
-    "ufmt==2.0.0b2",
-    "fvcore>=0.1.5.post20221221",
-    "pandas>=2.2.2",
-    "scikit-image>=0.24.0",
-    "tensorboard>=2.17.0",
-    "pycocotools>=2.0.8",
-    "tensordict>=0.5.0",
-    "opencv-python>=4.7.0",
-    "submitit>=1.5.1",
-]
-
-[tool.setuptools]
-# extensions = [{ name = "sam2._C", sources = ["sam2/csrc/connected_components.cu"] }]
-packages = ["sam2", "grounding_dino"]

sam2/benchmark.py (new file, 92 lines)
View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import numpy as np
import torch
from tqdm import tqdm
from sam2.build_sam import build_sam2_video_predictor
# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)
# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
p
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)
# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()
# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=ann_obj_id,
points=points,
labels=labels,
)
# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
with torch.inference_mode():
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
start = time.time()
# Start tracking
for (
out_frame_idx,
out_obj_ids,
out_mask_logits,
) in predictor.propagate_in_video(inference_state):
pass
end = time.time()
total += end - start
count += 1
if i == warm_up - 1:
print("Warmup FPS: ", count * num_frames / total)
total = 0
count = 0
print("FPS: ", count * num_frames / total)

View File

@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True", # Let sam2_base handle this
+        ]
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [

View File

@@ -36,7 +36,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -47,7 +47,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -40,7 +40,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -51,7 +51,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -50,7 +50,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -50,7 +50,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -97,7 +97,7 @@ trainer:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -36,7 +36,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -47,7 +47,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -40,7 +40,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -51,7 +51,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -50,7 +50,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
 self_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   embedding_dim: 256
   num_heads: 1
   downsample_rate: 1
@@ -50,7 +50,7 @@ model:
 cross_attention:
   _target_: sam2.modeling.sam.transformer.RoPEAttention
   rope_theta: 10000.0
-  feat_sizes: [32, 32]
+  feat_sizes: [64, 64]
   rope_k_repeat: True
   embedding_dim: 256
   num_heads: 1
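The [32, 32] to [64, 64] change repeated across the config diffs above tracks the new RoPEAttention default (see the transformer diff further below): stride-16 features at the 1024-px input give 64 tokens per side, where 32 matched the older 512-px assumption. A quick check:
assert 1024 // 16 == 64  # feature-map side at the 1024-px input resolution
assert 512 // 16 == 32   # the old default assumed 512-px inputs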

View File

@@ -32,9 +32,7 @@ def window_partition(x, window_size):
 Hp, Wp = H + pad_h, W + pad_w
 x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-windows = (
-    x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-)
+windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
 return windows, (Hp, Wp)
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
 Hp, Wp = pad_hw
 H, W = hw
 B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-x = windows.view(
+x = windows.reshape(
     B, Hp // window_size, Wp // window_size, window_size, window_size, -1
 )
-x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
 if Hp > H or Wp > W:
-    x = x[:, :H, :W, :].contiguous()
+    x = x[:, :H, :W, :]
 return x

View File

@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
     temperature: int = 10000,
     normalize: bool = True,
     scale: Optional[float] = None,
+    # Following settings only relevant
+    # for warmping up cache for compilation
+    warmup_cache: bool = True,
+    image_size: int = 1024,
+    strides: Tuple[int] = (4, 8, 16, 32),
 ):
     super().__init__()
     assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
     self.scale = scale
     self.cache = {}
+    if warmup_cache and torch.cuda.is_available():
+        # Warmup cache for cuda, to help with compilation
+        device = torch.device("cuda")
+        for stride in strides:
+            cache_key = (image_size // stride, image_size // stride)
+            self._pe(1, device, *cache_key)
 def _encode_xy(self, x, y):
     # The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
     return pos
 @torch.no_grad()
-def forward(self, x: torch.Tensor):
-    cache_key = (x.shape[-2], x.shape[-1])
+def _pe(self, B, device, *cache_key):
+    H, W = cache_key
     if cache_key in self.cache:
-        return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+        return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
     y_embed = (
-        torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+        torch.arange(1, H + 1, dtype=torch.float32, device=device)
         .view(1, -1, 1)
-        .repeat(x.shape[0], 1, x.shape[-1])
+        .repeat(B, 1, W)
     )
     x_embed = (
-        torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+        torch.arange(1, W + 1, dtype=torch.float32, device=device)
        .view(1, 1, -1)
-        .repeat(x.shape[0], x.shape[-2], 1)
+        .repeat(B, H, 1)
     )
     if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
     y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
     x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-    dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+    dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
     dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
     pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
     self.cache[cache_key] = pos[0]
     return pos
+@torch.no_grad()
+def forward(self, x: torch.Tensor):
+    B = x.shape[0]
+    cache_key = (x.shape[-2], x.shape[-1])
+    return self._pe(B, x.device, *cache_key)
 class PositionEmbeddingRandom(nn.Module):
     """

View File

@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
 point_embedding = self.pe_layer.forward_with_coords(
     points, self.input_image_size
 )
-point_embedding[labels == -1] = 0.0
-point_embedding[labels == -1] += self.not_a_point_embed.weight
-point_embedding[labels == 0] += self.point_embeddings[0].weight
-point_embedding[labels == 1] += self.point_embeddings[1].weight
-point_embedding[labels == 2] += self.point_embeddings[2].weight
-point_embedding[labels == 3] += self.point_embeddings[3].weight
+point_embedding = torch.where(
+    (labels == -1).unsqueeze(-1),
+    torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+    point_embedding,
+)
+point_embedding = torch.where(
+    (labels == 0).unsqueeze(-1),
+    point_embedding + self.point_embeddings[0].weight,
+    point_embedding,
+)
+point_embedding = torch.where(
+    (labels == 1).unsqueeze(-1),
+    point_embedding + self.point_embeddings[1].weight,
+    point_embedding,
+)
+point_embedding = torch.where(
+    (labels == 2).unsqueeze(-1),
+    point_embedding + self.point_embeddings[2].weight,
+    point_embedding,
+)
+point_embedding = torch.where(
+    (labels == 3).unsqueeze(-1),
+    point_embedding + self.point_embeddings[3].weight,
+    point_embedding,
+)
 return point_embedding
 def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
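This rewrite replaces boolean-mask in-place writes with branchless torch.where updates, presumably to keep the traced graph friendly to torch.compile. A small standalone check that the two styles agree (toy shapes; label -1 marks padding points):
import torch

labels = torch.tensor([1, -1])
emb = torch.randn(2, 4)
not_a_point = torch.randn(4)

a = emb.clone()          # old in-place style
a[labels == -1] = 0.0
a[labels == -1] += not_a_point

b = torch.where(         # new branchless style
    (labels == -1).unsqueeze(-1),
    torch.zeros_like(emb) + not_a_point,
    emb,
)
assert torch.allclose(a, b)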

View File

@@ -4,9 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import contextlib
 import math
-import warnings
 from functools import partial
 from typing import Tuple, Type
@@ -16,29 +14,6 @@ from torch import nn, Tensor
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
-
-warnings.simplefilter(action="ignore", category=FutureWarning)
-# Check whether Flash Attention is available (and use it by default)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
-# A fallback setting to allow all available kernels if Flash Attention fails
-ALLOW_ALL_KERNELS = False
-
-def sdp_kernel_context(dropout_p):
-    """
-    Get the context for the attention scaled dot-product kernel. We use Flash Attention
-    by default, but fall back to all available kernels if Flash Attention fails.
-    """
-    if ALLOW_ALL_KERNELS:
-        return contextlib.nullcontext()
-    return torch.backends.cuda.sdp_kernel(
-        enable_flash=USE_FLASH_ATTN,
-        # if Flash attention kernel is off, then math kernel needs to be enabled
-        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-        enable_mem_efficient=OLD_GPU,
-    )
 class TwoWayTransformer(nn.Module):
@@ -265,20 +240,7 @@ class Attention(nn.Module):
 dropout_p = self.dropout_p if self.training else 0.0
 # Attention
-try:
-    with sdp_kernel_context(dropout_p):
-        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-except Exception as e:
-    # Fall back to all kernels if the Flash attention kernel fails
-    warnings.warn(
-        f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-        f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-        category=UserWarning,
-        stacklevel=2,
-    )
-    global ALLOW_ALL_KERNELS
-    ALLOW_ALL_KERNELS = True
-    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 out = self._recombine_heads(out)
 out = self.out_proj(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
 # whether to repeat q rope to match k length
 # this is needed for cross-attention to memories
 rope_k_repeat=False,
-feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
 **kwargs,
 ):
 super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
     compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
 )
 freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
-self.freqs_cis = freqs_cis
+self.freqs_cis = (
+    freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+)
 self.rope_k_repeat = rope_k_repeat
 def forward(
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
 dropout_p = self.dropout_p if self.training else 0.0
 # Attention
-try:
-    with sdp_kernel_context(dropout_p):
-        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-except Exception as e:
-    # Fall back to all kernels if the Flash attention kernel fails
-    warnings.warn(
-        f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-        f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-        category=UserWarning,
-        stacklevel=2,
-    )
-    global ALLOW_ALL_KERNELS
-    ALLOW_ALL_KERNELS = True
-    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 out = self._recombine_heads(out)
 out = self.out_proj(out)
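The deleted wrapper forced the Flash kernel and flipped a global fallback on failure; the updated code calls F.scaled_dot_product_attention directly and lets PyTorch pick a kernel. If pinning a kernel is still wanted, a sketch with the current API (assuming PyTorch >= 2.3 and a CUDA device):
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):  # roughly what sdp_kernel_context did
    out = F.scaled_dot_product_attention(q, k, v)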

View File

@@ -628,7 +628,9 @@ class SAM2Base(torch.nn.Module):
 if self.add_tpos_enc_to_obj_ptrs:
     t_diff_max = max_obj_ptrs_in_encoder - 1
     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-    obj_pos = torch.tensor(pos_list, device=device)
+    obj_pos = torch.tensor(pos_list).to(
+        device=device, non_blocking=True
+    )
     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -8,7 +8,6 @@ import os
 import warnings
 from threading import Thread
-from typing import Tuple
 import numpy as np
 import torch
 from PIL import Image
@@ -210,74 +209,6 @@ def load_video_frames(
     "Only MP4 video and JPEG folder are supported at this moment"
 )
-def process_stream_frame(
-    img_array: np.ndarray,
-    image_size: int,
-    img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
-    img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
-    offload_to_cpu: bool = False,
-    compute_device: torch.device = torch.device("cuda"),
-):
-    """
-    Convert a raw image array (H,W,3 or 3,H,W) into a model-ready tensor.
-    Steps
-    -----
-    1. Resize the shorter side to `image_size`, keeping aspect ratio,
-       then center-crop/pad to `image_size` x `image_size`.
-    2. Change layout to [3, H, W] and cast to float32 in [0,1].
-    3. Normalise with ImageNet statistics.
-    4. Optionally move to `compute_device`.
-    Returns
-    -------
-    img_tensor : torch.FloatTensor  # shape [3, image_size, image_size]
-    orig_h : int
-    orig_w : int
-    """
-    # uses your existing helper so behaviour matches the batch loader
-    img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
-    # Normalisation (done *after* potential device move for efficiency)
-    img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
-    img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
-    if not offload_to_cpu:
-        img_tensor = img_tensor.to(compute_device)
-        img_mean_t = img_mean_t.to(compute_device)
-        img_std_t = img_std_t.to(compute_device)
-    img_tensor.sub_(img_mean_t).div_(img_std_t)
-    return img_tensor, orig_h, orig_w
-def _resize_and_convert_to_tensor(img_array, image_size):
-    """
-    Resize the input image array and convert it into a tensor.
-    Also return original image height and width.
-    """
-    # Convert numpy array to PIL image and ensure RGB
-    img_pil = Image.fromarray(img_array).convert("RGB")
-    # Save original size (PIL: size = (width, height))
-    video_width, video_height = img_pil.size
-    # Resize with high-quality LANCZOS filter
-    img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
-    # Convert resized image back to numpy and then to float tensor
-    img_resized_array = np.array(img_resized)
-    if img_resized_array.dtype == np.uint8:
-        img_resized_array = img_resized_array / 255.0
-    else:
-        raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
-    # Convert to PyTorch tensor and permute to [C, H, W]
-    img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
-    return img_tensor, video_height, video_width
 def load_video_frames_from_jpg_images(
     video_path,

View File

@@ -22,8 +22,8 @@ with open("README.md", "r", encoding="utf-8") as f:
 # Required dependencies
 REQUIRED_PACKAGES = [
-    "torch>=2.3.1",
-    "torchvision>=0.18.1",
+    "torch>=2.5.1",
+    "torchvision>=0.20.1",
     "numpy>=1.24.4",
     "tqdm>=4.66.1",
     "hydra-core>=1.3.2",
@@ -58,7 +58,7 @@ EXTRA_PACKAGES = {
     "scikit-image>=0.24.0",
     "tensorboard>=2.17.0",
     "pycocotools>=2.0.8",
-    "tensordict>=0.5.0",
+    "tensordict>=0.6.0",
     "opencv-python>=4.7.0",
     "submitit>=1.5.1",
 ],

View File

@@ -375,7 +375,7 @@ def main():
 parser.add_argument(
     "--sam2_checkpoint",
     type=str,
-    default="./checkpoints/sam2.1_hiera_b+.pt",
+    default="./checkpoints/sam2.1_hiera_base_plus.pt",
     help="path to the SAM 2 model checkpoint",
 )
 parser.add_argument(
@@ -434,6 +434,11 @@ def main():
     help="whether to track objects that appear later in the video (i.e. not on the first frame; "
    "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
 )
+parser.add_argument(
+    "--use_vos_optimized_video_predictor",
+    action="store_true",
+    help="whether to use vos optimized video predictor with all modules compiled",
+)
 args = parser.parse_args()
 # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -445,6 +450,7 @@ def main():
     ckpt_path=args.sam2_checkpoint,
     apply_postprocessing=args.apply_postprocessing,
     hydra_overrides_extra=hydra_overrides_extra,
+    vos_optimized=args.use_vos_optimized_video_predictor,
 )
 if args.use_all_masks:

View File

@@ -623,7 +623,7 @@ class Trainer:
 # compute output
 with torch.no_grad():
-    with torch.amp.autocast("cuda",
+    with torch.cuda.amp.autocast(
         enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
         dtype=(
             get_amp_type(self.optim_conf.amp.amp_dtype)
@@ -858,8 +858,7 @@ class Trainer:
 # grads will also update a model even if the step doesn't produce
 # gradients
 self.optim.zero_grad(set_to_none=True)
-with torch.amp.autocast(
-    "cuda",
+with torch.cuda.amp.autocast(
     enabled=self.optim_conf.amp.enabled,
     dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
 ):

uv.lock (generated, 4388 lines): file diff suppressed because it is too large.