# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
from typing import Any, Dict, Generator, List
|
|
|
|
import numpy as np
|
|
import torch
|
|
from app_conf import APP_ROOT, MODEL_SIZE
|
|
from inference.data_types import (
|
|
AddMaskRequest,
|
|
AddPointsRequest,
|
|
CancelPorpagateResponse,
|
|
CancelPropagateInVideoRequest,
|
|
ClearPointsInFrameRequest,
|
|
ClearPointsInVideoRequest,
|
|
ClearPointsInVideoResponse,
|
|
CloseSessionRequest,
|
|
CloseSessionResponse,
|
|
Mask,
|
|
PropagateDataResponse,
|
|
PropagateDataValue,
|
|
PropagateInVideoRequest,
|
|
RemoveObjectRequest,
|
|
RemoveObjectResponse,
|
|
StartSessionRequest,
|
|
StartSessionResponse,
|
|
)
|
|
from pycocotools.mask import decode as decode_masks, encode as encode_masks
|
|
from sam2.build_sam import build_sam2_video_predictor
|
|
|
|
|
|
# Module-level logger named after this module (via ``__name__``).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class InferenceAPI:
    """Session-oriented wrapper around a SAM 2 video predictor.

    Every session is stored in ``self.session_states`` under a random UUID
    string and holds the predictor's inference state plus a ``"canceled"``
    flag. All methods that call into the predictor do so while holding
    ``self.inference_lock`` and inside ``self.autocast_context()``.
    """

    def __init__(self) -> None:
        """Pick checkpoint/config and compute device, then build the predictor."""
        super().__init__()

        # session_id -> {"canceled": bool, "state": <predictor inference state>}
        self.session_states: Dict[str, Any] = {}
        # Mask values strictly greater than this threshold count as foreground.
        self.score_thresh = 0

        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # base_plus (default)
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

        # select the device for computation: forced CPU > CUDA > MPS > CPU
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
            device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            # Use the module logger (not the root logger) like the rest of the file.
            logger.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        # Guards every call into ``self.predictor``.
        self.inference_lock = Lock()

    def autocast_context(self):
        """Return bfloat16 CUDA autocast on CUDA devices, else a no-op context."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        return contextlib.nullcontext()

    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        """Initialize predictor state for ``request.path`` and register a session.

        Returns a response carrying the newly generated session id.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # for MPS devices, we offload the video frames to CPU by default to avoid
            # memory fragmentation in MPS (which sometimes crashes the entire process)
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        """Drop the session; ``success`` is False when the id was not registered."""
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(
        self, request: AddPointsRequest, test: str = ""
    ) -> PropagateDataResponse:
        """Add point prompts on one frame and return the masks for that frame.

        ``test`` is not used by this method; it is kept only so existing
        callers that pass it keep working.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)
            inference_state = session["state"]

            # add new prompts and instantly get the output on the same frame
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=request.frame_index,
                obj_id=request.object_id,
                points=request.points,
                labels=request.labels,
                clear_old_points=request.clear_old_points,
                normalize_coords=False,
            )
            return self.__build_frame_response(frame_idx, object_ids, masks)

    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add a new mask prompt on a specific video frame.
        - ``request.mask`` is an RLE mask (``counts`` / ``size``); it is decoded and
          every pixel greater than 0 is passed to the predictor as foreground.
        Note: providing an input mask would overwrite any previous input points on this frame.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }

            mask = decode_masks(rle_mask)

            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            # The predictor is stored as ``self.predictor``; this class never
            # defines a ``self.model`` attribute.
            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            return self.__build_frame_response(frame_idx, obj_ids, video_res_masks)

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id

            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]
            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    inference_state, frame_idx, obj_id
                )
            )
            return self.__build_frame_response(frame_idx, obj_ids, video_res_masks)

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            self.predictor.reset_state(inference_state)
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state.

        The response holds one ``PropagateDataResponse`` per frame that the
        predictor reports as updated.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            new_obj_ids, updated_frames = self.predictor.remove_object(
                inference_state, obj_id
            )

            results = [
                self.__build_frame_response(frame_index, new_obj_ids, video_res_masks)
                for frame_index, video_res_masks in updated_frames
            ]
            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across video.

        Yields one ``PropagateDataResponse`` per predictor output, first for the
        forward pass and then for the backward pass. The generator stops early
        once the session's ``"canceled"`` flag is set.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session
                (raised on first iteration, since this is a generator).
        """
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        propagation_direction = "both"
        max_frame_num_to_track = None

        # Note that as this method is a generator, we also need to use autocast_context
        # in caller to this method to ensure that it's called under the correct context
        # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False

                inference_state = session["state"]
                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # Forward propagation first, then backward (reverse in time).
                reverse_passes = []
                if propagation_direction in ["both", "forward"]:
                    reverse_passes.append(False)
                if propagation_direction in ["both", "backward"]:
                    reverse_passes.append(True)

                for reverse in reverse_passes:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=reverse,
                    ):
                        # Set by `cancel_propagate_in_video`, which does not take the lock.
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        yield self.__build_frame_response(
                            frame_idx, obj_ids, video_res_masks
                        )
            finally:
                # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
                # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )

    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        """Flag the session as canceled so a running propagation stops.

        Does not acquire ``self.inference_lock``: ``propagate_in_video`` holds
        that lock for as long as it is running.

        Raises:
            RuntimeError: if ``request.session_id`` is not a live session.
        """
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __build_frame_response(
        self, frame_idx: int, object_ids: List[int], video_res_masks: Any
    ) -> PropagateDataResponse:
        """Threshold predictor masks for one frame and wrap them in a response.

        Values strictly above ``self.score_thresh`` become foreground; index 0 of
        the second axis is selected before moving the result to a numpy array.
        """
        masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
        rle_mask_list = self.__get_rle_mask_list(
            object_ids=object_ids, masks=masks_binary
        )
        return PropagateDataResponse(
            frame_index=frame_idx,
            results=rle_mask_list,
        )

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.

        ``object_ids`` and ``masks`` are paired positionally with ``zip``.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.

        The mask is copied to a Fortran-ordered ``uint8`` array, RLE-encoded, and
        the encoded ``counts`` bytes are decoded to ``str``.
        """
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )

    def __get_session(self, session_id: str) -> Dict[str, Any]:
        """Return the stored session dict for ``session_id``.

        Raises:
            RuntimeError: if no session is registered under ``session_id``.
        """
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self) -> str:
        """Get a statistics string for live sessions and their GPU usage."""
        # print both the session ids and their video frame numbers
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        # NOTE: a leftover "Test String Here - -" debug prefix was removed here.
        session_stats_str = (
            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
        )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        """Remove ``session_id`` from the registry; return whether it existed."""
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False

        logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
        return True