SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2: - We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`). * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS. * In the VOS prediction script `tools/vos_inference.py`, you can specify this option in `tools/vos_inference.py` via the `--use_vos_optimized_video_predictor` flag. * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model. * **PyTorch 2.5.1 is the minimum version for full support of this feature**. (Earlier PyTorch versions might run into compilation errors in some cases.) Therefore, we have updated the minimum PyTorch version to 2.5.1 accordingly in the installation scripts. - We also update the implementation of the `SAM2VideoPredictor` class for the SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`: * Now **we handle the inference of each object independently** (as if we are opening a separate session for each object) while sharing their backbone features. * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame receives clicks for only a subset of objects, the rest of the (non-prompted) objects are assumed to be non-existent in this frame (i.e., in such frames, the user is telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we do not make any assumptions about the remaining (non-prompted) objects (i.e., now each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts** after this change (which was previously a restriction on usage). * We believe that the new version is a more natural inference behavior and therefore switched to it as the default behavior. The previous implementation of `SAM2VideoPredictor` is backed up to in `sam2/sam2_video_predictor_legacy.py`. All the VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
This commit is contained in:
@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
|
||||
temperature: int = 10000,
|
||||
normalize: bool = True,
|
||||
scale: Optional[float] = None,
|
||||
# Following settings only relevant
|
||||
# for warmping up cache for compilation
|
||||
warmup_cache: bool = True,
|
||||
image_size: int = 1024,
|
||||
strides: Tuple[int] = (4, 8, 16, 32),
|
||||
):
|
||||
super().__init__()
|
||||
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
||||
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
|
||||
self.scale = scale
|
||||
|
||||
self.cache = {}
|
||||
if warmup_cache and torch.cuda.is_available():
|
||||
# Warmup cache for cuda, to help with compilation
|
||||
device = torch.device("cuda")
|
||||
for stride in strides:
|
||||
cache_key = (image_size // stride, image_size // stride)
|
||||
self._pe(1, device, *cache_key)
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
# The positions are expected to be normalized
|
||||
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
def _pe(self, B, device, *cache_key):
|
||||
H, W = cache_key
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
||||
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
||||
|
||||
y_embed = (
|
||||
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
|
||||
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
||||
.view(1, -1, 1)
|
||||
.repeat(x.shape[0], 1, x.shape[-1])
|
||||
.repeat(B, 1, W)
|
||||
)
|
||||
x_embed = (
|
||||
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
|
||||
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
||||
.view(1, 1, -1)
|
||||
.repeat(x.shape[0], x.shape[-2], 1)
|
||||
.repeat(B, H, 1)
|
||||
)
|
||||
|
||||
if self.normalize:
|
||||
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
|
||||
self.cache[cache_key] = pos[0]
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
B = x.shape[0]
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
return self._pe(B, x.device, *cache_key)
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user