update to latest SAM 2
This commit is contained in:
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
|
Reference in New Issue
Block a user