update to latest SAM 2
This commit is contained in:
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
Reference in New Issue
Block a user