SAM2.1
SAM2.1 checkpoints + training code + Demo
This commit is contained in:
@@ -28,6 +28,9 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
clear_non_cond_mem_around_input=False,
|
||||
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
||||
clear_non_cond_mem_for_multi_obj=False,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -35,6 +38,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
self.non_overlap_masks = non_overlap_masks
|
||||
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
||||
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def init_state(
|
||||
@@ -468,6 +472,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
"object_score_logits": torch.full(
|
||||
size=(batch_size, 1),
|
||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
||||
fill_value=10.0,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
}
|
||||
empty_mask_ptr = None
|
||||
for obj_idx in range(batch_size):
|
||||
@@ -512,6 +524,9 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
)
|
||||
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
||||
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
||||
"object_score_logits"
|
||||
]
|
||||
|
||||
# Optionally, apply non-overlapping constraints on the consolidated scores
|
||||
# and rerun the memory encoder
|
||||
@@ -530,6 +545,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
high_res_masks=high_res_masks,
|
||||
object_score_logits=consolidated_out["object_score_logits"],
|
||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
||||
)
|
||||
consolidated_out["maskmem_features"] = maskmem_features
|
||||
@@ -749,6 +765,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
||||
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
||||
}
|
||||
if maskmem_features is not None:
|
||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
||||
@@ -756,6 +773,77 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
@torch.inference_mode()
def clear_all_prompts_in_frame(
    self, inference_state, frame_idx, obj_id, need_output=True
):
    """Remove all input points or mask in a specific frame for a given object."""
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)

    # Drop any point/mask prompts stored for this object on this frame.
    inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
    inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)

    # Drop this object's temporary (pre-consolidation) outputs on this frame.
    temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
    temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
    temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)

    # Does any object still have an input (point or mask) on this frame?
    batch_size = self._get_obj_num(inference_state)
    frame_has_input = any(
        frame_idx in inference_state["point_inputs_per_obj"][idx]
        or frame_idx in inference_state["mask_inputs_per_obj"][idx]
        for idx in range(batch_size)
    )

    # If no object has inputs left on this frame, the frame also loses its
    # conditioning-frame status.
    if not frame_has_input:
        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
        consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
        # "Downgrade" the frame's conditioning output (if it exists) to a
        # non-conditioning one, since the frame no longer receives inputs.
        out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
        if out is not None:
            output_dict["non_cond_frame_outputs"][frame_idx] = out
            inference_state["frames_already_tracked"].pop(frame_idx, None)
        # Apply the same downgrade to each object's sliced output.
        for idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][idx]
            obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
            if obj_out is not None:
                obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out

        # With no conditioning frames remaining at all, wipe the tracking outputs.
        if len(output_dict["cond_frame_outputs"]) == 0:
            self._reset_tracking_results(inference_state)

    if not need_output:
        return
    # Re-render the per-object masks on this frame now that the inputs above
    # have been removed.
    obj_ids = inference_state["obj_ids"]
    is_cond = any(
        frame_idx in obj_temp_out["cond_frame_outputs"]
        for obj_temp_out in temp_output_dict_per_obj.values()
    )
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        run_mem_encoder=False,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
|
||||
|
||||
@torch.inference_mode()
|
||||
def reset_state(self, inference_state):
|
||||
"""Remove all input points or mask in all frames throughout the video."""
|
||||
@@ -878,17 +966,25 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
||||
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
||||
obj_ptr = current_out["obj_ptr"]
|
||||
object_score_logits = current_out["object_score_logits"]
|
||||
# make a compact version of this frame's output to reduce the state size
|
||||
compact_current_out = {
|
||||
"maskmem_features": maskmem_features,
|
||||
"maskmem_pos_enc": maskmem_pos_enc,
|
||||
"pred_masks": pred_masks,
|
||||
"obj_ptr": obj_ptr,
|
||||
"object_score_logits": object_score_logits,
|
||||
}
|
||||
return compact_current_out, pred_masks_gpu
|
||||
|
||||
def _run_memory_encoder(
|
||||
self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
|
||||
self,
|
||||
inference_state,
|
||||
frame_idx,
|
||||
batch_size,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
is_mask_from_pts,
|
||||
):
|
||||
"""
|
||||
Run the memory encoder on `high_res_masks`. This is usually after applying
|
||||
@@ -903,6 +999,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks,
|
||||
object_score_logits=object_score_logits,
|
||||
is_mask_from_pts=is_mask_from_pts,
|
||||
)
|
||||
|
||||
@@ -941,6 +1038,120 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
expanded_maskmem_pos_enc = None
|
||||
return expanded_maskmem_pos_enc
|
||||
|
||||
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
    """
    Remove an object id from the tracking state. If strict is True, we check whether
    the object id actually exists and raise an error if it doesn't exist.
    """
    old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
    updated_frames = []
    # Unknown object id: either silently ignore it or fail, depending on `strict`.
    if old_obj_idx_to_rm is None:
        if strict:
            raise RuntimeError(
                f"Cannot remove object id {obj_id} as it doesn't exist. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )
        return inference_state["obj_ids"], updated_frames

    # Removing the only remaining object is just a full state reset.
    if len(inference_state["obj_id_to_idx"]) == 1:
        self.reset_state(inference_state)
        return inference_state["obj_ids"], updated_frames

    # Other objects remain after this removal, so this object's storage must be
    # carved out of the inference-state tensors.
    # Step 0: clear the inputs on every frame where this object has a point or
    # mask prompt (required, as it may downgrade conditioning frames to
    # non-conditioning ones).
    obj_input_frames_inds = set(
        inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
    )
    obj_input_frames_inds.update(
        inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
    )
    for frame_idx in obj_input_frames_inds:
        self.clear_all_prompts_in_frame(
            inference_state, frame_idx, obj_id, need_output=False
        )

    # Step 1: rebuild the object id <-> index mappings (must come after Step 0,
    # which still needs the old mappings in inference_state).
    old_obj_ids = inference_state["obj_ids"]
    old_obj_inds = list(range(len(old_obj_ids)))
    remain_old_obj_inds = [i for i in old_obj_inds if i != old_obj_idx_to_rm]
    new_obj_ids = [old_obj_ids[i] for i in remain_old_obj_inds]
    new_obj_inds = list(range(len(new_obj_ids)))
    # fresh mappings without the removed object
    old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
    inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
    inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
    inference_state["obj_ids"] = new_obj_ids

    # Step 2: shift the obj_idx keys of every per-object dict storage.
    # ("consolidated_frame_inds" needs no update here; Step 0 already handled it.)
    def _remap_keys(container):
        # Pop every old key first, then re-insert only the surviving objects
        # under their new (shifted) indices.
        popped = {k: container.pop(k) for k in old_obj_inds}
        for old_idx in remain_old_obj_inds:
            container[old_idx_to_new_idx[old_idx]] = popped[old_idx]

    _remap_keys(inference_state["point_inputs_per_obj"])
    _remap_keys(inference_state["mask_inputs_per_obj"])
    _remap_keys(inference_state["output_dict_per_obj"])
    _remap_keys(inference_state["temp_output_dict_per_obj"])

    # Step 3: index the packed tensor storage down to the remaining object ids
    # and rebuild the per-object slices from it.
    def _slice_state(output_dict, storage_key):
        for frame_idx, out in output_dict[storage_key].items():
            out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
            out["maskmem_pos_enc"] = [
                x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
            ]
            # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
            out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
            out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
            out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
            out["object_score_logits"] = out["object_score_logits"][
                remain_old_obj_inds
            ]
            # also refresh this frame's per-object output slices
            self._add_output_per_object(
                inference_state, frame_idx, out, storage_key
            )

    _slice_state(inference_state["output_dict"], "cond_frame_outputs")
    _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")

    # Step 4: collect updated outputs on the frames in `obj_input_frames_inds`,
    # where objects previously occluded by the removed one may now show an
    # updated mask.
    if need_output:
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        for frame_idx in obj_input_frames_inds:
            is_cond = any(
                frame_idx in obj_temp_out["cond_frame_outputs"]
                for obj_temp_out in temp_output_dict_per_obj.values()
            )
            consolidated_out = self._consolidate_temp_output_across_obj(
                inference_state,
                frame_idx,
                is_cond=is_cond,
                run_mem_encoder=False,
                consolidate_at_video_res=True,
            )
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, consolidated_out["pred_masks_video_res"]
            )
            updated_frames.append((frame_idx, video_res_masks))

    return inference_state["obj_ids"], updated_frames
|
||||
|
||||
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
|
||||
"""
|
||||
Remove the non-conditioning memory around the input frame. When users provide
|
||||
|
Reference in New Issue
Block a user