support gsam2 image predictor model
This commit is contained in:
149
sam2/modeling/sam2_utils.py
Normal file
149
sam2/modeling/sam2_utils.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
||||
"""
|
||||
Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
|
||||
that are temporally closest to the current frame at `frame_idx`. Here, we take
|
||||
- a) the closest conditioning frame before `frame_idx` (if any);
|
||||
- b) the closest conditioning frame after `frame_idx` (if any);
|
||||
- c) any other temporally closest conditioning frames until reaching a total
|
||||
of `max_cond_frame_num` conditioning frames.
|
||||
|
||||
Outputs:
|
||||
- selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
|
||||
- unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
|
||||
"""
|
||||
if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
|
||||
selected_outputs = cond_frame_outputs
|
||||
unselected_outputs = {}
|
||||
else:
|
||||
assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
|
||||
selected_outputs = {}
|
||||
|
||||
# the closest conditioning frame before `frame_idx` (if any)
|
||||
idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
|
||||
if idx_before is not None:
|
||||
selected_outputs[idx_before] = cond_frame_outputs[idx_before]
|
||||
|
||||
# the closest conditioning frame after `frame_idx` (if any)
|
||||
idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
|
||||
if idx_after is not None:
|
||||
selected_outputs[idx_after] = cond_frame_outputs[idx_after]
|
||||
|
||||
# add other temporally closest conditioning frames until reaching a total
|
||||
# of `max_cond_frame_num` conditioning frames.
|
||||
num_remain = max_cond_frame_num - len(selected_outputs)
|
||||
inds_remain = sorted(
|
||||
(t for t in cond_frame_outputs if t not in selected_outputs),
|
||||
key=lambda x: abs(x - frame_idx),
|
||||
)[:num_remain]
|
||||
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
|
||||
unselected_outputs = {
|
||||
t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
|
||||
}
|
||||
|
||||
return selected_outputs, unselected_outputs
|
||||
|
||||
|
||||
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
|
||||
"""
|
||||
Get 1D sine positional embedding as in the original Transformer paper.
|
||||
"""
|
||||
pe_dim = dim // 2
|
||||
dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
|
||||
dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
|
||||
|
||||
pos_embed = pos_inds.unsqueeze(-1) / dim_t
|
||||
pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
def get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
# adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
||||
def __init__(self, drop_prob=0.0, scale_by_keep=True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
# Lightly adapted from
|
||||
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
activation: nn.Module = nn.ReLU,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(
|
||||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
||||
)
|
||||
self.sigmoid_output = sigmoid_output
|
||||
self.act = activation()
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
if self.sigmoid_output:
|
||||
x = F.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
||||
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
Reference in New Issue
Block a user