SAM2.1
SAM2.1 checkpoints + training code + Demo
This commit is contained in:
136
demo/backend/server/app.py
Normal file
136
demo/backend/server/app.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Generator
|
||||
|
||||
from app_conf import (
|
||||
GALLERY_PATH,
|
||||
GALLERY_PREFIX,
|
||||
POSTERS_PATH,
|
||||
POSTERS_PREFIX,
|
||||
UPLOADS_PATH,
|
||||
UPLOADS_PREFIX,
|
||||
)
|
||||
from data.loader import preload_data
|
||||
from data.schema import schema
|
||||
from data.store import set_videos
|
||||
from flask import Flask, make_response, Request, request, Response, send_from_directory
|
||||
from flask_cors import CORS
|
||||
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
|
||||
from inference.multipart import MultipartResponseBuilder
|
||||
from inference.predictor import InferenceAPI
|
||||
from strawberry.flask.views import GraphQLView
|
||||
|
||||
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Allow cross-origin requests from the frontend, including cookies/credentials.
cors = CORS(app, supports_credentials=True)

# Load gallery videos once at startup and publish them to the in-memory
# store so GraphQL resolvers can access them.
videos = preload_data()
set_videos(videos)

# Single shared predictor instance used by all routes and resolvers.
inference_api = InferenceAPI()
|
||||
|
||||
|
||||
@app.route("/healthy")
def healthy() -> Response:
    """Liveness probe endpoint; always answers 200 OK."""
    response = make_response("OK", 200)
    return response
|
||||
|
||||
|
||||
@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve a gallery video file from GALLERY_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. it does not exist).
    """
    try:
        return send_from_directory(
            GALLERY_PATH,
            path,
        )
    # A bare `except:` also swallowed SystemExit/KeyboardInterrupt; narrow to
    # Exception and chain the original cause for debuggability.
    except Exception as e:
        raise ValueError("resource not found") from e
|
||||
|
||||
|
||||
@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve a poster image (first video frame) from POSTERS_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. it does not exist).
    """
    try:
        return send_from_directory(
            POSTERS_PATH,
            path,
        )
    # A bare `except:` also swallowed SystemExit/KeyboardInterrupt; narrow to
    # Exception and chain the original cause for debuggability.
    except Exception as e:
        raise ValueError("resource not found") from e
|
||||
|
||||
|
||||
@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve an uploaded video file from UPLOADS_PATH.

    Raises:
        ValueError: if the file cannot be served (e.g. it does not exist).
    """
    # Return annotation added for consistency with the sibling routes.
    try:
        return send_from_directory(
            UPLOADS_PATH,
            path,
        )
    # A bare `except:` also swallowed SystemExit/KeyboardInterrupt; narrow to
    # Exception and chain the original cause for debuggability.
    except Exception as e:
        raise ValueError("resource not found") from e
|
||||
|
||||
|
||||
# TODO: Protect route with ToS permission check
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream mask-propagation results for a session as a multipart response."""
    payload = request.json
    session_id = payload["session_id"]
    start_frame_index = payload.get("start_frame_index", 0)

    boundary = "frame"
    stream = gen_track_with_mask_stream(
        boundary,
        session_id=session_id,
        start_frame_index=start_frame_index,
    )
    mimetype = f"multipart/x-savi-stream; boundary={boundary}"
    return Response(stream, mimetype=mimetype)
