[New Feature] Support SAM 2.1 (#59)

* support sam 2.1

* refine config path and ckpt path

* update README
This commit is contained in:
Ren Tianhe
2024-10-10 14:55:50 +08:00
committed by GitHub
parent e899ad99e8
commit 82e503604f
340 changed files with 39100 additions and 608 deletions

View File

@@ -59,9 +59,6 @@ class SAM2Base(torch.nn.Module):
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
memory_temporal_stride_for_eval=1,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False,
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
non_overlap_masks_for_mem_enc=False,
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
@@ -73,6 +70,9 @@ class SAM2Base(torch.nn.Module):
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
proj_tpos_enc_in_obj_ptrs=False,
# whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
# (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
use_signed_tpos_enc_to_obj_ptrs=False,
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
only_obj_ptrs_in_the_past_for_eval=False,
@@ -88,6 +88,8 @@ class SAM2Base(torch.nn.Module):
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
# add no obj embedding to spatial frames
no_obj_embed_spatial: bool = False,
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
@@ -110,12 +112,13 @@ class SAM2Base(torch.nn.Module):
if proj_tpos_enc_in_obj_ptrs:
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
# Part 2: memory attention to condition current frame's visual features
# with memories (and obj ptrs) from past frames
self.memory_attention = memory_attention
self.hidden_dim = memory_attention.d_model
self.hidden_dim = image_encoder.neck.d_model
# Part 3: memory encoder for the previous frame's outputs
self.memory_encoder = memory_encoder
@@ -170,9 +173,12 @@ class SAM2Base(torch.nn.Module):
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
trunc_normal_(self.no_obj_ptr, std=0.02)
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
self.no_obj_embed_spatial = None
if no_obj_embed_spatial:
self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
trunc_normal_(self.no_obj_embed_spatial, std=0.02)
self._build_sam_heads()
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.max_cond_frames_in_attn = max_cond_frames_in_attn
# Model compilation
@@ -194,8 +200,8 @@ class SAM2Base(torch.nn.Module):
def forward(self, *args, **kwargs):
raise NotImplementedError(
"Please use the corresponding methods in SAM2VideoPredictor for inference."
"See notebooks/video_predictor_example.ipynb for an example."
"Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
"See notebooks/video_predictor_example.ipynb for an inference example."
)
def _build_sam_heads(self):
@@ -388,8 +394,6 @@ class SAM2Base(torch.nn.Module):
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
# Only hard possible with gt
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
@@ -513,6 +517,7 @@ class SAM2Base(torch.nn.Module):
return pix_feat
num_obj_ptr_tokens = 0
tpos_sign_mul = -1 if track_in_reverse else 1
# Step 1: condition the visual features of the current frame on previous memories
if not is_init_cond_frame:
# Retrieve the memories encoded with the maskmem backbone
@@ -528,9 +533,9 @@ class SAM2Base(torch.nn.Module):
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
# We also allow taking the memory frame non-consecutively (with r>1), in which case
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
r = self.memory_temporal_stride_for_eval
# We also allow taking the memory frame non-consecutively (with stride>1), in which case
# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
stride = 1 if self.training else self.memory_temporal_stride_for_eval
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # how many frames before current frame
if t_rel == 1:
@@ -546,15 +551,15 @@ class SAM2Base(torch.nn.Module):
if not track_in_reverse:
# first find the nearest frame among every r-th frames before this frame
# for r=1, this would be (frame_idx - 2)
prev_frame_idx = ((frame_idx - 2) // r) * r
prev_frame_idx = ((frame_idx - 2) // stride) * stride
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
else:
# first find the nearest frame among every r-th frames after this frame
# for r=1, this would be (frame_idx + 2)
prev_frame_idx = -(-(frame_idx + 2) // r) * r
prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
if out is None:
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
@@ -593,7 +598,14 @@ class SAM2Base(torch.nn.Module):
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [
# Temporal pos encoding contains how far away each pointer is from current frame
(abs(frame_idx - t), out["obj_ptr"])
(
(
(frame_idx - t) * tpos_sign_mul
if self.use_signed_tpos_enc_to_obj_ptrs
else abs(frame_idx - t)
),
out["obj_ptr"],
)
for t, out in ptr_cond_outputs.items()
]
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
@@ -666,6 +678,7 @@ class SAM2Base(torch.nn.Module):
current_vision_feats,
feat_sizes,
pred_masks_high_res,
object_score_logits,
is_mask_from_pts,
):
"""Encode the current image and its prediction into a memory feature."""
@@ -698,9 +711,104 @@ class SAM2Base(torch.nn.Module):
)
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
is_obj_appearing = (object_score_logits > 0).float()
maskmem_features += (
1 - is_obj_appearing[..., None, None]
) * self.no_obj_embed_spatial[..., None, None].expand(
*maskmem_features.shape
)
return maskmem_features, maskmem_pos_enc
def _track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
):
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(
pix_feat, high_res_features, mask_inputs
)
else:
# fused the visual feature with previous memory features in the memory bank
pix_feat = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
return current_out, sam_outputs, high_res_features, pix_feat
def _encode_memory_in_output(
self,
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
):
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
object_score_logits=object_score_logits,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
def track_step(
self,
frame_idx,
@@ -722,50 +830,20 @@ class SAM2Base(torch.nn.Module):
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(
pix_feat, high_res_features, mask_inputs
)
else:
# fused the visual feature with previous memory features in the memory bank
pix_feat_with_mem = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
current_out, sam_outputs, _, _ = self._track_step(
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
)
(
_,
_,
@@ -773,28 +851,28 @@ class SAM2Base(torch.nn.Module):
low_res_masks,
high_res_masks,
obj_ptr,
_,
object_score_logits,
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
if not self.training:
# Only add this in inference (to avoid unused param in activation checkpointing;
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
current_out["object_score_logits"] = object_score_logits
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
)
return current_out