77 Commits

Author SHA1 Message Date
rentainhe
8b56c25344 update to latest sam2 2024-12-21 11:21:48 +08:00
Ronghang Hu
2b90b9f5ce remove .pin_memory() in obj_pos of SAM2Base to resolve an error on MPS (#495)
In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve an error on MPS. Investigation shows that `.pin_memory()` causes the error `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487.

(close https://github.com/facebookresearch/sam2/issues/487)
2024-12-15 16:47:17 -08:00
Ronghang Hu
722d1d1511 patch for the case of offload_state_to_cpu=True in the new SAM2VideoPredictor (#490)
This PR adds a patch for the case of `offload_state_to_cpu=True`, where `pred_masks` might have been offloaded to the CPU device (close https://github.com/facebookresearch/sam2/issues/489)
2024-12-12 15:12:13 -08:00
Ronghang Hu
393ae336a7 SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can turn on this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
  * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
2024-12-11 15:00:55 -08:00
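The per-object design described above can be sketched in miniature. This is a conceptual toy, not the actual `SAM2VideoPredictor` API: each object keeps its own prompt state, while backbone features are computed once per frame and shared.

```python
# Conceptual toy of the new per-object inference design (NOT the real
# SAM2VideoPredictor API): each object keeps its own prompt/output state,
# while backbone features are computed once per frame and shared.
class PerObjectTracker:
    def __init__(self):
        self.shared_features = {}  # frame_idx -> cached backbone features
        self.per_obj_state = {}    # obj_id -> {frame_idx: prompt}

    def _features(self, frame_idx):
        # Computed at most once per frame, then reused by every object.
        if frame_idx not in self.shared_features:
            self.shared_features[frame_idx] = ("feat", frame_idx)
        return self.shared_features[frame_idx]

    def add_object(self, obj_id, frame_idx, point):
        # New objects may be added at any time; no assumption is made
        # about objects that received no prompt on this frame.
        self._features(frame_idx)
        self.per_obj_state.setdefault(obj_id, {})[frame_idx] = point
```

Adding a second object long after the first one started tracking leaves the first object's state untouched, which is the relaxation the commit describes.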
Haitham Khedr
c2ec8e14a1 remove unused paths (#384) 2024-10-14 10:40:54 -04:00
Roman Rädle
c98aa6bea3 Merge pull request #364 from facebookresearch/pr364
[sam2][demo][1/x] Fix file upload

Summary:

The Strawberry GraphQL library recently disabled multipart requests by default. This resulted in a video upload request returning "Unsupported content type" instead of uploading the video, processing it, and returning the video path.

This issue was raised in #361. A forward fix is to add `multipart_uploads_enabled=True` to the endpoint view.

Test Plan:

Tested locally with cURL and upload succeeds

*Request*

```
curl http://localhost:7263/graphql \
  -F operations='{ "query": "mutation($file: Upload!){ uploadVideo(file: $file) { path } }", "variables": { "file": null } }' \
  -F map='{ "file": ["variables.file"] }' \
  -F file=@video.mov
```

*Response*

```
{"data": {"uploadVideo": {"path": "uploads/<HASH>.mp4"}}}
```
2024-10-08 15:28:14 -07:00
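The request above follows the GraphQL multipart request convention. The same form fields can be built in Python; this is a sketch mirroring the cURL call, with the `uploadVideo` mutation taken from the test plan above.

```python
import json

# Form fields for a GraphQL multipart upload (mirrors the cURL example):
# `operations` holds the mutation with a null file placeholder, and `map`
# tells the server which form part fills which variable.
operations = {
    "query": "mutation($file: Upload!){ uploadVideo(file: $file) { path } }",
    "variables": {"file": None},
}
file_map = {"file": ["variables.file"]}

form_fields = {
    "operations": json.dumps(operations),
    "map": json.dumps(file_map),
    # plus a "file" part carrying the video bytes, e.g. video.mov
}
```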
Roman Rädle
ff9704fc0e [sam2][demo][1/x] Fix file upload
Summary:

The Strawberry GraphQL library recently disabled multipart requests by default. This resulted in a video upload request returning "Unsupported content type" instead of uploading the video, processing it, and returning the video path.

This issue was raised in #361. A forward fix is to add `multipart_uploads_enabled=True` to the endpoint view.

Test Plan:

Tested locally with cURL and upload succeeds

*Request*

```
curl http://localhost:7263/graphql \
  -F operations='{ "query": "mutation($file: Upload!){ uploadVideo(file: $file) { path } }", "variables": { "file": null } }' \
  -F map='{ "file": ["variables.file"] }' \
  -F file=@video.mov
```

*Response*

```
{"data": {"uploadVideo": {"path": "uploads/<HASH>.mp4"}}}
```
2024-10-08 14:58:28 -07:00
Ronghang Hu
29267c8e39 [doc] Check and raise an error if the user is running Python from the parent directory of the sam2 repo (#359)
If the user has "sam2/sam2" in their path, they are likely importing the repo itself as "sam2" rather than importing the "sam2" python package (i.e. the "sam2/sam2" directory). This typically happens because the user is running Python from the parent directory that contains the sam2 repo they cloned.

In general, the user should not run Python from the parent directory into which the repo is cloned (the same is true for e.g. the NumPy repo, which contains `numpy/numpy` where the module and the repo have the same name), as the user encountered in https://github.com/facebookresearch/sam2/issues/346.

(close https://github.com/facebookresearch/sam2/issues/346)
2024-10-05 00:34:06 -07:00
Ronghang Hu
e22521832f [demo] add GPU to resources (#355)
This small PR adds GPU specification in `docker-compose.yaml` for the SAM 2 interactive webdemo, following https://docs.docker.com/compose/how-tos/gpu-support/#example-of-a-compose-file-for-running-a-service-with-access-to-1-gpu-device. It fixes a GPU access error as reported in https://github.com/facebookresearch/sam2/issues/354.

(close https://github.com/facebookresearch/sam2/issues/354)
2024-10-03 16:48:56 -07:00
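Following the linked Docker Compose GPU documentation, the reservation added to `docker-compose.yaml` takes roughly this shape (the service name here is illustrative):

```yaml
services:
  backend:
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
```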
Haitham Khedr
8bf0920e66 Add MANIFEST.in (#353) 2024-10-03 10:40:13 -07:00
Haitham Khedr
52198ead0e Merge pull request #2 from kit1980/patch-1
Use `weights_only` for loading
2024-10-01 22:32:58 +02:00
Ronghang Hu
98fcb164bf Update links after renaming the repo from segment-anything-2 to sam2 (#341)
This PR updates repo links after we renamed the repo from `segment-anything-2` to `sam2`. It also changes `NAME` in setup.py to `SAM-2` (which is already the name used in the pip setup, since Python package names don't allow whitespace)
2024-09-30 20:27:44 -07:00
Ronghang Hu
05d9e57fb3 [docs] add a release note and new installation instructions for SAM 2.1 (#338) 2024-09-30 09:55:58 -07:00
Ronghang Hu
429a2c7360 minor update README.md 2024-09-28 23:32:25 -07:00
Chay Ryali
3a7889d905 Merge pull request #335 from facebookresearch/sam2.1
SAM 2.1
2024-09-28 23:01:29 -07:00
Haitham Khedr
aa9b8722d0 SAM2.1
SAM2.1 checkpoints + training code + Demo
2024-09-29 05:49:56 +00:00
Sergii Dymchenko
0f6515ae85 Merge branch 'main' into patch-1 2024-08-26 15:49:40 -07:00
Ronghang Hu
7e1596c0b6 open README.md with unicode (to support Hugging Face emoji); fix various typos (#218)
(close #217, #66, #67, #69, #91, #126, #127, #145)
2024-08-14 09:06:25 -07:00
Haitham Khedr
0db838b117 Merge pull request #205 from facebookresearch/haitham/fix_hf_image_predictor
Fix HF image predictor
2024-08-12 17:04:04 -07:00
Haitham Khedr
fd5125b97a accept kwargs in auto_mask_generator 2024-08-13 00:02:36 +00:00
Haitham Khedr
1191677e1e Fix HF image predictor 2024-08-12 23:41:41 +00:00
Ronghang Hu
dce7b5446f improving warning message and adding further tips for installation (#204) 2024-08-12 11:37:41 -07:00
Ronghang Hu
1034ee2a1a better support for non-CUDA devices (CPU, MPS) (#192) 2024-08-12 10:46:50 -07:00
Chay Ryali
778e112740 Merge pull request #167 from arun477/patch-1
remove unused attributes from hieradet.py
2024-08-09 10:31:56 -07:00
Arun
8f607e2de1 Merge branch 'main' into patch-1 2024-08-09 11:14:43 +05:30
Arun
46945a2122 Update hieradet.py
ufmt formatting fixed.
2024-08-09 11:14:11 +05:30
Ronghang Hu
d421e0b040 add Colab support to the notebooks; pack config files in sam2_configs package during installation (#176) 2024-08-08 11:03:22 -07:00
Arun
102ddb8899 Merge branch 'main' into patch-1 2024-08-08 09:59:47 +05:30
Ronghang Hu
6186d1529a also catch errors during installation in case CUDAExtension cannot be loaded (#175)
Previously we only caught build errors from `BuildExtension` in https://github.com/facebookresearch/segment-anything-2/pull/155. However, in some cases the `CUDAExtension` instance itself might fail to load. In this PR, we also catch such errors for `CUDAExtension`.
2024-08-07 12:26:11 -07:00
Ronghang Hu
6ecb5ff8d0 Add interface for box prompt in SAM 2 video predictor (#174)
This PR adds support for providing a box prompt in SAM 2 as input to the `add_new_points_or_box` API (renamed from `add_new_points`, which is kept for backward compatibility). If `box` is provided, we add it as the first two points with labels 2 and 3, along with the user-provided points (consistent with how SAM 2 is trained).

The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from box prompt as an example.
2024-08-07 11:54:30 -07:00
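The box encoding described above (corners as two extra points with labels 2 and 3) can be sketched as follows. The helper name and exact array handling are illustrative, not the `add_new_points_or_box` internals:

```python
import numpy as np

def encode_box_as_points(box, points=None, labels=None):
    # Illustrative sketch: a box (x1, y1, x2, y2) becomes its two corners
    # with the special labels 2 and 3, placed before any user clicks,
    # consistent with how SAM 2 is trained (per the commit message above).
    box_pts = np.asarray(box, dtype=np.float32).reshape(2, 2)
    box_lbl = np.array([2, 3], dtype=np.int32)
    if points is None:
        return box_pts, box_lbl
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    lbl = np.asarray(labels, dtype=np.int32).reshape(-1)
    return np.concatenate([box_pts, pts]), np.concatenate([box_lbl, lbl])
```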
Arun
086daf0641 Merge branch 'main' into patch-1 2024-08-07 21:50:26 +05:30
Haitham Khedr
6ba4c65cb2 Merge pull request #128 from NielsRogge/add_hf
Integrate with Hugging Face
2024-08-07 08:54:49 -07:00
Niels
9b58611e24 Address comment 2024-08-07 17:48:12 +02:00
Arun
6ec8560436 Update hieradet.py
Not used:

    head_dim = dim_out // num_heads
    self.scale = head_dim**-0.5

`F.scaled_dot_product_attention` takes care of this scaling automatically.
2024-08-07 11:35:46 +05:30
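The removal is sound because the default scaling in `F.scaled_dot_product_attention` is exactly the factor the deleted attribute computed. A small NumPy sketch of that equivalence (a toy implementation, not PyTorch's):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, scale=None):
    # Toy attention: when no scale is given, default to head_dim ** -0.5,
    # the same factor F.scaled_dot_product_attention applies by default --
    # which is why the hand-computed self.scale in hieradet.py was unused.
    head_dim = q.shape[-1]
    if scale is None:
        scale = head_dim ** -0.5
    return softmax(q @ k.swapaxes(-1, -2) * scale) @ v
```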
Niels
43c385c263 Update docstrings 2024-08-06 23:00:26 +02:00
Niels
322aa3e7e5 Revert code snippet 2024-08-06 22:57:07 +02:00
Diego Garcia
511199d7a9 Updated INSTALL.md with CUDA_HOME-related troubleshooting (#140)
This refers to https://github.com/facebookresearch/segment-anything-2/issues/137, which in turn refers to a common problem during installation, mentioned in https://github.com/facebookresearch/segment-anything-2/issues/19.

Some users may encounter significant trouble installing the project, running into the error `OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root.`. Simply adding the `--no-build-isolation` flag to the pip install, e.g. `pip install --no-build-isolation -e .`, usually solves this problem. However, this fix is not mentioned anywhere within the readmes or installation troubleshooting docs. This PR adds this recommendation into the INSTALL.md file under the "My installation failed with `CUDA_HOME environment variable is not set` " section, ensuring that more users are aware of this potential fix.

Examples of users experiencing related difficulties when installing:
https://github.com/facebookresearch/segment-anything-2/issues/19
https://github.com/facebookresearch/segment-anything-2/issues/41
https://github.com/facebookresearch/segment-anything-2/issues/99
https://github.com/facebookresearch/segment-anything-2/issues/133
2024-08-06 13:45:15 -07:00
Niels
8f15c6255a Format using ufmt 2024-08-06 22:43:35 +02:00
jhj0517
0bac418736 Update INSTALL.md (#156)
This PR suggests a way to resolve the `unsupported Microsoft Visual Studio version!` error in INSTALL.md.
Adding the `-allow-unsupported-compiler` argument to `nvcc` worked.

Editing [setup.py](https://github.com/facebookresearch/segment-anything-2/blob/main/setup.py) is required to add the `-allow-unsupported-compiler` argument for `nvcc`.

```python
def get_extensions():
    srcs = ["sam2/csrc/connected_components.cu"]
    compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
            "-allow-unsupported-compiler"  # Add this argument
        ],
    }
    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
    return ext_modules
```
2024-08-06 13:43:12 -07:00
Niels
27a167c004 Update README 2024-08-06 22:41:32 +02:00
Ronghang Hu
6f7e700c37 Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, since many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step.

We also fall back to all available kernels if the Flash Attention kernel fails.
2024-08-06 10:52:01 -07:00
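The two install-time behaviors above (explicit opt-out and error tolerance) can be sketched like this. The function names are illustrative; only the `SAM2_BUILD_CUDA` variable comes from the PR text:

```python
import os

def cuda_build_enabled(env=None):
    # SAM2_BUILD_CUDA=0 explicitly turns off building the CUDA extension
    # (variable name from the PR description above).
    env = os.environ if env is None else env
    return env.get("SAM2_BUILD_CUDA", "1") != "0"

def get_extensions_tolerantly(make_extensions, env=None):
    # Illustrative wrapper: catch errors while constructing the extension
    # (e.g. CUDAExtension failing to load) and fall back to no extension
    # with a warning, instead of failing the whole install.
    if not cuda_build_enabled(env):
        return []
    try:
        return make_extensions()
    except Exception as e:
        print(f"Skipping SAM 2 CUDA extension: {e}")
        return []
```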
Niels
a36edf1e01 Clean up 2024-08-06 08:34:42 +02:00
Niels
e815f70a38 Address comment 2024-08-06 08:32:36 +02:00
Niels
fbf7e3a664 Add link 2024-08-05 22:12:15 +02:00
Niels
e9503c96fe Move HF to separate section 2024-08-05 22:10:57 +02:00
Niels
c3393d8b5f Include original code snippet 2024-08-05 22:08:54 +02:00
Haitham Khedr
0230c5ff93 Merge pull request #152 from haithamkhedr/main
Configure a workflow for format checking
2024-08-05 10:02:32 -07:00
Haitham Khedr
5e3d6ca6b5 Merge pull request #1 from haithamkhedr/CI 2024-08-05 09:36:54 -07:00
Haitham Khedr
3b0fd9e4a9 Update workflow 2024-08-05 09:28:28 -07:00
Haitham Khedr
acd3939f88 Add workflow 2024-08-05 09:16:29 -07:00
Niels
841cc1f015 Update docstring 2024-08-05 09:44:03 +02:00
Niels
e93be7f6aa Update README 2024-08-05 09:43:04 +02:00
Niels
cb48213066 Update links 2024-08-05 09:41:40 +02:00
Niels
6aeee34775 Make huggingface_hub soft dependency 2024-08-05 09:37:53 +02:00
Niels
0c28c630c2 Do not load config from the hub 2024-08-03 14:45:20 +02:00
Niels
3af4e82263 Add model_id_to_filenames 2024-08-03 14:18:23 +02:00
Niels
17b74501fb Use classmethod 2024-08-03 14:14:12 +02:00
Niels
b72a8a97f0 First draft 2024-08-03 12:57:05 +02:00
Ronghang Hu
57bc94b739 Merge pull request #119 from facebookresearch/ronghanghu/installation_faq
[doc] add `INSTALL.md` as an installation FAQ page
2024-08-02 15:07:25 -07:00
Ronghang Hu
b744a3c084 [doc] add INSTALL.md as an installation FAQ page 2024-08-02 22:05:46 +00:00
Haitham Khedr
d1fc9a0686 Merge pull request #116 from facebookresearch/arXiv-paper
Update arXiv paper citation
2024-08-02 13:02:08 -07:00
Haitham Khedr
59550d4deb Update README.md 2024-08-02 12:58:23 -07:00
Haitham Khedr
de4db16676 Update README.md 2024-08-02 12:56:06 -07:00
Ronghang Hu
0e78a11899 Merge pull request #61 from DanBrown47/main
Change git repo url from SSH to HTTPS
2024-07-31 21:30:11 -07:00
Danwand
fa2796bb47 Change git repo url from SSH to HTTPS
This change lets researchers easily clone the project from systems and platforms where SSH keys are not set up with GitHub, e.g. Google Colab and remote GPU servers.
2024-07-31 13:55:06 +05:30
Haitham Khedr
86827e2fba Merge pull request #32 from CharlesCNorton/patch-5
Fix: Hyphenate to "model-in-the-loop"
2024-07-30 07:54:23 -07:00
Ronghang Hu
cd270ed4f1 Merge pull request #29 from CharlesCNorton/patch-3
fix: correct spelling
2024-07-30 07:49:01 -07:00
Ronghang Hu
32750fa695 Merge pull request #30 from CharlesCNorton/patch-4
Fix typo in comment: "evalaution" to "evaluation"
2024-07-30 07:48:16 -07:00
CharlesCNorton
e62ec497b8 Fix: Hyphenate to "model-in-the-loop"
The phrase "model-in-the-loop" is now hyphenated to align with standard practices in technical literature, where hyphenation of compound adjectives clarifies their function as a single descriptor.
2024-07-30 08:35:59 -04:00
CharlesCNorton
c8127182c1 Fix typo in comment: "evalaution" to "evaluation"
Corrected the spelling of "evaluation" in the comment describing the use of the `--use_all_masks` flag in the `main` function.
2024-07-30 08:24:39 -04:00
CharlesCNorton
f882beb157 fix: correct spelling
Corrected two instances of "a even number" to "an even number" for correct article usage before a vowel sound.
2024-07-30 08:13:47 -04:00
Ronghang Hu
82b026cd55 Merge pull request #7 from CharlesCNorton/patch-2
Correct typo in sav_dataset README.md
2024-07-29 19:28:17 -07:00
CharlesCNorton
de05a2e0c5 Correct typo in sav_dataset README.md
Change "there licenses" to "their licenses" in the README.md
2024-07-29 22:25:56 -04:00
Ronghang Hu
b3011f0ea6 Merge pull request #5 from CharlesCNorton/patch-1
Fix typo in README: "Aything" corrected to "Anything"
2024-07-29 19:24:31 -07:00
CharlesCNorton
662fd3d90e Fix typo in README: "Aything" corrected to "Anything"
Corrected a typo in the Segment Anything Video Dataset section of the README file. The word "Aything" has been updated to "Anything."
2024-07-29 22:15:41 -04:00
Sergii Dymchenko
658aaba327 Use weights_only for loading
sam2/build_sam.py:81:14: TOR102 [*] `torch.load` without `weights_only` parameter is unsafe. Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`.

Found with https://github.com/pytorch-labs/torchfix/
2024-07-29 16:54:54 -07:00
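The safe pattern TorchFix recommends, in miniature: a round-trip through an in-memory buffer, assuming a PyTorch version with the `weights_only` argument (i.e. ≥ 1.13).

```python
import io

import torch

# Save a plain tensor state dict, then load it with weights_only=True,
# which restricts unpickling to tensor data and rejects arbitrary
# pickled objects -- the safe default TorchFix suggests.
buf = io.BytesIO()
torch.save({"w": torch.ones(3)}, buf)
buf.seek(0)
state = torch.load(buf, weights_only=True)
```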
Haitham Khedr
0c5f8c5432 Initial commit 2024-07-29 21:54:20 +00:00
28 changed files with 1796 additions and 434 deletions

17
.github/workflows/check_fmt.yml vendored Normal file

@@ -0,0 +1,17 @@
name: SAM2/fmt
on:
  pull_request:
    branches:
      - main
jobs:
  ufmt_check:
    runs-on: ubuntu-latest
    steps:
      - name: Check formatting
        uses: omnilib/ufmt@action-v1
        with:
          path: sam2 tools
          version: "2.0.0b2"
          python-version: "3.10"
          black-version: "24.2.0"
          usort-version: "1.0.2"

1
.gitignore vendored

@@ -145,3 +145,4 @@ dmypy.json
 outputs/
 .idea/
+demo/backend/checkpoints/*.pt

View File

@@ -2,7 +2,7 @@
 ### Requirements
-- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
   * Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
 - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
 - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar
 This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
-In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
-We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
 </details>
 <details>

View File

@@ -186,7 +186,7 @@
 same "printed page" as the copyright notice for easier
 identification within third-party archives.
-Copyright 2023 - present, IDEA Research.
+Copyright [yyyy] [name of copyright owner]
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.

27
RELEASE_NOTES.md Normal file

@@ -0,0 +1,27 @@
## SAM 2 release notes
### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking
- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
* Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
* In the VOS prediction script `tools/vos_inference.py`, you can turn on this option via the `--use_vos_optimized_video_predictor` flag.
* Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
* **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
* Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features.
* This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage).
* We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released
- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
### 07/29/2024 -- SAM 2 is released
- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
* SAM 2 code: https://github.com/facebookresearch/sam2
* SAM 2 demo: https://sam2.metademolab.com/
* SAM 2 paper: https://arxiv.org/abs/2408.00714

View File

@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
 ARG MODEL_SIZE=base_plus
 FROM ${BASE_IMAGE}

View File

@@ -105,7 +105,7 @@ cd demo/backend/server/
 ```bash
 PYTORCH_ENABLE_MPS_FALLBACK=1 \
 APP_ROOT="$(pwd)/../../../" \
-APP_URL=http://localhost:7263 \
+API_URL=http://localhost:7263 \
 MODEL_SIZE=base_plus \
 DATA_PATH="$(pwd)/../../data" \
 DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \

View File

@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools>=61.0",
-    "torch>=2.3.1",
+    "torch>=2.5.1",
 ]
 build-backend = "setuptools.build_meta"

92
sam2/benchmark.py Normal file

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)

# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)

# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass
            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)

View File

@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True",  # Let sam2_base handle this
+        ]
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [

View File

@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -97,7 +97,7 @@ trainer:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1


@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
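The repeated `feat_sizes: [32, 32]` to `[64, 64]` edits across the configs follow from the comment on `RoPEAttention` in the transformer diff further below: `feat_sizes` is the spatial size of the stride-16 feature map, i.e. the input resolution divided by the stride. A one-line check of that arithmetic:

```python
def feat_size(image_size: int, stride: int = 16) -> int:
    """Spatial size of a feature map produced at the given stride."""
    return image_size // stride

# Old default assumed a 512-pixel input; SAM 2 uses 1024-pixel inputs.
print(feat_size(512))   # 32
print(feat_size(1024))  # 64
```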


@@ -32,9 +32,7 @@ def window_partition(x, window_size):
     Hp, Wp = H + pad_h, W + pad_w
     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-    windows = (
-        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    )
+    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(
+    x = windows.reshape(
         B, Hp // window_size, Wp // window_size, window_size, window_size, -1
     )
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
     if Hp > H or Wp > W:
-        x = x[:, :H, :W, :].contiguous()
+        x = x[:, :H, :W, :]
     return x


@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
         temperature: int = 10000,
         normalize: bool = True,
         scale: Optional[float] = None,
+        # Following settings only relevant
+        # for warmping up cache for compilation
+        warmup_cache: bool = True,
+        image_size: int = 1024,
+        strides: Tuple[int] = (4, 8, 16, 32),
     ):
         super().__init__()
         assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
         self.scale = scale
         self.cache = {}
+        if warmup_cache and torch.cuda.is_available():
+            # Warmup cache for cuda, to help with compilation
+            device = torch.device("cuda")
+            for stride in strides:
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)

     def _encode_xy(self, x, y):
         # The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
         return pos

     @torch.no_grad()
-    def forward(self, x: torch.Tensor):
-        cache_key = (x.shape[-2], x.shape[-1])
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
         y_embed = (
-            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, H + 1, dtype=torch.float32, device=device)
             .view(1, -1, 1)
-            .repeat(x.shape[0], 1, x.shape[-1])
+            .repeat(B, 1, W)
         )
         x_embed = (
-            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, W + 1, dtype=torch.float32, device=device)
             .view(1, 1, -1)
-            .repeat(x.shape[0], x.shape[-2], 1)
+            .repeat(B, H, 1)
         )

         if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

         pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
         self.cache[cache_key] = pos[0]
         return pos

+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        B = x.shape[0]
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)
+

 class PositionEmbeddingRandom(nn.Module):
     """


@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(
             points, self.input_image_size
         )
-        point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight
-        point_embedding[labels == 0] += self.point_embeddings[0].weight
-        point_embedding[labels == 1] += self.point_embeddings[1].weight
-        point_embedding[labels == 2] += self.point_embeddings[2].weight
-        point_embedding[labels == 3] += self.point_embeddings[3].weight
+
+        point_embedding = torch.where(
+            (labels == -1).unsqueeze(-1),
+            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 0).unsqueeze(-1),
+            point_embedding + self.point_embeddings[0].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1).unsqueeze(-1),
+            point_embedding + self.point_embeddings[1].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 2).unsqueeze(-1),
+            point_embedding + self.point_embeddings[2].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 3).unsqueeze(-1),
+            point_embedding + self.point_embeddings[3].weight,
+            point_embedding,
+        )
         return point_embedding

     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
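The rewrite trades data-dependent in-place masked assignment for pure `torch.where` selects. Both forms compute the same values; the out-of-place form is presumably friendlier to `torch.compile` (that motivation is our reading of the change, not stated in the diff). A NumPy sketch of the equivalence for two of the labels, with made-up stand-in weights:

```python
import numpy as np

emb = np.ones((1, 3, 4), dtype=np.float32)           # (B, N, C) point embeddings
labels = np.array([[-1, 0, 1]])                      # per-point labels
not_a_point = np.full((4,), 7.0, dtype=np.float32)   # stand-in learned weight
pos_weight = np.full((4,), 0.5, dtype=np.float32)    # stand-in learned weight

# in-place masked-assignment style (the old code path)
a = emb.copy()
a[labels == -1] = 0.0
a[labels == -1] += not_a_point
a[labels == 1] += pos_weight

# functional where-select style (the new code path)
b = np.where((labels == -1)[..., None], np.zeros_like(emb) + not_a_point, emb)
b = np.where((labels == 1)[..., None], b + pos_weight, b)

assert np.allclose(a, b)
```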


@@ -4,9 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-import contextlib
 import math
-import warnings
 from functools import partial
 from typing import Tuple, Type

@@ -16,29 +14,6 @@ from torch import nn, Tensor

 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
-
-warnings.simplefilter(action="ignore", category=FutureWarning)
-# Check whether Flash Attention is available (and use it by default)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
-
-# A fallback setting to allow all available kernels if Flash Attention fails
-ALLOW_ALL_KERNELS = False
-
-
-def sdp_kernel_context(dropout_p):
-    """
-    Get the context for the attention scaled dot-product kernel. We use Flash Attention
-    by default, but fall back to all available kernels if Flash Attention fails.
-    """
-    if ALLOW_ALL_KERNELS:
-        return contextlib.nullcontext()
-    return torch.backends.cuda.sdp_kernel(
-        enable_flash=USE_FLASH_ATTN,
-        # if Flash attention kernel is off, then math kernel needs to be enabled
-        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-        enable_mem_efficient=OLD_GPU,
-    )

 class TwoWayTransformer(nn.Module):
@@ -265,20 +240,7 @@ class Attention(nn.Module):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

         out = self._recombine_heads(out)
         out = self.out_proj(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
         # whether to repeat q rope to match k length
         # this is needed for cross-attention to memories
         rope_k_repeat=False,
-        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
             compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
         )
         freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
-        self.freqs_cis = freqs_cis
+        self.freqs_cis = (
+            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+        )
         self.rope_k_repeat = rope_k_repeat

     def forward(
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

         out = self._recombine_heads(out)
         out = self.out_proj(out)
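With the kernel-selection context manager removed, both attention classes call `F.scaled_dot_product_attention` directly and let PyTorch pick the backend. For reference, the computation behind that call (ignoring masking and dropout) is softmax(QKᵀ/√d)V; a NumPy sketch:

```python
import numpy as np

def sdpa(q, k, v):
    """softmax(q @ k^T / sqrt(d)) @ v -- the math performed by
    F.scaled_dot_product_attention, without masking or dropout."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(2, 4, 8)) for _ in range(3))
out = sdpa(q, k, v)
assert out.shape == (2, 4, 8)
```

With all-zero queries and keys the attention weights are uniform, so each output row is the mean of the value rows — a quick sanity check on the implementation.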


@@ -628,7 +628,9 @@ class SAM2Base(torch.nn.Module):
                 if self.add_tpos_enc_to_obj_ptrs:
                     t_diff_max = max_obj_ptrs_in_encoder - 1
                     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                    obj_pos = torch.tensor(pos_list, device=device)
+                    obj_pos = torch.tensor(pos_list).to(
+                        device=device, non_blocking=True
+                    )
                     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)


@@ -8,6 +8,7 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
@@ -26,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
clear_non_cond_mem_around_input=False, clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False, add_all_frames_to_correct_as_cond=False,
@@ -37,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
self.fill_hole_area = fill_hole_area self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
@torch.inference_mode() @torch.inference_mode()
@@ -87,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["obj_id_to_idx"] = OrderedDict() inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict() inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = [] inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict" # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {} inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame # A temporary storage to hold new outputs when user interact with a frame
@@ -99,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["temp_output_dict_per_obj"] = {} inference_state["temp_output_dict_per_obj"] = {}
# Frames that already holds consolidated outputs from click or mask inputs # Frames that already holds consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking) # (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked) # metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False inference_state["frames_tracked_per_obj"] = {}
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0 # Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1) self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state return inference_state
@@ -133,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
if obj_idx is not None: if obj_idx is not None:
return obj_idx return obj_idx
# This is a new object id not sent to the server before. We only allow adding # We always allow adding new objects (including after tracking starts).
# new objects *before* the tracking starts. allow_new_object = True
allow_new_object = not inference_state["tracking_has_started"]
if allow_new_object: if allow_new_object:
# get the next object slot # get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"]) obj_idx = len(inference_state["obj_id_to_idx"])
@@ -153,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>} "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>} "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
} }
inference_state["frames_tracked_per_obj"][obj_idx] = {}
return obj_idx return obj_idx
else: else:
raise RuntimeError( raise RuntimeError(
@@ -213,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
"box prompt must be provided before any point prompt " "box prompt must be provided before any point prompt "
"(please use clear_old_points=True instead)" "(please use clear_old_points=True instead)"
) )
if inference_state["tracking_has_started"]:
warnings.warn(
"You are adding a box after tracking starts. SAM 2 may not always be "
"able to incorporate a box prompt for *refinement*. If you intend to "
"use box prompt as an *initial* input before tracking, please call "
"'reset_state' on the inference state to restart from scratch.",
category=UserWarning,
stacklevel=2,
)
if not isinstance(box, torch.Tensor): if not isinstance(box, torch.Tensor):
box = torch.tensor(box, dtype=torch.float32, device=points.device) box = torch.tensor(box, dtype=torch.float32, device=points.device)
box_coords = box.reshape(1, 2, 2) box_coords = box.reshape(1, 2, 2)
@@ -251,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
# frame, meaning that the inputs points are to generate segments on this frame without # frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks. # the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order # whether to track in reverse time order
if is_init_cond_frame: if is_init_cond_frame:
reverse = False reverse = False
else: else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or # Add a frame to conditioning output if it's an initial conditioning frame or
@@ -305,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, inference_state,
frame_idx, frame_idx,
is_cond=is_cond, is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True, consolidate_at_video_res=True,
) )
_, video_res_masks = self._get_orig_video_res_output( _, video_res_masks = self._get_orig_video_res_output(
@@ -356,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
# frame, meaning that the inputs points are to generate segments on this frame without # frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked), # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks. # the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order # whether to track in reverse time order
if is_init_cond_frame: if is_init_cond_frame:
reverse = False reverse = False
else: else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or # Add a frame to conditioning output if it's an initial conditioning frame or
@@ -393,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, inference_state,
frame_idx, frame_idx,
is_cond=is_cond, is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True, consolidate_at_video_res=True,
) )
_, video_res_masks = self._get_orig_video_res_output( _, video_res_masks = self._get_orig_video_res_output(
@@ -428,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, inference_state,
frame_idx, frame_idx,
is_cond, is_cond,
run_mem_encoder,
consolidate_at_video_res=False, consolidate_at_video_res=False,
): ):
""" """
@@ -445,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
# Optionally, we allow consolidating the temporary outputs at the original # Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts). # video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res: if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"] consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"] consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res" consolidated_mask_key = "pred_masks_video_res"
@@ -458,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
# constraints to object scores. Its "pred_masks" are prefilled with a large # constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects. # negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = { consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full( consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W), size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE, fill_value=NO_OBJ_SCORE,
dtype=torch.float32, dtype=torch.float32,
device=inference_state["storage_device"], device=inference_state["storage_device"],
), ),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
"object_score_logits": torch.full(
size=(batch_size, 1),
# default to 10.0 for object_score_logits, i.e. assuming the object is
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
fill_value=10.0,
dtype=torch.float32,
device=inference_state["device"],
),
} }
empty_mask_ptr = None
for obj_idx in range(batch_size): for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
@@ -498,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer. # placeholder above) and set its object pointer to be a dummy pointer.
if out is None: if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue continue
# Add the temporary object output mask to consolidated output mask # Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"] obj_mask = out["pred_masks"]
@@ -523,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
align_corners=False, align_corners=False,
) )
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
-            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
-            consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
-                "object_score_logits"
-            ]
-
-        # Optionally, apply non-overlapping constraints on the consolidated scores
-        # and rerun the memory encoder
-        if run_mem_encoder:
-            device = inference_state["device"]
-            high_res_masks = torch.nn.functional.interpolate(
-                consolidated_out["pred_masks"].to(device, non_blocking=True),
-                size=(self.image_size, self.image_size),
-                mode="bilinear",
-                align_corners=False,
-            )
-            if self.non_overlap_masks_for_mem_enc:
-                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
-            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
-                inference_state=inference_state,
-                frame_idx=frame_idx,
-                batch_size=batch_size,
-                high_res_masks=high_res_masks,
-                object_score_logits=consolidated_out["object_score_logits"],
-                is_mask_from_pts=True,  # these frames are what the user interacted with
-            )
-            consolidated_out["maskmem_features"] = maskmem_features
-            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
-
         return consolidated_out
-    def _get_empty_mask_ptr(self, inference_state, frame_idx):
-        """Get a dummy object pointer based on an empty mask on the current frame."""
-        # A dummy (empty) mask with a single object
-        batch_size = 1
-        mask_inputs = torch.zeros(
-            (batch_size, 1, self.image_size, self.image_size),
-            dtype=torch.float32,
-            device=inference_state["device"],
-        )
-
-        # Retrieve correct image features
-        (
-            _,
-            _,
-            current_vision_feats,
-            current_vision_pos_embeds,
-            feat_sizes,
-        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
-
-        # Feed the empty mask and image feature above to get a dummy object pointer
-        current_out = self.track_step(
-            frame_idx=frame_idx,
-            is_init_cond_frame=True,
-            current_vision_feats=current_vision_feats,
-            current_vision_pos_embeds=current_vision_pos_embeds,
-            feat_sizes=feat_sizes,
-            point_inputs=None,
-            mask_inputs=mask_inputs,
-            output_dict={},
-            num_frames=inference_state["num_frames"],
-            track_in_reverse=False,
-            run_mem_encoder=False,
-            prev_sam_mask_logits=None,
-        )
-        return current_out["obj_ptr"]
     @torch.inference_mode()
     def propagate_in_video_preflight(self, inference_state):
         """Prepare inference_state and consolidate temporary outputs before tracking."""
-        # Tracking has started and we don't allow adding new objects until session is reset.
-        inference_state["tracking_has_started"] = True
+        # Check and make sure that every object has received input points or masks.
         batch_size = self._get_obj_num(inference_state)
+        if batch_size == 0:
+            raise RuntimeError(
+                "No input points or masks are provided for any object; please add inputs first."
+            )

         # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
         # add them into "output_dict".
-        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
-        output_dict = inference_state["output_dict"]
-        # "consolidated_frame_inds" contains indices of those frames where consolidated
-        # temporary outputs have been added (either in this call or any previous calls
-        # to `propagate_in_video_preflight`).
-        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
-        for is_cond in [False, True]:
-            # Separately consolidate conditioning and non-conditioning temp outputs
-            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
-            # Find all the frames that contain temporary outputs for any objects
-            # (these should be the frames that have just received clicks for mask inputs
-            # via `add_new_points_or_box` or `add_new_mask`)
-            temp_frame_inds = set()
-            for obj_temp_output_dict in temp_output_dict_per_obj.values():
-                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
-            consolidated_frame_inds[storage_key].update(temp_frame_inds)
-            # consolidate the temporary output across all objects on this frame
-            for frame_idx in temp_frame_inds:
-                consolidated_out = self._consolidate_temp_output_across_obj(
-                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
-                )
-                # merge them into "output_dict" and also create per-object slices
-                output_dict[storage_key][frame_idx] = consolidated_out
-                self._add_output_per_object(
-                    inference_state, frame_idx, consolidated_out, storage_key
-                )
-                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
-                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
-                )
-                if clear_non_cond_mem:
-                    # clear non-conditioning memory of the surrounding frames
-                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
-            # clear temporary outputs in `temp_output_dict_per_obj`
-            for obj_temp_output_dict in temp_output_dict_per_obj.values():
-                obj_temp_output_dict[storage_key].clear()
-
-        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
-        # output on the same frame in "non_cond_frame_outputs"
-        for frame_idx in output_dict["cond_frame_outputs"]:
-            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
-        for obj_output_dict in inference_state["output_dict_per_obj"].values():
-            for frame_idx in obj_output_dict["cond_frame_outputs"]:
-                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
-        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
-            assert frame_idx in output_dict["cond_frame_outputs"]
-            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
-
-        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
-        # with either points or mask inputs (which should be true under a correct workflow).
-        all_consolidated_frame_inds = (
-            consolidated_frame_inds["cond_frame_outputs"]
-            | consolidated_frame_inds["non_cond_frame_outputs"]
-        )
-        input_frames_inds = set()
-        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
-            input_frames_inds.update(point_inputs_per_frame.keys())
-        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
-            input_frames_inds.update(mask_inputs_per_frame.keys())
-        assert all_consolidated_frame_inds == input_frames_inds
+        for obj_idx in range(batch_size):
+            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+            for is_cond in [False, True]:
+                # Separately consolidate conditioning and non-conditioning temp outputs
+                storage_key = (
+                    "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+                )
+                # Find all the frames that contain temporary outputs for any objects
+                # (these should be the frames that have just received clicks for mask inputs
+                # via `add_new_points_or_box` or `add_new_mask`)
+                for frame_idx, out in obj_temp_output_dict[storage_key].items():
+                    # Run memory encoder on the temporary outputs (if the memory feature is missing)
+                    if out["maskmem_features"] is None:
+                        high_res_masks = torch.nn.functional.interpolate(
+                            out["pred_masks"].to(inference_state["device"]),
+                            size=(self.image_size, self.image_size),
+                            mode="bilinear",
+                            align_corners=False,
+                        )
+                        maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+                            inference_state=inference_state,
+                            frame_idx=frame_idx,
+                            batch_size=1,  # run on the slice of a single object
+                            high_res_masks=high_res_masks,
+                            object_score_logits=out["object_score_logits"],
+                            # these frames are what the user interacted with
+                            is_mask_from_pts=True,
+                        )
+                        out["maskmem_features"] = maskmem_features
+                        out["maskmem_pos_enc"] = maskmem_pos_enc
+                    obj_output_dict[storage_key][frame_idx] = out
+                    if self.clear_non_cond_mem_around_input:
+                        # clear non-conditioning memory of the surrounding frames
+                        self._clear_obj_non_cond_mem_around_input(
+                            inference_state, frame_idx, obj_idx
+                        )
+                # clear temporary outputs in `temp_output_dict_per_obj`
+                obj_temp_output_dict[storage_key].clear()
+
+            # check and make sure that every object has received input points or masks
+            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+            if len(obj_output_dict["cond_frame_outputs"]) == 0:
+                obj_id = self._obj_idx_to_id(inference_state, obj_idx)
+                raise RuntimeError(
+                    f"No input points or masks are provided for object id {obj_id}; please add inputs first."
+                )
+            # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+            # output on the same frame in "non_cond_frame_outputs"
+            for frame_idx in obj_output_dict["cond_frame_outputs"]:
+                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
     @torch.inference_mode()
     def propagate_in_video(
@@ -670,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
         """Propagate the input points across frames to track in the entire video."""
         self.propagate_in_video_preflight(inference_state)
-        output_dict = inference_state["output_dict"]
-        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
         obj_ids = inference_state["obj_ids"]
         num_frames = inference_state["num_frames"]
         batch_size = self._get_obj_num(inference_state)
-        if len(output_dict["cond_frame_outputs"]) == 0:
-            raise RuntimeError("No points are provided; please add points first")
-        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
-            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
-        )

         # set start index, end index, and processing order
         if start_frame_idx is None:
             # default: start from the earliest frame with input points
-            start_frame_idx = min(output_dict["cond_frame_outputs"])
+            start_frame_idx = min(
+                t
+                for obj_output_dict in inference_state["output_dict_per_obj"].values()
+                for t in obj_output_dict["cond_frame_outputs"]
+            )
         if max_frame_num_to_track is None:
             # default: track all the frames in the video
             max_frame_num_to_track = num_frames
@@ -701,78 +581,54 @@ class SAM2VideoPredictor(SAM2Base):
         processing_order = range(start_frame_idx, end_frame_idx + 1)
         for frame_idx in tqdm(processing_order, desc="propagate in video"):
-            # We skip those frames already in consolidated outputs (these are frames
-            # that received input clicks or mask). Note that we cannot directly run
-            # batched forward on them via `_run_single_frame_inference` because the
-            # number of clicks on each object might be different.
-            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
-                storage_key = "cond_frame_outputs"
-                current_out = output_dict[storage_key][frame_idx]
-                pred_masks = current_out["pred_masks"]
-                if clear_non_cond_mem:
-                    # clear non-conditioning memory of the surrounding frames
-                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
-            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
-                storage_key = "non_cond_frame_outputs"
-                current_out = output_dict[storage_key][frame_idx]
-                pred_masks = current_out["pred_masks"]
-            else:
-                storage_key = "non_cond_frame_outputs"
-                current_out, pred_masks = self._run_single_frame_inference(
-                    inference_state=inference_state,
-                    output_dict=output_dict,
-                    frame_idx=frame_idx,
-                    batch_size=batch_size,
-                    is_init_cond_frame=False,
-                    point_inputs=None,
-                    mask_inputs=None,
-                    reverse=reverse,
-                    run_mem_encoder=True,
-                )
-                output_dict[storage_key][frame_idx] = current_out
-            # Create slices of per-object outputs for subsequent interaction with each
-            # individual object after tracking.
-            self._add_output_per_object(
-                inference_state, frame_idx, current_out, storage_key
-            )
-            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+            pred_masks_per_obj = [None] * batch_size
+            for obj_idx in range(batch_size):
+                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+                # We skip those frames already in consolidated outputs (these are frames
+                # that received input clicks or mask). Note that we cannot directly run
+                # batched forward on them via `_run_single_frame_inference` because the
+                # number of clicks on each object might be different.
+                if frame_idx in obj_output_dict["cond_frame_outputs"]:
+                    storage_key = "cond_frame_outputs"
+                    current_out = obj_output_dict[storage_key][frame_idx]
+                    device = inference_state["device"]
+                    pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
+                    if self.clear_non_cond_mem_around_input:
+                        # clear non-conditioning memory of the surrounding frames
+                        self._clear_obj_non_cond_mem_around_input(
+                            inference_state, frame_idx, obj_idx
+                        )
+                else:
+                    storage_key = "non_cond_frame_outputs"
+                    current_out, pred_masks = self._run_single_frame_inference(
+                        inference_state=inference_state,
+                        output_dict=obj_output_dict,
+                        frame_idx=frame_idx,
+                        batch_size=1,  # run on the slice of a single object
+                        is_init_cond_frame=False,
+                        point_inputs=None,
+                        mask_inputs=None,
+                        reverse=reverse,
+                        run_mem_encoder=True,
+                    )
+                    obj_output_dict[storage_key][frame_idx] = current_out
+
+                inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
+                    "reverse": reverse
+                }
+                pred_masks_per_obj[obj_idx] = pred_masks

             # Resize the output mask to the original video resolution (we directly use
             # the mask scores on GPU for output to avoid any CPU conversion in between)
+            if len(pred_masks_per_obj) > 1:
+                all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
+            else:
+                all_pred_masks = pred_masks_per_obj[0]
             _, video_res_masks = self._get_orig_video_res_output(
-                inference_state, pred_masks
+                inference_state, all_pred_masks
             )
             yield frame_idx, obj_ids, video_res_masks
-    def _add_output_per_object(
-        self, inference_state, frame_idx, current_out, storage_key
-    ):
-        """
-        Split a multi-object output into per-object output slices and add them into
-        `output_dict_per_obj`. The resulting slices share the same tensor storage.
-        """
-        maskmem_features = current_out["maskmem_features"]
-        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
-        maskmem_pos_enc = current_out["maskmem_pos_enc"]
-        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
-        output_dict_per_obj = inference_state["output_dict_per_obj"]
-        for obj_idx, obj_output_dict in output_dict_per_obj.items():
-            obj_slice = slice(obj_idx, obj_idx + 1)
-            obj_out = {
-                "maskmem_features": None,
-                "maskmem_pos_enc": None,
-                "pred_masks": current_out["pred_masks"][obj_slice],
-                "obj_ptr": current_out["obj_ptr"][obj_slice],
-                "object_score_logits": current_out["object_score_logits"][obj_slice],
-            }
-            if maskmem_features is not None:
-                obj_out["maskmem_features"] = maskmem_features[obj_slice]
-            if maskmem_pos_enc is not None:
-                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
-            obj_output_dict[storage_key][frame_idx] = obj_out
     @torch.inference_mode()
     def clear_all_prompts_in_frame(
         self, inference_state, frame_idx, obj_id, need_output=True
@@ -788,41 +644,14 @@ class SAM2VideoPredictor(SAM2Base):
         temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
         temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
-        # Check and see if there are still any inputs left on this frame
-        batch_size = self._get_obj_num(inference_state)
-        frame_has_input = False
-        for obj_idx2 in range(batch_size):
-            if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
-                frame_has_input = True
-                break
-            if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
-                frame_has_input = True
-                break
-
-        # If this frame has no remaining inputs for any objects, we further clear its
-        # conditioning frame status
-        if not frame_has_input:
-            output_dict = inference_state["output_dict"]
-            consolidated_frame_inds = inference_state["consolidated_frame_inds"]
-            consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
-            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
-            # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
-            out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
-            if out is not None:
-                # The frame is not a conditioning frame anymore since it's not receiving inputs,
-                # so we "downgrade" its output (if exists) to a non-conditioning frame output.
-                output_dict["non_cond_frame_outputs"][frame_idx] = out
-                inference_state["frames_already_tracked"].pop(frame_idx, None)
-            # Similarly, do it for the sliced output on each object.
-            for obj_idx2 in range(batch_size):
-                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
-                obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
-                if obj_out is not None:
-                    obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
-            # If all the conditioning frames have been removed, we also clear the tracking outputs
-            if len(output_dict["cond_frame_outputs"]) == 0:
-                self._reset_tracking_results(inference_state)
+        # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
+        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+        out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
+        if out is not None:
+            # The frame is not a conditioning frame anymore since it's not receiving inputs,
+            # so we "downgrade" its output (if exists) to a non-conditioning frame output.
+            obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
+        inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
         if not need_output:
             return
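The per-object `clear_all_prompts_in_frame` above reduces to a small dictionary move: pop the frame's entry from the conditioning store and, if one exists, keep it as a regular non-conditioning (tracked) result. A minimal torch-free sketch with plain dicts (the helper name and placeholder values are illustrative, not the library API):

```python
# Torch-free sketch of the "downgrade" step in `clear_all_prompts_in_frame`:
# once a frame's prompts are cleared, its output stops being a conditioning
# frame output and is kept only as a regular tracked result.
def downgrade_frame(obj_output_dict, frame_idx):
    out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
    if out is not None:
        # keep the prediction, but no longer treat the frame as a prompt anchor
        obj_output_dict["non_cond_frame_outputs"][frame_idx] = out

store = {"cond_frame_outputs": {5: "mask@5"}, "non_cond_frame_outputs": {}}
downgrade_frame(store, frame_idx=5)
```

Popping with a `None` default makes the operation idempotent: clearing a frame that was never a conditioning frame is a no-op.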
@@ -836,7 +665,6 @@ class SAM2VideoPredictor(SAM2Base):
             inference_state,
             frame_idx,
             is_cond=is_cond,
-            run_mem_encoder=False,
             consolidate_at_video_res=True,
         )
         _, video_res_masks = self._get_orig_video_res_output(
@@ -856,6 +684,7 @@ class SAM2VideoPredictor(SAM2Base):
         inference_state["mask_inputs_per_obj"].clear()
         inference_state["output_dict_per_obj"].clear()
         inference_state["temp_output_dict_per_obj"].clear()
+        inference_state["frames_tracked_per_obj"].clear()
     def _reset_tracking_results(self, inference_state):
         """Reset all tracking inputs and results across the videos."""
@@ -869,12 +698,8 @@ class SAM2VideoPredictor(SAM2Base):
         for v in inference_state["temp_output_dict_per_obj"].values():
             v["cond_frame_outputs"].clear()
             v["non_cond_frame_outputs"].clear()
-        inference_state["output_dict"]["cond_frame_outputs"].clear()
-        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
-        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
-        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
-        inference_state["tracking_has_started"] = False
-        inference_state["frames_already_tracked"].clear()
+        for v in inference_state["frames_tracked_per_obj"].values():
+            v.clear()
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
@@ -1092,8 +917,6 @@ class SAM2VideoPredictor(SAM2Base):
         inference_state["obj_ids"] = new_obj_ids

         # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
-        # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
-        # it's already handled in Step 0)
         def _map_keys(container):
             new_kvs = []
             for k in old_obj_inds:
@@ -1106,30 +929,9 @@ class SAM2VideoPredictor(SAM2Base):
         _map_keys(inference_state["mask_inputs_per_obj"])
         _map_keys(inference_state["output_dict_per_obj"])
         _map_keys(inference_state["temp_output_dict_per_obj"])
+        _map_keys(inference_state["frames_tracked_per_obj"])

-        # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
-        def _slice_state(output_dict, storage_key):
-            for frame_idx, out in output_dict[storage_key].items():
-                out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
-                out["maskmem_pos_enc"] = [
-                    x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
-                ]
-                # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
-                out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
-                out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
-                out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
-                out["object_score_logits"] = out["object_score_logits"][
-                    remain_old_obj_inds
-                ]
-                # also update the per-object slices
-                self._add_output_per_object(
-                    inference_state, frame_idx, out, storage_key
-                )
-
-        _slice_state(inference_state["output_dict"], "cond_frame_outputs")
-        _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
-
-        # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
+        # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
         # could show an updated mask for objects previously occluded by the object being removed
         if need_output:
             temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
@@ -1142,7 +944,6 @@ class SAM2VideoPredictor(SAM2Base):
                 inference_state,
                 frame_idx,
                 is_cond=is_cond,
-                run_mem_encoder=False,
                 consolidate_at_video_res=True,
             )
             _, video_res_masks = self._get_orig_video_res_output(
@@ -1164,9 +965,259 @@ class SAM2VideoPredictor(SAM2Base):
         r = self.memory_temporal_stride_for_eval
         frame_idx_begin = frame_idx - r * self.num_maskmem
         frame_idx_end = frame_idx + r * self.num_maskmem
-        output_dict = inference_state["output_dict"]
-        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
-        for t in range(frame_idx_begin, frame_idx_end + 1):
-            non_cond_frame_outputs.pop(t, None)
-            for obj_output_dict in inference_state["output_dict_per_obj"].values():
-                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+        batch_size = self._get_obj_num(inference_state)
+        for obj_idx in range(batch_size):
+            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+            non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
+            for t in range(frame_idx_begin, frame_idx_end + 1):
+                non_cond_frame_outputs.pop(t, None)
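The memory-clearing method above drops non-conditioning outputs within `num_maskmem * stride` frames on either side of a prompted frame. A torch-free sketch of that windowed cleanup (the helper name and placeholder values are illustrative):

```python
# Torch-free sketch of the surrounding-frame cleanup: drop non-conditioning
# outputs within num_maskmem * stride frames of the prompted frame.
def clear_window(non_cond_frame_outputs, frame_idx, num_maskmem, stride):
    frame_idx_begin = frame_idx - stride * num_maskmem
    frame_idx_end = frame_idx + stride * num_maskmem
    for t in range(frame_idx_begin, frame_idx_end + 1):
        non_cond_frame_outputs.pop(t, None)

outputs = {t: f"mem{t}" for t in range(20)}
# window is 10 - 6 .. 10 + 6, i.e. frames 4..16 inclusive
clear_window(outputs, frame_idx=10, num_maskmem=3, stride=2)
```

Negative or out-of-range indices in the window are harmless because `dict.pop(t, None)` simply does nothing for missing keys.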
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
    """Optimized for the VOS setting"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._compile_all_components()

    def _compile_all_components(self):
        print("Compiling all components for VOS setting. First time may be very slow.")
        self.memory_encoder.forward = torch.compile(
            self.memory_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )
        self.memory_attention.forward = torch.compile(
            self.memory_attention.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,  # Num. of memories varies
        )
        self.sam_prompt_encoder.forward = torch.compile(
            self.sam_prompt_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )
        self.sam_mask_decoder.forward = torch.compile(
            self.sam_mask_decoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

    def forward_image(self, img_batch: torch.Tensor):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the backbone features and pos encoding to enable compilation.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        # Clone to help torch.compile
        for i in range(len(backbone_out["backbone_fpn"])):
            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
                i
            ].clone()
        return backbone_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone image_pe and the outputs of sam_prompt_encoder
        # to enable compilation
        sparse_embeddings = sparse_embeddings.clone()
        dense_embeddings = dense_embeddings.clone()
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        # Clone the output of sam_mask_decoder
        # to enable compilation
        low_res_multimasks = low_res_multimasks.clone()
        ious = ious.clone()
        sam_output_tokens = sam_output_tokens.clone()
        object_score_logits = object_score_logits.clone()

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0
            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the memories and their pos enc to enable compilation.
        """
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        # Clone the feats and pos_enc to enable compilation
        maskmem_features = maskmem_out["vision_features"].clone()
        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc
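The multimask branch of `_forward_sam_heads` above keeps, for each batch element, the candidate mask with the highest predicted IoU. A torch-free sketch of that selection rule, using nested lists as a stand-in for the `torch.argmax(ious, dim=-1)` plus gather step (names and values are illustrative):

```python
# Torch-free sketch of the best-mask selection in `_forward_sam_heads`:
# per batch element, keep the candidate whose IoU estimate is highest.
def pick_best_masks(multimasks, ious):
    picked = []
    for cand_masks, cand_ious in zip(multimasks, ious):
        best = max(range(len(cand_ious)), key=lambda i: cand_ious[i])
        picked.append(cand_masks[best])
    return picked

masks = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
scores = [[0.1, 0.8, 0.3], [0.7, 0.2, 0.6]]
picked = pick_best_masks(masks, scores)
```

In the real method the same index is also used to pick the matching SAM output token when more than one token is returned, so mask and object pointer stay consistent.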

File diff suppressed because it is too large


@@ -22,8 +22,8 @@ with open("README.md", "r", encoding="utf-8") as f:
 # Required dependencies
 REQUIRED_PACKAGES = [
-    "torch>=2.3.1",
-    "torchvision>=0.18.1",
+    "torch>=2.5.1",
+    "torchvision>=0.20.1",
     "numpy>=1.24.4",
     "tqdm>=4.66.1",
     "hydra-core>=1.3.2",
@@ -58,7 +58,7 @@ EXTRA_PACKAGES = {
         "scikit-image>=0.24.0",
         "tensorboard>=2.17.0",
         "pycocotools>=2.0.8",
-        "tensordict>=0.5.0",
+        "tensordict>=0.6.0",
         "opencv-python>=4.7.0",
         "submitit>=1.5.1",
     ],


@@ -375,7 +375,7 @@ def main():
     parser.add_argument(
         "--sam2_checkpoint",
         type=str,
-        default="./checkpoints/sam2.1_hiera_b+.pt",
+        default="./checkpoints/sam2.1_hiera_base_plus.pt",
         help="path to the SAM 2 model checkpoint",
     )
     parser.add_argument(
@@ -434,6 +434,11 @@ def main():
         help="whether to track objects that appear later in the video (i.e. not on the first frame; "
         "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
     )
+    parser.add_argument(
+        "--use_vos_optimized_video_predictor",
+        action="store_true",
+        help="whether to use vos optimized video predictor with all modules compiled",
+    )
     args = parser.parse_args()

     # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -445,6 +450,7 @@ def main():
         ckpt_path=args.sam2_checkpoint,
         apply_postprocessing=args.apply_postprocessing,
         hydra_overrides_extra=hydra_overrides_extra,
+        vos_optimized=args.use_vos_optimized_video_predictor,
     )

     if args.use_all_masks: