SAM2.1
SAM2.1 checkpoints + training code + Demo
This commit is contained in:
427
demo/backend/server/inference/predictor.py
Normal file
427
demo/backend/server/inference/predictor.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Generator, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from app_conf import APP_ROOT, MODEL_SIZE
|
||||
from inference.data_types import (
|
||||
AddMaskRequest,
|
||||
AddPointsRequest,
|
||||
CancelPorpagateResponse,
|
||||
CancelPropagateInVideoRequest,
|
||||
ClearPointsInFrameRequest,
|
||||
ClearPointsInVideoRequest,
|
||||
ClearPointsInVideoResponse,
|
||||
CloseSessionRequest,
|
||||
CloseSessionResponse,
|
||||
Mask,
|
||||
PropagateDataResponse,
|
||||
PropagateDataValue,
|
||||
PropagateInVideoRequest,
|
||||
RemoveObjectRequest,
|
||||
RemoveObjectResponse,
|
||||
StartSessionRequest,
|
||||
StartSessionResponse,
|
||||
)
|
||||
from pycocotools.mask import decode as decode_masks, encode as encode_masks
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InferenceAPI:
|
||||
|
||||
def __init__(self) -> None:
|
||||
super(InferenceAPI, self).__init__()
|
||||
|
||||
self.session_states: Dict[str, Any] = {}
|
||||
self.score_thresh = 0
|
||||
|
||||
if MODEL_SIZE == "tiny":
|
||||
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
|
||||
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
|
||||
elif MODEL_SIZE == "small":
|
||||
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
|
||||
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
|
||||
elif MODEL_SIZE == "large":
|
||||
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
|
||||
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||
else: # base_plus (default)
|
||||
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
|
||||
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
||||
|
||||
# select the device for computation
|
||||
force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
|
||||
if force_cpu_device:
|
||||
logger.info("forcing CPU device for SAM 2 demo")
|
||||
if torch.cuda.is_available() and not force_cpu_device:
|
||||
device = torch.device("cuda")
|
||||
elif torch.backends.mps.is_available() and not force_cpu_device:
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info(f"using device: {device}")
|
||||
|
||||
if device.type == "cuda":
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
elif device.type == "mps":
|
||||
logging.warning(
|
||||
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
|
||||
"give numerically different outputs and sometimes degraded performance on MPS. "
|
||||
"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
|
||||
)
|
||||
|
||||
self.device = device
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
model_cfg, checkpoint, device=device
|
||||
)
|
||||
self.inference_lock = Lock()
|
||||
|
||||
def autocast_context(self):
|
||||
if self.device.type == "cuda":
|
||||
return torch.autocast("cuda", dtype=torch.bfloat16)
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session_id = str(uuid.uuid4())
|
||||
# for MPS devices, we offload the video frames to CPU by default to avoid
|
||||
# memory fragmentation in MPS (which sometimes crashes the entire process)
|
||||
offload_video_to_cpu = self.device.type == "mps"
|
||||
inference_state = self.predictor.init_state(
|
||||
request.path,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
)
|
||||
self.session_states[session_id] = {
|
||||
"canceled": False,
|
||||
"state": inference_state,
|
||||
}
|
||||
return StartSessionResponse(session_id=session_id)
|
||||
|
||||
def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
|
||||
is_successful = self.__clear_session_state(request.session_id)
|
||||
return CloseSessionResponse(success=is_successful)
|
||||
|
||||
def add_points(
|
||||
self, request: AddPointsRequest, test: str = ""
|
||||
) -> PropagateDataResponse:
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session = self.__get_session(request.session_id)
|
||||
inference_state = session["state"]
|
||||
|
||||
frame_idx = request.frame_index
|
||||
obj_id = request.object_id
|
||||
points = request.points
|
||||
labels = request.labels
|
||||
clear_old_points = request.clear_old_points
|
||||
|
||||
# add new prompts and instantly get the output on the same frame
|
||||
frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=obj_id,
|
||||
points=points,
|
||||
labels=labels,
|
||||
clear_old_points=clear_old_points,
|
||||
normalize_coords=False,
|
||||
)
|
||||
|
||||
masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=object_ids, masks=masks_binary
|
||||
)
|
||||
|
||||
return PropagateDataResponse(
|
||||
frame_index=frame_idx,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
|
||||
def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
|
||||
"""
|
||||
Add new points on a specific video frame.
|
||||
- mask is a numpy array of shape [H_im, W_im] (containing 1 for foreground and 0 for background).
|
||||
Note: providing an input mask would overwrite any previous input points on this frame.
|
||||
"""
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session_id = request.session_id
|
||||
frame_idx = request.frame_index
|
||||
obj_id = request.object_id
|
||||
rle_mask = {
|
||||
"counts": request.mask.counts,
|
||||
"size": request.mask.size,
|
||||
}
|
||||
|
||||
mask = decode_masks(rle_mask)
|
||||
|
||||
logger.info(
|
||||
f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
|
||||
)
|
||||
session = self.__get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
|
||||
frame_idx, obj_ids, video_res_masks = self.model.add_new_mask(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=obj_id,
|
||||
mask=torch.tensor(mask > 0),
|
||||
)
|
||||
masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=obj_ids, masks=masks_binary
|
||||
)
|
||||
|
||||
return PropagateDataResponse(
|
||||
frame_index=frame_idx,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
|
||||
def clear_points_in_frame(
|
||||
self, request: ClearPointsInFrameRequest
|
||||
) -> PropagateDataResponse:
|
||||
"""
|
||||
Remove all input points in a specific frame.
|
||||
"""
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session_id = request.session_id
|
||||
frame_idx = request.frame_index
|
||||
obj_id = request.object_id
|
||||
|
||||
logger.info(
|
||||
f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
|
||||
)
|
||||
session = self.__get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
frame_idx, obj_ids, video_res_masks = (
|
||||
self.predictor.clear_all_prompts_in_frame(
|
||||
inference_state, frame_idx, obj_id
|
||||
)
|
||||
)
|
||||
masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=obj_ids, masks=masks_binary
|
||||
)
|
||||
|
||||
return PropagateDataResponse(
|
||||
frame_index=frame_idx,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
|
||||
def clear_points_in_video(
|
||||
self, request: ClearPointsInVideoRequest
|
||||
) -> ClearPointsInVideoResponse:
|
||||
"""
|
||||
Remove all input points in all frames throughout the video.
|
||||
"""
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session_id = request.session_id
|
||||
logger.info(f"clear all inputs across the video in session {session_id}")
|
||||
session = self.__get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
self.predictor.reset_state(inference_state)
|
||||
return ClearPointsInVideoResponse(success=True)
|
||||
|
||||
def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
|
||||
"""
|
||||
Remove an object id from the tracking state.
|
||||
"""
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
session_id = request.session_id
|
||||
obj_id = request.object_id
|
||||
logger.info(f"remove object in session {session_id}: {obj_id=}")
|
||||
session = self.__get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
new_obj_ids, updated_frames = self.predictor.remove_object(
|
||||
inference_state, obj_id
|
||||
)
|
||||
|
||||
results = []
|
||||
for frame_index, video_res_masks in updated_frames:
|
||||
masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=new_obj_ids, masks=masks
|
||||
)
|
||||
results.append(
|
||||
PropagateDataResponse(
|
||||
frame_index=frame_index,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
)
|
||||
|
||||
return RemoveObjectResponse(results=results)
|
||||
|
||||
def propagate_in_video(
|
||||
self, request: PropagateInVideoRequest
|
||||
) -> Generator[PropagateDataResponse, None, None]:
|
||||
session_id = request.session_id
|
||||
start_frame_idx = request.start_frame_index
|
||||
propagation_direction = "both"
|
||||
max_frame_num_to_track = None
|
||||
|
||||
"""
|
||||
Propagate existing input points in all frames to track the object across video.
|
||||
"""
|
||||
|
||||
# Note that as this method is a generator, we also need to use autocast_context
|
||||
# in caller to this method to ensure that it's called under the correct context
|
||||
# (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
|
||||
with self.autocast_context(), self.inference_lock:
|
||||
logger.info(
|
||||
f"propagate in video in session {session_id}: "
|
||||
f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
|
||||
)
|
||||
|
||||
try:
|
||||
session = self.__get_session(session_id)
|
||||
session["canceled"] = False
|
||||
|
||||
inference_state = session["state"]
|
||||
if propagation_direction not in ["both", "forward", "backward"]:
|
||||
raise ValueError(
|
||||
f"invalid propagation direction: {propagation_direction}"
|
||||
)
|
||||
|
||||
# First doing the forward propagation
|
||||
if propagation_direction in ["both", "forward"]:
|
||||
for outputs in self.predictor.propagate_in_video(
|
||||
inference_state=inference_state,
|
||||
start_frame_idx=start_frame_idx,
|
||||
max_frame_num_to_track=max_frame_num_to_track,
|
||||
reverse=False,
|
||||
):
|
||||
if session["canceled"]:
|
||||
return None
|
||||
|
||||
frame_idx, obj_ids, video_res_masks = outputs
|
||||
masks_binary = (
|
||||
(video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
)
|
||||
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=obj_ids, masks=masks_binary
|
||||
)
|
||||
|
||||
yield PropagateDataResponse(
|
||||
frame_index=frame_idx,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
|
||||
# Then doing the backward propagation (reverse in time)
|
||||
if propagation_direction in ["both", "backward"]:
|
||||
for outputs in self.predictor.propagate_in_video(
|
||||
inference_state=inference_state,
|
||||
start_frame_idx=start_frame_idx,
|
||||
max_frame_num_to_track=max_frame_num_to_track,
|
||||
reverse=True,
|
||||
):
|
||||
if session["canceled"]:
|
||||
return None
|
||||
|
||||
frame_idx, obj_ids, video_res_masks = outputs
|
||||
masks_binary = (
|
||||
(video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
|
||||
)
|
||||
|
||||
rle_mask_list = self.__get_rle_mask_list(
|
||||
object_ids=obj_ids, masks=masks_binary
|
||||
)
|
||||
|
||||
yield PropagateDataResponse(
|
||||
frame_index=frame_idx,
|
||||
results=rle_mask_list,
|
||||
)
|
||||
finally:
|
||||
# Log upon completion (so that e.g. we can see if two propagations happen in parallel).
|
||||
# Using `finally` here to log even when the tracking is aborted with GeneratorExit.
|
||||
logger.info(
|
||||
f"propagation ended in session {session_id}; {self.__get_session_stats()}"
|
||||
)
|
||||
|
||||
def cancel_propagate_in_video(
|
||||
self, request: CancelPropagateInVideoRequest
|
||||
) -> CancelPorpagateResponse:
|
||||
session = self.__get_session(request.session_id)
|
||||
session["canceled"] = True
|
||||
return CancelPorpagateResponse(success=True)
|
||||
|
||||
def __get_rle_mask_list(
|
||||
self, object_ids: List[int], masks: np.ndarray
|
||||
) -> List[PropagateDataValue]:
|
||||
"""
|
||||
Return a list of data values, i.e. list of object/mask combos.
|
||||
"""
|
||||
return [
|
||||
self.__get_mask_for_object(object_id=object_id, mask=mask)
|
||||
for object_id, mask in zip(object_ids, masks)
|
||||
]
|
||||
|
||||
def __get_mask_for_object(
|
||||
self, object_id: int, mask: np.ndarray
|
||||
) -> PropagateDataValue:
|
||||
"""
|
||||
Create a data value for an object/mask combo.
|
||||
"""
|
||||
mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
|
||||
mask_rle["counts"] = mask_rle["counts"].decode()
|
||||
return PropagateDataValue(
|
||||
object_id=object_id,
|
||||
mask=Mask(
|
||||
size=mask_rle["size"],
|
||||
counts=mask_rle["counts"],
|
||||
),
|
||||
)
|
||||
|
||||
def __get_session(self, session_id: str):
|
||||
session = self.session_states.get(session_id, None)
|
||||
if session is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot find session {session_id}; it might have expired"
|
||||
)
|
||||
return session
|
||||
|
||||
def __get_session_stats(self):
|
||||
"""Get a statistics string for live sessions and their GPU usage."""
|
||||
# print both the session ids and their video frame numbers
|
||||
live_session_strs = [
|
||||
f"'{session_id}' ({session['state']['num_frames']} frames, "
|
||||
f"{len(session['state']['obj_ids'])} objects)"
|
||||
for session_id, session in self.session_states.items()
|
||||
]
|
||||
session_stats_str = (
|
||||
"Test String Here - -"
|
||||
f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
|
||||
f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
|
||||
f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
|
||||
f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
|
||||
f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
|
||||
)
|
||||
return session_stats_str
|
||||
|
||||
def __clear_session_state(self, session_id: str) -> bool:
|
||||
session = self.session_states.pop(session_id, None)
|
||||
if session is None:
|
||||
logger.warning(
|
||||
f"cannot close session {session_id} as it does not exist (it might have expired); "
|
||||
f"{self.__get_session_stats()}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
|
||||
return True
|
Reference in New Issue
Block a user