better support for non-CUDA devices (CPU, MPS) (#192)

This commit is contained in:
Ronghang Hu
2024-08-12 10:46:50 -07:00
committed by GitHub
parent 778e112740
commit 1034ee2a1a
8 changed files with 213 additions and 377 deletions

View File

@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to the compute device (it's a no-op if it's already there).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (