feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes … (#97)

* feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes #74)

* update README
This commit is contained in:
Embodied Learner
2025-05-08 11:02:33 +08:00
committed by GitHub
parent 7fec804683
commit c5780dabeb
4 changed files with 766 additions and 12 deletions

View File

@@ -12,7 +12,7 @@ import torch
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames ,process_stream_frame
class SAM2VideoPredictor(SAM2Base):
@@ -43,23 +43,33 @@ class SAM2VideoPredictor(SAM2Base):
@torch.inference_mode()
def init_state(
self,
video_path,
video_path=None,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
if video_path is not None:
# Preload video frames from file
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state["images"] = images
inference_state["num_frames"] = len(images)
else:
# Real-time streaming mode
print("Real-time streaming mode: waiting for first image input...")
images = None
video_height, video_width = None, None
inference_state["images"] = None
inference_state["num_frames"] = 0
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
@@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
if video_path is not None:
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
@@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def add_new_frame(self, inference_state, new_image):
"""
Add a new frame to the inference state and cache its image features.
Args:
inference_state (dict): The current inference state containing cached frames, features, and tracking information.
new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing).
Returns:
frame_idx (int): The index of the newly added frame within the inference state.
"""
device = inference_state["device"]
# Preprocess the input frame and convert it to a normalized tensor
img_tensor, orig_h, orig_w = process_stream_frame(
img_array=new_image,
image_size=self.image_size,
offload_to_cpu=False,
compute_device=device,
)
# Handle initialization of the image sequence if this is the first frame
images = inference_state.get("images", None)
if images is None or (isinstance(images, list) and len(images) == 0):
# First frame: initialize image tensor batch
inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W]
else:
# Append to existing tensor batch
if isinstance(images, list):
raise ValueError(
"inference_state['images'] should be a Tensor, not a list after initialization."
)
img_tensor = img_tensor.to(images.device)
inference_state["images"] = torch.cat(
[images, img_tensor.unsqueeze(0)], dim=0
)
# Update frame count and compute new frame index
inference_state["num_frames"] = inference_state["images"].shape[0]
frame_idx = inference_state["num_frames"] - 1
# Cache visual features for the newly added frame
image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W]
backbone_out = self.forward_image(image_batch)
inference_state["cached_features"][frame_idx] = (image_batch, backbone_out)
return frame_idx
@torch.inference_mode()
def infer_single_frame(self, inference_state, frame_idx):
"""
Run inference on a single frame using existing points/masks in the inference state.
Args:
inference_state (dict): The current state of the tracking process.
frame_idx (int): Index of the frame to run inference on.
Returns:
frame_idx (int): Same as input; the index of the processed frame.
obj_ids (list): List of currently tracked object IDs.
video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame.
"""
if frame_idx >= inference_state["num_frames"]:
raise ValueError(
f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})."
)
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
batch_size = self._get_obj_num(inference_state)
# Ensure that initial conditioning points exist
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError(
"No conditioning points provided. Please add points before inference."
)
# Decide whether to clear nearby memory based on number of objects
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
obj_ids = inference_state["obj_ids"]
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
# If output is already consolidated with conditioning inputs
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
# If output was inferred without conditioning
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
# Run model inference for this frame
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Organize per-object outputs and mark frame as tracked
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": False}
# Convert output to original video resolution
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
return frame_idx, obj_ids, video_res_masks
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key