[New Feature] Support SAM 2.1 (#59)

* support sam 2.1

* refine config path and ckpt path

* update README
This commit is contained in:
Ren Tianhe
2024-10-10 14:55:50 +08:00
committed by GitHub
parent e899ad99e8
commit 82e503604f
340 changed files with 39100 additions and 608 deletions

140
demo/backend/server/app.py Normal file
View File

@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Generator
from app_conf import (
GALLERY_PATH,
GALLERY_PREFIX,
POSTERS_PATH,
POSTERS_PREFIX,
UPLOADS_PATH,
UPLOADS_PREFIX,
)
from data.loader import preload_data
from data.schema import schema
from data.store import set_videos
from flask import Flask, make_response, Request, request, Response, send_from_directory
from flask_cors import CORS
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
from inference.multipart import MultipartResponseBuilder
from inference.predictor import InferenceAPI
from strawberry.flask.views import GraphQLView
# Module-level singletons created at import time: the Flask app with CORS
# enabled (credentialed requests allowed), the preloaded gallery catalog
# pushed into the shared video store, and one InferenceAPI instance shared
# by all request handlers.
logger = logging.getLogger(__name__)
app = Flask(__name__)
cors = CORS(app, supports_credentials=True)
videos = preload_data()
set_videos(videos)
inference_api = InferenceAPI()
@app.route("/healthy")
def healthy() -> Response:
    """Liveness probe: always answers HTTP 200 with body "OK"."""
    ok_response = make_response("OK", 200)
    return ok_response
@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve a gallery video file located under GALLERY_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. missing file).
    """
    try:
        return send_from_directory(
            GALLERY_PATH,
            path,
        )
    except Exception as e:
        # Narrowed from a bare `except:` (which also caught SystemExit and
        # KeyboardInterrupt); chain the cause so the real error is logged.
        raise ValueError("resource not found") from e
@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve a poster image (first video frame) located under POSTERS_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. missing file).
    """
    try:
        return send_from_directory(
            POSTERS_PATH,
            path,
        )
    except Exception as e:
        # Narrowed from a bare `except:`; chain the cause for debuggability.
        raise ValueError("resource not found") from e
