Merge branch 'main' into patch-1
This commit is contained in:
17
.github/workflows/check_fmt.yml
vendored
Normal file
17
.github/workflows/check_fmt.yml
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
name: SAM2/fmt
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
ufmt_check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check formatting
|
||||
uses: omnilib/ufmt@action-v1
|
||||
with:
|
||||
path: sam2 tools
|
||||
version: "2.0.0b2"
|
||||
python-version: "3.10"
|
||||
black-version: "24.2.0"
|
||||
usort-version: "1.0.2"
|
167
INSTALL.md
Normal file
167
INSTALL.md
Normal file
@@ -0,0 +1,167 @@
|
||||
## Installation
|
||||
|
||||
### Requirements
|
||||
|
||||
- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
|
||||
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
|
||||
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
|
||||
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
||||
|
||||
Then, install SAM 2 from the root of this repository via
|
||||
```bash
|
||||
pip install -e ".[demo]"
|
||||
```
|
||||
|
||||
Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
|
||||
```bash
|
||||
# skip the SAM 2 CUDA extension
|
||||
SAM2_BUILD_CUDA=0 pip install -e ".[demo]"
|
||||
```
|
||||
This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
|
||||
|
||||
### Building the SAM 2 CUDA extension
|
||||
|
||||
By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
|
||||
|
||||
If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.
|
||||
|
||||
If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
|
||||
```bash
|
||||
pip uninstall -y SAM-2 && \
|
||||
rm -f ./sam2/*.so && \
|
||||
SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[demo]"
|
||||
```
|
||||
|
||||
Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
|
||||
|
||||
Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.
|
||||
|
||||
### Common Installation Issues
|
||||
|
||||
Click each issue for its solutions:
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `ImportError: cannot import name '_C' from 'sam2'`
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
This is usually because you haven't run the `pip install -e ".[demo]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
|
||||
|
||||
In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/segment-anything-2/issues/77.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `MissingConfigException: Cannot find primary config 'sam2_hiera_l.yaml'`
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
This is usually because you haven't run the `pip install -e .` step above, so `sam2_configs` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
|
||||
```bash
|
||||
export SAM2_REPO_ROOT=/path/to/segment-anything-2 # path to this repo
|
||||
export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
|
||||
```
|
||||
to manually add `sam2_configs` into your Python's `sys.path`.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
My installation failed with `CUDA_HOME environment variable is not set`
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
|
||||
```
|
||||
export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path
|
||||
```
|
||||
and rerun the installation.
|
||||
|
||||
Also, you should make sure
|
||||
```
|
||||
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
|
||||
```
|
||||
print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
|
||||
|
||||
If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
|
||||
```
|
||||
pip install --no-build-isolation -e .
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
|
||||
|
||||
In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
|
||||
|
||||
We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/segment-anything-2/issues/22, https://github.com/facebookresearch/segment-anything-2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `CUDA error: no kernel image is available for execution on the device`
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system).
|
||||
|
||||
You can try pulling the latest code from the SAM 2 repo and running the following
|
||||
```
|
||||
export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0`
|
||||
```
|
||||
to manually specify the CUDA capability in the compilation target that matches your GPU.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line
|
||||
```python
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
```
|
||||
in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
|
||||
```python
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
|
||||
```
|
||||
to relax the attention kernel setting and use other kernels than Flash Attention.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
I got `Error compiling objects for extension`
|
||||
</summary>
|
||||
<br/>
|
||||
|
||||
You may see error log of:
|
||||
> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
|
||||
|
||||
This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).<br>
|
||||
You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/segment-anything-2/blob/main/setup.py). <br>
|
||||
After adding the argument, `get_extension()` will look like this:
|
||||
```python
|
||||
def get_extensions():
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
"nvcc": [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-allow-unsupported-compiler" # Add this argument
|
||||
],
|
||||
}
|
||||
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
||||
return ext_modules
|
||||
```
|
||||
</details>
|
70
README.md
70
README.md
@@ -8,19 +8,20 @@
|
||||
|
||||

