support gsam2 image predictor model
This commit is contained in:
5
sam2/modeling/backbones/__init__.py
Normal file
5
sam2/modeling/backbones/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
295
sam2/modeling/backbones/hieradet.py
Normal file
295
sam2/modeling/backbones/hieradet.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam2.modeling.backbones.utils import (
|
||||
PatchEmbed,
|
||||
window_partition,
|
||||
window_unpartition,
|
||||
)
|
||||
|
||||
from sam2.modeling.sam2_utils import DropPath, MLP
|
||||
|
||||
|
||||
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Apply a spatial pooling module to a channels-last feature map.

    Args:
        x: Tensor indexed as (B, H, W, C).
        pool: Module applied to the channels-first view of ``x``. When it is
            ``None`` the input is returned untouched (``norm`` is not applied).
        norm: Optional module applied to the pooled, channels-last result.

    Returns:
        The pooled tensor in (B, H', W', C) layout, or ``x`` itself when
        ``pool`` is ``None``.
    """
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    # Compare against None explicitly: modules that define __len__ (container
    # modules, for example) can be falsy even though they are perfectly usable.
    if norm is not None:
        x = norm(x)

    return x
|
||||
|
||||
|
||||
class MultiScaleAttention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) map with optional query pooling.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels (also the total q/k/v width).
        num_heads: Number of attention heads; the q/k/v projection is split
            evenly across heads in ``forward``.
        q_pool: Optional pooling module applied to the queries only, which
            shrinks the spatial size of the output.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out

        self.num_heads = num_heads
        head_dim = dim_out // num_heads
        # Stored for reference; forward() relies on SDPA's default scaling.
        self.scale = head_dim**-0.5

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all H * W positions of ``x`` and project the result.

        Args:
            x: Tensor of shape (B, H, W, dim).

        Returns:
            Tensor of shape (B, H', W', dim_out), where (H', W') equals (H, W)
            unless ``q_pool`` reduced the query resolution.
        """
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes).
        # Explicit None test: a module that defines __len__ may be falsy.
        if self.q_pool is not None:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
