update to latest SAM 2

This commit is contained in:
rentainhe
2024-08-21 18:11:44 +08:00
parent 35efb4a5cb
commit 6e0ddadf7c
12 changed files with 140 additions and 87 deletions

View File

@@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
output_mode: str = "binary_mask",
use_m2m: bool = False,
multimask_output: bool = True,
**kwargs,
) -> None:
"""
Using a SAM 2 model, generates masks for the entire image.
@@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
self.use_m2m = use_m2m
self.multimask_output = multimask_output
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
"""
Load a pretrained model from the Hugging Face hub.
Arguments:
model_id (str): The Hugging Face repository ID.
**kwargs: Additional arguments to pass to the model constructor.
Returns:
(SAM2AutomaticMaskGenerator): The loaded model.
"""
from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model, **kwargs)
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
@@ -284,7 +302,9 @@ class SAM2AutomaticMaskGenerator:
orig_h, orig_w = orig_size
# Run model on this batch
points = torch.as_tensor(points, device=self.predictor.device)
points = torch.as_tensor(
points, dtype=torch.float32, device=self.predictor.device
)
in_points = self.predictor._transforms.transform_coords(
points, normalize=normalize, orig_hw=im_size
)

View File

@@ -19,6 +19,7 @@ def build_sam2(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
if apply_postprocessing:
@@ -47,6 +48,7 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",

View File

@@ -46,11 +46,7 @@ class MultiScaleAttention(nn.Module):
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)

View File

@@ -16,7 +16,7 @@ from torch import nn
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
used by the Attention Is All You Need paper, generalized to work on images.
"""
def __init__(
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)

View File

@@ -567,10 +567,10 @@ class SAM2Base(torch.nn.Module):
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
@@ -642,7 +642,7 @@ class SAM2Base(torch.nn.Module):
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

View File

@@ -24,6 +24,7 @@ class SAM2ImagePredictor:
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to calculate the image embedding for an image, and then
@@ -33,8 +34,10 @@ class SAM2ImagePredictor:
sam_model (Sam-2): The model to use for mask prediction.
mask_threshold (float): The threshold to use when converting mask logits
to binary masks. Masks are thresholded at 0 by default.
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
the maximum area of fill_hole_area in low_res_masks.
max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__()
self.model = sam_model
@@ -77,7 +80,7 @@ class SAM2ImagePredictor:
from sam2.build_sam import build_sam2_hf
sam_model = build_sam2_hf(model_id, **kwargs)
return cls(sam_model)
return cls(sam_model, **kwargs)
@torch.no_grad()
def set_image(
@@ -180,7 +183,7 @@ class SAM2ImagePredictor:
normalize_coords=True,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
"""
assert self._is_batch, "This function should only be used when in batched mode"
if not self._is_image_set:

View File

@@ -44,12 +44,14 @@ class SAM2VideoPredictor(SAM2Base):
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize a inference state."""
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
@@ -65,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
inference_state["device"] = compute_device
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
inference_state["storage_device"] = compute_device
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
@@ -119,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
from sam2.build_sam import build_sam2_video_predictor_hf
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
return cls(sam_model)
return sam_model
def _obj_id_to_idx(self, inference_state, obj_id):
"""Map client-side object id to model-side object index."""
@@ -270,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
device = inference_state["device"]
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -586,7 +589,7 @@ class SAM2VideoPredictor(SAM2Base):
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outptus
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs
@@ -595,7 +598,7 @@ class SAM2VideoPredictor(SAM2Base):
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temprary output across all objects on this frame
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
@@ -793,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
)
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
device = inference_state["device"]
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).

View File

