update to latest SAM 2

This commit is contained in:
rentainhe
2024-08-21 18:11:44 +08:00
parent 35efb4a5cb
commit 6e0ddadf7c
12 changed files with 140 additions and 87 deletions

View File

@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)

View File

@@ -16,7 +16,7 @@ from torch import nn
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
used by the Attention Is All You Need paper, generalized to work on images.
"""
def __init__(
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)

View File

@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]