feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes … (#97)
* feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes #74) * update README
This commit is contained in:
@@ -12,7 +12,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
||||
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
||||
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames ,process_stream_frame
|
||||
|
||||
|
||||
class SAM2VideoPredictor(SAM2Base):
|
||||
@@ -43,23 +43,33 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
@torch.inference_mode()
|
||||
def init_state(
|
||||
self,
|
||||
video_path,
|
||||
video_path=None,
|
||||
offload_video_to_cpu=False,
|
||||
offload_state_to_cpu=False,
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""Initialize an inference state."""
|
||||
compute_device = self.device # device of the model
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state = {}
|
||||
inference_state["images"] = images
|
||||
inference_state["num_frames"] = len(images)
|
||||
if video_path is not None:
|
||||
# Preload video frames from file
|
||||
images, video_height, video_width = load_video_frames(
|
||||
video_path=video_path,
|
||||
image_size=self.image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
inference_state["images"] = images
|
||||
inference_state["num_frames"] = len(images)
|
||||
else:
|
||||
# Real-time streaming mode
|
||||
print("Real-time streaming mode: waiting for first image input...")
|
||||
images = None
|
||||
video_height, video_width = None, None
|
||||
inference_state["images"] = None
|
||||
inference_state["num_frames"] = 0
|
||||
|
||||
# whether to offload the video frames to CPU memory
|
||||
# turning on this option saves the GPU memory with only a very small overhead
|
||||
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
||||
@@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = {}
|
||||
# Warm up the visual backbone and cache the image feature on frame 0
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
if video_path is not None:
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
|
||||
return inference_state
|
||||
|
||||
@classmethod
|
||||
@@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state, pred_masks
|
||||
)
|
||||
yield frame_idx, obj_ids, video_res_masks
|
||||
@torch.inference_mode()
|
||||
def add_new_frame(self, inference_state, new_image):
|
||||
"""
|
||||
Add a new frame to the inference state and cache its image features.
|
||||
Args:
|
||||
inference_state (dict): The current inference state containing cached frames, features, and tracking information.
|
||||
new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing).
|
||||
Returns:
|
||||
frame_idx (int): The index of the newly added frame within the inference state.
|
||||
"""
|
||||
device = inference_state["device"]
|
||||
|
||||
# Preprocess the input frame and convert it to a normalized tensor
|
||||
img_tensor, orig_h, orig_w = process_stream_frame(
|
||||
img_array=new_image,
|
||||
image_size=self.image_size,
|
||||
offload_to_cpu=False,
|
||||
compute_device=device,
|
||||
)
|
||||
|
||||
# Handle initialization of the image sequence if this is the first frame
|
||||
images = inference_state.get("images", None)
|
||||
if images is None or (isinstance(images, list) and len(images) == 0):
|
||||
# First frame: initialize image tensor batch
|
||||
inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W]
|
||||
else:
|
||||
# Append to existing tensor batch
|
||||
if isinstance(images, list):
|
||||
raise ValueError(
|
||||
"inference_state['images'] should be a Tensor, not a list after initialization."
|
||||
)
|
||||
|
||||
img_tensor = img_tensor.to(images.device)
|
||||
inference_state["images"] = torch.cat(
|
||||
[images, img_tensor.unsqueeze(0)], dim=0
|
||||
)
|
||||
|
||||
# Update frame count and compute new frame index
|
||||
inference_state["num_frames"] = inference_state["images"].shape[0]
|
||||
frame_idx = inference_state["num_frames"] - 1
|
||||
|
||||
# Cache visual features for the newly added frame
|
||||
image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W]
|
||||
backbone_out = self.forward_image(image_batch)
|
||||
inference_state["cached_features"][frame_idx] = (image_batch, backbone_out)
|
||||
|
||||
return frame_idx
|
||||
|
||||
@torch.inference_mode()
|
||||
def infer_single_frame(self, inference_state, frame_idx):
|
||||
"""
|
||||
Run inference on a single frame using existing points/masks in the inference state.
|
||||
Args:
|
||||
inference_state (dict): The current state of the tracking process.
|
||||
frame_idx (int): Index of the frame to run inference on.
|
||||
Returns:
|
||||
frame_idx (int): Same as input; the index of the processed frame.
|
||||
obj_ids (list): List of currently tracked object IDs.
|
||||
video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame.
|
||||
"""
|
||||
if frame_idx >= inference_state["num_frames"]:
|
||||
raise ValueError(
|
||||
f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})."
|
||||
)
|
||||
|
||||
self.propagate_in_video_preflight(inference_state)
|
||||
|
||||
output_dict = inference_state["output_dict"]
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
|
||||
# Ensure that initial conditioning points exist
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
raise RuntimeError(
|
||||
"No conditioning points provided. Please add points before inference."
|
||||
)
|
||||
|
||||
# Decide whether to clear nearby memory based on number of objects
|
||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||
)
|
||||
|
||||
obj_ids = inference_state["obj_ids"]
|
||||
|
||||
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
# If output is already consolidated with conditioning inputs
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
|
||||
if clear_non_cond_mem:
|
||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||
|
||||
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||
# If output was inferred without conditioning
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
|
||||
else:
|
||||
# Run model inference for this frame
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out, pred_masks = self._run_single_frame_inference(
|
||||
inference_state=inference_state,
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=False,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
# Organize per-object outputs and mark frame as tracked
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, current_out, storage_key
|
||||
)
|
||||
inference_state["frames_already_tracked"][frame_idx] = {"reverse": False}
|
||||
|
||||
# Convert output to original video resolution
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
inference_state, pred_masks
|
||||
)
|
||||
|
||||
return frame_idx, obj_ids, video_res_masks
|
||||
|
||||
def _add_output_per_object(
|
||||
self, inference_state, frame_idx, current_out, storage_key
|
||||
|
@@ -8,6 +8,7 @@ import os
|
||||
import warnings
|
||||
from threading import Thread
|
||||
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -209,6 +210,74 @@ def load_video_frames(
|
||||
"Only MP4 video and JPEG folder are supported at this moment"
|
||||
)
|
||||
|
||||
def process_stream_frame(
|
||||
img_array: np.ndarray,
|
||||
image_size: int,
|
||||
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
|
||||
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
|
||||
offload_to_cpu: bool = False,
|
||||
compute_device: torch.device = torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Convert a raw image array (H,W,3 or 3,H,W) into a model‑ready tensor.
|
||||
Steps
|
||||
-----
|
||||
1. Resize the shorter side to `image_size`, keeping aspect ratio,
|
||||
then center‑crop/pad to `image_size` × `image_size`.
|
||||
2. Change layout to [3, H, W] and cast to float32 in [0,1].
|
||||
3. Normalise with ImageNet statistics.
|
||||
4. Optionally move to `compute_device`.
|
||||
Returns
|
||||
-------
|
||||
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
|
||||
orig_h : int
|
||||
orig_w : int
|
||||
"""
|
||||
|
||||
# ↪ uses your existing helper so behaviour matches the batch loader
|
||||
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
|
||||
|
||||
# Normalisation (done *after* potential device move for efficiency)
|
||||
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
|
||||
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
|
||||
|
||||
if not offload_to_cpu:
|
||||
img_tensor = img_tensor.to(compute_device)
|
||||
img_mean_t = img_mean_t.to(compute_device)
|
||||
img_std_t = img_std_t.to(compute_device)
|
||||
|
||||
img_tensor.sub_(img_mean_t).div_(img_std_t)
|
||||
|
||||
return img_tensor, orig_h, orig_w
|
||||
|
||||
|
||||
def _resize_and_convert_to_tensor(img_array, image_size):
|
||||
"""
|
||||
Resize the input image array and convert it into a tensor.
|
||||
Also return original image height and width.
|
||||
"""
|
||||
# Convert numpy array to PIL image and ensure RGB
|
||||
img_pil = Image.fromarray(img_array).convert("RGB")
|
||||
|
||||
# Save original size (PIL: size = (width, height))
|
||||
video_width, video_height = img_pil.size
|
||||
|
||||
# Resize with high-quality LANCZOS filter
|
||||
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert resized image back to numpy and then to float tensor
|
||||
img_resized_array = np.array(img_resized)
|
||||
|
||||
if img_resized_array.dtype == np.uint8:
|
||||
img_resized_array = img_resized_array / 255.0
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
|
||||
|
||||
# Convert to PyTorch tensor and permute to [C, H, W]
|
||||
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
|
||||
|
||||
return img_tensor, video_height, video_width
|
||||
|
||||
|
||||
def load_video_frames_from_jpg_images(
|
||||
video_path,
|
||||
|
Reference in New Issue
Block a user