|
||||
|
||||
**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model in the loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
|
||||
**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
|
||||
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
Please install SAM 2 on a GPU machine using:
|
||||
SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:facebookresearch/segment-anything-2.git
|
||||
git clone https://github.com/facebookresearch/segment-anything-2.git
|
||||
|
||||
cd segment-anything-2; pip install -e .
|
||||
cd segment-anything-2 & pip install -e .
|
||||
```
|
||||
If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
||||
|
||||
To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
|
||||
|
||||
@@ -28,6 +29,13 @@ To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplot
|
||||
pip install -e ".[demo]"
|
||||
```
|
||||
|
||||
Note:
|
||||
1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
|
||||
2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
|
||||
3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).
|
||||
|
||||
Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Download Checkpoints
|
||||
@@ -35,8 +43,9 @@ pip install -e ".[demo]"
|
||||
First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
|
||||
|
||||
```bash
|
||||
cd checkpoints
|
||||
./download_ckpts.sh
|
||||
cd checkpoints && \
|
||||
./download_ckpts.sh && \
|
||||
cd ..
|
||||
```
|
||||
|
||||
or individually from:
|
||||
@@ -66,9 +75,9 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
masks, _, _ = predictor.predict(<input_prompts>)
|
||||
```
|
||||
|
||||
Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases.
|
||||
Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.
|
||||
|
||||
SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images.
|
||||
SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.
|
||||
|
||||
### Video prediction
|
||||
|
||||
@@ -86,14 +95,50 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
state = predictor.init_state(<your_video>)
|
||||
|
||||
# add new prompts and instantly get the output on the same frame
|
||||
frame_idx, object_ids, masks = predictor.add_new_points(state, <your prompts>):
|
||||
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
||||
|
||||
# propagate the prompts to get masklets throughout the video
|
||||
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
||||
...
|
||||
```
|
||||
|
||||
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
|
||||
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
|
||||
|
||||
## Load from 🤗 Hugging Face
|
||||
|
||||
Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
|
||||
|
||||
For image prediction:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
predictor.set_image(<your_image>)
|
||||
masks, _, _ = predictor.predict(<input_prompts>)
|
||||
```
|
||||
|
||||
For video prediction:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
||||
|
||||
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
state = predictor.init_state(<your_video>)
|
||||
|
||||
# add new prompts and instantly get the output on the same frame
|
||||
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
||||
|
||||
# propagate the prompts to get masklets throughout the video
|
||||
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
||||
...
|
||||
```
|
||||
|
||||
## Model Description
|
||||
|
||||
@@ -106,7 +151,7 @@ Please refer to the examples in [video_predictor_example.ipynb](./notebooks/vide
|
||||
|
||||
\* Compile the model by setting `compile_image_encoder: True` in the config.
|
||||
|
||||
## Segment Aything Video Dataset
|
||||
## Segment Anything Video Dataset
|
||||
|
||||
See [sav_dataset/README.md](sav_dataset/README.md) for details.
|
||||
|
||||
@@ -134,7 +179,8 @@ If you use SAM 2 or the SA-V dataset in your research, please use the following
|
||||
@article{ravi2024sam2,
|
||||
title={SAM 2: Segment Anything in Images and Videos},
|
||||
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
|
||||
journal={arXiv preprint},
|
||||
journal={arXiv preprint arXiv:2408.00714},
|
||||
url={https://arxiv.org/abs/2408.00714},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
||||
output_mode: str = "binary_mask",
|
||||
use_m2m: bool = False,
|
||||
multimask_output: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Using a SAM 2 model, generates masks for the entire image.
|
||||
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
||||
self.use_m2m = use_m2m
|
||||
self.multimask_output = multimask_output
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2AutomaticMaskGenerator): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# Run model on this batch
|
||||
points = torch.as_tensor(points, device=self.predictor.device)
|
||||
points = torch.as_tensor(
|
||||
points, dtype=torch.float32, device=self.predictor.device
|
||||
)
|
||||
in_points = self.predictor._transforms.transform_coords(
|
||||
points, normalize=normalize, orig_hw=im_size
|
||||
)
|
||||
|
@@ -19,6 +19,7 @@ def build_sam2(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if apply_postprocessing:
|
||||
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
hydra_overrides = [
|
||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||
@@ -76,6 +78,44 @@ def build_sam2_video_predictor(
|
||||
return model
|
||||
|
||||
|
||||
def build_sam2_hf(model_id, **kwargs):
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_id_to_filenames = {
|
||||
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||
"facebook/sam2-hiera-base-plus": (
|
||||
"sam2_hiera_b+.yaml",
|
||||
"sam2_hiera_base_plus.pt",
|
||||
),
|
||||
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||
}
|
||||
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
||||
|
||||
|
||||
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_id_to_filenames = {
|
||||
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||
"facebook/sam2-hiera-base-plus": (
|
||||
"sam2_hiera_b+.yaml",
|
||||
"sam2_hiera_base_plus.pt",
|
||||
),
|
||||
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||
}
|
||||
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||
return build_sam2_video_predictor(
|
||||
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _load_checkpoint(model, ckpt_path):
|
||||
if ckpt_path is not None:
|
||||
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
||||
|
@@ -223,8 +223,8 @@ std::vector<torch::Tensor> get_connected_componnets(
|
||||
const uint32_t W = inputs.size(3);
|
||||
|
||||
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
|
||||
AT_ASSERTM((H % 2) == 0, "height must be a even number");
|
||||
AT_ASSERTM((W % 2) == 0, "width must be a even number");
|
||||
AT_ASSERTM((H % 2) == 0, "height must be an even number");
|
||||
AT_ASSERTM((W % 2) == 0, "width must be an even number");
|
||||
|
||||
// label must be uint32_t
|
||||
auto label_options =
|
||||
|
@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
@@ -14,12 +15,30 @@ import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
# Check whether Flash Attention is available (and use it by default)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
||||
ALLOW_ALL_KERNELS = False
|
||||
|
||||
|
||||
def sdp_kernel_context(dropout_p):
|
||||
"""
|
||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
||||
by default, but fall back to all available kernels if Flash Attention fails.
|
||||
"""
|
||||
if ALLOW_ALL_KERNELS:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
@@ -246,12 +265,19 @@ class Attention(nn.Module):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
@@ -313,12 +339,19 @@ class RoPEAttention(Attention):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
|
@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
|
||||
mask_threshold=0.0,
|
||||
max_hole_area=0.0,
|
||||
max_sprinkle_area=0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM-2 to calculate the image embedding for an image, and then
|
||||
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
|
||||
sam_model (Sam-2): The model to use for mask prediction.
|
||||
mask_threshold (float): The threshold to use when converting mask logits
|
||||
to binary masks. Masks are thresholded at 0 by default.
|
||||
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of fill_hole_area in low_res_masks.
|
||||
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of max_hole_area in low_res_masks.
|
||||
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||
the maximum area of max_sprinkle_area in low_res_masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = sam_model
|
||||
@@ -62,6 +65,23 @@ class SAM2ImagePredictor:
|
||||
(64, 64),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2ImagePredictor): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def set_image(
|
||||
self,
|
||||
@@ -163,7 +183,7 @@ class SAM2ImagePredictor:
|
||||
normalize_coords=True,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
||||
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
|
||||
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
|
||||
"""
|
||||
assert self._is_batch, "This function should only be used when in batched mode"
|
||||
if not self._is_image_set:
|
||||
|
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
@@ -43,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
offload_state_to_cpu=False,
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""Initialize a inference state."""
|
||||
"""Initialize an inference state."""
|
||||
compute_device = self.device # device of the model
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state = {}
|
||||
inference_state["images"] = images
|
||||
@@ -64,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# the original video height and width, used for resizing final output scores
|
||||
inference_state["video_height"] = video_height
|
||||
inference_state["video_width"] = video_width
|
||||
inference_state["device"] = torch.device("cuda")
|
||||
inference_state["device"] = compute_device
|
||||
if offload_state_to_cpu:
|
||||
inference_state["storage_device"] = torch.device("cpu")
|
||||
else:
|
||||
inference_state["storage_device"] = torch.device("cuda")
|
||||
inference_state["storage_device"] = compute_device
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
@@ -103,6 +106,23 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
return inference_state
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2VideoPredictor): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_video_predictor_hf
|
||||
|
||||
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
||||
return sam_model
|
||||
|
||||
def _obj_id_to_idx(self, inference_state, obj_id):
|
||||
"""Map client-side object id to model-side object index."""
|
||||
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||
@@ -146,29 +166,66 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
return len(inference_state["obj_idx_to_id"])
|
||||
|
||||
@torch.inference_mode()
|
||||
def add_new_points(
|
||||
def add_new_points_or_box(
|
||||
self,
|
||||
inference_state,
|
||||
frame_idx,
|
||||
obj_id,
|
||||
points,
|
||||
labels,
|
||||
points=None,
|
||||
labels=None,
|
||||
clear_old_points=True,
|
||||
normalize_coords=True,
|
||||
box=None,
|
||||
):
|
||||
"""Add new points to a frame."""
|
||||
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
||||
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
||||
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
||||
|
||||
if not isinstance(points, torch.Tensor):
|
||||
if (points is not None) != (labels is not None):
|
||||
raise ValueError("points and labels must be provided together")
|
||||
if points is None and box is None:
|
||||
raise ValueError("at least one of points or box must be provided as input")
|
||||
|
||||
if points is None:
|
||||
points = torch.zeros(0, 2, dtype=torch.float32)
|
||||
elif not isinstance(points, torch.Tensor):
|
||||
points = torch.tensor(points, dtype=torch.float32)
|
||||
if not isinstance(labels, torch.Tensor):
|
||||
if labels is None:
|
||||
labels = torch.zeros(0, dtype=torch.int32)
|
||||
elif not isinstance(labels, torch.Tensor):
|
||||
labels = torch.tensor(labels, dtype=torch.int32)
|
||||
if points.dim() == 2:
|
||||
points = points.unsqueeze(0) # add batch dimension
|
||||
if labels.dim() == 1:
|
||||
labels = labels.unsqueeze(0) # add batch dimension
|
||||
|
||||
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
||||
# along with the user-provided points (consistent with how SAM 2 is trained).
|
||||
if box is not None:
|
||||
if not clear_old_points:
|
||||
raise ValueError(
|
||||
"cannot add box without clearing old points, since "
|
||||
"box prompt must be provided before any point prompt "
|
||||
"(please use clear_old_points=True instead)"
|
||||
)
|
||||
if inference_state["tracking_has_started"]:
|
||||
warnings.warn(
|
||||
"You are adding a box after tracking starts. SAM 2 may not always be "
|
||||
"able to incorporate a box prompt for *refinement*. If you intend to "
|
||||
"use box prompt as an *initial* input before tracking, please call "
|
||||
"'reset_state' on the inference state to restart from scratch.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not isinstance(box, torch.Tensor):
|
||||
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||
box_coords = box.reshape(1, 2, 2)
|
||||
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
||||
box_labels = box_labels.reshape(1, 2)
|
||||
points = torch.cat([box_coords, points], dim=1)
|
||||
labels = torch.cat([box_labels, labels], dim=1)
|
||||
|
||||
if normalize_coords:
|
||||
video_H = inference_state["video_height"]
|
||||
video_W = inference_state["video_width"]
|
||||
@@ -215,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
|
||||
if prev_out is not None and prev_out["pred_masks"] is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
||||
device = inference_state["device"]
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
||||
current_out, _ = self._run_single_frame_inference(
|
||||
@@ -251,6 +309,10 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
return frame_idx, obj_ids, video_res_masks
|
||||
|
||||
def add_new_points(self, *args, **kwargs):
|
||||
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
||||
return self.add_new_points_or_box(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def add_new_mask(
|
||||
self,
|
||||
@@ -527,16 +589,16 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points` or `add_new_mask`)
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temprary output across all objects on this frame
|
||||
# consolidate the temporary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
@@ -734,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
if backbone_out is None:
|
||||
# Cache miss -- we will run inference on a single image
|
||||
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
||||
device = inference_state["device"]
|
||||
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
||||
backbone_out = self.forward_image(image)
|
||||
# Cache the most recent frame's feature (for repeated interactions with
|
||||
# a frame; we can use an LRU cache for more frames in the future).
|
||||
|
@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
|
||||
compute bounding box given an input mask
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
|
||||
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
|
||||
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
|
||||
A list of video frames to be load asynchronously without blocking session start.
|
||||
"""
|
||||
|
||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
||||
def __init__(
|
||||
self,
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
):
|
||||
self.img_paths = img_paths
|
||||
self.image_size = image_size
|
||||
self.offload_video_to_cpu = offload_video_to_cpu
|
||||
self.img_mean = img_mean
|
||||
self.img_std = img_std
|
||||
# items in `self._images` will be loaded asynchronously
|
||||
# items in `self.images` will be loaded asynchronously
|
||||
self.images = [None] * len(img_paths)
|
||||
# catch and raise any exceptions in the async loading thread
|
||||
self.exception = None
|
||||
# video_height and video_width be filled when loading the first image
|
||||
self.video_height = None
|
||||
self.video_width = None
|
||||
self.compute_device = compute_device
|
||||
|
||||
# load the first frame to fill video_height and video_width and also
|
||||
# to cache it (since it's most likely where the user will click)
|
||||
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
|
||||
img -= self.img_mean
|
||||
img /= self.img_std
|
||||
if not self.offload_video_to_cpu:
|
||||
img = img.cuda(non_blocking=True)
|
||||
img = img.to(self.compute_device, non_blocking=True)
|
||||
self.images[index] = img
|
||||
return img
|
||||
|
||||
@@ -167,6 +176,7 @@ def load_video_frames(
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
@@ -179,7 +189,15 @@ def load_video_frames(
|
||||
if isinstance(video_path, str) and os.path.isdir(video_path):
|
||||
jpg_folder = video_path
|
||||
else:
|
||||
raise NotImplementedError("Only JPEG frames are supported at this moment")
|
||||
raise NotImplementedError(
|
||||
"Only JPEG frames are supported at this moment. For video files, you may use "
|
||||
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
|
||||
"```\n"
|
||||
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
|
||||
"```\n"
|
||||
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
|
||||
"ffmpeg to start the JPEG file from 00000.jpg."
|
||||
)
|
||||
|
||||
frame_names = [
|
||||
p
|
||||
@@ -196,7 +214,12 @@ def load_video_frames(
|
||||
|
||||
if async_loading_frames:
|
||||
lazy_images = AsyncVideoFrameLoader(
|
||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
)
|
||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||
|
||||
@@ -204,9 +227,9 @@ def load_video_frames(
|
||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||
if not offload_video_to_cpu:
|
||||
images = images.cuda()
|
||||
img_mean = img_mean.cuda()
|
||||
img_std = img_std.cuda()
|
||||
images = images.to(compute_device)
|
||||
img_mean = img_mean.to(compute_device)
|
||||
img_std = img_std.to(compute_device)
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
@@ -220,10 +243,25 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
# Holes are those connected components in background with area <= self.max_area
|
||||
# (background regions are those with mask scores <= 0)
|
||||
assert max_area > 0, "max_area must be positive"
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
|
||||
input_mask = mask
|
||||
try:
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step on removing small holes if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
mask = input_mask
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
|
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -78,22 +80,39 @@ class SAM2Transforms(nn.Module):
|
||||
from sam2.utils.misc import get_connected_components
|
||||
|
||||
masks = masks.float()
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
input_masks = masks
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
try:
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat <= self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat > self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
masks = input_masks
|
||||
|
||||
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
@@ -161,4 +161,4 @@ In the paper for the experiments on SA-V val and test, we run inference on the 2
|
||||
|
||||
The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in SA-V Dataset are released under CC BY 4.0.
|
||||
|
||||
Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with there licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).
|
||||
Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).
|
||||
|
@@ -72,7 +72,7 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--do_not_skip_first_and_last_frame",
|
||||
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
|
||||
"Set this to true for evaluation on settings that doen't skip first and last frames",
|
||||
"Set this to true for evaluation on settings that doesn't skip first and last frames",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
|
@@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None):
|
||||
|
||||
assert not (
|
||||
width > w | height > h | abs(ar1 - ar2) > 0.01
|
||||
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
|
||||
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
|
||||
|
||||
e = np.zeros_like(seg)
|
||||
s = np.zeros_like(seg)
|
||||
|
103
setup.py
103
setup.py
@@ -3,9 +3,9 @@
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
# Package metadata
|
||||
NAME = "SAM 2"
|
||||
@@ -17,7 +17,7 @@ AUTHOR_EMAIL = "segment-anything@meta.com"
|
||||
LICENSE = "Apache 2.0"
|
||||
|
||||
# Read the contents of README file
|
||||
with open("README.md", "r") as f:
|
||||
with open("README.md", "r", encoding="utf-8") as f:
|
||||
LONG_DESCRIPTION = f.read()
|
||||
|
||||
# Required dependencies
|
||||
@@ -36,22 +36,95 @@ EXTRA_PACKAGES = {
|
||||
"dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
|
||||
}
|
||||
|
||||
# By default, we also build the SAM 2 CUDA extension.
|
||||
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
|
||||
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
|
||||
# By default, we allow SAM 2 installation to proceed even with build errors.
|
||||
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
|
||||
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
|
||||
|
||||
# Catch and skip errors during extension building and print a warning message
|
||||
# (note that this message only shows up under verbose build mode
|
||||
# "pip install -v -e ." or "python setup.py build_ext -v")
|
||||
CUDA_ERROR_MSG = (
|
||||
"{}\n\n"
|
||||
"Failed to build the SAM 2 CUDA extension due to the error above. "
|
||||
"You can still use SAM 2 and it's OK to ignore the error above, although some "
|
||||
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
|
||||
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
|
||||
)
|
||||
|
||||
|
||||
def get_extensions():
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
"nvcc": [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
],
|
||||
}
|
||||
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
||||
if not BUILD_CUDA:
|
||||
return []
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
|
||||
srcs = ["sam2/csrc/connected_components.cu"]
|
||||
compile_args = {
|
||||
"cxx": [],
|
||||
"nvcc": [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
],
|
||||
}
|
||||
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
||||
except Exception as e:
|
||||
if BUILD_ALLOW_ERRORS:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
ext_modules = []
|
||||
else:
|
||||
raise e
|
||||
|
||||
return ext_modules
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import BuildExtension
|
||||
|
||||
class BuildExtensionIgnoreErrors(BuildExtension):
|
||||
|
||||
def finalize_options(self):
|
||||
try:
|
||||
super().finalize_options()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def build_extensions(self):
|
||||
try:
|
||||
super().build_extensions()
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
|
||||
def get_ext_filename(self, ext_name):
|
||||
try:
|
||||
return super().get_ext_filename(ext_name)
|
||||
except Exception as e:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
self.extensions = []
|
||||
return "_C.so"
|
||||
|
||||
cmdclass = {
|
||||
"build_ext": (
|
||||
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
|
||||
if BUILD_ALLOW_ERRORS
|
||||
else BuildExtension.with_options(no_python_abi_suffix=True)
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
cmdclass = {}
|
||||
if BUILD_ALLOW_ERRORS:
|
||||
print(CUDA_ERROR_MSG.format(e))
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
# Setup configuration
|
||||
setup(
|
||||
name=NAME,
|
||||
@@ -64,9 +137,11 @@ setup(
|
||||
author_email=AUTHOR_EMAIL,
|
||||
license=LICENSE,
|
||||
packages=find_packages(exclude="notebooks"),
|
||||
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
|
||||
include_package_data=True,
|
||||
install_requires=REQUIRED_PACKAGES,
|
||||
extras_require=EXTRA_PACKAGES,
|
||||
python_requires=">=3.10.0",
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
|
||||
cmdclass=cmdclass,
|
||||
)
|
||||
|
@@ -250,7 +250,7 @@ def main():
|
||||
action="store_true",
|
||||
help="whether to use all available PNG files in input_mask_dir "
|
||||
"(default without this flag: just the first PNG file as input to the SAM 2 model; "
|
||||
"usually we don't need this flag, since semi-supervised VOS evalaution usually takes input from the first frame only)",
|
||||
"usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_obj_png_file",
|
||||
|
Reference in New Issue
Block a user