Files
Grounded-SAM-2/sam2/utils/transforms.py
Ronghang Hu 6f7e700c37 Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, having observed that many users encounter difficulties with the CUDA compilation step.
1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`.
2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step.

We also fall back to all available kernels if the Flash Attention kernel fails.
2024-08-06 10:52:01 -07:00

118 lines
4.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Normalize, Resize, ToTensor
class SAM2Transforms(nn.Module):
    """Pre- and post-processing helpers used around the SAM 2 model.

    Bundles three things:
      * image preprocessing (to tensor, resize to a square, normalize),
      * rescaling of point / box prompts to the model input resolution,
      * post-processing of predicted mask scores (optional hole filling and
        sprinkle removal, then resizing to the original image size).
    """

    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
    ):
        """
        Transforms for SAM2.

        Args:
            resolution: side length of the square image the model consumes;
                images are resized to (resolution, resolution) and prompt
                coordinates are scaled by this value.
            mask_threshold: score above which a mask pixel counts as foreground.
            max_hole_area: background components with area <= this value are
                filled in by `postprocess_masks`; values <= 0 disable the step.
            max_sprinkle_area: foreground components with area <= this value
                are removed by `postprocess_masks`; values <= 0 disable the step.
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        # Per-channel normalization statistics (the widely used ImageNet RGB
        # mean / std values).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.to_tensor = ToTensor()
        # Resize + normalize are scripted once here and reused for every image.
        self.transforms = torch.jit.script(
            nn.Sequential(
                Resize((self.resolution, self.resolution)),
                Normalize(self.mean, self.std),
            )
        )

    def __call__(self, x):
        """Convert a single image to a resized, normalized tensor.

        Note that this overrides `nn.Module.__call__` directly instead of
        defining `forward`.
        """
        x = self.to_tensor(x)
        return self.transforms(x)

    def forward_batch(self, img_list):
        """Preprocess every image in `img_list` and stack them along dim 0."""
        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
        img_batch = torch.stack(img_batch, dim=0)
        return img_batch

    def transform_coords(
        self, coords: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Rescale point coordinates to the model input resolution.

        Expects a torch tensor with length 2 in the last dimension, ordered
        (x, y). The coordinates can be in absolute image coordinates or already
        normalized to [0, 1]. If they are in absolute image coordinates,
        `normalize` should be set to True and the original image size
        `orig_hw` = (height, width) is required.

        The input tensor is never modified in place.

        Returns
            Coordinates multiplied by `self.resolution`, i.e. expressed in the
            pixel frame of the resized square image expected by the SAM2 model.
        """
        if normalize:
            assert (
                orig_hw is not None
            ), "orig_hw is required when normalize=True"
            h, w = orig_hw
            if coords.is_floating_point():
                coords = coords.clone()
            else:
                # Integer tensors would truncate the fractional results of the
                # in-place divisions below (almost always to 0), so promote to
                # float first. `.to` returns a new tensor here because the
                # dtype changes, so the caller's tensor is left untouched.
                coords = coords.to(torch.float32)
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h
        coords = coords * self.resolution  # unnormalize coords
        return coords

    def transform_boxes(
        self, boxes: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Rescale boxes to the model input resolution.

        Expects a tensor of shape Bx4. The coordinates can be in absolute image
        or normalized coordinates; if they are in absolute image coordinates,
        `normalize` should be set to True and the original image size is
        required.

        Returns
            A tensor of shape Bx2x2: each box is viewed as two (x, y) corner
            points and transformed with `transform_coords`.
        """
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform PostProcessing on output masks.

        Optionally fills small holes and removes small sprinkles in the mask
        scores, then bilinearly resizes the result to `orig_hw`.

        Args:
            masks: mask scores with two leading (batch-like) dimensions followed
                by the spatial dimensions (4-D, as bilinear interpolation
                requires).
            orig_hw: target (height, width) of the returned masks.

        Returns
            Float mask scores resized to `orig_hw`. If the connected-components
            step raises, a warning is emitted and the unmodified scores are
            resized instead.
        """
        from sam2.utils.misc import get_connected_components

        masks = masks.float()
        input_masks = masks
        # Both steps below analyse the ORIGINAL scores, so filling holes does
        # not influence which foreground components count as sprinkles.
        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
        try:
            if self.max_hole_area > 0:
                # Holes are those connected components in background with
                # area <= self.max_hole_area (background regions are those
                # with mask scores <= self.mask_threshold).
                # NOTE(review): `labels > 0` presumably selects pixels that
                # belong to some component - confirm against
                # get_connected_components.
                labels, areas = get_connected_components(
                    mask_flat <= self.mask_threshold
                )
                is_hole = (labels > 0) & (areas <= self.max_hole_area)
                is_hole = is_hole.reshape_as(masks)
                # Fill holes with a score 10.0 above the threshold to turn
                # them into foreground.
                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
            if self.max_sprinkle_area > 0:
                # Sprinkles are small connected components in foreground with
                # area <= self.max_sprinkle_area.
                labels, areas = get_connected_components(
                    mask_flat > self.mask_threshold
                )
                is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
                is_sprinkle = is_sprinkle.reshape_as(masks)
                # Overwrite sprinkles with a score 10.0 below the threshold to
                # turn them into background.
                masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
        except Exception as e:
            # Deliberately best-effort: skip the post-processing step if the
            # CUDA kernel fails, and fall back to the untouched scores.
            warnings.warn(
                f"{e}\n\nSkipping the post-processing step due to the error above. "
                "Consider building SAM 2 with CUDA extension to enable post-processing (see "
                "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
                category=UserWarning,
                stacklevel=2,
            )
            masks = input_masks
        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks