Initial commit
This commit is contained in:
5
sam2/modeling/sam/__init__.py
Normal file
5
sam2/modeling/sam/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
295
sam2/modeling/sam/mask_decoder.py
Normal file
295
sam2/modeling/sam/mask_decoder.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sam2.modeling.sam2_utils import LayerNorm2d, MLP
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
use_high_res_features: bool = False,
|
||||
iou_prediction_use_sigmoid=False,
|
||||
dynamic_multimask_via_stability=False,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
pred_obj_scores: bool = False,
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a
|
||||
transformer architecture.
|
||||
|
||||
Arguments:
|
||||
transformer_dim (int): the channel dimension of the transformer
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict
|
||||
when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when
|
||||
upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict
|
||||
mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
||||
used to predict mask quality
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
if self.pred_obj_scores:
|
||||
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
||||
),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(
|
||||
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
||||
),
|
||||
activation(),
|
||||
)
|
||||
self.use_high_res_features = use_high_res_features
|
||||
if use_high_res_features:
|
||||
self.conv_s0 = nn.Conv2d(
|
||||
transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
|
||||
)
|
||||
self.conv_s1 = nn.Conv2d(
|
||||
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
|
||||
)
|
||||
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
||||
for i in range(self.num_mask_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
self.num_mask_tokens,
|
||||
iou_head_depth,
|
||||
sigmoid_output=iou_prediction_use_sigmoid,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
||||
if pred_obj_scores_mlp:
|
||||
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
||||
|
||||
# When outputting a single mask, optionally we can dynamically fall back to the best
|
||||
# multimask output token if the single mask output token gives low stability scores.
|
||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
||||
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
||||
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Arguments:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
||||
multimask_output (bool): Whether to return multiple masks or a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: batched predicted masks
|
||||
torch.Tensor: batched predictions of mask quality
|
||||
torch.Tensor: batched SAM token for mask output
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
repeat_image=repeat_image,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
masks = masks[:, 1:, :, :]
|
||||
iou_pred = iou_pred[:, 1:]
|
||||
elif self.dynamic_multimask_via_stability and not self.training:
|
||||
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
||||
else:
|
||||
masks = masks[:, 0:1, :, :]
|
||||
iou_pred = iou_pred[:, 0:1]
|
||||
|
||||
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
||||
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
||||
else:
|
||||
# Take the mask output token. Here we *always* use the token for single mask output.
|
||||
# At test time, even if we track after 1-click (and using multimask_output=True),
|
||||
# we still take the single mask token here. The rationale is that we always track
|
||||
# after multiple clicks during training, so the past tokens seen during training
|
||||
# are always the single mask token (and we'll let it be the object-memory token).
|
||||
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred, sam_tokens_out, object_score_logits
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
# Concatenate output tokens
|
||||
s = 0
|
||||
if self.pred_obj_scores:
|
||||
output_tokens = torch.cat(
|
||||
[
|
||||
self.obj_score_token.weight,
|
||||
self.iou_token.weight,
|
||||
self.mask_tokens.weight,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
s = 1
|
||||
else:
|
||||
output_tokens = torch.cat(
|
||||
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
||||
)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(
|
||||
sparse_prompt_embeddings.size(0), -1, -1
|
||||
)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
if repeat_image:
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
else:
|
||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
assert (
|
||||
image_pe.size(0) == 1
|
||||
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, s, :]
|
||||
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
if not self.use_high_res_features:
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
else:
|
||||
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
||||
feat_s0, feat_s1 = high_res_features
|
||||
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
||||
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
||||
|
||||
hyper_in_list: List[torch.Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
hyper_in_list.append(
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
||||
)
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
if self.pred_obj_scores:
|
||||
assert s == 1
|
||||
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
||||
else:
|
||||
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
|
||||
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
||||
|
||||
return masks, iou_pred, mask_tokens_out, object_score_logits
|
||||
|
||||
def _get_stability_scores(self, mask_logits):
|
||||
"""
|
||||
Compute stability scores of the mask logits based on the IoU between upper and
|
||||
lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
|
||||
"""
|
||||
mask_logits = mask_logits.flatten(-2)
|
||||
stability_delta = self.dynamic_multimask_stability_delta
|
||||
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
||||
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
||||
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
|
||||
return stability_scores
|
||||
|
||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
||||
"""
|
||||
When outputting a single mask, if the stability score from the current single-mask
|
||||
output (based on output token 0) falls below a threshold, we instead select from
|
||||
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
|
||||
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
|
||||
"""
|
||||
# The best mask from multimask output tokens (1~3)
|
||||
multimask_logits = all_mask_logits[:, 1:, :, :]
|
||||
multimask_iou_scores = all_iou_scores[:, 1:]
|
||||
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
||||
batch_inds = torch.arange(
|
||||
multimask_iou_scores.size(0), device=all_iou_scores.device
|
||||
)
|
||||
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
||||
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
||||
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
||||
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
||||
|
||||
# The mask from singlemask output token 0 and its stability score
|
||||
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
||||
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
||||
stability_scores = self._get_stability_scores(singlemask_logits)
|
||||
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
||||
|
||||
# Dynamically fall back to best multimask output upon low stability scores.
|
||||
mask_logits_out = torch.where(
|
||||
is_stable[..., None, None].expand_as(singlemask_logits),
|
||||
singlemask_logits,
|
||||
best_multimask_logits,
|
||||
)
|
||||
iou_scores_out = torch.where(
|
||||
is_stable.expand_as(singlemask_iou_scores),
|
||||
singlemask_iou_scores,
|
||||
best_multimask_iou_scores,
|
||||
)
|
||||
return mask_logits_out, iou_scores_out
|
182
sam2/modeling/sam/prompt_encoder.py
Normal file
182
sam2/modeling/sam/prompt_encoder.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
||||
|
||||
from sam2.modeling.sam2_utils import LayerNorm2d
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_embedding_size: Tuple[int, int],
|
||||
input_image_size: Tuple[int, int],
|
||||
mask_in_chans: int,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Arguments:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
input_image_size (int): The padded size of the image as input
|
||||
to the image encoder, as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for
|
||||
encoding input masks.
|
||||
activation (nn.Module): The activation to use when encoding
|
||||
input masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.input_image_size = input_image_size
|
||||
self.image_embedding_size = image_embedding_size
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
|
||||
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
||||
point_embeddings = [
|
||||
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
||||
]
|
||||
self.point_embeddings = nn.ModuleList(point_embeddings)
|
||||
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
self.mask_input_size = (
|
||||
4 * image_embedding_size[0],
|
||||
4 * image_embedding_size[1],
|
||||
)
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans // 4),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
||||
)
|
||||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
||||
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(
|
||||
points, self.input_image_size
|
||||
)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
||||
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(
|
||||
coords, self.input_image_size
|
||||
)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
mask_embedding = self.mask_downscaling(masks)
|
||||
return mask_embedding
|
||||
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embeds different types of prompts, returning both sparse and dense
|
||||
embeddings.
|
||||
|
||||
Arguments:
|
||||
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
||||
and labels to embed.
|
||||
boxes (torch.Tensor or none): boxes to embed
|
||||
masks (torch.Tensor or none): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
||||
BxNx(embed_dim), where N is determined by the number of input points
|
||||
and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty(
|
||||
(bs, 0, self.embed_dim), device=self._get_device()
|
||||
)
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
||||
if boxes is not None:
|
||||
box_embeddings = self._embed_boxes(boxes)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
||||
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
327
sam2/modeling/sam/transformer.py
Normal file
327
sam2/modeling/sam/transformer.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
embedding_dim (int): the channel dimension for the input embeddings
|
||||
num_heads (int): the number of heads for multihead attention. Must
|
||||
divide embedding_dim
|
||||
mlp_dim (int): the channel dimension internal to the MLP block
|
||||
activation (nn.Module): the activation to use in the MLP block
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_dim = mlp_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
)
|
||||
)
|
||||
|
||||
self.final_attn_token_to_image = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embedding: Tensor,
|
||||
image_pe: Tensor,
|
||||
point_embedding: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape
|
||||
B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
||||
have the same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the processed point_embedding
|
||||
torch.Tensor: the processed image_embedding
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
||||
|
||||
# Prepare queries
|
||||
queries = point_embedding
|
||||
keys = image_embedding
|
||||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_pe=point_embedding,
|
||||
key_pe=image_pe,
|
||||
)
|
||||
|
||||
# Apply the final attention layer from the points to the image
|
||||
q = queries + point_embedding
|
||||
k = keys + image_pe
|
||||
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm_final_attn(queries)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse
|
||||
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
||||
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
|
||||
Arguments:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
activation (nn.Module): the activation of the mlp block
|
||||
skip_first_layer_pe (bool): skip the PE on the first layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
self.norm1 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.cross_attn_token_to_image = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.mlp = MLP(
|
||||
embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
|
||||
)
|
||||
self.norm3 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.norm4 = nn.LayerNorm(embedding_dim)
|
||||
self.cross_attn_image_to_token = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(
|
||||
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.self_attn(q=queries, k=queries, v=queries)
|
||||
else:
|
||||
q = queries + query_pe
|
||||
attn_out = self.self_attn(q=q, k=q, v=queries)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm1(queries)
|
||||
|
||||
# Cross attention block, tokens attending to image embedding
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm2(queries)
|
||||
|
||||
# MLP block
|
||||
mlp_out = self.mlp(queries)
|
||||
queries = queries + mlp_out
|
||||
queries = self.norm3(queries)
|
||||
|
||||
# Cross attention block, image embedding attending to tokens
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
||||
keys = keys + attn_out
|
||||
keys = self.norm4(keys)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding
|
||||
after projection to queries, keys, and values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
dropout: float = 0.0,
|
||||
kv_in_dim: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert (
|
||||
self.internal_dim % num_heads == 0
|
||||
), "num_heads must divide embedding_dim."
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
self.dropout_p = dropout
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RoPEAttention(Attention):
|
||||
"""Attention with rotary position encoding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
rope_theta=10000.0,
|
||||
# whether to repeat q rope to match k length
|
||||
# this is needed for cross-attention to memories
|
||||
rope_k_repeat=False,
|
||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.compute_cis = partial(
|
||||
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
||||
)
|
||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
||||
self.freqs_cis = freqs_cis
|
||||
self.rope_k_repeat = rope_k_repeat
|
||||
|
||||
def forward(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
|
||||
) -> Tensor:
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Apply rotary position encoding
|
||||
w = h = math.sqrt(q.shape[-2])
|
||||
self.freqs_cis = self.freqs_cis.to(q.device)
|
||||
if self.freqs_cis.shape[0] != q.shape[-2]:
|
||||
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
|
||||
if q.shape[-2] != k.shape[-2]:
|
||||
assert self.rope_k_repeat
|
||||
|
||||
num_k_rope = k.size(-2) - num_k_exclude_rope
|
||||
q, k[:, :, :num_k_rope] = apply_rotary_enc(
|
||||
q,
|
||||
k[:, :, :num_k_rope],
|
||||
freqs_cis=self.freqs_cis,
|
||||
repeat_freqs_k=self.rope_k_repeat,
|
||||
)
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
Reference in New Issue
Block a user