@@ -68,7 +68,7 @@ def mask_to_box(masks: torch.Tensor):
compute bounding box given an input mask
Inputs:
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
- masks: [B, 1, H, W] masks, dtype=torch.Tensor
Returns:
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
@@ -106,19 +106,28 @@ class AsyncVideoFrameLoader:
A list of video frames to be load asynchronously without blocking session start.
"""
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self._images` will be loaded asynchronously
# items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths)
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width be filled when loading the first image
self.video_height = None
self.video_width = None
self.compute_device = compute_device
# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
img = img.cuda(non_blocking=True)
img = img.to(self.compute_device, non_blocking=True)
self.images[index] = img
return img
@@ -167,6 +176,7 @@ def load_video_frames(
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
compute_device=torch.device("cuda"),
):
"""
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,12 +189,20 @@ def load_video_frames(
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")
raise NotImplementedError(
"Only JPEG frames are supported at this moment. For video files, you may use "
"ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
"```\n"
"ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
"```\n"
"where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
"ffmpeg to start the JPEG file from 00000.jpg."
)
frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
@@ -196,7 +214,12 @@ def load_video_frames(
if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
)
return lazy_images, lazy_images.video_height, lazy_images.video_width
@@ -204,9 +227,9 @@ def load_video_frames(
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
img_std = img_std.to(compute_device)
# normalize by mean and std
images -= img_mean
images /= img_std
@@ -230,8 +253,9 @@ def fill_holes_in_mask_scores(mask, max_area):
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,

View File

@@ -105,8 +105,9 @@ class SAM2Transforms(nn.Module):
except Exception as e:
# Skip the post-processing step if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
"functionality may be limited (which doesn't affect the results in most cases; see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,

View File

@@ -72,7 +72,7 @@ parser.add_argument(
parser.add_argument(
"--do_not_skip_first_and_last_frame",
help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. "
"Set this to true for evaluation on settings that doen't skip first and last frames",
"Set this to true for evaluation on settings that doesn't skip first and last frames",
action="store_true",
)

View File

@@ -183,7 +183,7 @@ def _seg2bmap(seg, width=None, height=None):
assert not (
width > w | height > h | abs(ar1 - ar2) > 0.01
), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
), "Cannot convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
e = np.zeros_like(seg)
s = np.zeros_like(seg)

View File

@@ -6,7 +6,6 @@
import os
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Package metadata
NAME = "SAM 2"
@@ -18,35 +17,18 @@ AUTHOR_EMAIL = "segment-anything@meta.com"
LICENSE = "Apache 2.0"
# Read the contents of README file
with open("README.md", "r") as f:
with open("README.md", "r", encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
# Required dependencies
REQUIRED_PACKAGES = [
"torch>=2.3.1",
"torchvision>=0.18.1",
"transformers",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
"huggingface_hub",
"diffusers[torch]==0.15.1",
"onnxruntime==1.14.1",
"onnx==1.13.1",
"ipykernel==6.16.2",
"scipy",
"gradio",
"openai",
"matplotlib>=3.9.1",
"opencv-python>=4.7.0",
"dds_cloudapi_sdk",
"addict",
"yapf",
"timm",
"supervision>=0.22.0",
"pycocotools",
]
EXTRA_PACKAGES = {
@@ -67,7 +49,8 @@ BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
CUDA_ERROR_MSG = (
"{}\n\n"
"Failed to build the SAM 2 CUDA extension due to the error above. "
"You can still use SAM 2, but some post-processing functionality may be limited "
"You can still use SAM 2 and it's OK to ignore the error above, although some "
"post-processing functionality may be limited (which doesn't affect the results in most cases; "
"(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)
@@ -77,6 +60,8 @@ def get_extensions():
return []
try:
from torch.utils.cpp_extension import CUDAExtension
srcs = ["sam2/csrc/connected_components.cu"]
compile_args = {
"cxx": [],
@@ -98,7 +83,10 @@ def get_extensions():
return ext_modules
class BuildExtensionIgnoreErrors(BuildExtension):
try:
from torch.utils.cpp_extension import BuildExtension
class BuildExtensionIgnoreErrors(BuildExtension):
def finalize_options(self):
try:
@@ -122,6 +110,20 @@ class BuildExtensionIgnoreErrors(BuildExtension):
self.extensions = []
return "_C.so"
cmdclass = {
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
)
}
except Exception as e:
cmdclass = {}
if BUILD_ALLOW_ERRORS:
print(CUDA_ERROR_MSG.format(e))
else:
raise e
# Setup configuration
setup(
@@ -135,15 +137,11 @@ setup(
author_email=AUTHOR_EMAIL,
license=LICENSE,
packages=find_packages(exclude="notebooks"),
package_data={"": ["*.yaml"]}, # SAM 2 configuration files
include_package_data=True,
install_requires=REQUIRED_PACKAGES,
extras_require=EXTRA_PACKAGES,
python_requires=">=3.10.0",
ext_modules=get_extensions(),
cmdclass={
"build_ext": (
BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
if BUILD_ALLOW_ERRORS
else BuildExtension.with_options(no_python_abi_suffix=True)
),
},
cmdclass=cmdclass,
)