|
||||
|
||||
|
||||
class MultiScaleBlock(nn.Module):
    """Pre-norm transformer block with optional windowing and query pooling.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels. When it differs from ``dim`` the
            skip connection is linearly projected (and pooled, if a pooling
            module exists).
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim_out``.
        drop_path: Drop-path rate; ``0.0`` uses an identity instead.
        norm_layer: Normalization constructor, or the name of a class in
            ``torch.nn`` (instantiated with ``eps=1e-6``).
        q_stride: Kernel/stride of the max-pool applied to the queries and to
            the skip connection; ``None`` disables pooling.
        act_layer: Activation passed to the MLP.
        window_size: Side length of the attention windows; ``0`` means the
            attention runs over the whole feature map.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        # The projection only exists when the channel count changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sub-layers with residual connections.

        Args:
            x: Tensor of shape (B, H, W, dim).

        Returns:
            Tensor of shape (B, H', W', dim_out).
        """
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        windowed = window_size > 0
        if windowed:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)

        # Reverse window partition
        if windowed:
            if self.q_stride:
                # Shapes have changed due to Q pooling
                window_size = self.window_size // self.q_stride[0]
                H, W = shortcut.shape[1:3]

                pad_h = (window_size - H % window_size) % window_size
                pad_w = (window_size - W % window_size) % window_size
                pad_hw = (H + pad_h, W + pad_w)
            x = window_unpartition(x, window_size, pad_hw, (H, W))
        # NOTE: the padding above is only recomputed for windowed blocks. With
        # window_size == 0 there is nothing to un-partition, and evaluating
        # `H % 0` would raise ZeroDivisionError for a block that combines
        # global attention with q_stride.

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
||||
|
||||
|
||||
class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989

    Stack of ``MultiScaleBlock`` layers grouped into stages. The first block of
    every stage after the first multiplies the channel count by ``dim_mul`` and
    the head count by ``head_mul``; the first ``q_pool`` of those blocks also
    receive ``q_stride``. ``forward`` returns the channels-first output of the
    last block of each stage (or only of the final stage when
    ``return_interm_layers`` is False).
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of every stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # Blocks that open a new stage; only the first `q_pool` of them pool.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        # Output widths, listed from the last stage to the first.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """Build the positional embedding for an (h, w) feature map.

        The background embedding is bicubically resized to (h, w) and the
        window embedding is tiled on top of it.

        Args:
            hw: Target spatial size (h, w).

        Returns:
            Tensor of shape (1, h, w, C).
        """
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        # Tile with ceil division and crop back to (h, w). Floor division
        # produced a tensor smaller than (h, w) whenever the map size was not
        # a multiple of the window size, which made the addition fail. For
        # exact multiples the crop is a no-op.
        repeats = [-(-x // y) for x, y in zip(pos_embed.shape, window_embed.shape)]
        pos_embed = pos_embed + window_embed.tile(repeats)[:, :, :h, :w]
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Embed the input, run all blocks and collect per-stage features.

        Args:
            x: Input passed to ``self.patch_embed``.

        Returns:
            List of tensors permuted to (B, C, H, W), in stage order. Contains
            only the final stage when ``return_interm_layers`` is False.
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs
|
133
sam2/modeling/backbones/image_encoder.py
Normal file
133
sam2/modeling/backbones/image_encoder.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
    """Run a trunk followed by a neck and package the neck outputs in a dict.

    Args:
        trunk: Callable backbone exposing a ``channel_list`` attribute.
        neck: Callable returning ``(features, pos)`` and exposing a
            ``backbone_channel_list`` attribute that must equal the trunk's
            ``channel_list``.
        scalp: Number of trailing entries dropped from the neck outputs.
    """

    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        # Both halves must agree on the per-level channel counts.
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        """Return a dict with the kept feature levels and their position encodings.

        Keys: ``vision_features`` (last kept feature level),
        ``vision_pos_enc`` (kept position encodings) and ``backbone_fpn``
        (kept feature levels).
        """
        # Forward through backbone
        pyramid, pos_encodings = self.neck(self.trunk(sample))

        if self.scalp > 0:
            # Discard the lowest resolution features
            kept = slice(None, -self.scalp)
            pyramid = pyramid[kept]
            pos_encodings = pos_encodings[kept]

        return {
            "vision_features": pyramid[-1],
            "vision_pos_enc": pos_encodings,
            "backbone_fpn": pyramid,
        }
|
||||
|
||||
|
||||
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (the output conv is removed; the top-down upsampling mode is configurable
    through ``fpn_interp_model`` and defaults to "bilinear").
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck

        :param position_encoding: the positional encoding to use; called on
            every output feature map
        :param d_model: number of output channels of every lateral conv
        :param backbone_channel_list: input channels of the lateral convs, one
            entry per level. ``forward`` feeds ``xs[i]`` to the conv built from
            entry ``len - 1 - i``, so this list is in the reverse order of the
            feature list passed to ``forward``
        :param kernel_size: kernel size of the lateral convs
        :param stride: stride of the lateral convs
        :param padding: padding of the lateral convs
        :param fpn_interp_model: ``mode`` given to ``F.interpolate`` for the
            top-down upsampling
        :param fuse_type: "sum" adds lateral and top-down features, "avg"
            additionally halves the sum
        :param fpn_top_down_levels: levels whose output includes top-down
            features; ``None`` enables it on all levels
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            # Kept as a Sequential with a named "conv" child so that parameter
            # names stay `convs.<i>.conv.*`.
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )

            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        """Fuse the backbone features top-down.

        :param xs: one feature map per level; must have as many entries as
            there are lateral convs. Levels are visited from the last index to
            the first, and each top-down step upsamples by a factor of 2.
        :return: ``(out, pos)`` — the fused feature maps and the position
            encodings computed from them, both indexed like ``xs``
        """

        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
            # convs are stored in reverse order relative to xs
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos
|
95
sam2/modeling/backbones/utils.py
Normal file
95
sam2/modeling/backbones/utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Some utilities for backbones, in particular for windowing"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows, zero-padding the
    bottom/right edges when the size is not a multiple of the window size.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): side length of each window.
    Returns:
        windows: tensor with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): height and width after padding, before partitioning.
    """
    batch, height, width, channels = x.shape

    # Amount needed to reach the next multiple of window_size (0 if aligned).
    extra_h = -height % window_size
    extra_w = -width % window_size
    if not (extra_h <= 0 and extra_w <= 0):
        # F.pad takes pairs starting from the last dim: C, then W, then H.
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other, then merge them
    # (together with the batch axis) into the leading dimension.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
|
||||
|
||||
|
||||
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reassemble windows into full feature maps and crop away any padding.
    Args:
        windows (tensor): windows with [B * num_windows, window_size, window_size, C].
        window_size (int): side length of each window.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: reassembled feature maps with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    # Recover the batch size from the number of windows per padded map.
    windows_per_map = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_map

    grid = windows.view(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    # Interleave grid rows with in-window rows (and likewise for columns).
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    merged = merged.view(batch, padded_h, padded_w, -1)

    if not (padded_h <= height and padded_w <= width):
        merged = merged[:, :height, :width, :].contiguous()
    return merged
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """
    Image to patch embedding: a single strided 2D convolution whose output is
    returned in channels-last layout.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection convolution.
            stride (Tuple): stride of the projection convolution.
            padding (Tuple): padding of the projection convolution.
            in_chans (int): number of input image channels.
            embed_dim (int): number of output channels (embedding dimension).
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and move channels last: B C H W -> B H W C."""
        return self.proj(x).permute(0, 2, 3, 1)
|
Reference in New Issue
Block a user