Make it optional to build CUDA extension for SAM 2; also fallback to all available kernels if Flash Attention fails (#155)
In this PR, we make it optional to build the SAM 2 CUDA extension, in observation that many users encounter difficulties with the CUDA compilation step. 1. During installation, we catch build errors and print a warning message. We also allow explicitly turning off the CUDA extension building with `SAM2_BUILD_CUDA=0`. 2. At runtime, we catch CUDA kernel errors from connected components and print a warning on skipping the post processing step. We also fall back to the all available kernels if the Flash Attention kernel fails.
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -78,22 +80,38 @@ class SAM2Transforms(nn.Module):
|
||||
from sam2.utils.misc import get_connected_components
|
||||
|
||||
masks = masks.float()
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
input_masks = masks
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
try:
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat <= self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat > self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
masks = input_masks
|
||||
|
||||
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
Reference in New Issue
Block a user