update to latest SAM 2
This commit is contained in:
@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
|
Reference in New Issue
Block a user