diff --git a/README.md b/README.md index d4a799f..8e00ae6 100644 --- a/README.md +++ b/README.md @@ -335,6 +335,16 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py ``` +### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream) + +This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream. + +```bash +python grounded_sam2_tracking_camera_with_continuous_id.py +``` + + + ## Grounded SAM 2 Florence-2 Demos ### Grounded SAM 2 Florence-2 Image Demo diff --git a/grounded_sam2_tracking_camera_with_continuous_id.py b/grounded_sam2_tracking_camera_with_continuous_id.py new file mode 100644 index 0000000..bd8d31d --- /dev/null +++ b/grounded_sam2_tracking_camera_with_continuous_id.py @@ -0,0 +1,536 @@ +import copy +import os + +import cv2 +import numpy as np +import supervision as sv +import torch +from PIL import Image +from sam2.build_sam import build_sam2, build_sam2_video_predictor +from sam2.sam2_image_predictor import SAM2ImagePredictor +from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor +from utils.common_utils import CommonUtils +from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo +from utils.track_utils import sample_points_from_masks +from utils.video_utils import create_video_from_images + +# Setup environment +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() +if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +class GroundingDinoPredictor: + """ + Wrapper for using a GroundingDINO model for zero-shot object detection. + """ + + def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"): + """ + Initialize the GroundingDINO predictor. + Args: + model_id (str): HuggingFace model ID to load. + device (str): Device to run the model on ('cuda' or 'cpu'). + """ + from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor + + self.device = device + self.processor = AutoProcessor.from_pretrained(model_id) + self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to( + device + ) + + def predict( + self, + image: "PIL.Image.Image", + text_prompts: str, + box_threshold=0.25, + text_threshold=0.25, + ): + """ + Perform object detection using text prompts. + Args: + image (PIL.Image.Image): Input RGB image. + text_prompts (str): Text prompt describing target objects. + box_threshold (float): Confidence threshold for box selection. + text_threshold (float): Confidence threshold for text match. + Returns: + Tuple[Tensor, List[str]]: Bounding boxes and matched class labels. + """ + inputs = self.processor( + images=image, text=text_prompts, return_tensors="pt" + ).to(self.device) + with torch.no_grad(): + outputs = self.model(**inputs) + + results = self.processor.post_process_grounded_object_detection( + outputs, + inputs.input_ids, + box_threshold=box_threshold, + text_threshold=text_threshold, + target_sizes=[image.size[::-1]], + ) + + return results[0]["boxes"], results[0]["labels"] + + +class SAM2ImageSegmentor: + """ + Wrapper class for SAM2-based segmentation given bounding boxes. + """ + + def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"): + """ + Initialize the SAM2 image segmentor. + Args: + sam_model_cfg (str): Path to the SAM2 config file. + sam_model_ckpt (str): Path to the SAM2 checkpoint file. + device (str): Device to load the model on ('cuda' or 'cpu'). + """ + from sam2.build_sam import build_sam2 + from sam2.sam2_image_predictor import SAM2ImagePredictor + + self.device = device + sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device) + self.predictor = SAM2ImagePredictor(sam_model) + + def set_image(self, image: np.ndarray): + """ + Set the input image for segmentation. + Args: + image (np.ndarray): RGB image array with shape (H, W, 3). + """ + self.predictor.set_image(image) + + def predict_masks_from_boxes(self, boxes: torch.Tensor): + """ + Predict segmentation masks from given bounding boxes. + Args: + boxes (torch.Tensor): Bounding boxes as (N, 4) tensor. + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: + - masks: Binary masks per box, shape (N, H, W) + - scores: Confidence scores for each mask + - logits: Raw logits from the model + """ + masks, scores, logits = self.predictor.predict( + point_coords=None, + point_labels=None, + box=boxes, + multimask_output=False, + ) + + # Normalize shape to (N, H, W) + if masks.ndim == 2: + masks = masks[None] + scores = scores[None] + logits = logits[None] + elif masks.ndim == 4: + masks = masks.squeeze(1) + + return masks, scores, logits + + +class IncrementalObjectTracker: + def __init__( + self, + grounding_model_id="IDEA-Research/grounding-dino-tiny", + sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml", + sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt", + device="cuda", + prompt_text="car.", + detection_interval=20, + ): + """ + Initialize an incremental object tracker using GroundingDINO and SAM2. + Args: + grounding_model_id (str): HuggingFace model ID for GroundingDINO. + sam2_model_cfg (str): Path to SAM2 model config file. + sam2_ckpt_path (str): Path to SAM2 model checkpoint. + device (str): Device to run the models on ('cuda' or 'cpu'). + prompt_text (str): Initial text prompt for detection. + detection_interval (int): Frame interval between full detections. + """ + self.device = device + self.detection_interval = detection_interval + self.prompt_text = prompt_text + + # Load models + self.grounding_predictor = GroundingDinoPredictor( + model_id=grounding_model_id, device=device + ) + self.sam2_segmentor = SAM2ImageSegmentor( + sam_model_cfg=sam2_model_cfg, + sam_model_ckpt=sam2_ckpt_path, + device=device, + ) + self.video_predictor = build_sam2_video_predictor( + sam2_model_cfg, sam2_ckpt_path + ) + + # Initialize inference state + self.inference_state = self.video_predictor.init_state() + self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device) + self.total_frames = 0 + self.objects_count = 0 + self.frame_cache_limit = detection_interval - 1 # or higher depending on memory + + # Store tracking results + self.last_mask_dict = MaskDictionaryModel() + self.track_dict = MaskDictionaryModel() + + def add_image(self, image_np: np.ndarray): + """ + Add a new image frame to the tracker and perform detection or tracking update. + Args: + image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8. + Returns: + np.ndarray: Annotated image with object masks and labels. + """ + import numpy as np + from PIL import Image + + img_pil = Image.fromarray(image_np) + + # Step 1: Perform detection every detection_interval frames + if self.total_frames % self.detection_interval == 0: + if ( + self.inference_state["video_height"] is None + or self.inference_state["video_width"] is None + ): + ( + self.inference_state["video_height"], + self.inference_state["video_width"], + ) = image_np.shape[:2] + + if self.inference_state["images"].shape[0] > self.frame_cache_limit: + print( + f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory." + ) + self.inference_state = self.video_predictor.init_state() + self.inference_state["images"] = torch.empty( + (0, 3, 1024, 1024), device=self.device + ) + ( + self.inference_state["video_height"], + self.inference_state["video_width"], + ) = image_np.shape[:2] + + # 1.1 GroundingDINO object detection + boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text) + if boxes.shape[0] == 0: + return + + # 1.2 SAM2 segmentation from detection boxes + self.sam2_segmentor.set_image(image_np) + masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes) + + # 1.3 Build MaskDictionaryModel + mask_dict = MaskDictionaryModel( + promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy" + ) + mask_dict.add_new_frame_annotation( + mask_list=torch.tensor(masks).to(self.device), + box_list=torch.tensor(boxes), + label_list=labels, + ) + + # 1.4 Object ID tracking and IOU-based update + self.objects_count = mask_dict.update_masks( + tracking_annotation_dict=self.last_mask_dict, + iou_threshold=0.3, + objects_count=self.objects_count, + ) + + # 1.5 Reset video tracker state + frame_idx = self.video_predictor.add_new_frame( + self.inference_state, image_np + ) + self.video_predictor.reset_state(self.inference_state) + + for object_id, object_info in mask_dict.labels.items(): + frame_idx, _, _ = self.video_predictor.add_new_mask( + self.inference_state, + frame_idx, + object_id, + object_info.mask, + ) + + self.track_dict = copy.deepcopy(mask_dict) + self.last_mask_dict = mask_dict + + else: + # Step 2: Use incremental tracking for intermediate frames + frame_idx = self.video_predictor.add_new_frame( + self.inference_state, image_np + ) + + # Step 3: Tracking propagation using the video predictor + frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame( + inference_state=self.inference_state, + frame_idx=frame_idx, + ) + + # Step 4: Update the mask dictionary based on tracked masks + frame_masks = MaskDictionaryModel() + for i, obj_id in enumerate(obj_ids): + out_mask = video_res_masks[i] > 0.0 + object_info = ObjectInfo( + instance_id=obj_id, + mask=out_mask[0], + class_name=self.track_dict.get_target_class_name(obj_id), + logit=self.track_dict.get_target_logit(obj_id), + ) + object_info.update_box() + frame_masks.labels[obj_id] = object_info + frame_masks.mask_name = f"mask_{frame_idx:05d}.npy" + frame_masks.mask_height = out_mask.shape[-2] + frame_masks.mask_width = out_mask.shape[-1] + + self.last_mask_dict = copy.deepcopy(frame_masks) + + # Step 5: Build mask array + H, W = image_np.shape[:2] + mask_img = torch.zeros((H, W), dtype=torch.int32) + for obj_id, obj_info in self.last_mask_dict.labels.items(): + mask_img[obj_info.mask == True] = obj_id + + mask_array = mask_img.cpu().numpy() + + # Step 6: Visualization + annotated_frame = self.visualize_frame_with_mask_and_metadata( + image_np=image_np, + mask_array=mask_array, + json_metadata=self.last_mask_dict.to_dict(), + ) + + print(f"[Tracker] Total processed frames: {self.total_frames}") + self.total_frames += 1 + torch.cuda.empty_cache() + return annotated_frame + + def set_prompt(self, new_prompt: str): + """ + Dynamically update the GroundingDINO prompt and reset tracking state + to force a new object detection. + """ + self.prompt_text = new_prompt + self.total_frames = 0 # Trigger immediate re-detection + self.inference_state = self.video_predictor.init_state() + self.inference_state["images"] = torch.empty( + (0, 3, 1024, 1024), device=self.device + ) + self.inference_state["video_height"] = None + self.inference_state["video_width"] = None + + print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.") + + def save_current_state(self, output_dir, raw_image: np.ndarray = None): + """ + Save the current mask, metadata, raw image, and annotated result. + Args: + output_dir (str): The root output directory. + raw_image (np.ndarray, optional): The original input image (RGB). + """ + mask_data_dir = os.path.join(output_dir, "mask_data") + json_data_dir = os.path.join(output_dir, "json_data") + image_data_dir = os.path.join(output_dir, "images") + vis_data_dir = os.path.join(output_dir, "result") + + os.makedirs(mask_data_dir, exist_ok=True) + os.makedirs(json_data_dir, exist_ok=True) + os.makedirs(image_data_dir, exist_ok=True) + os.makedirs(vis_data_dir, exist_ok=True) + + frame_masks = self.last_mask_dict + + # Ensure mask_name is valid + if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"): + frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy" + + base_name = f"image_{self.total_frames:05d}" + + # Save segmentation mask + mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width) + for obj_id, obj_info in frame_masks.labels.items(): + mask_img[obj_info.mask == True] = obj_id + np.save( + os.path.join(mask_data_dir, frame_masks.mask_name), + mask_img.numpy().astype(np.uint16), + ) + + # Save metadata as JSON + json_path = os.path.join(json_data_dir, base_name + ".json") + frame_masks.to_json(json_path) + + # Save raw input image + if raw_image is not None: + image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr) + + # Save annotated image with mask, bounding boxes, and labels + annotated_image = self.visualize_frame_with_mask_and_metadata( + image_np=raw_image, + mask_array=mask_img.numpy().astype(np.uint16), + json_metadata=frame_masks.to_dict(), + ) + annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR) + cv2.imwrite( + os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr + ) + print( + f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully." + ) + + def visualize_frame_with_mask_and_metadata( + self, + image_np: np.ndarray, + mask_array: np.ndarray, + json_metadata: dict, + ): + image = image_np.copy() + H, W = image.shape[:2] + + # Step 1: Parse metadata and build object entries + metadata_lookup = json_metadata.get("labels", {}) + + all_object_ids = [] + all_object_boxes = [] + all_object_classes = [] + all_object_masks = [] + + for obj_id_str, obj_info in metadata_lookup.items(): + instance_id = obj_info.get("instance_id") + if instance_id is None or instance_id == 0: + continue + if instance_id not in np.unique(mask_array): + continue + + object_mask = mask_array == instance_id + all_object_ids.append(instance_id) + x1 = obj_info.get("x1", 0) + y1 = obj_info.get("y1", 0) + x2 = obj_info.get("x2", 0) + y2 = obj_info.get("y2", 0) + all_object_boxes.append([x1, y1, x2, y2]) + all_object_classes.append(obj_info.get("class_name", "unknown")) + all_object_masks.append(object_mask[None]) # Shape (1, H, W) + + # Step 2: Check if valid objects exist + if len(all_object_ids) == 0: + print("No valid object instances found in metadata.") + return image + + # Step 3: Sort by instance ID + paired = list( + zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes) + ) + paired.sort(key=lambda x: x[0]) + + all_object_ids = [p[0] for p in paired] + all_object_boxes = [p[1] for p in paired] + all_object_masks = [p[2] for p in paired] + all_object_classes = [p[3] for p in paired] + + # Step 4: Build detections + all_object_masks = np.concatenate(all_object_masks, axis=0) + detections = sv.Detections( + xyxy=np.array(all_object_boxes), + mask=all_object_masks, + class_id=np.array(all_object_ids, dtype=np.int32), + ) + labels = [ + f"{instance_id}: {class_name}" + for instance_id, class_name in zip(all_object_ids, all_object_classes) + ] + + # Step 5: Annotate image + annotated_frame = image.copy() + mask_annotator = sv.MaskAnnotator() + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator() + + annotated_frame = mask_annotator.annotate(annotated_frame, detections) + annotated_frame = box_annotator.annotate(annotated_frame, detections) + annotated_frame = label_annotator.annotate(annotated_frame, detections, labels) + + return annotated_frame + + +import os + +import cv2 +import torch +from utils.common_utils import CommonUtils + + +def main(): + # Parameter settings + output_dir = "./outputs" + prompt_text = "hand." + detection_interval = 20 + max_frames = 300 # Maximum number of frames to process (prevents infinite loop) + + os.makedirs(output_dir, exist_ok=True) + + # Initialize the object tracker + tracker = IncrementalObjectTracker( + grounding_model_id="IDEA-Research/grounding-dino-tiny", + sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml", + sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt", + device="cuda", + prompt_text=prompt_text, + detection_interval=detection_interval, + ) + tracker.set_prompt("person.") + + # Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4")) + cap = cv2.VideoCapture(0) + if not cap.isOpened(): + print("[Error] Cannot open camera.") + return + + print("[Info] Camera opened. Press 'q' to quit.") + frame_idx = 0 + + try: + while True: + ret, frame = cap.read() + if not ret: + print("[Warning] Failed to capture frame.") + break + + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + print(f"[Frame {frame_idx}] Processing live frame...") + process_image = tracker.add_image(frame_rgb) + + if process_image is None or not isinstance(process_image, np.ndarray): + print(f"[Warning] Skipped frame {frame_idx} due to empty result.") + frame_idx += 1 + continue + + # process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR) + # cv2.imshow("Live Inference", process_image_bgr) + + + # if cv2.waitKey(1) & 0xFF == ord('q'): + # print("[Info] Quit signal received.") + # break + + tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb) + frame_idx += 1 + + if frame_idx >= max_frames: + print(f"[Info] Reached max_frames {max_frames}. Stopping.") + break + except KeyboardInterrupt: + print("[Info] Interrupted by user (Ctrl+C).") + finally: + cap.release() + cv2.destroyAllWindows() + print("[Done] Live inference complete.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index c7e01cc..4f6e080 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -12,7 +12,7 @@ import torch from tqdm import tqdm from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base -from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames ,process_stream_frame class SAM2VideoPredictor(SAM2Base): @@ -43,23 +43,33 @@ class SAM2VideoPredictor(SAM2Base): @torch.inference_mode() def init_state( self, - video_path, + video_path=None, offload_video_to_cpu=False, offload_state_to_cpu=False, async_loading_frames=False, ): """Initialize an inference state.""" compute_device = self.device # device of the model - images, video_height, video_width = load_video_frames( - video_path=video_path, - image_size=self.image_size, - offload_video_to_cpu=offload_video_to_cpu, - async_loading_frames=async_loading_frames, - compute_device=compute_device, - ) inference_state = {} - inference_state["images"] = images - inference_state["num_frames"] = len(images) + if video_path is not None: + # Preload video frames from file + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state["images"] = images + inference_state["num_frames"] = len(images) + else: + # Real-time streaming mode + print("Real-time streaming mode: waiting for first image input...") + images = None + video_height, video_width = None, None + inference_state["images"] = None + inference_state["num_frames"] = 0 + # whether to offload the video frames to CPU memory # turning on this option saves the GPU memory with only a very small overhead inference_state["offload_video_to_cpu"] = offload_video_to_cpu @@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base): inference_state["tracking_has_started"] = False inference_state["frames_already_tracked"] = {} # Warm up the visual backbone and cache the image feature on frame 0 - self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + if video_path is not None: + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state @classmethod @@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base): inference_state, pred_masks ) yield frame_idx, obj_ids, video_res_masks + @torch.inference_mode() + def add_new_frame(self, inference_state, new_image): + """ + Add a new frame to the inference state and cache its image features. + Args: + inference_state (dict): The current inference state containing cached frames, features, and tracking information. + new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing). + Returns: + frame_idx (int): The index of the newly added frame within the inference state. + """ + device = inference_state["device"] + + # Preprocess the input frame and convert it to a normalized tensor + img_tensor, orig_h, orig_w = process_stream_frame( + img_array=new_image, + image_size=self.image_size, + offload_to_cpu=False, + compute_device=device, + ) + + # Handle initialization of the image sequence if this is the first frame + images = inference_state.get("images", None) + if images is None or (isinstance(images, list) and len(images) == 0): + # First frame: initialize image tensor batch + inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W] + else: + # Append to existing tensor batch + if isinstance(images, list): + raise ValueError( + "inference_state['images'] should be a Tensor, not a list after initialization." + ) + + img_tensor = img_tensor.to(images.device) + inference_state["images"] = torch.cat( + [images, img_tensor.unsqueeze(0)], dim=0 + ) + + # Update frame count and compute new frame index + inference_state["num_frames"] = inference_state["images"].shape[0] + frame_idx = inference_state["num_frames"] - 1 + + # Cache visual features for the newly added frame + image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W] + backbone_out = self.forward_image(image_batch) + inference_state["cached_features"][frame_idx] = (image_batch, backbone_out) + + return frame_idx + + @torch.inference_mode() + def infer_single_frame(self, inference_state, frame_idx): + """ + Run inference on a single frame using existing points/masks in the inference state. + Args: + inference_state (dict): The current state of the tracking process. + frame_idx (int): Index of the frame to run inference on. + Returns: + frame_idx (int): Same as input; the index of the processed frame. + obj_ids (list): List of currently tracked object IDs. + video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame. + """ + if frame_idx >= inference_state["num_frames"]: + raise ValueError( + f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})." + ) + + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + batch_size = self._get_obj_num(inference_state) + + # Ensure that initial conditioning points exist + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError( + "No conditioning points provided. Please add points before inference." + ) + + # Decide whether to clear nearby memory based on number of objects + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + obj_ids = inference_state["obj_ids"] + + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + # If output is already consolidated with conditioning inputs + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + + if clear_non_cond_mem: + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + # If output was inferred without conditioning + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + + else: + # Run model inference for this frame + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + + # Organize per-object outputs and mark frame as tracked + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": False} + + # Convert output to original video resolution + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + + return frame_idx, obj_ids, video_res_masks def _add_output_per_object( self, inference_state, frame_idx, current_out, storage_key diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index b65ee82..07ba2c2 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -8,6 +8,7 @@ import os import warnings from threading import Thread +from typing import Tuple import numpy as np import torch from PIL import Image @@ -209,6 +210,74 @@ def load_video_frames( "Only MP4 video and JPEG folder are supported at this moment" ) +def process_stream_frame( + img_array: np.ndarray, + image_size: int, + img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), + img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225), + offload_to_cpu: bool = False, + compute_device: torch.device = torch.device("cuda"), +): + """ + Convert a raw image array (H,W,3 or 3,H,W) into a model‑ready tensor. + Steps + ----- + 1. Resize the shorter side to `image_size`, keeping aspect ratio, + then center‑crop/pad to `image_size` × `image_size`. + 2. Change layout to [3, H, W] and cast to float32 in [0,1]. + 3. Normalise with ImageNet statistics. + 4. Optionally move to `compute_device`. + Returns + ------- + img_tensor : torch.FloatTensor # shape [3, image_size, image_size] + orig_h : int + orig_w : int + """ + + # ↪ uses your existing helper so behaviour matches the batch loader + img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size) + + # Normalisation (done *after* potential device move for efficiency) + img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if not offload_to_cpu: + img_tensor = img_tensor.to(compute_device) + img_mean_t = img_mean_t.to(compute_device) + img_std_t = img_std_t.to(compute_device) + + img_tensor.sub_(img_mean_t).div_(img_std_t) + + return img_tensor, orig_h, orig_w + + +def _resize_and_convert_to_tensor(img_array, image_size): + """ + Resize the input image array and convert it into a tensor. + Also return original image height and width. + """ + # Convert numpy array to PIL image and ensure RGB + img_pil = Image.fromarray(img_array).convert("RGB") + + # Save original size (PIL: size = (width, height)) + video_width, video_height = img_pil.size + + # Resize with high-quality LANCZOS filter + img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS) + + # Convert resized image back to numpy and then to float tensor + img_resized_array = np.array(img_resized) + + if img_resized_array.dtype == np.uint8: + img_resized_array = img_resized_array / 255.0 + else: + raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}") + + # Convert to PyTorch tensor and permute to [C, H, W] + img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1) + + return img_tensor, video_height, video_width + def load_video_frames_from_jpg_images( video_path,