|
||||
|
||||
|
||||
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield multipart-encoded JSON chunks from the inference propagation stream.

    Args:
        boundary: Multipart boundary string used in every emitted part.
        session_id: Inference session to propagate.
        start_frame_index: Frame to start propagation from.

    Yields:
        bytes: one multipart message per propagation chunk.
    """
    with inference_api.autocast_context():
        # Renamed from `request`: the original local shadowed flask's
        # module-level `request` import used elsewhere in this file.
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )

        for chunk in inference_api.propagate_in_video(request=propagate_request):
            yield MultipartResponseBuilder.build(
                boundary=boundary,
                headers={
                    "Content-Type": "application/json; charset=utf-8",
                    "Frame-Current": "-1",
                    # Total frames minus the reference frame
                    "Frame-Total": "-1",
                    "Mask-Type": "RLE[]",
                },
                body=chunk.to_json().encode("UTF-8"),
            ).get_message()
|
||||
|
||||
|
||||
class MyGraphQLView(GraphQLView):
    """GraphQL view that injects the shared InferenceAPI into the resolver context."""

    def get_context(self, request: Request, response: Response) -> Any:
        # Resolvers read the predictor via info.context["inference_api"].
        return {"inference_api": inference_api}
|
||||
|
||||
|
||||
# Add GraphQL route to Flask app. Uses MyGraphQLView so every resolver
# receives the shared inference_api in its context.
app.add_url_rule(
    "/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # Disable GET queries
        # https://strawberry.rocks/docs/operations/deployment
        # https://strawberry.rocks/docs/integrations/flask
        allow_queries_via_get=False,
    ),
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NOTE(review): binds to all interfaces on port 5000; presumably fronted
    # by the demo's container networking — confirm before exposing publicly.
    app.run(host="0.0.0.0", port=5000)
|
55
demo/backend/server/app_conf.py
Normal file
55
demo/backend/server/app_conf.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)

# Root of the SAM2 checkout inside the image/container.
APP_ROOT = os.getenv("APP_ROOT", "/opt/sam2")

# Base URL used to build absolute video/poster URLs returned to clients.
API_URL = os.getenv("API_URL", "http://localhost:7263")

# SAM2 model size to load (e.g. "base_plus").
MODEL_SIZE = os.getenv("MODEL_SIZE", "base_plus")

logger.info(f"using model size {MODEL_SIZE}")

# Thread count passed to ffmpeg for transcoding/poster extraction.
FFMPEG_NUM_THREADS = int(os.getenv("FFMPEG_NUM_THREADS", "1"))

# Path for all data used in API
DATA_PATH = Path(os.getenv("DATA_PATH", "/data"))

# Max duration an uploaded video can have in seconds. The default is 10
# seconds.
MAX_UPLOAD_VIDEO_DURATION = float(os.environ.get("MAX_UPLOAD_VIDEO_DURATION", "10"))

# If set, it will define which video is returned by the default video query for
# desktop
DEFAULT_VIDEO_PATH = os.getenv("DEFAULT_VIDEO_PATH")

# Prefix for gallery videos
GALLERY_PREFIX = "gallery"

# Path where all gallery videos are stored
GALLERY_PATH = DATA_PATH / GALLERY_PREFIX

# Prefix for uploaded videos
UPLOADS_PREFIX = "uploads"

# Path where all uploaded videos are stored
UPLOADS_PATH = DATA_PATH / UPLOADS_PREFIX

# Prefix for video posters (1st frame of video)
POSTERS_PREFIX = "posters"

# Path where all posters are stored
POSTERS_PATH = DATA_PATH / POSTERS_PREFIX

# Make sure all of these paths exist.
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(GALLERY_PATH, exist_ok=True)
os.makedirs(UPLOADS_PATH, exist_ok=True)
os.makedirs(POSTERS_PATH, exist_ok=True)
|
154
demo/backend/server/data/data_types.py
Normal file
154
demo/backend/server/data/data_types.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import strawberry
|
||||
from app_conf import API_URL
|
||||
from data.resolver import resolve_videos
|
||||
from dataclasses_json import dataclass_json
|
||||
from strawberry import relay
|
||||
|
||||
|
||||
@strawberry.type
class Video(relay.Node):
    """Core type for video."""

    code: relay.NodeID[str]
    path: str
    poster_path: Optional[str]
    width: int
    height: int

    @strawberry.field
    def url(self) -> str:
        """Absolute URL for fetching the video file."""
        return f"{API_URL}/{self.path}"

    @strawberry.field
    def poster_url(self) -> str:
        """Absolute URL for the poster image (first frame)."""
        return f"{API_URL}/{self.poster_path}"

    @classmethod
    def resolve_nodes(
        cls,
        *,
        info: relay.PageInfo,
        node_ids: Iterable[str],
        required: bool = False,
    ):
        # Relay node resolution delegates to the in-memory video store.
        return resolve_videos(node_ids, required)
|
||||
|
||||
|
||||
@strawberry.type
class RLEMask:
    """Core type for Onevision GraphQL RLE mask."""

    # NOTE(review): fields mirror a run-length-encoded mask (size + counts +
    # memory order); exact encoding defined by the inference backend — confirm.
    size: List[int]
    counts: str
    order: str
|
||||
|
||||
|
||||
@strawberry.type
class RLEMaskForObject:
    """Type for RLE mask associated with a specific object id."""

    object_id: int
    rle_mask: RLEMask
|
||||
|
||||
|
||||
@strawberry.type
class RLEMaskListOnFrame:
    """Type for a list of object-associated RLE masks on a specific video frame."""

    frame_index: int
    rle_mask_list: List[RLEMaskForObject]
|
||||
|
||||
|
||||
@strawberry.input
class StartSessionInput:
    """Input for startSession: path of the video to open a session on."""

    path: str
|
||||
|
||||
|
||||
@strawberry.type
class StartSession:
    """Result of startSession: the id of the newly created session."""

    session_id: str
|
||||
|
||||
|
||||
@strawberry.input
class PingInput:
    """Input for a session keep-alive ping."""

    session_id: str
|
||||
|
||||
|
||||
@strawberry.type
class Pong:
    """Result of a session keep-alive ping."""

    success: bool
|
||||
|
||||
|
||||
@strawberry.input
class CloseSessionInput:
    """Input for closeSession."""

    session_id: str
|
||||
|
||||
|
||||
@strawberry.type
class CloseSession:
    """Result of closeSession."""

    success: bool
|
||||
|
||||
|
||||
@strawberry.input
class AddPointsInput:
    """Input for addPoints: point prompts for one object on one frame."""

    session_id: str
    frame_index: int
    # Presumably discards previously added points for this object when True —
    # confirm against the predictor implementation.
    clear_old_points: bool
    object_id: int
    # One label per entry in `points`; mirrored by AddPointsRequest.
    labels: List[int]
    points: List[List[float]]
|
||||
|
||||
|
||||
@strawberry.input
class ClearPointsInFrameInput:
    """Input for clearPointsInFrame: drop prompts of one object on one frame."""

    session_id: str
    frame_index: int
    object_id: int
|
||||
|
||||
|
||||
@strawberry.input
class ClearPointsInVideoInput:
    """Input for clearPointsInVideo."""

    session_id: str
|
||||
|
||||
|
||||
@strawberry.type
class ClearPointsInVideo:
    """Result of clearPointsInVideo."""

    success: bool
|
||||
|
||||
|
||||
@strawberry.input
class RemoveObjectInput:
    """Input for removeObject: delete one tracked object from a session."""

    session_id: str
    object_id: int
|
||||
|
||||
|
||||
@strawberry.input
class PropagateInVideoInput:
    """Input for propagateInVideo: start propagation at a given frame."""

    session_id: str
    start_frame_index: int
|
||||
|
||||
|
||||
@strawberry.input
class CancelPropagateInVideoInput:
    """Input for cancelPropagateInVideo."""

    session_id: str
|
||||
|
||||
|
||||
@strawberry.type
class CancelPropagateInVideo:
    """Result of cancelPropagateInVideo."""

    success: bool
|
||||
|
||||
|
||||
@strawberry.type
class SessionExpiration:
    """Expiration/TTL information for an inference session."""

    session_id: str
    # NOTE(review): units/epoch of the time fields are not visible in this
    # file — confirm against the producer before relying on them.
    expiration_time: int
    max_expiration_time: int
    ttl: int
|
92
demo/backend/server/data/loader.py
Normal file
92
demo/backend/server/data/loader.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import imagesize
|
||||
from app_conf import GALLERY_PATH, POSTERS_PATH, POSTERS_PREFIX
|
||||
from data.data_types import Video
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def preload_data() -> Dict[str, Video]:
    """
    Preload data including gallery videos and their posters.
    """
    # Dict insertion order is preserved (Python 3.7+), so the videos keep the
    # order in which they were discovered on disk when iterating `.values()`.
    # https://stackoverflow.com/questions/39980323/are-dictionaries-ordered-in-python-3-6
    pattern = os.path.join(GALLERY_PATH, "**/*.mp4")
    gallery_files = glob(pattern, recursive=True)

    all_videos: Dict[str, Video] = {}
    for path in tqdm(gallery_files):
        vid = get_video(path, GALLERY_PATH)
        all_videos[vid.code] = vid
    return all_videos
|
||||
|
||||
|
||||
def get_video(
    filepath: os.PathLike,
    absolute_path: Path,
    file_key: Optional[str] = None,
    generate_poster: bool = True,
    width: Optional[int] = None,
    height: Optional[int] = None,
    verbose: Optional[bool] = False,
) -> Video:
    """
    Build a Video object for a video file on disk.

    Args:
        filepath: Path of the video file.
        absolute_path: Base directory; the code/path is made relative to its
            parent so the directory name is part of the video path.
        file_key: Optional key used as the video path instead of the relative
            path (used for uploads).
        generate_poster: When True, extract the first frame with ffmpeg as a
            poster JPEG; width/height are then read from the poster and the
            `width`/`height` arguments are ignored.
        width: Width to record when no poster is generated.
        height: Height to record when no poster is generated.
        verbose: When True, let ffmpeg write to stdout/stderr.

    Raises:
        RuntimeError: if a poster is requested but ffmpeg is not installed.
    """
    # Use absolute_path to include the parent directory in the video
    video_path = os.path.relpath(filepath, absolute_path.parent)
    poster_path = None
    if generate_poster:
        poster_id = os.path.splitext(os.path.basename(filepath))[0]
        poster_filename = f"{str(poster_id)}.jpg"
        poster_path = f"{POSTERS_PREFIX}/{poster_filename}"

        # Extract the first frame from video
        poster_output_path = os.path.join(POSTERS_PATH, poster_filename)
        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            # Fail loudly instead of passing None to subprocess below.
            raise RuntimeError("ffmpeg not found in PATH; cannot generate poster")
        subprocess.call(
            [
                ffmpeg,
                "-y",
                "-i",
                str(filepath),
                "-pix_fmt",
                "yuv420p",
                "-frames:v",
                "1",
                "-update",
                "1",
                "-strict",
                "unofficial",
                str(poster_output_path),
            ],
            stdout=None if verbose else subprocess.DEVNULL,
            stderr=None if verbose else subprocess.DEVNULL,
        )

        # Extract video width and height from poster. This is important to optimize
        # rendering previews in the mosaic video preview.
        width, height = imagesize.get(poster_output_path)

    return Video(
        code=video_path,
        path=video_path if file_key is None else file_key,
        poster_path=poster_path,
        width=width,
        height=height,
    )
|
18
demo/backend/server/data/resolver.py
Normal file
18
demo/backend/server/data/resolver.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def resolve_videos(node_ids: Iterable[str], required: bool = False):
    """
    Resolve videos given node ids.

    With required=True a missing id raises KeyError; otherwise missing ids
    resolve to None.
    """
    # Imported lazily, presumably to avoid a circular import with data.store.
    from data.store import get_videos

    videos = get_videos()
    if required:
        return [videos[nid] for nid in node_ids]
    return [videos.get(nid) for nid in node_ids]
|
357
demo/backend/server/data/schema.py
Normal file
357
demo/backend/server/data/schema.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import av
|
||||
import strawberry
|
||||
from app_conf import (
|
||||
DATA_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
MAX_UPLOAD_VIDEO_DURATION,
|
||||
UPLOADS_PATH,
|
||||
UPLOADS_PREFIX,
|
||||
)
|
||||
from data.data_types import (
|
||||
AddPointsInput,
|
||||
CancelPropagateInVideo,
|
||||
CancelPropagateInVideoInput,
|
||||
ClearPointsInFrameInput,
|
||||
ClearPointsInVideo,
|
||||
ClearPointsInVideoInput,
|
||||
CloseSession,
|
||||
CloseSessionInput,
|
||||
RemoveObjectInput,
|
||||
RLEMask,
|
||||
RLEMaskForObject,
|
||||
RLEMaskListOnFrame,
|
||||
StartSession,
|
||||
StartSessionInput,
|
||||
Video,
|
||||
)
|
||||
from data.loader import get_video
|
||||
from data.store import get_videos
|
||||
from data.transcoder import get_video_metadata, transcode, VideoMetadata
|
||||
from inference.data_types import (
|
||||
AddPointsRequest,
|
||||
CancelPropagateInVideoRequest,
|
||||
CancelPropagateInVideoRequest,
|
||||
ClearPointsInFrameRequest,
|
||||
ClearPointsInVideoRequest,
|
||||
CloseSessionRequest,
|
||||
RemoveObjectRequest,
|
||||
StartSessionRequest,
|
||||
)
|
||||
from inference.predictor import InferenceAPI
|
||||
from strawberry import relay
|
||||
from strawberry.file_uploads import Upload
|
||||
|
||||
|
||||
@strawberry.type
class Query:
    """Read-only GraphQL API surface for the demo backend."""

    @strawberry.field
    def default_video(self) -> Video:
        """
        Return the default video.

        The default video can be set with the DEFAULT_VIDEO_PATH environment
        variable. It will return the video that matches this path. If no video
        is found, it will return the first video.
        """
        all_videos = get_videos()

        # Find the video that matches the default path and return that as
        # default video. Iterate values directly — the keys were unused.
        for v in all_videos.values():
            if v.path == DEFAULT_VIDEO_PATH:
                return v

        # Fallback is returning the first video
        return next(iter(all_videos.values()))

    @relay.connection(relay.ListConnection[Video])
    def videos(
        self,
    ) -> Iterable[Video]:
        """
        Return all available videos.
        """
        all_videos = get_videos()
        return all_videos.values()
|
||||
|
||||
|
||||
@strawberry.type
class Mutation:
    """GraphQL mutations: video upload and inference-session operations.

    Session mutations forward typed requests to the InferenceAPI instance
    supplied in the resolver context (see the GraphQL view's get_context).
    """

    @strawberry.mutation
    def upload_video(
        self,
        file: Upload,
        start_time_sec: Optional[float] = None,
        duration_time_sec: Optional[float] = None,
    ) -> Video:
        """
        Receive a video file, transcode it, and store it under UPLOADS_PATH.

        (Docstring fixed: the previous text claimed S3 storage, but
        process_video moves the transcoded file into the local uploads dir.)
        """
        max_time = MAX_UPLOAD_VIDEO_DURATION
        filepath, file_key, vm = process_video(
            file,
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # No poster for uploads; dimensions come from the transcode metadata.
        video = get_video(
            filepath,
            UPLOADS_PATH,
            file_key=file_key,
            width=vm.width,
            height=vm.height,
            generate_poster=False,
        )
        return video

    @strawberry.mutation
    def start_session(
        self, input: StartSessionInput, info: strawberry.Info
    ) -> StartSession:
        """Open a new inference session for a video under DATA_PATH."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = StartSessionRequest(
            type="start_session",
            path=f"{DATA_PATH}/{input.path}",
        )

        response = inference_api.start_session(request=request)

        return StartSession(session_id=response.session_id)

    @strawberry.mutation
    def close_session(
        self, input: CloseSessionInput, info: strawberry.Info
    ) -> CloseSession:
        """Close an inference session."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = CloseSessionRequest(
            type="close_session",
            session_id=input.session_id,
        )
        response = inference_api.close_session(request)
        return CloseSession(success=response.success)

    @strawberry.mutation
    def add_points(
        self, input: AddPointsInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Add point prompts for an object and return the resulting masks."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = AddPointsRequest(
            type="add_points",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
            points=input.points,
            labels=input.labels,
            clear_old_points=input.clear_old_points,
        )
        # Fixed local-name typo (was `reponse`).
        response = inference_api.add_points(request)

        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=[
                RLEMaskForObject(
                    object_id=r.object_id,
                    rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
                )
                for r in response.results
            ],
        )

    @strawberry.mutation
    def remove_object(
        self, input: RemoveObjectInput, info: strawberry.Info
    ) -> List[RLEMaskListOnFrame]:
        """Remove a tracked object; returns updated masks per affected frame."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = RemoveObjectRequest(
            type="remove_object", session_id=input.session_id, object_id=input.object_id
        )

        response = inference_api.remove_object(request)

        return [
            RLEMaskListOnFrame(
                frame_index=res.frame_index,
                rle_mask_list=[
                    RLEMaskForObject(
                        object_id=r.object_id,
                        rle_mask=RLEMask(
                            counts=r.mask.counts, size=r.mask.size, order="F"
                        ),
                    )
                    for r in res.results
                ],
            )
            for res in response.results
        ]

    @strawberry.mutation
    def clear_points_in_frame(
        self, input: ClearPointsInFrameInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Clear an object's prompts on one frame; returns updated masks."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = ClearPointsInFrameRequest(
            type="clear_points_in_frame",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
        )

        response = inference_api.clear_points_in_frame(request)

        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=[
                RLEMaskForObject(
                    object_id=r.object_id,
                    rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
                )
                for r in response.results
            ],
        )

    @strawberry.mutation
    def clear_points_in_video(
        self, input: ClearPointsInVideoInput, info: strawberry.Info
    ) -> ClearPointsInVideo:
        """Clear all prompts in the session's video."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = ClearPointsInVideoRequest(
            type="clear_points_in_video",
            session_id=input.session_id,
        )
        response = inference_api.clear_points_in_video(request)
        return ClearPointsInVideo(success=response.success)

    @strawberry.mutation
    def cancel_propagate_in_video(
        self, input: CancelPropagateInVideoInput, info: strawberry.Info
    ) -> CancelPropagateInVideo:
        """Cancel an in-flight propagation for the session."""
        inference_api: InferenceAPI = info.context["inference_api"]

        request = CancelPropagateInVideoRequest(
            type="cancel_propagate_in_video",
            session_id=input.session_id,
        )
        response = inference_api.cancel_propagate_in_video(request)
        return CancelPropagateInVideo(success=response.success)
|
||||
|
||||
|
||||
def get_file_hash(video_path_or_file) -> str:
    """Return the SHA-256 hex digest of a file path or a binary file-like object."""
    if isinstance(video_path_or_file, str):
        # A path was given: read the whole file from disk.
        with open(video_path_or_file, "rb") as fp:
            data = fp.read()
    else:
        # A file-like object: rewind first so the full content is hashed.
        video_path_or_file.seek(0)
        data = video_path_or_file.read()
    return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def _get_start_sec_duration_sec(
|
||||
start_time_sec: Union[float, None],
|
||||
duration_time_sec: Union[float, None],
|
||||
max_time: float,
|
||||
) -> Tuple[float, float]:
|
||||
default_seek_t = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0"))
|
||||
if start_time_sec is None:
|
||||
start_time_sec = default_seek_t
|
||||
|
||||
if duration_time_sec is not None:
|
||||
duration_time_sec = min(duration_time_sec, max_time)
|
||||
else:
|
||||
duration_time_sec = max_time
|
||||
return start_time_sec, duration_time_sec
|
||||
|
||||
|
||||
def process_video(
    file: Upload,
    max_time: float,
    start_time_sec: Optional[float] = None,
    duration_time_sec: Optional[float] = None,
) -> Tuple[str, str, VideoMetadata]:
    """
    Process file upload including video trimming and content moderation checks.

    Returns the filepath, file key & video metadata as a tuple.
    (Return annotation fixed: the function returns 3 values, not 4.)

    Raises:
        Exception: if the upload is not a valid video, lacks a video stream,
            lacks width/height or duration metadata, or transcodes to an
            empty video.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_path = f"{tempdir}/in.mp4"
        out_path = f"{tempdir}/out.mp4"
        with open(in_path, "wb") as in_f:
            in_f.write(file.read())

        try:
            video_metadata = get_video_metadata(in_path)
        except av.InvalidDataError:
            raise Exception("not valid video file")

        if video_metadata.num_video_streams == 0:
            raise Exception("video container does not contain a video stream")
        if video_metadata.width is None or video_metadata.height is None:
            raise Exception("video container does not contain width or height metadata")

        # Message fixed: previous wording dropped "not contain".
        if video_metadata.duration_sec in (None, 0):
            raise Exception("video container does not contain duration metadata")

        start_time_sec, duration_time_sec = _get_start_sec_duration_sec(
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # Transcode video to make sure videos returned to the app are all in
        # the same format, duration, resolution, fps.
        transcode(
            in_path,
            out_path,
            video_metadata,
            seek_t=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        os.remove(in_path)  # don't need original video now

        out_video_metadata = get_video_metadata(out_path)
        if out_video_metadata.num_video_frames == 0:
            raise Exception(
                "transcode produced empty video; check seek time or your input video"
            )

        # Content-addressed destination: the file key/path are derived from
        # the hash of the transcoded bytes.
        with open(out_path, "rb") as file_data:
            file_hash = get_file_hash(file_data)

        file_key = UPLOADS_PREFIX + "/" + f"{file_hash}.mp4"
        filepath = os.path.join(UPLOADS_PATH, f"{file_hash}.mp4")

        shutil.move(out_path, filepath)

        return filepath, file_key, out_video_metadata
|
||||
|
||||
|
||||
# GraphQL schema exposing the Query and Mutation types defined above;
# served by the Flask GraphQL view in app.py.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
)
|
28
demo/backend/server/data/store.py
Normal file
28
demo/backend/server/data/store.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from data.data_types import Video
|
||||
|
||||
# In-memory video store, keyed by video code. Fixed: the annotation says
# Dict but the initializer was the empty *list* `[]`, which would break the
# dict accessors until set_videos() replaced it.
ALL_VIDEOS: Dict[str, Video] = {}
|
||||
|
||||
|
||||
def set_videos(videos: Dict[str, Video]) -> None:
|
||||
"""
|
||||
Set the videos available in the backend. The data is kept in-memory, but a future change could replace the
|
||||
in-memory storage with a database backend. This would also be more efficient when querying videos given a
|
||||
dataset name etc.
|
||||
"""
|
||||
global ALL_VIDEOS
|
||||
ALL_VIDEOS = videos
|
||||
|
||||
|
||||
def get_videos() -> Dict[str, Video]:
    """
    Return the videos available in the backend.
    """
    # A `global` declaration is unnecessary for a read-only access; the
    # module-level name is found either way.
    return ALL_VIDEOS
|
186
demo/backend/server/data/transcoder.py
Normal file
186
demo/backend/server/data/transcoder.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
from app_conf import FFMPEG_NUM_THREADS
|
||||
from dataclasses_json import dataclass_json
|
||||
|
||||
# NOTE(review): not referenced in this chunk; presumably a cache-busting
# version for transcoded outputs — confirm where it is consumed.
TRANSCODE_VERSION = 1
|
||||
|
||||
|
||||
@dataclass_json
@dataclass
class VideoMetadata:
    """Metadata probed from a video container (see get_video_metadata)."""

    # max(container_duration_sec, video_duration_sec); treated as missing
    # upstream when None or 0.
    duration_sec: Optional[float]
    video_duration_sec: Optional[float]
    container_duration_sec: Optional[float]
    fps: Optional[float]
    # Displayed dimensions; get_video_metadata swaps them for 90/270 rotations.
    width: Optional[int]
    height: Optional[int]
    num_video_frames: int
    num_video_streams: int
    # Start time of the first video stream, in seconds.
    video_start_time: float
|
||||
|
||||
|
||||
def transcode(
    in_path: str,
    out_path: str,
    in_metadata: Optional[VideoMetadata],
    seek_t: float,
    duration_time_sec: float,
):
    """Re-encode a video into the demo's normalized format.

    Encoder settings are read from the VIDEO_ENCODE_* environment variables,
    then delegated to normalize_video.
    """
    env = os.environ
    codec = env.get("VIDEO_ENCODE_CODEC", "libx264")
    crf = int(env.get("VIDEO_ENCODE_CRF", "23"))
    fps = int(env.get("VIDEO_ENCODE_FPS", "24"))
    max_w = int(env.get("VIDEO_ENCODE_MAX_WIDTH", "1280"))
    max_h = int(env.get("VIDEO_ENCODE_MAX_HEIGHT", "720"))
    verbose = ast.literal_eval(env.get("VIDEO_ENCODE_VERBOSE", "False"))

    normalize_video(
        in_path=in_path,
        out_path=out_path,
        max_w=max_w,
        max_h=max_h,
        seek_t=seek_t,
        max_time=duration_time_sec,
        in_metadata=in_metadata,
        codec=codec,
        crf=crf,
        fps=fps,
        verbose=verbose,
    )
|
||||
|
||||
|
||||
def get_video_metadata(path: str) -> VideoMetadata:
    """Probe a video file with PyAV and return a VideoMetadata.

    Width/height are swapped when the display matrix reports a +/-90 or
    +/-270 degree rotation, so they reflect the displayed orientation.
    """
    with av.open(path) as cont:
        num_video_streams = len(cont.streams.video)
        width, height, fps = None, None, None
        video_duration_sec = 0
        # Container duration is reported in av.time_base units; missing -> 0.
        container_duration_sec = float((cont.duration or 0) / av.time_base)
        video_start_time = 0.0
        rotation_deg = 0
        num_video_frames = 0
        if num_video_streams > 0:
            video_stream = cont.streams.video[0]
            assert video_stream.time_base is not None

            # for rotation, see: https://github.com/PyAV-Org/PyAV/pull/1249
            rotation_deg = video_stream.side_data.get("DISPLAYMATRIX", 0)
            num_video_frames = video_stream.frames
            video_start_time = float(video_stream.start_time * video_stream.time_base)
            width, height = video_stream.width, video_stream.height
            fps = float(video_stream.guessed_rate)
            # Fallback rate in case the guessed rate is unavailable.
            fps_avg = video_stream.average_rate
            if video_stream.duration is not None:
                video_duration_sec = float(
                    video_stream.duration * video_stream.time_base
                )
            if fps is None:
                fps = float(fps_avg)

            # Quarter-turn rotations swap the displayed width/height.
            if not math.isnan(rotation_deg) and int(rotation_deg) in (
                90,
                -90,
                270,
                -270,
            ):
                width, height = height, width

        # Prefer whichever duration source reports the larger value.
        duration_sec = max(container_duration_sec, video_duration_sec)

        return VideoMetadata(
            duration_sec=duration_sec,
            container_duration_sec=container_duration_sec,
            video_duration_sec=video_duration_sec,
            video_start_time=video_start_time,
            fps=fps,
            width=width,
            height=height,
            num_video_streams=num_video_streams,
            num_video_frames=num_video_frames,
        )
|
||||
|
||||
|
||||
def normalize_video(
    in_path: str,
    out_path: str,
    max_w: int,
    max_h: int,
    seek_t: float,
    max_time: float,
    in_metadata: Optional[VideoMetadata],
    codec: str = "libx264",
    crf: int = 23,
    fps: int = 24,
    verbose: bool = False,
):
    """Seek/trim, downscale and re-encode a video with ffmpeg.

    Args:
        in_path/out_path: Source and destination file paths.
        max_w/max_h: Upper bounds for the output dimensions (aspect ratio
            preserved).
        seek_t: Seek offset in seconds.
        max_time: Maximum output duration in seconds.
        in_metadata: Pre-probed metadata; probed from in_path when None.
        codec/crf/fps: Encoder settings.
        verbose: When True, print the command and let ffmpeg write to
            stdout/stderr.
    """
    if in_metadata is None:
        in_metadata = get_video_metadata(in_path)

    assert in_metadata.num_video_streams > 0, "no video stream present"

    w, h = in_metadata.width, in_metadata.height
    assert w is not None, "width not available"
    assert h is not None, "height not available"

    # rescale to max_w:max_h if needed & preserve aspect ratio
    # Bug fix: the limits were hard-coded to 720/1280, silently ignoring the
    # max_h/max_w parameters (identical behavior under the default env values).
    r = w / h
    if r < 1:
        h = min(max_h, h)
        w = h * r
    else:
        w = min(max_w, w)
        h = w / r

    # h264 cannot encode w/ odd dimensions
    w = int(w)
    h = int(h)
    if w % 2 != 0:
        w += 1
    if h % 2 != 0:
        h += 1

    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg is None:
        # Fail loudly instead of passing None to subprocess below.
        raise RuntimeError("ffmpeg not found in PATH")
    cmd = [
        ffmpeg,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # global threads
        "-ss",
        f"{seek_t:.2f}",
        "-t",
        f"{max_time:.2f}",
        "-i",
        in_path,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # decode (or filter..?) threads
        "-vf",
        f"fps={fps},scale={w}:{h},setsar=1:1",
        "-c:v",
        codec,
        "-crf",
        f"{crf}",
        "-pix_fmt",
        "yuv420p",
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # encode threads
        out_path,
        "-y",
    ]
    if verbose:
        print(" ".join(cmd))

    subprocess.call(
        cmd,
        stdout=None if verbose else subprocess.DEVNULL,
        stderr=None if verbose else subprocess.DEVNULL,
    )
|
191
demo/backend/server/inference/data_types.py
Normal file
191
demo/backend/server/inference/data_types.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from dataclasses_json import dataclass_json
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass_json
@dataclass
class Mask:
    """COCO-style run-length-encoded (RLE) binary segmentation mask."""

    # [height, width] of the encoded mask
    size: List[int]
    # RLE payload string, as produced/consumed by pycocotools
    counts: str


@dataclass_json
@dataclass
class BaseRequest:
    """Common base for all inference requests; `type` names the request kind."""

    type: str


@dataclass_json
@dataclass
class StartSessionRequest(BaseRequest):
    """Start a new inference session on the video at `path`."""

    type: str
    path: str
    # optional client-provided id; presumably a fresh id is generated when
    # omitted — confirm against the session handler
    session_id: Optional[str] = None


@dataclass_json
@dataclass
class SaveSessionRequest(BaseRequest):
    """Persist the session identified by `session_id`."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class LoadSessionRequest(BaseRequest):
    """Load a previously saved session."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class RenewSessionRequest(BaseRequest):
    """Keep the session identified by `session_id` alive."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class CloseSessionRequest(BaseRequest):
    """Close the session and release its resources."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class AddPointsRequest(BaseRequest):
    """Add point prompts for one object on one video frame."""

    type: str
    session_id: str
    frame_index: int
    # when True, previously added points on this frame/object are discarded
    clear_old_points: bool
    object_id: int
    # per-point labels, parallel to `points` (presumably 1 = foreground,
    # 0 = background, the SAM convention — confirm against the client)
    labels: List[int]
    # [x, y] coordinates, one pair per point
    points: List[List[float]]


@dataclass_json
@dataclass
class AddMaskRequest(BaseRequest):
    """Add a full mask prompt for one object on one video frame."""

    type: str
    session_id: str
    frame_index: int
    object_id: int
    mask: Mask


@dataclass_json
@dataclass
class ClearPointsInFrameRequest(BaseRequest):
    """Remove all prompts for one object on a single frame."""

    type: str
    session_id: str
    frame_index: int
    object_id: int


@dataclass_json
@dataclass
class ClearPointsInVideoRequest(BaseRequest):
    """Remove all prompts on every frame of the session's video."""

    type: str
    session_id: str


@dataclass_json
@dataclass
class RemoveObjectRequest(BaseRequest):
    """Drop an object id from the session's tracking state."""

    type: str
    session_id: str
    object_id: int


@dataclass_json
@dataclass
class PropagateInVideoRequest(BaseRequest):
    """Propagate existing prompts through the video from `start_frame_index`."""

    type: str
    session_id: str
    start_frame_index: int


@dataclass_json
@dataclass
class CancelPropagateInVideoRequest(BaseRequest):
    """Cancel an in-flight propagation for the session."""

    type: str
    session_id: str
|
||||
|
||||
|
||||
@dataclass_json
@dataclass
class StartSessionResponse:
    """Carries the id of the newly created session."""

    session_id: str


@dataclass_json
@dataclass
class SaveSessionResponse:
    """Acknowledges a save, echoing the session id."""

    session_id: str


@dataclass_json
@dataclass
class LoadSessionResponse:
    """Acknowledges a load, echoing the session id."""

    session_id: str


@dataclass_json
@dataclass
class RenewSessionResponse:
    """Acknowledges a renewal, echoing the session id."""

    session_id: str


@dataclass_json
@dataclass
class CloseSessionResponse:
    """Reports whether the session was found and closed."""

    success: bool


@dataclass_json
@dataclass
class ClearPointsInVideoResponse:
    """Reports whether all prompts were cleared from the video."""

    success: bool


@dataclass_json
@dataclass
class PropagateDataValue:
    """Mask result for a single tracked object."""

    object_id: int
    mask: Mask


@dataclass_json
@dataclass
class PropagateDataResponse:
    """Per-frame result: one mask per tracked object on `frame_index`."""

    frame_index: int
    results: List[PropagateDataValue]


@dataclass_json
@dataclass
class RemoveObjectResponse:
    """Updated per-frame masks for every frame affected by an object removal."""

    results: List[PropagateDataResponse]
|
||||
|
||||
|
||||
@dataclass_json
@dataclass
class CancelPropagateResponse:
    """Reports that the cancellation flag was set for the session.

    Fixes the misspelled class name ("Porpagate"); see the alias below.
    """

    success: bool


# Backward-compatible alias preserving the original (misspelled) public name,
# which existing callers such as the inference predictor still import.
CancelPorpagateResponse = CancelPropagateResponse
|
||||
|
||||
|
||||
@dataclass_json
@dataclass
class InferenceSession:
    """Bookkeeping record for a live inference session."""

    # creation timestamp (presumably seconds since epoch via time.time() —
    # confirm against the session manager)
    start_time: float
    # last-access timestamp, presumably used for session expiry — confirm
    last_use_time: float
    session_id: str
    # predictor inference state: string keys mapping to tensors or to
    # per-object-id tensor dictionaries
    state: Dict[str, Dict[str, Union[Tensor, Dict[int, Tensor]]]]
|
48
demo/backend/server/inference/multipart.py
Normal file
48
demo/backend/server/inference/multipart.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class MultipartResponseBuilder:
    """Builds one part of a multipart HTTP response as raw bytes.

    A part consists of the boundary marker line, a header section (with
    Content-Length computed from the body automatically), a blank separator
    line, and the payload — all delimited with CRLF line endings.
    """

    # accumulated bytes of the part built so far
    message: bytes

    def __init__(self, boundary: str) -> None:
        self.message = b"--" + boundary.encode("utf-8") + b"\r\n"

    @classmethod
    def build(
        cls, boundary: str, headers: Dict[str, str], body: Union[str, bytes]
    ) -> "MultipartResponseBuilder":
        """Assemble a complete part from `headers` and a str/bytes `body`."""
        part = cls(boundary=boundary)
        for header_key, header_value in headers.items():
            part.__append_header(key=header_key, value=header_value)
        if isinstance(body, bytes):
            payload = body
        elif isinstance(body, str):
            payload = body.encode("utf-8")
        else:
            raise ValueError(
                f"body needs to be of type bytes or str but got {type(body)}"
            )
        part.__append_body(payload)
        return part

    def get_message(self) -> bytes:
        """Return the bytes accumulated so far."""
        return self.message

    def __append_header(self, key: str, value: str) -> "MultipartResponseBuilder":
        """Append a single `key: value` header line."""
        self.message += b"".join(
            [key.encode("utf-8"), b": ", value.encode("utf-8"), b"\r\n"]
        )
        return self

    def __close_header(self) -> "MultipartResponseBuilder":
        """Terminate the header section with an empty line."""
        self.message += b"\r\n"
        return self

    def __append_body(self, body: bytes) -> "MultipartResponseBuilder":
        """Write Content-Length, close the header section, append the payload."""
        self.__append_header(key="Content-Length", value=str(len(body)))
        self.__close_header()
        self.message += body
        return self
|
427
demo/backend/server/inference/predictor.py
Normal file
427
demo/backend/server/inference/predictor.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Generator, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from app_conf import APP_ROOT, MODEL_SIZE
|
||||
from inference.data_types import (
|
||||
AddMaskRequest,
|
||||
AddPointsRequest,
|
||||
CancelPorpagateResponse,
|
||||
CancelPropagateInVideoRequest,
|
||||
ClearPointsInFrameRequest,
|
||||
ClearPointsInVideoRequest,
|
||||
ClearPointsInVideoResponse,
|
||||
CloseSessionRequest,
|
||||
CloseSessionResponse,
|
||||
Mask,
|
||||
PropagateDataResponse,
|
||||
PropagateDataValue,
|
||||
PropagateInVideoRequest,
|
||||
RemoveObjectRequest,
|
||||
RemoveObjectResponse,
|
||||
StartSessionRequest,
|
||||
StartSessionResponse,
|
||||
)
|
||||
from pycocotools.mask import decode as decode_masks, encode as encode_masks
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InferenceAPI:
    """Serving wrapper around a SAM 2 video predictor.

    Holds a single predictor model (sized via the MODEL_SIZE app config) and a
    map of live sessions, each pairing a predictor inference state with a
    cancellation flag. All public methods take/return the request/response
    dataclasses from `inference.data_types` and run under one lock, so at most
    one inference call executes at a time.
    """

    def __init__(self) -> None:
        super().__init__()

        # session_id -> {"canceled": bool, "state": predictor inference state}
        self.session_states: Dict[str, Any] = {}
        # mask logits above this threshold count as foreground pixels
        self.score_thresh = 0

        # pick checkpoint + config from the configured model size
        if MODEL_SIZE == "tiny":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_tiny.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        elif MODEL_SIZE == "small":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_small.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        elif MODEL_SIZE == "large":
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_large.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        else:  # base_plus (default)
            checkpoint = Path(APP_ROOT) / "checkpoints/sam2.1_hiera_base_plus.pt"
            model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

        # select the device for computation
        force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1"
        if force_cpu_device:
            logger.info("forcing CPU device for SAM 2 demo")
        if torch.cuda.is_available() and not force_cpu_device:
            device = torch.device("cuda")
        elif torch.backends.mps.is_available() and not force_cpu_device:
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        logger.info(f"using device: {device}")

        if device.type == "cuda":
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        elif device.type == "mps":
            # consistency fix: use the module logger (the original called the
            # root logger via `logging.warning` here)
            logger.warning(
                "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS. "
                "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
            )

        self.device = device
        self.predictor = build_sam2_video_predictor(
            model_cfg, checkpoint, device=device
        )
        # serializes all inference calls; the predictor state is not thread-safe
        self.inference_lock = Lock()

    def autocast_context(self):
        """Return a bfloat16 autocast context on CUDA, otherwise a no-op context."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            return contextlib.nullcontext()

    def start_session(self, request: StartSessionRequest) -> StartSessionResponse:
        """Create a new session for the video at `request.path` and return its id."""
        with self.autocast_context(), self.inference_lock:
            session_id = str(uuid.uuid4())
            # for MPS devices, we offload the video frames to CPU by default to avoid
            # memory fragmentation in MPS (which sometimes crashes the entire process)
            offload_video_to_cpu = self.device.type == "mps"
            inference_state = self.predictor.init_state(
                request.path,
                offload_video_to_cpu=offload_video_to_cpu,
            )
            self.session_states[session_id] = {
                "canceled": False,
                "state": inference_state,
            }
            return StartSessionResponse(session_id=session_id)

    def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
        """Drop the session's state; success=False if it was unknown or expired."""
        is_successful = self.__clear_session_state(request.session_id)
        return CloseSessionResponse(success=is_successful)

    def add_points(
        self, request: AddPointsRequest, test: str = ""
    ) -> PropagateDataResponse:
        """Add point prompts on one frame and return the resulting masks.

        NOTE(review): `test` is unused; kept only for signature compatibility
        with existing callers.
        """
        with self.autocast_context(), self.inference_lock:
            session = self.__get_session(request.session_id)
            inference_state = session["state"]

            frame_idx = request.frame_index
            obj_id = request.object_id
            points = request.points
            labels = request.labels
            clear_old_points = request.clear_old_points

            # add new prompts and instantly get the output on the same frame
            frame_idx, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=clear_old_points,
                normalize_coords=False,
            )

            # threshold logits -> binary masks, dropping the singleton channel dim
            masks_binary = (masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=object_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def add_mask(self, request: AddMaskRequest) -> PropagateDataResponse:
        """
        Add an input mask for one object on a specific video frame.
        - mask is an RLE-encoded binary mask of shape [H_im, W_im] (1 for
          foreground and 0 for background).
        Note: providing an input mask would overwrite any previous input points on this frame.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id
            rle_mask = {
                "counts": request.mask.counts,
                "size": request.mask.size,
            }

            mask = decode_masks(rle_mask)

            logger.info(
                f"add mask on frame {frame_idx} in session {session_id}: {obj_id=}, {mask.shape=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]

            # bug fix: the model is stored as `self.predictor`; `self.model`
            # was never assigned and raised AttributeError here
            frame_idx, obj_ids, video_res_masks = self.predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                mask=torch.tensor(mask > 0),
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_frame(
        self, request: ClearPointsInFrameRequest
    ) -> PropagateDataResponse:
        """
        Remove all input points in a specific frame and return the updated masks.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            frame_idx = request.frame_index
            obj_id = request.object_id

            logger.info(
                f"clear inputs on frame {frame_idx} in session {session_id}: {obj_id=}"
            )
            session = self.__get_session(session_id)
            inference_state = session["state"]
            frame_idx, obj_ids, video_res_masks = (
                self.predictor.clear_all_prompts_in_frame(
                    inference_state, frame_idx, obj_id
                )
            )
            masks_binary = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()

            rle_mask_list = self.__get_rle_mask_list(
                object_ids=obj_ids, masks=masks_binary
            )

            return PropagateDataResponse(
                frame_index=frame_idx,
                results=rle_mask_list,
            )

    def clear_points_in_video(
        self, request: ClearPointsInVideoRequest
    ) -> ClearPointsInVideoResponse:
        """
        Remove all input points in all frames throughout the video.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            logger.info(f"clear all inputs across the video in session {session_id}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            self.predictor.reset_state(inference_state)
            return ClearPointsInVideoResponse(success=True)

    def remove_object(self, request: RemoveObjectRequest) -> RemoveObjectResponse:
        """
        Remove an object id from the tracking state and return the updated
        masks for every frame the removal affected.
        """
        with self.autocast_context(), self.inference_lock:
            session_id = request.session_id
            obj_id = request.object_id
            logger.info(f"remove object in session {session_id}: {obj_id=}")
            session = self.__get_session(session_id)
            inference_state = session["state"]
            new_obj_ids, updated_frames = self.predictor.remove_object(
                inference_state, obj_id
            )

            results = []
            for frame_index, video_res_masks in updated_frames:
                masks = (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                rle_mask_list = self.__get_rle_mask_list(
                    object_ids=new_obj_ids, masks=masks
                )
                results.append(
                    PropagateDataResponse(
                        frame_index=frame_index,
                        results=rle_mask_list,
                    )
                )

            return RemoveObjectResponse(results=results)

    def propagate_in_video(
        self, request: PropagateInVideoRequest
    ) -> Generator[PropagateDataResponse, None, None]:
        """
        Propagate existing input points in all frames to track the object across video.
        Yields one PropagateDataResponse per processed frame (forward in time
        first, then backward); stops early when the session is canceled.
        """
        # (the docstring above was previously a no-op string expression
        # stranded mid-body; moved to the conventional position)
        session_id = request.session_id
        start_frame_idx = request.start_frame_index
        propagation_direction = "both"
        max_frame_num_to_track = None

        # Note that as this method is a generator, we also need to use autocast_context
        # in caller to this method to ensure that it's called under the correct context
        # (we've added `autocast_context` to `gen_track_with_mask_stream` in app.py).
        with self.autocast_context(), self.inference_lock:
            logger.info(
                f"propagate in video in session {session_id}: "
                f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
            )

            try:
                session = self.__get_session(session_id)
                session["canceled"] = False

                inference_state = session["state"]
                if propagation_direction not in ["both", "forward", "backward"]:
                    raise ValueError(
                        f"invalid propagation direction: {propagation_direction}"
                    )

                # First doing the forward propagation
                if propagation_direction in ["both", "forward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=False,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )

                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )

                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )

                # Then doing the backward propagation (reverse in time)
                if propagation_direction in ["both", "backward"]:
                    for outputs in self.predictor.propagate_in_video(
                        inference_state=inference_state,
                        start_frame_idx=start_frame_idx,
                        max_frame_num_to_track=max_frame_num_to_track,
                        reverse=True,
                    ):
                        if session["canceled"]:
                            return None

                        frame_idx, obj_ids, video_res_masks = outputs
                        masks_binary = (
                            (video_res_masks > self.score_thresh)[:, 0].cpu().numpy()
                        )

                        rle_mask_list = self.__get_rle_mask_list(
                            object_ids=obj_ids, masks=masks_binary
                        )

                        yield PropagateDataResponse(
                            frame_index=frame_idx,
                            results=rle_mask_list,
                        )
            finally:
                # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
                # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
                logger.info(
                    f"propagation ended in session {session_id}; {self.__get_session_stats()}"
                )

    def cancel_propagate_in_video(
        self, request: CancelPropagateInVideoRequest
    ) -> CancelPorpagateResponse:
        """Set the session's cancellation flag so an in-flight propagation stops."""
        session = self.__get_session(request.session_id)
        session["canceled"] = True
        return CancelPorpagateResponse(success=True)

    def __get_rle_mask_list(
        self, object_ids: List[int], masks: np.ndarray
    ) -> List[PropagateDataValue]:
        """
        Return a list of data values, i.e. list of object/mask combos.
        """
        return [
            self.__get_mask_for_object(object_id=object_id, mask=mask)
            for object_id, mask in zip(object_ids, masks)
        ]

    def __get_mask_for_object(
        self, object_id: int, mask: np.ndarray
    ) -> PropagateDataValue:
        """
        Create a data value for an object/mask combo.
        """
        mask_rle = encode_masks(np.array(mask, dtype=np.uint8, order="F"))
        mask_rle["counts"] = mask_rle["counts"].decode()
        return PropagateDataValue(
            object_id=object_id,
            mask=Mask(
                size=mask_rle["size"],
                counts=mask_rle["counts"],
            ),
        )

    def __get_session(self, session_id: str):
        """Return the session dict for `session_id`; raise if unknown/expired."""
        session = self.session_states.get(session_id, None)
        if session is None:
            raise RuntimeError(
                f"Cannot find session {session_id}; it might have expired"
            )
        return session

    def __get_session_stats(self):
        """Get a statistics string for live sessions and their GPU usage."""
        # print both the session ids and their video frame numbers
        live_session_strs = [
            f"'{session_id}' ({session['state']['num_frames']} frames, "
            f"{len(session['state']['obj_ids'])} objects)"
            for session_id, session in self.session_states.items()
        ]
        # bug fix: removed a stray "Test String Here - -" debug fragment that
        # was concatenated into this log message.
        # NOTE(review): torch.cuda.* is queried even on CPU/MPS devices —
        # confirm it is safe (returns 0) on non-CUDA builds.
        session_stats_str = (
            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
        )
        return session_stats_str

    def __clear_session_state(self, session_id: str) -> bool:
        """Pop the session from the registry; return whether it existed."""
        session = self.session_states.pop(session_id, None)
        if session is None:
            logger.warning(
                f"cannot close session {session_id} as it does not exist (it might have expired); "
                f"{self.__get_session_stats()}"
            )
            return False
        else:
            logger.info(f"removed session {session_id}; {self.__get_session_stats()}")
            return True
|
Reference in New Issue
Block a user