@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve an uploaded video file located under UPLOADS_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. missing file).
    """
    # Return annotation added for consistency with the sibling routes.
    try:
        return send_from_directory(
            UPLOADS_PATH,
            path,
        )
    except Exception as e:
        # Narrowed from a bare `except:`; chain the cause for debuggability.
        raise ValueError("resource not found") from e
# TODO: Protect route with ToS permission check
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream mask propagation results for a session as a multipart response.

    Expects a JSON body with a required "session_id" and an optional
    "start_frame_index" (defaults to 0). A missing/invalid body raises here
    (request.json may be None; KeyError for a missing session_id).
    """
    data = request.json
    args = {
        "session_id": data["session_id"],
        "start_frame_index": data.get("start_frame_index", 0),
    }
    # Each propagated frame becomes one multipart chunk with this boundary.
    boundary = "frame"
    frame = gen_track_with_mask_stream(boundary, **args)
    return Response(frame, mimetype="multipart/x-savi-stream; boundary=" + boundary)
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield one multipart chunk per frame of mask-propagation results.

    Args:
        boundary: multipart boundary token separating chunks.
        session_id: inference session to propagate in.
        start_frame_index: frame index at which propagation starts.
    """
    with inference_api.autocast_context():
        # Renamed from `request`, which shadowed the `flask.request` proxy
        # imported at module level.
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )
        for chunk in inference_api.propagate_in_video(request=propagate_request):
            yield MultipartResponseBuilder.build(
                boundary=boundary,
                headers={
                    "Content-Type": "application/json; charset=utf-8",
                    "Frame-Current": "-1",
                    # Total frames minus the reference frame
                    "Frame-Total": "-1",
                    "Mask-Type": "RLE[]",
                },
                body=chunk.to_json().encode("UTF-8"),
            ).get_message()
class MyGraphQLView(GraphQLView):
    # Inject the shared InferenceAPI singleton into the GraphQL context so
    # resolvers can reach the predictor via info.context["inference_api"].
    def get_context(self, request: Request, response: Response) -> Any:
        return {"inference_api": inference_api}
# Add GraphQL route to Flask app.
app.add_url_rule(
    "/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # Disable GET queries
        # https://strawberry.rocks/docs/operations/deployment
        # https://strawberry.rocks/docs/integrations/flask
        allow_queries_via_get=False,
        # Strawberry recently changed multipart request handling, which now
        # requires enabling support explicitly for views.
        # https://github.com/strawberry-graphql/strawberry/issues/3655
        multipart_uploads_enabled=True,
    ),
)
if __name__ == "__main__":
    # Development entry point; binds all interfaces on port 5000.
    app.run(host="0.0.0.0", port=5000)

View File

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)

# Root of the SAM2 checkout inside the container; checkpoints live beneath it.
APP_ROOT = os.getenv("APP_ROOT", "/opt/sam2")

# Base URL the frontend uses to reach this API (embedded in video/poster URLs).
API_URL = os.getenv("API_URL", "http://localhost:7263")

# Which SAM2 checkpoint to load: tiny | small | base_plus | large.
MODEL_SIZE = os.getenv("MODEL_SIZE", "base_plus")
logger.info(f"using model size {MODEL_SIZE}")

# Thread count passed to ffmpeg for decode/filter/encode stages.
FFMPEG_NUM_THREADS = int(os.getenv("FFMPEG_NUM_THREADS", "1"))

# Path for all data used in API
DATA_PATH = Path(os.getenv("DATA_PATH", "/data"))

# Max duration an uploaded video can have in seconds. The default is 10
# seconds.
MAX_UPLOAD_VIDEO_DURATION = float(os.environ.get("MAX_UPLOAD_VIDEO_DURATION", "10"))

# If set, it will define which video is returned by the default video query for
# desktop
DEFAULT_VIDEO_PATH = os.getenv("DEFAULT_VIDEO_PATH")

# Prefix for gallery videos
GALLERY_PREFIX = "gallery"

# Path where all gallery videos are stored
GALLERY_PATH = DATA_PATH / GALLERY_PREFIX

# Prefix for uploaded videos
UPLOADS_PREFIX = "uploads"

# Path where all uploaded videos are stored
UPLOADS_PATH = DATA_PATH / UPLOADS_PREFIX

# Prefix for video posters (1st frame of video)
POSTERS_PREFIX = "posters"

# Path where all posters are stored
POSTERS_PATH = DATA_PATH / POSTERS_PREFIX

# Create all data directories up front so later file writes cannot fail on a
# missing parent directory.
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(GALLERY_PATH, exist_ok=True)
os.makedirs(UPLOADS_PATH, exist_ok=True)
os.makedirs(POSTERS_PATH, exist_ok=True)

View File

@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Iterable, List, Optional
import strawberry
from app_conf import API_URL
from data.resolver import resolve_videos
from dataclasses_json import dataclass_json
from strawberry import relay
@strawberry.type
class Video(relay.Node):
    """Core type for video."""

    code: relay.NodeID[str]  # relay node id; set to the relative video path by the loader
    path: str  # path relative to the data root, used to build the URL
    poster_path: Optional[str]  # poster image path, or None when not generated
    width: int  # pixel width
    height: int  # pixel height

    @strawberry.field
    def url(self) -> str:
        """Absolute URL of the video file."""
        return f"{API_URL}/{self.path}"

    @strawberry.field
    def poster_url(self) -> str:
        # NOTE(review): if poster_path is None this yields ".../None";
        # presumably callers only query this when a poster exists — confirm.
        return f"{API_URL}/{self.poster_path}"

    @classmethod
    def resolve_nodes(
        cls,
        *,
        info: relay.PageInfo,
        node_ids: Iterable[str],
        required: bool = False,
    ):
        # Relay hook: batch-resolve node ids to Video objects via the store.
        return resolve_videos(node_ids, required)
@strawberry.type
class RLEMask:
    """Core type for Onevision GraphQL RLE mask."""

    size: List[int]  # mask height/width
    counts: str  # run-length-encoded counts string
    order: str  # memory order flag for the RLE counts (e.g. "F")


@strawberry.type
class RLEMaskForObject:
    """Type for RLE mask associated with a specific object id."""

    object_id: int
    rle_mask: RLEMask


@strawberry.type
class RLEMaskListOnFrame:
    """Type for a list of object-associated RLE masks on a specific video frame."""

    frame_index: int
    rle_mask_list: List[RLEMaskForObject]


@strawberry.input
class StartSessionInput:
    """Input for starting an inference session on a video path."""

    path: str


@strawberry.type
class StartSession:
    """Result of starting a session: the new session id."""

    session_id: str


@strawberry.input
class PingInput:
    """Input for pinging (keeping alive) a session."""

    session_id: str


@strawberry.type
class Pong:
    """Result of a ping."""

    success: bool


@strawberry.input
class CloseSessionInput:
    """Input for closing a session."""

    session_id: str


@strawberry.type
class CloseSession:
    """Result of closing a session."""

    success: bool


@strawberry.input
class AddPointsInput:
    """Input for adding point prompts to an object on one frame."""

    session_id: str
    frame_index: int
    clear_old_points: bool  # replace (True) or extend existing points
    object_id: int
    labels: List[int]  # one label per point
    points: List[List[float]]  # one [x, y] pair per point


@strawberry.input
class ClearPointsInFrameInput:
    """Input for clearing one object's prompts on one frame."""

    session_id: str
    frame_index: int
    object_id: int


@strawberry.input
class ClearPointsInVideoInput:
    """Input for clearing all prompts in a session's video."""

    session_id: str


@strawberry.type
class ClearPointsInVideo:
    """Result of clearing all prompts."""

    success: bool


@strawberry.input
class RemoveObjectInput:
    """Input for removing an object from a session."""

    session_id: str
    object_id: int


@strawberry.input
class PropagateInVideoInput:
    """Input for propagating masks through the video."""

    session_id: str
    start_frame_index: int


@strawberry.input
class CancelPropagateInVideoInput:
    """Input for cancelling an in-flight propagation."""

    session_id: str


@strawberry.type
class CancelPropagateInVideo:
    """Result of cancelling a propagation."""

    success: bool


@strawberry.type
class SessionExpiration:
    """Expiration bookkeeping for a session (times in server clock units)."""

    session_id: str
    expiration_time: int
    max_expiration_time: int
    ttl: int

View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import subprocess
from glob import glob
from pathlib import Path
from typing import Dict, Optional
import imagesize
from app_conf import GALLERY_PATH, POSTERS_PATH, POSTERS_PREFIX
from data.data_types import Video
from tqdm import tqdm
def preload_data() -> Dict[str, Video]:
    """Scan the gallery directory and build the in-memory video catalog.

    Returns a dict keyed by video code. Since Python 3.7 dicts preserve
    insertion order, so iterating `.values()` follows discovery order.
    """
    pattern = os.path.join(GALLERY_PATH, "**/*.mp4")
    discovered = glob(pattern, recursive=True)
    # tqdm gives progress feedback while posters are generated per video.
    return {
        video.code: video
        for video in (get_video(path, GALLERY_PATH) for path in tqdm(discovered))
    }
