Merge branch 'main' into patch-1
This commit is contained in:
@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
|
||||
output_mode: str = "binary_mask",
|
||||
use_m2m: bool = False,
|
||||
multimask_output: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Using a SAM 2 model, generates masks for the entire image.
|
||||
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
|
||||
self.use_m2m = use_m2m
|
||||
self.multimask_output = multimask_output
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2AutomaticMaskGenerator): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# Run model on this batch
|
||||
points = torch.as_tensor(points, device=self.predictor.device)
|
||||
points = torch.as_tensor(
|
||||
points, dtype=torch.float32, device=self.predictor.device
|
||||
)
|
||||
in_points = self.predictor._transforms.transform_coords(
|
||||
points, normalize=normalize, orig_hw=im_size
|
||||
)
|
||||
|
@@ -19,6 +19,7 @@ def build_sam2(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if apply_postprocessing:
|
||||
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
|
||||
mode="eval",
|
||||
hydra_overrides_extra=[],
|
||||
apply_postprocessing=True,
|
||||
**kwargs,
|
||||
):
|
||||
hydra_overrides = [
|
||||
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
||||
@@ -76,6 +78,44 @@ def build_sam2_video_predictor(
|
||||
return model
|
||||
|
||||
|
||||
def build_sam2_hf(model_id, **kwargs):
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_id_to_filenames = {
|
||||
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||
"facebook/sam2-hiera-base-plus": (
|
||||
"sam2_hiera_b+.yaml",
|
||||
"sam2_hiera_base_plus.pt",
|
||||
),
|
||||
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||
}
|
||||
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
||||
|
||||
|
||||
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_id_to_filenames = {
|
||||
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
||||
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
||||
"facebook/sam2-hiera-base-plus": (
|
||||
"sam2_hiera_b+.yaml",
|
||||
"sam2_hiera_base_plus.pt",
|
||||
),
|
||||
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
||||
}
|
||||
config_name, checkpoint_name = model_id_to_filenames[model_id]
|
||||
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
||||
return build_sam2_video_predictor(
|
||||
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _load_checkpoint(model, ckpt_path):
|
||||
if ckpt_path is not None:
|
||||
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
||||
|
@@ -223,8 +223,8 @@ std::vector<torch::Tensor> get_connected_componnets(
|
||||
const uint32_t W = inputs.size(3);
|
||||
|
||||
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
|
||||
AT_ASSERTM((H % 2) == 0, "height must be a even number");
|
||||
AT_ASSERTM((W % 2) == 0, "width must be a even number");
|
||||
AT_ASSERTM((H % 2) == 0, "height must be an even number");
|
||||
AT_ASSERTM((W % 2) == 0, "width must be an even number");
|
||||
|
||||
// label must be uint32_t
|
||||
auto label_options =
|
||||
|
@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from functools import partial
|
||||
@@ -14,12 +15,30 @@ import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
||||
|
||||
from sam2.modeling.sam2_utils import MLP
|
||||
from sam2.utils.misc import get_sdpa_settings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
# Check whether Flash Attention is available (and use it by default)
|
||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
||||
# A fallback setting to allow all available kernels if Flash Attention fails
|
||||
ALLOW_ALL_KERNELS = False
|
||||
|
||||
|
||||
def sdp_kernel_context(dropout_p):
|
||||
"""
|
||||
Get the context for the attention scaled dot-product kernel. We use Flash Attention
|
||||
by default, but fall back to all available kernels if Flash Attention fails.
|
||||
"""
|
||||
if ALLOW_ALL_KERNELS:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
@@ -246,12 +265,19 @@ class Attention(nn.Module):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
@@ -313,12 +339,19 @@ class RoPEAttention(Attention):
|
||||
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
# Attention
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=USE_FLASH_ATTN,
|
||||
# if Flash attention kernel is off, then math kernel needs to be enabled
|
||||
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
||||
enable_mem_efficient=OLD_GPU,
|
||||
):
|
||||
try:
|
||||
with sdp_kernel_context(dropout_p):
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
except Exception as e:
|
||||
# Fall back to all kernels if the Flash attention kernel fails
|
||||
warnings.warn(
|
||||
f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
|
||||
f"kernels for scaled_dot_product_attention (which may have a slower speed).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
global ALLOW_ALL_KERNELS
|
||||
ALLOW_ALL_KERNELS = True
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
|
@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
|
@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
|
||||
mask_threshold=0.0,
|
||||
max_hole_area=0.0,
|
||||
max_sprinkle_area=0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM-2 to calculate the image embedding for an image, and then
|
||||
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
|
||||
sam_model (Sam-2): The model to use for mask prediction.
|
||||
mask_threshold (float): The threshold to use when converting mask logits
|
||||
to binary masks. Masks are thresholded at 0 by default.
|
||||
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of fill_hole_area in low_res_masks.
|
||||
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of max_hole_area in low_res_masks.
|
||||
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||
the maximum area of max_sprinkle_area in low_res_masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = sam_model
|
||||
@@ -62,6 +65,23 @@ class SAM2ImagePredictor:
|
||||
(64, 64),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2ImagePredictor): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_hf
|
||||
|
||||
sam_model = build_sam2_hf(model_id, **kwargs)
|
||||
return cls(sam_model, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def set_image(
|
||||
self,
|
||||
@@ -163,7 +183,7 @@ class SAM2ImagePredictor:
|
||||
normalize_coords=True,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
||||
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
|
||||
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
|
||||
"""
|
||||
assert self._is_batch, "This function should only be used when in batched mode"
|
||||
if not self._is_image_set:
|
||||
|
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
@@ -43,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
offload_state_to_cpu=False,
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""Initialize a inference state."""
|
||||
"""Initialize an inference state."""
|
||||
compute_device = self.device # device of the model
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state = {}
|
||||
inference_state["images"] = images
|
||||
@@ -64,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# the original video height and width, used for resizing final output scores
|
||||
inference_state["video_height"] = video_height
|
||||
inference_state["video_width"] = video_width
|
||||
inference_state["device"] = torch.device("cuda")
|
||||
inference_state["device"] = compute_device
|
||||
if offload_state_to_cpu:
|
||||
inference_state["storage_device"] = torch.device("cpu")
|
||||
else:
|
||||
inference_state["storage_device"] = torch.device("cuda")
|
||||
inference_state["storage_device"] = compute_device
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
@@ -103,6 +106,23 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
return inference_state
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
||||
"""
|
||||
Load a pretrained model from the Hugging Face hub.
|
||||
|
||||
Arguments:
|
||||
model_id (str): The Hugging Face repository ID.
|
||||
**kwargs: Additional arguments to pass to the model constructor.
|
||||
|
||||
Returns:
|
||||
(SAM2VideoPredictor): The loaded model.
|
||||
"""
|
||||
from sam2.build_sam import build_sam2_video_predictor_hf
|
||||
|
||||
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
||||
return sam_model
|
||||
|
||||
def _obj_id_to_idx(self, inference_state, obj_id):
|
||||
"""Map client-side object id to model-side object index."""
|
||||
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||
@@ -146,29 +166,66 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
return len(inference_state["obj_idx_to_id"])
|
||||
|
||||
@torch.inference_mode()
|
||||
def add_new_points(
|
||||
def add_new_points_or_box(
|
||||
self,
|
||||
inference_state,
|
||||
frame_idx,
|
||||
obj_id,
|
||||
points,
|
||||
labels,
|
||||
points=None,
|
||||
labels=None,
|
||||
clear_old_points=True,
|
||||
normalize_coords=True,
|
||||
box=None,
|
||||
):
|
||||
"""Add new points to a frame."""
|
||||
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
||||
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
||||
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
||||
|
||||
if not isinstance(points, torch.Tensor):
|
||||
if (points is not None) != (labels is not None):
|
||||
raise ValueError("points and labels must be provided together")
|
||||
if points is None and box is None:
|
||||
raise ValueError("at least one of points or box must be provided as input")
|
||||
|
||||
if points is None:
|
||||
points = torch.zeros(0, 2, dtype=torch.float32)
|
||||
elif not isinstance(points, torch.Tensor):
|
||||
points = torch.tensor(points, dtype=torch.float32)
|
||||
if not isinstance(labels, torch.Tensor):
|
||||
if labels is None:
|
||||
labels = torch.zeros(0, dtype=torch.int32)
|
||||
elif not isinstance(labels, torch.Tensor):
|
||||
labels = torch.tensor(labels, dtype=torch.int32)
|
||||
if points.dim() == 2:
|
||||
points = points.unsqueeze(0) # add batch dimension
|
||||
if labels.dim() == 1:
|
||||
labels = labels.unsqueeze(0) # add batch dimension
|
||||
|
||||
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
||||
# along with the user-provided points (consistent with how SAM 2 is trained).
|
||||
if box is not None:
|
||||
if not clear_old_points:
|
||||
raise ValueError(
|
||||
"cannot add box without clearing old points, since "
|
||||
"box prompt must be provided before any point prompt "
|
||||
"(please use clear_old_points=True instead)"
|
||||
)
|
||||
if inference_state["tracking_has_started"]:
|
||||
warnings.warn(
|
||||
"You are adding a box after tracking starts. SAM 2 may not always be "
|
||||
"able to incorporate a box prompt for *refinement*. If you intend to "
|
||||
"use box prompt as an *initial* input before tracking, please call "
|
||||
"'reset_state' on the inference state to restart from scratch.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not isinstance(box, torch.Tensor):
|
||||
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||
box_coords = box.reshape(1, 2, 2)
|
||||
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
||||
box_labels = box_labels.reshape(1, 2)
|
||||
points = torch.cat([box_coords, points], dim=1)
|
||||
labels = torch.cat([box_labels, labels], dim=1)
|
||||
|
||||
if normalize_coords:
|
||||
video_H = inference_state["video_height"]
|
||||
video_W = inference_state["video_width"]
|
||||
@@ -215,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
|
||||
if prev_out is not None and prev_out["pred_masks"] is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
||||
device = inference_state["device"]
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
||||
current_out, _ = self._run_single_frame_inference(
|
||||
@@ -251,6 +309,10 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
return frame_idx, obj_ids, video_res_masks
|
||||
|
||||
def add_new_points(self, *args, **kwargs):
|
||||
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
||||
return self.add_new_points_or_box(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def add_new_mask(
|
||||
self,
|
||||
@@ -527,16 +589,16 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points` or `add_new_mask`)
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temprary output across all objects on this frame
|
||||
# consolidate the temporary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
@@ -734,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
if backbone_out is None:
|
||||
# Cache miss -- we will run inference on a single image
|
||||
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
||||
device = inference_state["device"]
|
||||
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
||||
backbone_out = self.forward_image(image)
|
||||
# Cache the most recent frame's feature (for repeated interactions with
|
||||
# a frame; we can use an LRU cache for more frames in the future).
|
||||
|
@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
|
||||
compute bounding box given an input mask
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
|
||||
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
|
||||
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
|
||||
A list of video frames to be load asynchronously without blocking session start.
|
||||
"""
|
||||
|
||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
||||
def __init__(
|
||||
self,
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
):
|
||||
self.img_paths = img_paths
|
||||
self.image_size = image_size
|
||||
self.offload_video_to_cpu = offload_video_to_cpu
|
||||
self.img_mean = img_mean
|
||||
self.img_std = img_std
|
||||
# items in `self._images` will be loaded asynchronously
|
||||
# items in `self.images` will be loaded asynchronously
|
||||
self.images = [None] * len(img_paths)
|
||||
# catch and raise any exceptions in the async loading thread
|
||||
self.exception = None
|
||||
# video_height and video_width be filled when loading the first image
|
||||
self.video_height = None
|
||||
self.video_width = None
|
||||
self.compute_device = compute_device
|
||||
|
||||
# load the first frame to fill video_height and video_width and also
|
||||
# to cache it (since it's most likely where the user will click)
|
||||
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
|
||||
img -= self.img_mean
|
||||
img /= self.img_std
|
||||
if not self.offload_video_to_cpu:
|
||||
img = img.cuda(non_blocking=True)
|
||||
img = img.to(self.compute_device, non_blocking=True)
|
||||
self.images[index] = img
|
||||
return img
|
||||
|
||||
@@ -167,6 +176,7 @@ def load_video_frames(
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
@@ -179,7 +189,15 @@ def load_video_frames(
|
||||
if isinstance(video_path, str) and os.path.isdir(video_path):
|
||||
jpg_folder = video_path
|
||||
else:
|
||||
raise NotImplementedError("Only JPEG frames are supported at this moment")
|
||||
raise NotImplementedError(
|
||||
"Only JPEG frames are supported at this moment. For video files, you may use "
|
||||
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
|
||||
"```\n"
|
||||
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
|
||||
"```\n"
|
||||
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
|
||||
"ffmpeg to start the JPEG file from 00000.jpg."
|
||||
)
|
||||
|
||||
frame_names = [
|
||||
p
|
||||
@@ -196,7 +214,12 @@ def load_video_frames(
|
||||
|
||||
if async_loading_frames:
|
||||
lazy_images = AsyncVideoFrameLoader(
|
||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
||||
img_paths,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean,
|
||||
img_std,
|
||||
compute_device,
|
||||
)
|
||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||
|
||||
@@ -204,9 +227,9 @@ def load_video_frames(
|
||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||
if not offload_video_to_cpu:
|
||||
images = images.cuda()
|
||||
img_mean = img_mean.cuda()
|
||||
img_std = img_std.cuda()
|
||||
images = images.to(compute_device)
|
||||
img_mean = img_mean.to(compute_device)
|
||||
img_std = img_std.to(compute_device)
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
@@ -220,10 +243,25 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
# Holes are those connected components in background with area <= self.max_area
|
||||
# (background regions are those with mask scores <= 0)
|
||||
assert max_area > 0, "max_area must be positive"
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
|
||||
input_mask = mask
|
||||
try:
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step on removing small holes if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
mask = input_mask
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
|
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -78,22 +80,39 @@ class SAM2Transforms(nn.Module):
|
||||
from sam2.utils.misc import get_connected_components
|
||||
|
||||
masks = masks.float()
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
input_masks = masks
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
try:
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat <= self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(
|
||||
mask_flat > self.mask_threshold
|
||||
)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
masks = input_masks
|
||||
|
||||
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
Reference in New Issue
Block a user