feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes … (#97)

* feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes #74)

* update README
Embodied Learner
2025-05-08 11:02:33 +08:00
committed by GitHub
parent 7fec804683
commit c5780dabeb
4 changed files with 766 additions and 12 deletions

README.md View File

@@ -335,6 +335,16 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
```
### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
```bash
python grounded_sam2_tracking_camera_with_continuous_id.py
```
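By default the script opens camera index 0 and tracks the text prompt set in its `main()` function; to run on a recorded video instead, swap `cv2.VideoCapture(0)` for something like `cv2.VideoCapture("video.mp4")` inside the script. Per-frame masks, JSON metadata, raw frames, and annotated results are saved under `./outputs/`.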
## Grounded SAM 2 Florence-2 Demos
### Grounded SAM 2 Florence-2 Image Demo

grounded_sam2_tracking_camera_with_continuous_id.py View File

@@ -0,0 +1,536 @@
import copy
import os
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
# Setup environment
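# The whole script runs under bfloat16 autocast; on GPUs with compute capability >= 8
# (Ampere or newer), TF32 matmul/cuDNN kernels are also enabled for extra throughput.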
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class GroundingDinoPredictor:
"""
Wrapper for using a GroundingDINO model for zero-shot object detection.
"""
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
"""
Initialize the GroundingDINO predictor.
Args:
model_id (str): HuggingFace model ID to load.
device (str): Device to run the model on ('cuda' or 'cpu').
"""
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
self.device = device
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
device
)
def predict(
self,
image: "PIL.Image.Image",
text_prompts: str,
box_threshold=0.25,
text_threshold=0.25,
):
"""
Perform object detection using text prompts.
Args:
image (PIL.Image.Image): Input RGB image.
text_prompts (str): Text prompt describing target objects.
box_threshold (float): Confidence threshold for box selection.
text_threshold (float): Confidence threshold for text match.
Returns:
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
"""
inputs = self.processor(
images=image, text=text_prompts, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[image.size[::-1]],
)
return results[0]["boxes"], results[0]["labels"]
class SAM2ImageSegmentor:
"""
Wrapper class for SAM2-based segmentation given bounding boxes.
"""
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
"""
Initialize the SAM2 image segmentor.
Args:
sam_model_cfg (str): Path to the SAM2 config file.
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
device (str): Device to load the model on ('cuda' or 'cpu').
"""
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
self.device = device
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
self.predictor = SAM2ImagePredictor(sam_model)
def set_image(self, image: np.ndarray):
"""
Set the input image for segmentation.
Args:
image (np.ndarray): RGB image array with shape (H, W, 3).
"""
self.predictor.set_image(image)
def predict_masks_from_boxes(self, boxes: torch.Tensor):
"""
Predict segmentation masks from given bounding boxes.
Args:
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
- masks: Binary masks per box, shape (N, H, W)
- scores: Confidence scores for each mask
- logits: Raw logits from the model
"""
masks, scores, logits = self.predictor.predict(
point_coords=None,
point_labels=None,
box=boxes,
multimask_output=False,
)
# Normalize shape to (N, H, W)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
return masks, scores, logits
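# Usage sketch (illustrative only, not called by the tracker below): the two wrappers
# can be composed directly for single-image grounded segmentation. Assumes a CUDA
# device, the default model ID / checkpoint paths, and `frame_rgb` being an
# (H, W, 3) uint8 RGB array:
#
#   detector = GroundingDinoPredictor(device="cuda")
#   segmentor = SAM2ImageSegmentor(
#       "configs/sam2.1/sam2.1_hiera_l.yaml",
#       "./checkpoints/sam2.1_hiera_large.pt",
#   )
#   boxes, labels = detector.predict(Image.fromarray(frame_rgb), "car.")
#   segmentor.set_image(frame_rgb)
#   masks, scores, logits = segmentor.predict_masks_from_boxes(boxes)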
class IncrementalObjectTracker:
def __init__(
self,
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text="car.",
detection_interval=20,
):
"""
Initialize an incremental object tracker using GroundingDINO and SAM2.
Args:
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
sam2_model_cfg (str): Path to SAM2 model config file.
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
device (str): Device to run the models on ('cuda' or 'cpu').
prompt_text (str): Initial text prompt for detection.
detection_interval (int): Frame interval between full detections.
"""
self.device = device
self.detection_interval = detection_interval
self.prompt_text = prompt_text
# Load models
self.grounding_predictor = GroundingDinoPredictor(
model_id=grounding_model_id, device=device
)
self.sam2_segmentor = SAM2ImageSegmentor(
sam_model_cfg=sam2_model_cfg,
sam_model_ckpt=sam2_ckpt_path,
device=device,
)
self.video_predictor = build_sam2_video_predictor(
sam2_model_cfg, sam2_ckpt_path
)
# Initialize inference state
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
self.total_frames = 0
self.objects_count = 0
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
# Store tracking results
self.last_mask_dict = MaskDictionaryModel()
self.track_dict = MaskDictionaryModel()
def add_image(self, image_np: np.ndarray):
"""
Add a new image frame to the tracker and perform detection or tracking update.
Args:
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
Returns:
np.ndarray: Annotated image with object masks and labels.
"""
import numpy as np
from PIL import Image
img_pil = Image.fromarray(image_np)
# Step 1: Perform detection every detection_interval frames
if self.total_frames % self.detection_interval == 0:
if (
self.inference_state["video_height"] is None
or self.inference_state["video_width"] is None
):
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
print(
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
)
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
# 1.1 GroundingDINO object detection
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
if boxes.shape[0] == 0:
return
# 1.2 SAM2 segmentation from detection boxes
self.sam2_segmentor.set_image(image_np)
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
# 1.3 Build MaskDictionaryModel
mask_dict = MaskDictionaryModel(
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
)
mask_dict.add_new_frame_annotation(
mask_list=torch.tensor(masks).to(self.device),
box_list=torch.tensor(boxes),
label_list=labels,
)
# 1.4 Object ID tracking and IOU-based update
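# update_masks (from utils.mask_dictionary_model) compares the new detections with
# self.last_mask_dict by mask IoU: matches above iou_threshold keep their previous
# instance IDs, unmatched detections receive new IDs beyond objects_count; this is
# what keeps IDs continuous across periodic re-detections.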
self.objects_count = mask_dict.update_masks(
tracking_annotation_dict=self.last_mask_dict,
iou_threshold=0.3,
objects_count=self.objects_count,
)
# 1.5 Reset video tracker state
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
self.video_predictor.reset_state(self.inference_state)
for object_id, object_info in mask_dict.labels.items():
frame_idx, _, _ = self.video_predictor.add_new_mask(
self.inference_state,
frame_idx,
object_id,
object_info.mask,
)
self.track_dict = copy.deepcopy(mask_dict)
self.last_mask_dict = mask_dict
else:
# Step 2: Use incremental tracking for intermediate frames
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
# Step 3: Tracking propagation using the video predictor
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
inference_state=self.inference_state,
frame_idx=frame_idx,
)
# Step 4: Update the mask dictionary based on tracked masks
frame_masks = MaskDictionaryModel()
for i, obj_id in enumerate(obj_ids):
out_mask = video_res_masks[i] > 0.0
object_info = ObjectInfo(
instance_id=obj_id,
mask=out_mask[0],
class_name=self.track_dict.get_target_class_name(obj_id),
logit=self.track_dict.get_target_logit(obj_id),
)
object_info.update_box()
frame_masks.labels[obj_id] = object_info
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
self.last_mask_dict = copy.deepcopy(frame_masks)
# Step 5: Build mask array
H, W = image_np.shape[:2]
mask_img = torch.zeros((H, W), dtype=torch.int32)
for obj_id, obj_info in self.last_mask_dict.labels.items():
mask_img[obj_info.mask == True] = obj_id
mask_array = mask_img.cpu().numpy()
# Step 6: Visualization
annotated_frame = self.visualize_frame_with_mask_and_metadata(
image_np=image_np,
mask_array=mask_array,
json_metadata=self.last_mask_dict.to_dict(),
)
print(f"[Tracker] Total processed frames: {self.total_frames}")
self.total_frames += 1
torch.cuda.empty_cache()
return annotated_frame
def set_prompt(self, new_prompt: str):
"""
Dynamically update the GroundingDINO prompt and reset tracking state
to force a new object detection.
"""
self.prompt_text = new_prompt
self.total_frames = 0 # Trigger immediate re-detection
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
self.inference_state["video_height"] = None
self.inference_state["video_width"] = None
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
"""
Save the current mask, metadata, raw image, and annotated result.
Args:
output_dir (str): The root output directory.
raw_image (np.ndarray, optional): The original input image (RGB).
"""
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
image_data_dir = os.path.join(output_dir, "images")
vis_data_dir = os.path.join(output_dir, "result")
os.makedirs(mask_data_dir, exist_ok=True)
os.makedirs(json_data_dir, exist_ok=True)
os.makedirs(image_data_dir, exist_ok=True)
os.makedirs(vis_data_dir, exist_ok=True)
frame_masks = self.last_mask_dict
# Ensure mask_name is valid
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
base_name = f"image_{self.total_frames:05d}"
# Save segmentation mask
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
for obj_id, obj_info in frame_masks.labels.items():
mask_img[obj_info.mask == True] = obj_id
np.save(
os.path.join(mask_data_dir, frame_masks.mask_name),
mask_img.numpy().astype(np.uint16),
)
# Save metadata as JSON
json_path = os.path.join(json_data_dir, base_name + ".json")
frame_masks.to_json(json_path)
# Save raw input image
if raw_image is not None:
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
# Save annotated image with mask, bounding boxes, and labels
annotated_image = self.visualize_frame_with_mask_and_metadata(
image_np=raw_image,
mask_array=mask_img.numpy().astype(np.uint16),
json_metadata=frame_masks.to_dict(),
)
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
)
print(
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
)
def visualize_frame_with_mask_and_metadata(
self,
image_np: np.ndarray,
mask_array: np.ndarray,
json_metadata: dict,
):
image = image_np.copy()
H, W = image.shape[:2]
# Step 1: Parse metadata and build object entries
metadata_lookup = json_metadata.get("labels", {})
all_object_ids = []
all_object_boxes = []
all_object_classes = []
all_object_masks = []
for obj_id_str, obj_info in metadata_lookup.items():
instance_id = obj_info.get("instance_id")
if instance_id is None or instance_id == 0:
continue
if instance_id not in np.unique(mask_array):
continue
object_mask = mask_array == instance_id
all_object_ids.append(instance_id)
x1 = obj_info.get("x1", 0)
y1 = obj_info.get("y1", 0)
x2 = obj_info.get("x2", 0)
y2 = obj_info.get("y2", 0)
all_object_boxes.append([x1, y1, x2, y2])
all_object_classes.append(obj_info.get("class_name", "unknown"))
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
# Step 2: Check if valid objects exist
if len(all_object_ids) == 0:
print("No valid object instances found in metadata.")
return image
# Step 3: Sort by instance ID
paired = list(
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
)
paired.sort(key=lambda x: x[0])
all_object_ids = [p[0] for p in paired]
all_object_boxes = [p[1] for p in paired]
all_object_masks = [p[2] for p in paired]
all_object_classes = [p[3] for p in paired]
# Step 4: Build detections
all_object_masks = np.concatenate(all_object_masks, axis=0)
detections = sv.Detections(
xyxy=np.array(all_object_boxes),
mask=all_object_masks,
class_id=np.array(all_object_ids, dtype=np.int32),
)
labels = [
f"{instance_id}: {class_name}"
for instance_id, class_name in zip(all_object_ids, all_object_classes)
]
# Step 5: Annotate image
annotated_frame = image.copy()
mask_annotator = sv.MaskAnnotator()
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
annotated_frame = box_annotator.annotate(annotated_frame, detections)
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
return annotated_frame
import os
import cv2
import torch
from utils.common_utils import CommonUtils
def main():
# Parameter settings
output_dir = "./outputs"
prompt_text = "hand."
detection_interval = 20
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
os.makedirs(output_dir, exist_ok=True)
# Initialize the object tracker
tracker = IncrementalObjectTracker(
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text=prompt_text,
detection_interval=detection_interval,
)
tracker.set_prompt("person.")
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("[Error] Cannot open camera.")
return
print("[Info] Camera opened. Press 'q' to quit.")
frame_idx = 0
try:
while True:
ret, frame = cap.read()
if not ret:
print("[Warning] Failed to capture frame.")
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
print(f"[Frame {frame_idx}] Processing live frame...")
process_image = tracker.add_image(frame_rgb)
if process_image is None or not isinstance(process_image, np.ndarray):
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
frame_idx += 1
continue
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
# cv2.imshow("Live Inference", process_image_bgr)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("[Info] Quit signal received.")
# break
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
frame_idx += 1
if frame_idx >= max_frames:
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
break
except KeyboardInterrupt:
print("[Info] Interrupted by user (Ctrl+C).")
finally:
cap.release()
cv2.destroyAllWindows()
print("[Done] Live inference complete.")
if __name__ == "__main__":
main()

sam2/sam2_video_predictor.py View File

@@ -12,7 +12,7 @@ import torch
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames, process_stream_frame
class SAM2VideoPredictor(SAM2Base):
@@ -43,13 +43,16 @@ class SAM2VideoPredictor(SAM2Base):
@torch.inference_mode()
def init_state(
self,
video_path=None,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
):
"""Initialize an inference state."""
compute_device = self.device  # device of the model
inference_state = {}
if video_path is not None:
# Preload video frames from file
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
@@ -57,9 +60,16 @@ class SAM2VideoPredictor(SAM2Base):
async_loading_frames=async_loading_frames,
compute_device=compute_device,
)
inference_state["images"] = images
inference_state["num_frames"] = len(images)
else:
# Real-time streaming mode
print("Real-time streaming mode: waiting for first image input...")
images = None
video_height, video_width = None, None
inference_state["images"] = None
inference_state["num_frames"] = 0
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
@@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
if video_path is not None:
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@classmethod
@@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, pred_masks
)
yield frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def add_new_frame(self, inference_state, new_image):
"""
Add a new frame to the inference state and cache its image features.
Args:
inference_state (dict): The current inference state containing cached frames, features, and tracking information.
new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing).
Returns:
frame_idx (int): The index of the newly added frame within the inference state.
"""
device = inference_state["device"]
# Preprocess the input frame and convert it to a normalized tensor
img_tensor, orig_h, orig_w = process_stream_frame(
img_array=new_image,
image_size=self.image_size,
offload_to_cpu=False,
compute_device=device,
)
# Handle initialization of the image sequence if this is the first frame
images = inference_state.get("images", None)
if images is None or (isinstance(images, list) and len(images) == 0):
# First frame: initialize image tensor batch
inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W]
else:
# Append to existing tensor batch
if isinstance(images, list):
raise ValueError(
"inference_state['images'] should be a Tensor, not a list after initialization."
)
img_tensor = img_tensor.to(images.device)
inference_state["images"] = torch.cat(
[images, img_tensor.unsqueeze(0)], dim=0
)
# Update frame count and compute new frame index
inference_state["num_frames"] = inference_state["images"].shape[0]
frame_idx = inference_state["num_frames"] - 1
# Cache visual features for the newly added frame
image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W]
backbone_out = self.forward_image(image_batch)
inference_state["cached_features"][frame_idx] = (image_batch, backbone_out)
return frame_idx
@torch.inference_mode()
def infer_single_frame(self, inference_state, frame_idx):
"""
Run inference on a single frame using existing points/masks in the inference state.
Args:
inference_state (dict): The current state of the tracking process.
frame_idx (int): Index of the frame to run inference on.
Returns:
frame_idx (int): Same as input; the index of the processed frame.
obj_ids (list): List of currently tracked object IDs.
video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame.
"""
if frame_idx >= inference_state["num_frames"]:
raise ValueError(
f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})."
)
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
batch_size = self._get_obj_num(inference_state)
# Ensure that initial conditioning points exist
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError(
"No conditioning points provided. Please add points before inference."
)
# Decide whether to clear nearby memory based on number of objects
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
obj_ids = inference_state["obj_ids"]
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
# If output is already consolidated with conditioning inputs
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
# If output was inferred without conditioning
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
# Run model inference for this frame
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Organize per-object outputs and mark frame as tracked
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": False}
# Convert output to original video resolution
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
return frame_idx, obj_ids, video_res_masks
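# Streaming usage sketch (assumes `predictor` is a SAM2VideoPredictor, `frame_rgb`
# is an (H, W, 3) RGB array, and object prompts were already registered on an
# earlier frame via add_new_mask / add_new_points_or_box):
#
#   state = predictor.init_state()                   # no video_path: streaming mode
#   idx = predictor.add_new_frame(state, frame_rgb)  # append frame and cache features
#   idx, obj_ids, masks = predictor.infer_single_frame(state, idx)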
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key

sam2/utils/misc.py View File

@@ -8,6 +8,7 @@ import os
import warnings
from threading import Thread
from typing import Tuple
import numpy as np
import torch
from PIL import Image
@@ -209,6 +210,74 @@ def load_video_frames(
"Only MP4 video and JPEG folder are supported at this moment"
)
def process_stream_frame(
img_array: np.ndarray,
image_size: int,
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
offload_to_cpu: bool = False,
compute_device: torch.device = torch.device("cuda"),
):
"""
Convert a raw RGB image array (H, W, 3) into a model-ready tensor.
Steps
-----
1. Resize the frame to `image_size` × `image_size` (aspect ratio is not preserved,
matching the behaviour of the batch video loader).
2. Change layout to [3, H, W] and cast to float in [0, 1].
3. Normalise with ImageNet statistics.
4. Optionally move to `compute_device`.
Returns
-------
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
orig_h : int
orig_w : int
"""
# ↪ uses your existing helper so behaviour matches the batch loader
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
# Normalisation (done *after* potential device move for efficiency)
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if not offload_to_cpu:
img_tensor = img_tensor.to(compute_device)
img_mean_t = img_mean_t.to(compute_device)
img_std_t = img_std_t.to(compute_device)
img_tensor.sub_(img_mean_t).div_(img_std_t)
return img_tensor, orig_h, orig_w
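# Example (sketch): a 480x640 RGB uint8 frame becomes a normalized square tensor
# plus its original height/width; offload_to_cpu=True keeps everything on the CPU
# so no GPU is needed for the illustration:
#
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   tensor, h, w = process_stream_frame(frame, image_size=1024, offload_to_cpu=True)
#   # tensor.shape == (3, 1024, 1024); (h, w) == (480, 640)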
def _resize_and_convert_to_tensor(img_array, image_size):
"""
Resize the input image array and convert it into a tensor.
Also return original image height and width.
"""
# Convert numpy array to PIL image and ensure RGB
img_pil = Image.fromarray(img_array).convert("RGB")
# Save original size (PIL: size = (width, height))
video_width, video_height = img_pil.size
# Resize with high-quality LANCZOS filter
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
# Convert resized image back to numpy and then to float tensor
img_resized_array = np.array(img_resized)
if img_resized_array.dtype == np.uint8:
img_resized_array = img_resized_array / 255.0
else:
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
# Convert to PyTorch tensor and permute to [C, H, W]
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
return img_tensor, video_height, video_width
def load_video_frames_from_jpg_images(
video_path,