def get_video(
    filepath: os.PathLike,
    absolute_path: Path,
    file_key: Optional[str] = None,
    generate_poster: bool = True,
    width: Optional[int] = None,
    height: Optional[int] = None,
    verbose: Optional[bool] = False,
) -> Video:
    """
    Build a Video object for a file on disk, optionally generating a poster.

    Args:
        filepath: path of the video file.
        absolute_path: base directory; the stored video path is made relative
            to this directory's parent (so the parent dir name is included).
        file_key: optional storage key that overrides the relative path.
        generate_poster: if True, extract the first frame as a JPEG poster and
            derive width/height from it; otherwise the passed width/height
            are used as-is.
        verbose: if True, let ffmpeg write to stdout/stderr.

    Raises:
        RuntimeError: if a poster is requested but ffmpeg is not on PATH.
    """
    # Use absolute_path to include the parent directory in the video path.
    video_path = os.path.relpath(filepath, absolute_path.parent)
    poster_path = None
    if generate_poster:
        poster_id = os.path.splitext(os.path.basename(filepath))[0]
        poster_filename = f"{str(poster_id)}.jpg"
        poster_path = f"{POSTERS_PREFIX}/{poster_filename}"

        # Extract the first frame from video
        poster_output_path = os.path.join(POSTERS_PATH, poster_filename)

        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            # Previously a None executable was passed straight to subprocess,
            # failing later with an opaque TypeError.
            raise RuntimeError("ffmpeg executable not found on PATH")
        subprocess.call(
            [
                ffmpeg,
                "-y",
                "-i",
                str(filepath),
                "-pix_fmt",
                "yuv420p",
                "-frames:v",
                "1",
                "-update",
                "1",
                "-strict",
                "unofficial",
                str(poster_output_path),
            ],
            stdout=None if verbose else subprocess.DEVNULL,
            stderr=None if verbose else subprocess.DEVNULL,
        )

        # Extract video width and height from poster. This is important to optimize
        # rendering previews in the mosaic video preview.
        width, height = imagesize.get(poster_output_path)

    return Video(
        code=video_path,
        path=video_path if file_key is None else file_key,
        poster_path=poster_path,
        width=width,
        height=height,
    )

View File

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable
def resolve_videos(node_ids: Iterable[str], required: bool = False):
    """
    Resolve videos given node ids.

    With ``required`` True a missing id raises KeyError; otherwise missing
    ids resolve to None.
    """
    # Imported lazily to avoid a circular import with data.store.
    from data.store import get_videos

    videos = get_videos()
    if required:
        return [videos[node_id] for node_id in node_ids]
    return [videos.get(node_id) for node_id in node_ids]

View File

@@ -0,0 +1,357 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import os
import shutil
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
import av
import strawberry
from app_conf import (
DATA_PATH,
DEFAULT_VIDEO_PATH,
MAX_UPLOAD_VIDEO_DURATION,
UPLOADS_PATH,
UPLOADS_PREFIX,
)
from data.data_types import (
AddPointsInput,
CancelPropagateInVideo,
CancelPropagateInVideoInput,
ClearPointsInFrameInput,
ClearPointsInVideo,
ClearPointsInVideoInput,
CloseSession,
CloseSessionInput,
RemoveObjectInput,
RLEMask,
RLEMaskForObject,
RLEMaskListOnFrame,
StartSession,
StartSessionInput,
Video,
)
from data.loader import get_video
from data.store import get_videos
from data.transcoder import get_video_metadata, transcode, VideoMetadata
from inference.data_types import (
AddPointsRequest,
CancelPropagateInVideoRequest,
CancelPropagateInVideoRequest,
ClearPointsInFrameRequest,
ClearPointsInVideoRequest,
CloseSessionRequest,
RemoveObjectRequest,
StartSessionRequest,
)
from inference.predictor import InferenceAPI
from strawberry import relay
from strawberry.file_uploads import Upload
@strawberry.type
class Query:
    @strawberry.field
    def default_video(self) -> Video:
        """
        Return the default video.

        The default video can be set with the DEFAULT_VIDEO_PATH environment
        variable. It will return the video that matches this path. If no video
        is found, it will return the first video.
        """
        all_videos = get_videos()
        # Find the video that matches the default path and return that as
        # default video. Iterate values directly (keys were unused before).
        for video in all_videos.values():
            if video.path == DEFAULT_VIDEO_PATH:
                return video
        # Fallback is returning the first video. Raises StopIteration if the
        # store is empty (i.e. no videos were preloaded).
        return next(iter(all_videos.values()))

    @relay.connection(relay.ListConnection[Video])
    def videos(
        self,
    ) -> Iterable[Video]:
        """
        Return all available videos.
        """
        return get_videos().values()
