
This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can enable this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts.
- We also updated the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * Now **we handle the inference of each object independently** (as if we were opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the rest of the (non-prompted) objects were assumed to be non-existent in this frame (i.e., in such frames, the user was telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., each object is handled independently and is not affected by how other objects are prompted). As a result, **we now allow adding new objects after tracking starts** (this was previously a restriction on usage).
  * We believe that the new version is a more natural inference behavior and have therefore made it the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
203 lines
7.4 KiB
Python
203 lines
7.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import Optional, Tuple, Type
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
|
|
|
from sam2.modeling.sam2_utils import LayerNorm2d
|
|
|
|
|
|
class PromptEncoder(nn.Module):
    """
    Encodes point, box, and mask prompts into the sparse and dense embeddings
    consumed by SAM's mask decoder.

    Point/box prompts become a sparse token sequence (B x N x embed_dim); a
    mask prompt (or a learned "no mask" embedding) becomes a dense map
    (B x embed_dim x H x W) over the image-embedding grid.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        # NOTE(review): presumably PositionEmbeddingRandom(d) emits 2 * d
        # channels (e.g. sin + cos), giving embed_dim overall -- confirm in
        # sam2/modeling/position_encoding.py.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # One learned embedding per prompt label: labels 0/1 are point clicks,
        # labels 2/3 are the two box corners (see _embed_points/_embed_boxes).
        # Creation order is kept stable so parameter init and state_dict keys
        # ("point_embeddings.<i>.weight") are unchanged.
        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        self.point_embeddings = nn.ModuleList(
            [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        )
        # Replaces the positional encoding of padding entries (label -1).
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # mask_downscaling applies two kernel-2/stride-2 convs (4x downscale),
        # so a mask of this size maps onto the image_embedding_size grid.
        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Dense embedding broadcast over the whole grid when no mask is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """
        Embeds point prompts.

        Arguments:
          points (torch.Tensor): point coordinates of shape B x N x 2
            (presumably pixel coordinates in the input_image_size frame).
          labels (torch.Tensor): per-point labels of shape B x N. Label -1
            marks padding; labels 0..num_point_embeddings-1 select the
            learned embedding added to the positional encoding.
          pad (bool): if True, append one padding point (label -1) per batch
            element before embedding.

        Returns:
          torch.Tensor: embeddings of shape B x N(+1 if pad) x embed_dim.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        # Padding entries discard their positional encoding entirely and are
        # represented only by the learned "not a point" embedding.
        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        # Every real label adds its own learned embedding on top of the
        # positional encoding. The range is a Python int, so this loop is
        # unrolled under torch.compile exactly like the explicit form.
        for label_id in range(self.num_point_embeddings):
            point_embedding = torch.where(
                (labels == label_id).unsqueeze(-1),
                point_embedding + self.point_embeddings[label_id].weight,
                point_embedding,
            )
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Embeds box prompts.

        Each box is reshaped into two (x, y) corners (presumably
        x1, y1, x2, y2 order -- confirm with callers); the first corner gets
        point_embeddings[2] and the second point_embeddings[3].

        Returns:
          torch.Tensor: corner embeddings of shape (num_boxes) x 2 x embed_dim.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs by running them through mask_downscaling."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input
        prompts. Points take precedence over boxes, boxes over masks; with no
        prompts at all the batch size is 1.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        """Returns the device holding this module's point embeddings."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            # A padding point is appended only when there is no box prompt.
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # Broadcast the learned "no mask" vector over the embedding grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
|