Merge branch 'main' into patch-1

Arun authored 2024-08-08 09:59:47 +05:30, committed by GitHub
4 changed files with 483 additions and 112 deletions

README.md

@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)

     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):

     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
         ...
 ```
-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.

 ## Load from 🤗 Hugging Face

@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)

     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):

     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
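With this README change, box prompts go through the same entry point as click prompts. A minimal, self-contained sketch of the renamed call follows; the config and checkpoint names match the SAM 2 release, while the frames directory, frame index, object id, and box coordinates are placeholder assumptions:

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

# Config/checkpoint names from the SAM 2 release; local paths are placeholders.
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # In this release, init_state expects a directory of JPEG frames.
    state = predictor.init_state("./videos/bedroom")
    # A single box prompt (x0, y0, x1, y1) on frame 0 for object id 1.
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(
        state, frame_idx=0, obj_id=1, box=[210.0, 350.0, 500.0, 700.0]
    )
    # Propagate the prompt to get masklets throughout the video.
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        pass  # consume the per-frame masks here
```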

notebooks/video_predictor_example.ipynb

File diff suppressed because one or more lines are too long

sam2/sam2_video_predictor.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import warnings
 from collections import OrderedDict

 import torch
@@ -163,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])

     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
@@ -268,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks

+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
@@ -548,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base):
         storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
         # Find all the frames that contain temporary outputs for any objects
         # (these should be the frames that have just received clicks for mask inputs
-        # via `add_new_points_or_box` or `add_new_mask`)
         temp_frame_inds = set()
         for obj_temp_output_dict in temp_output_dict_per_obj.values():
             temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
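To make the box handling above concrete: a box is not routed through a separate head; it is folded into the point prompts as two corner pseudo-clicks, with labels 2 and 3 distinguishing the corners from foreground (1) and background (0) clicks. A hypothetical standalone illustration with made-up coordinates:

```python
import torch

# A box (x0, y0, x1, y1) becomes its top-left corner with label 2 and its
# bottom-right corner with label 3, prepended to any user click points.
box = torch.tensor([210.0, 350.0, 500.0, 700.0])
points = torch.tensor([[[355.0, 525.0]]])        # one positive click, batched
labels = torch.tensor([[1]], dtype=torch.int32)  # 1 = foreground click

box_coords = box.reshape(1, 2, 2)                # [[x0, y0], [x1, y1]]
box_labels = torch.tensor([[2, 3]], dtype=torch.int32)

merged_points = torch.cat([box_coords, points], dim=1)  # shape (1, 3, 2)
merged_labels = torch.cat([box_labels, labels], dim=1)  # shape (1, 3)
print(merged_points, merged_labels)
```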

setup.py

@@ -44,11 +44,22 @@ BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
 # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
 BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"

+# Catch and skip errors during extension building and print a warning message
+# (note that this message only shows up under verbose build mode
+# "pip install -v -e ." or "python setup.py build_ext -v")
+CUDA_ERROR_MSG = (
+    "{}\n\n"
+    "Failed to build the SAM 2 CUDA extension due to the error above. "
+    "You can still use SAM 2, but some post-processing functionality may be limited "
+    "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
+)
+

 def get_extensions():
     if not BUILD_CUDA:
         return []
+    try:
         srcs = ["sam2/csrc/connected_components.cu"]
         compile_args = {
             "cxx": [],
@@ -60,39 +71,37 @@ def get_extensions():
             ],
         }
         ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
+    except Exception as e:
+        if BUILD_ALLOW_ERRORS:
+            print(CUDA_ERROR_MSG.format(e))
+            ext_modules = []
+        else:
+            raise e
     return ext_modules


 class BuildExtensionIgnoreErrors(BuildExtension):
-    # Catch and skip errors during extension building and print a warning message
-    # (note that this message only shows up under verbose build mode
-    # "pip install -v -e ." or "python setup.py build_ext -v")
-    ERROR_MSG = (
-        "{}\n\n"
-        "Failed to build the SAM 2 CUDA extension due to the error above. "
-        "You can still use SAM 2, but some post-processing functionality may be limited "
-        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
-    )
-
     def finalize_options(self):
         try:
             super().finalize_options()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []

     def build_extensions(self):
         try:
             super().build_extensions()
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []

     def get_ext_filename(self, ext_name):
         try:
             return super().get_ext_filename(ext_name)
         except Exception as e:
-            print(self.ERROR_MSG.format(e))
+            print(CUDA_ERROR_MSG.format(e))
             self.extensions = []
             return "_C.so"