@strawberry.type
class Mutation:
    @staticmethod
    def _to_rle_mask_list(results) -> List[RLEMaskForObject]:
        # Shared conversion from inference mask results to GraphQL RLE mask
        # objects; previously duplicated inline in three mutations.
        # order="F" flags the RLE counts as column-major.
        return [
            RLEMaskForObject(
                object_id=r.object_id,
                rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
            )
            for r in results
        ]

    @strawberry.mutation
    def upload_video(
        self,
        file: Upload,
        start_time_sec: Optional[float] = None,
        duration_time_sec: Optional[float] = None,
    ) -> Video:
        """
        Receive a video file, trim/transcode it, and store it under the local
        uploads path. (The previous docstring claimed S3 storage, but
        process_video writes to UPLOADS_PATH on disk.)
        """
        max_time = MAX_UPLOAD_VIDEO_DURATION
        filepath, file_key, vm = process_video(
            file,
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )
        # Uploads skip poster generation; dimensions come from the transcoded
        # video's metadata instead.
        video = get_video(
            filepath,
            UPLOADS_PATH,
            file_key=file_key,
            width=vm.width,
            height=vm.height,
            generate_poster=False,
        )
        return video

    @strawberry.mutation
    def start_session(
        self, input: StartSessionInput, info: strawberry.Info
    ) -> StartSession:
        """Start an inference session for the video at input.path."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = StartSessionRequest(
            type="start_session",
            path=f"{DATA_PATH}/{input.path}",
        )
        response = inference_api.start_session(request=request)
        return StartSession(session_id=response.session_id)

    @strawberry.mutation
    def close_session(
        self, input: CloseSessionInput, info: strawberry.Info
    ) -> CloseSession:
        """Close an inference session and report success."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CloseSessionRequest(
            type="close_session",
            session_id=input.session_id,
        )
        response = inference_api.close_session(request)
        return CloseSession(success=response.success)

    @strawberry.mutation
    def add_points(
        self, input: AddPointsInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Add point prompts on a frame and return the resulting masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = AddPointsRequest(
            type="add_points",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
            points=input.points,
            labels=input.labels,
            clear_old_points=input.clear_old_points,
        )
        # Local was previously misspelled `reponse`.
        response = inference_api.add_points(request)
        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=Mutation._to_rle_mask_list(response.results),
        )

    @strawberry.mutation
    def remove_object(
        self, input: RemoveObjectInput, info: strawberry.Info
    ) -> List[RLEMaskListOnFrame]:
        """Remove an object; return per-frame masks for the remaining objects."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = RemoveObjectRequest(
            type="remove_object", session_id=input.session_id, object_id=input.object_id
        )
        response = inference_api.remove_object(request)
        return [
            RLEMaskListOnFrame(
                frame_index=res.frame_index,
                rle_mask_list=Mutation._to_rle_mask_list(res.results),
            )
            for res in response.results
        ]

    @strawberry.mutation
    def clear_points_in_frame(
        self, input: ClearPointsInFrameInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Clear one object's prompts on a frame and return updated masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInFrameRequest(
            type="clear_points_in_frame",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
        )
        response = inference_api.clear_points_in_frame(request)
        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=Mutation._to_rle_mask_list(response.results),
        )

    @strawberry.mutation
    def clear_points_in_video(
        self, input: ClearPointsInVideoInput, info: strawberry.Info
    ) -> ClearPointsInVideo:
        """Clear all prompts in the session's video."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInVideoRequest(
            type="clear_points_in_video",
            session_id=input.session_id,
        )
        response = inference_api.clear_points_in_video(request)
        return ClearPointsInVideo(success=response.success)

    @strawberry.mutation
    def cancel_propagate_in_video(
        self, input: CancelPropagateInVideoInput, info: strawberry.Info
    ) -> CancelPropagateInVideo:
        """Request cancellation of an in-flight propagation."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CancelPropagateInVideoRequest(
            type="cancel_propagate_in_video",
            session_id=input.session_id,
        )
        response = inference_api.cancel_propagate_in_video(request)
        return CancelPropagateInVideo(success=response.success)
def get_file_hash(video_path_or_file) -> str:
    """Return the SHA-256 hex digest of a video, given a path or file object.

    Reads in fixed-size chunks so large videos are not loaded into memory in
    one piece (the previous implementation called .read() on the whole file).
    File objects are rewound to the start before hashing; as before, the
    stream position is left at EOF afterwards.
    """
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory bounded
    if isinstance(video_path_or_file, str):
        with open(video_path_or_file, "rb") as in_f:
            for chunk in iter(lambda: in_f.read(chunk_size), b""):
                hasher.update(chunk)
    else:
        video_path_or_file.seek(0)
        for chunk in iter(lambda: video_path_or_file.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
def _get_start_sec_duration_sec(
    start_time_sec: Union[float, None],
    duration_time_sec: Union[float, None],
    max_time: float,
) -> Tuple[float, float]:
    """Normalize optional trim parameters.

    A missing start falls back to the VIDEO_ENCODE_SEEK_TIME env var
    (default 0); the duration is capped at ``max_time``.
    """
    default_seek = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0"))
    start = default_seek if start_time_sec is None else start_time_sec
    duration = (
        max_time if duration_time_sec is None else min(duration_time_sec, max_time)
    )
    return start, duration
def process_video(
    file: Upload,
    max_time: float,
    start_time_sec: Optional[float] = None,
    duration_time_sec: Optional[float] = None,
) -> Tuple[str, str, VideoMetadata]:
    """
    Process file upload including video trimming and content moderation checks.

    Returns (filepath, file_key, video_metadata) for the transcoded video.
    (The previous annotation advertised a 4-tuple including a hash, but the
    function has always returned these 3 values; the hash is only used
    internally to name the stored file.)
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_path = f"{tempdir}/in.mp4"
        out_path = f"{tempdir}/out.mp4"

        # Spool the upload to disk so the metadata probe and ffmpeg can read it.
        with open(in_path, "wb") as in_f:
            in_f.write(file.read())

        try:
            video_metadata = get_video_metadata(in_path)
        except av.InvalidDataError:
            raise Exception("not valid video file")

        if video_metadata.num_video_streams == 0:
            raise Exception("video container does not contain a video stream")
        if video_metadata.width is None or video_metadata.height is None:
            raise Exception("video container does not contain width or height metadata")
        if video_metadata.duration_sec in (None, 0):
            # Message fixed: previously read "video container does time
            # duration metadata".
            raise Exception("video container does not contain duration metadata")

        start_time_sec, duration_time_sec = _get_start_sec_duration_sec(
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # Transcode video to make sure videos returned to the app are all in
        # the same format, duration, resolution, fps.
        transcode(
            in_path,
            out_path,
            video_metadata,
            seek_t=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        os.remove(in_path)  # don't need original video now

        out_video_metadata = get_video_metadata(out_path)
        if out_video_metadata.num_video_frames == 0:
            raise Exception(
                "transcode produced empty video; check seek time or your input video"
            )

        # Name the stored file by content hash so identical uploads dedupe.
        filepath = None
        file_key = None
        with open(out_path, "rb") as file_data:
            file_hash = get_file_hash(file_data)
            file_data.seek(0)
        file_key = UPLOADS_PREFIX + "/" + f"{file_hash}.mp4"
        filepath = os.path.join(UPLOADS_PATH, f"{file_hash}.mp4")

        assert filepath is not None and file_key is not None
        # Move out of the temp dir before it is cleaned up.
        shutil.move(out_path, filepath)

    return filepath, file_key, out_video_metadata
# Root GraphQL schema wiring together the Query and Mutation types above;
# served by the Flask GraphQL view.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
)

View File

@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
from data.data_types import Video
# In-memory catalog of all videos, keyed by video code.
# Fixed: this was initialized to a list (`[]`) despite the Dict annotation;
# an empty dict matches the annotation and what set_videos/get_videos expect.
ALL_VIDEOS: Dict[str, Video] = {}
def set_videos(videos: Dict[str, Video]) -> None:
    """Replace the in-memory video catalog with ``videos``.

    The data is kept in-memory; a future change could swap this for a
    database backend, which would also make filtered queries (e.g. by
    dataset name) more efficient.
    """
    global ALL_VIDEOS  # rebinding the module-level store requires `global`
    ALL_VIDEOS = videos
def get_videos() -> Dict[str, Video]:
    """Return the current in-memory video catalog (keyed by video code)."""
    # Reading a module-level name needs no `global` declaration.
    return ALL_VIDEOS

View File

@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import ast
import math
import os
import shutil
import subprocess
from dataclasses import dataclass
from typing import Optional
import av
from app_conf import FFMPEG_NUM_THREADS
from dataclasses_json import dataclass_json
# Version marker for the transcode pipeline (not referenced elsewhere in
# this file; presumably used to invalidate cached transcodes — confirm).
TRANSCODE_VERSION = 1


@dataclass_json
@dataclass
class VideoMetadata:
    # Fields are Optional because containers may lack the corresponding
    # metadata; see get_video_metadata() for how they are populated.
    duration_sec: Optional[float]  # max of container and video-stream durations
    video_duration_sec: Optional[float]  # duration of the first video stream
    container_duration_sec: Optional[float]  # container-level duration
    fps: Optional[float]  # guessed rate, falling back to average rate
    width: Optional[int]  # pixel width (rotation-corrected)
    height: Optional[int]  # pixel height (rotation-corrected)
    num_video_frames: int
    num_video_streams: int
    video_start_time: float  # start time of the video stream in seconds
def transcode(
    in_path: str,
    out_path: str,
    in_metadata: Optional[VideoMetadata],
    seek_t: float,
    duration_time_sec: float,
):
    """Transcode a clip with encoder settings taken from the environment.

    VIDEO_ENCODE_CODEC / CRF / FPS / MAX_WIDTH / MAX_HEIGHT / VERBOSE
    override the defaults; the actual work is done by normalize_video().
    """
    env = os.environ.get
    normalize_video(
        in_path=in_path,
        out_path=out_path,
        max_w=int(env("VIDEO_ENCODE_MAX_WIDTH", "1280")),
        max_h=int(env("VIDEO_ENCODE_MAX_HEIGHT", "720")),
        seek_t=seek_t,
        max_time=duration_time_sec,
        in_metadata=in_metadata,
        codec=env("VIDEO_ENCODE_CODEC", "libx264"),
        crf=int(env("VIDEO_ENCODE_CRF", "23")),
        fps=int(env("VIDEO_ENCODE_FPS", "24")),
        verbose=ast.literal_eval(env("VIDEO_ENCODE_VERBOSE", "False")),
    )
def get_video_metadata(path: str) -> VideoMetadata:
    """Probe a media file with PyAV and collect container/stream metadata.

    Width/height are swapped when the display matrix indicates a 90/270
    degree rotation, so they reflect the displayed orientation.
    """
    with av.open(path) as cont:
        num_video_streams = len(cont.streams.video)
        width, height, fps = None, None, None
        video_duration_sec = 0
        # Container duration is reported in av.time_base units; 0 if absent.
        container_duration_sec = float((cont.duration or 0) / av.time_base)
        video_start_time = 0.0
        rotation_deg = 0
        num_video_frames = 0
        if num_video_streams > 0:
            video_stream = cont.streams.video[0]
            assert video_stream.time_base is not None
            # for rotation, see: https://github.com/PyAV-Org/PyAV/pull/1249
            rotation_deg = video_stream.side_data.get("DISPLAYMATRIX", 0)
            num_video_frames = video_stream.frames
            video_start_time = float(video_stream.start_time * video_stream.time_base)
            width, height = video_stream.width, video_stream.height
            fps = float(video_stream.guessed_rate)
            fps_avg = video_stream.average_rate
            if video_stream.duration is not None:
                video_duration_sec = float(
                    video_stream.duration * video_stream.time_base
                )
            # Fall back to the average rate when no rate could be guessed.
            if fps is None:
                fps = float(fps_avg)
            # Quarter-turn rotations swap the displayed width and height.
            if not math.isnan(rotation_deg) and int(rotation_deg) in (
                90,
                -90,
                270,
                -270,
            ):
                width, height = height, width
        # Prefer whichever duration source reports the larger value.
        duration_sec = max(container_duration_sec, video_duration_sec)
        return VideoMetadata(
            duration_sec=duration_sec,
            container_duration_sec=container_duration_sec,
            video_duration_sec=video_duration_sec,
            video_start_time=video_start_time,
            fps=fps,
            width=width,
            height=height,
            num_video_streams=num_video_streams,
            num_video_frames=num_video_frames,
        )
def normalize_video(
    in_path: str,
    out_path: str,
    max_w: int,
    max_h: int,
    seek_t: float,
    max_time: float,
    in_metadata: Optional[VideoMetadata],
    codec: str = "libx264",
    crf: int = 23,
    fps: int = 24,
    verbose: bool = False,
):
    """
    Re-encode a clip with ffmpeg: seek to ``seek_t``, keep at most
    ``max_time`` seconds, resample to ``fps``, and downscale so it fits
    within max_w x max_h while preserving aspect ratio.

    Raises:
        RuntimeError: if ffmpeg is not found on PATH.
    """
    if in_metadata is None:
        in_metadata = get_video_metadata(in_path)

    assert in_metadata.num_video_streams > 0, "no video stream present"

    w, h = in_metadata.width, in_metadata.height
    assert w is not None, "width not available"
    assert h is not None, "height not available"

    # Rescale to fit within max_w:max_h if needed & preserve aspect ratio.
    # BUG FIX: the caps were previously hard-coded to 720/1280, silently
    # ignoring the max_w/max_h parameters (and the VIDEO_ENCODE_MAX_* env
    # vars that feed them via transcode()).
    r = w / h
    if r < 1:
        # Portrait orientation: height is the limiting dimension.
        h = min(max_h, h)
        w = h * r
    else:
        w = min(max_w, w)
        h = w / r

    # h264 cannot encode w/ odd dimensions
    w = int(w)
    h = int(h)
    if w % 2 != 0:
        w += 1
    if h % 2 != 0:
        h += 1

    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg is None:
        # Fail loudly instead of passing None to subprocess.
        raise RuntimeError("ffmpeg executable not found on PATH")
    cmd = [
        ffmpeg,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # global threads
        "-ss",
        f"{seek_t:.2f}",
        "-t",
        f"{max_time:.2f}",
        "-i",
        in_path,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # decode (or filter..?) threads
        "-vf",
        f"fps={fps},scale={w}:{h},setsar=1:1",
        "-c:v",
        codec,
        "-crf",
        f"{crf}",
        "-pix_fmt",
        "yuv420p",
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # encode threads
        out_path,
        "-y",
    ]
    if verbose:
        print(" ".join(cmd))
    subprocess.call(
        cmd,
        stdout=None if verbose else subprocess.DEVNULL,
        stderr=None if verbose else subprocess.DEVNULL,
    )

View File

@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from dataclasses_json import dataclass_json
from torch import Tensor
@dataclass_json
@dataclass
class Mask:
    """COCO-style run-length-encoded mask."""

    size: List[int]  # mask height/width
    counts: str  # RLE counts string


@dataclass_json
@dataclass
class BaseRequest:
    """Base for all inference requests; `type` is the request discriminator."""

    type: str


@dataclass_json
@dataclass
class StartSessionRequest(BaseRequest):
    """Open an inference session on the video at `path`."""

    type: str
    path: str
    session_id: Optional[str] = None  # optional pre-assigned session id


@dataclass_json
@dataclass
class SaveSessionRequest(BaseRequest):
    """Persist a session's state."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class LoadSessionRequest(BaseRequest):
    """Restore a previously saved session."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class RenewSessionRequest(BaseRequest):
    """Extend a session's lifetime."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class CloseSessionRequest(BaseRequest):
    """Close a session and release its resources."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class AddPointsRequest(BaseRequest):
    """Add point prompts for one object on one frame."""

    type: str
    session_id: str
    frame_index: int
    clear_old_points: bool  # replace (True) or extend existing points
    object_id: int
    labels: List[int]  # one label per point
    points: List[List[float]]  # one [x, y] pair per point


@dataclass_json
@dataclass
class AddMaskRequest(BaseRequest):
    """Set a full mask prompt for one object on one frame."""

    type: str
    session_id: str
    frame_index: int
    object_id: int
    mask: Mask


@dataclass_json
@dataclass
class ClearPointsInFrameRequest(BaseRequest):
    """Clear one object's prompts on one frame."""

    type: str
    session_id: str
    frame_index: int
    object_id: int


@dataclass_json
@dataclass
class ClearPointsInVideoRequest(BaseRequest):
    """Clear all prompts in a session's video."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class RemoveObjectRequest(BaseRequest):
    """Remove an object (and its prompts/masks) from a session."""

    type: str
    session_id: str
    object_id: int


@dataclass_json
@dataclass
class PropagateInVideoRequest(BaseRequest):
    """Propagate masks through the video starting at start_frame_index."""

    type: str
    session_id: str
    start_frame_index: int


@dataclass_json
@dataclass
class CancelPropagateInVideoRequest(BaseRequest):
    """Cancel an in-flight propagation."""

    type: str
    session_id: str
@dataclass_json
@dataclass
class StartSessionResponse:
    """Response carrying the id of the newly started session."""

    session_id: str


@dataclass_json
@dataclass
class SaveSessionResponse:
    """Response for a save-session request."""

    session_id: str


@dataclass_json
@dataclass
class LoadSessionResponse:
    """Response for a load-session request."""

    session_id: str


@dataclass_json
@dataclass
class RenewSessionResponse:
    """Response for a renew-session request."""

    session_id: str


@dataclass_json
@dataclass
class CloseSessionResponse:
    """Whether the session was closed successfully."""

    success: bool


@dataclass_json
@dataclass
class ClearPointsInVideoResponse:
    """Whether all prompts were cleared successfully."""

    success: bool


@dataclass_json
@dataclass
class PropagateDataValue:
    """A single object's mask on a frame."""

    object_id: int
    mask: Mask


@dataclass_json
@dataclass
class PropagateDataResponse:
    """All object masks produced for one frame."""

    frame_index: int
    results: List[PropagateDataValue]


@dataclass_json
@dataclass
class RemoveObjectResponse:
    """Per-frame masks recomputed after removing an object."""

    results: List[PropagateDataResponse]


@dataclass_json
@dataclass
class CancelPorpagateResponse:
    # NOTE(review): "Porpagate" is a typo, but the name is imported elsewhere
    # (e.g. inference.predictor), so it is kept for compatibility.
    success: bool


@dataclass_json
@dataclass
class InferenceSession:
    """Server-side session record holding the predictor state tensors."""

    start_time: float  # session creation timestamp
    last_use_time: float  # last activity timestamp
    session_id: str
    state: Dict[str, Dict[str, Union[Tensor, Dict[int, Tensor]]]]

View File

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Union
class MultipartResponseBuilder:
    """Builds one part of a multipart HTTP response.

    Message layout: boundary line, header lines, one blank line, then the
    raw body. A Content-Length header for the body is added automatically.
    """

    message: bytes

    def __init__(self, boundary: str) -> None:
        # Start with the boundary delimiter line.
        self.message = b"--" + boundary.encode("utf-8") + b"\r\n"

    @classmethod
    def build(
        cls, boundary: str, headers: Dict[str, str], body: Union[str, bytes]
    ) -> "MultipartResponseBuilder":
        """Assemble a complete part from headers and a str/bytes body."""
        builder = cls(boundary=boundary)
        for header_key, header_value in headers.items():
            builder.__append_header(key=header_key, value=header_value)
        if isinstance(body, str):
            body = body.encode("utf-8")
        if not isinstance(body, bytes):
            raise ValueError(
                f"body needs to be of type bytes or str but got {type(body)}"
            )
        builder.__append_body(body)
        return builder

    def get_message(self) -> bytes:
        """Return the assembled message bytes."""
        return self.message

    def __append_header(self, key: str, value: str) -> "MultipartResponseBuilder":
        # Headers are emitted as UTF-8 "Key: Value\r\n" lines.
        self.message += f"{key}: {value}".encode("utf-8") + b"\r\n"
        return self

    def __close_header(self) -> "MultipartResponseBuilder":
        # Blank line terminating the header section.
        self.message += b"\r\n"
        return self

    def __append_body(self, body: bytes) -> "MultipartResponseBuilder":
        # Content-Length always reflects the exact body byte count.
        self.__append_header(key="Content-Length", value=str(len(body)))
        self.__close_header()
        self.message += body
        return self

View File

@@ -0,0 +1,427 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import uuid
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Generator, List
import numpy as np
import torch
from app_conf import APP_ROOT, MODEL_SIZE
from inference.data_types import (
AddMaskRequest,
AddPointsRequest,
CancelPorpagateResponse,
CancelPropagateInVideoRequest,
ClearPointsInFrameRequest,
ClearPointsInVideoRequest,
ClearPointsInVideoResponse,
CloseSessionRequest,
CloseSessionResponse,
Mask,
PropagateDataResponse,
PropagateDataValue,
PropagateInVideoRequest,
RemoveObjectRequest,
RemoveObjectResponse,
StartSessionRequest,
StartSessionResponse,
)
from pycocotools.mask import decode as decode_masks, encode as encode_masks
from sam2.build_sam import build_sam2_video_predictor
logger = logging.getLogger(__name__)
class InferenceAPI:
def __init__(self) -> None:
super(InferenceAPI, self).__init__()
self.session_states: Dict[str, Any] = {}
self.score_thresh = 0
if MODEL_SIZE == "tiny":
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
elif MODEL_SIZE == "small":
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
elif MODEL_SIZE == "large":
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
else: # base_plus (default)
checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
# select the device for computation
force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
if force_cpu_device:
logger.info("forcing CPU device for SAM 2 demo")
if torch.cuda.is_available() and not force_cpu_device:
device = torch.device("cuda")
elif torch.backends.mps.is_available() and not force_cpu_device:
device = torch.device("mps")
else:
device = torch.device("cpu")
logger.info(f"using device: {device}")
if device.type == "cuda":
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
logging.warning(
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS. "
"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
)
self.device = device
self.predictor = build_sam2_video_predictor(
model_cfg, checkpoint, device=device
)
self.inference_lock = Lock()
def autocast_context(self):
if self.device.type == "cuda":
return torch.autocast("cuda", dtype=torch.bfloat16)
else:
return contextlib.nullcontext()
def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
with self.autocast_context(), self.inference_lock:
session_id = str(uuid.uuid4())
# for MPS devices, we offload the video frames to CPU by default to avoid
# memory fragmentation in MPS (which sometimes crashes the entire process)
offload_video_to_cpu = self.device.type == "mps"
inference_state = self.predictor.init_state(
request.path,
offload_video_to_cpu=offload_video_to_cpu,
)
self.session_states[session_id] = {
"canceled": False,
"state": inference_state,
}
return StartSessionResponse(session_id=session_id)
def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
is_successful = self.__clear_session_state(request.session_id)
return CloseSessionResponse(success=is_successful)
def add_points(
self, request: AddPointsRequest, test: str = ""
) -> PropagateDataResponse:
with self.autocast_context(), self.inference_lock:
session = self.__get_session(request.session_id)
inference_state = session["state"]
frame_idx = request.frame_index
obj_id = request.object_id
points = request.points
labels = request.labels
clear_old_points = request.clear_old_points
# add new prompts and instantly get the output on the same frame
frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=obj_id,
points=points,
labels=labels,
clear_old_points=clear_old_points,
normalize_coords=False,
)
masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()
rle_mask_list = self.__get_rle_mask_list(
object_ids=object_ids, masks=masks_binary
)
return PropagateDataResponse(
frame_index=frame_idx,
results=rle_mask_list,
)
def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
"""
Add new points on a specific video frame.
- mask is a numpy array of shape [H_im, W_im] (containing 1 for foreground and 0 for background).
Note: providing an input mask would overwrite any previous input points on this frame.
"""
with self.autocast_context(), self.inference_lock:
session_id = request.session_id
frame_idx = request.frame_index
obj_id = request.object_id
rle_mask = {
"counts": request.mask.counts,
"size": request.mask.size,
}
mask = decode_masks(rle_mask)
logger.info(
f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
)
session = self.__get_session(session_id)
inference_state = session["state"]
frame_idx, obj_ids, video_res_masks = self.model.add_new_mask(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=obj_id,
mask=torch.tensor(mask > 0),
)
masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
rle_mask_list = self.__get_rle_mask_list(
object_ids=obj_ids, masks=masks_binary
)
return PropagateDataResponse(
frame_index=frame_idx,
results=rle_mask_list,
)
def clear_points_in_frame(
self, request: ClearPointsInFrameRequest
) -> PropagateDataResponse:
"""
Remove all input points in a specific frame.
"""
with self.autocast_context(), self.inference_lock:
session_id = request.session_id
frame_idx = request.frame_index
obj_id = request.object_id
logger.info(
f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
)
session = self.__get_session(session_id)
inference_state = session["state"]
frame_idx, obj_ids, video_res_masks = (
self.predictor.clear_all_prompts_in_frame(
inference_state, frame_idx, obj_id
)
)
masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
rle_mask_list = self.__get_rle_mask_list(
object_ids=obj_ids, masks=masks_binary
)
return PropagateDataResponse(
frame_index=frame_idx,
results=rle_mask_list,
)
def clear_points_in_video(
self, request: ClearPointsInVideoRequest
) -> ClearPointsInVideoResponse:
"""
Remove all input points in all frames throughout the video.
"""
with self.autocast_context(), self.inference_lock:
session_id = request.session_id
logger.info(f"clear all inputs across the video in session {session_id}")
session = self.__get_session(session_id)
inference_state = session["state"]
self.predictor.reset_state(inference_state)
return ClearPointsInVideoResponse(success=True)
def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
"""
Remove an object id from the tracking state.
"""
with self.autocast_context(), self.inference_lock:
session_id = request.session_id
obj_id = request.object_id
logger.info(f"remove object in session {session_id}: {obj_id=}")
session = self.__get_session(session_id)
inference_state = session["state"]
new_obj_ids, updated_frames = self.predictor.remove_object(
inference_state, obj_id
)
results = []
for frame_index, video_res_masks in updated_frames:
masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
rle_mask_list = self.__get_rle_mask_list(
object_ids=new_obj_ids, masks=masks
)
results.append(
PropagateDataResponse(
frame_index=frame_index,
results=rle_mask_list,
)
)
return RemoveObjectResponse(results=results)
def propagate_in_video(
self, request: PropagateInVideoRequest
) -> Generator[PropagateDataResponse, None, None]:
session_id = request.session_id
start_frame_idx = request.start_frame_index
propagation_direction = "both"
max_frame_num_to_track = None
"""
Propagate existing input points in all frames to track the object across video.
"""
# Note that as this method is a generator, we also need to use autocast_context
# in caller to this method to ensure that it's called under the correct context
# (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
with self.autocast_context(), self.inference_lock:
logger.info(
f"propagate in video in session {session_id}: "
f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
)
try:
session = self.__get_session(session_id)
session["canceled"] = False
inference_state = session["state"]
if propagation_direction not in ["both", "forward", "backward"]:
raise ValueError(
f"invalid propagation direction: {propagation_direction}"
)
# First doing the forward propagation
if propagation_direction in ["both", "forward"]:
for outputs in self.predictor.propagate_in_video(
inference_state=inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=False,
):
if session["canceled"]:
return None
frame_idx, obj_ids, video_res_masks = outputs
masks_binary = (
(video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
)
rle_mask_list = self.__get_rle_mask_list(
object_ids=obj_ids, masks=masks_binary
)
yield PropagateDataResponse(
frame_index=frame_idx,
results=rle_mask_list,
)
# Then doing the backward propagation (reverse in time)
if propagation_direction in ["both", "backward"]:
for outputs in self.predictor.propagate_in_video(
inference_state=inference_state,
start_frame_idx=start_frame_idx,
max_frame_num_to_track=max_frame_num_to_track,
reverse=True,
):
if session["canceled"]:
return None
frame_idx, obj_ids, video_res_masks = outputs
masks_binary = (
(video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
)
rle_mask_list = self.__get_rle_mask_list(
object_ids=obj_ids, masks=masks_binary
)
yield PropagateDataResponse(
frame_index=frame_idx,
results=rle_mask_list,
)
finally:
# Log upon completion (so that e.g. we can see if two propagations happen in parallel).
# Using `finally` here to log even when the tracking is aborted with GeneratorExit.
logger.info(
f"propagation ended in session {session_id}; {self.__get_session_stats()}"
)
def cancel_propagate_in_video(
self, request: CancelPropagateInVideoRequest
) -> CancelPorpagateResponse:
session = self.__get_session(request.session_id)
session["canceled"] = True
return CancelPorpagateResponse(success=True)
def __get_rle_mask_list(
self, object_ids: List[int], masks: np.ndarray
) -> List[PropagateDataValue]:
"""
Return a list of data values, i.e. list of object/mask combos.
"""
return [
self.__get_mask_for_object(object_id=object_id, mask=mask)
for object_id, mask in zip(object_ids, masks)
]
def __get_mask_for_object(
self, object_id: int, mask: np.ndarray
) -> PropagateDataValue:
"""
Create a data value for an object/mask combo.
"""
mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
mask_rle["counts"] = mask_rle["counts"].decode()
return PropagateDataValue(
object_id=object_id,
mask=Mask(
size=mask_rle["size"],
counts=mask_rle["counts"],
),
)
def __get_session(self, session_id: str):
session = self.session_states.get(session_id, None)
if session is None:
raise RuntimeError(
f"Cannot find session {session_id}; it might have expired"
)
return session
def __get_session_stats(self):
"""Get a statistics string for live sessions and their GPU usage."""
# print both the session ids and their video frame numbers
live_session_strs = [
f"'{session_id}' ({session['state']['num_frames']} frames, "
f"{len(session['state']['obj_ids'])} objects)"
for session_id, session in self.session_states.items()
]
session_stats_str = (
"Test String Here - -"
f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
)
return session_stats_str
def __clear_session_state(self, session_id: str) -> bool:
session = self.session_states.pop(session_id, None)
if session is None:
logger.warning(
f"cannot close session {session_id} as it does not exist (it might have expired); "
f"{self.__get_session_stats()}"
)
return False
else:
logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
return True