SAM2.1 checkpoints + training code + Demo
This commit is contained in:
Haitham Khedr
2024-09-28 08:20:56 -07:00
parent 7e1596c0b6
commit aa9b8722d0
325 changed files with 38174 additions and 223 deletions

View File

@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Iterable, List, Optional
import strawberry
from app_conf import API_URL
from data.resolver import resolve_videos
from dataclasses_json import dataclass_json
from strawberry import relay
@strawberry.type
class Video(relay.Node):
    """Core type for video."""

    # Relay global node id; also the key the video store indexes by.
    code: relay.NodeID[str]
    # Video file path relative to the serving root.
    path: str
    # Poster (thumbnail) image path; None when no poster was generated.
    poster_path: Optional[str]
    width: int
    height: int

    @strawberry.field
    def url(self) -> str:
        """Absolute URL of the video file."""
        return f"{API_URL}/{self.path}"

    @strawberry.field
    def poster_url(self) -> str:
        """Absolute URL of the poster image.

        NOTE(review): when poster_path is None this yields ".../None" —
        confirm callers only request this field when a poster exists.
        """
        return f"{API_URL}/{self.poster_path}"

    @classmethod
    def resolve_nodes(
        cls,
        *,
        # NOTE(review): strawberry usually passes `Info` here, not
        # `relay.PageInfo` — annotation looks suspect but is unused at runtime.
        info: relay.PageInfo,
        node_ids: Iterable[str],
        required: bool = False,
    ):
        # Relay node resolution: delegate to the in-memory video store.
        return resolve_videos(node_ids, required)
@strawberry.type
class RLEMask:
    """Core type for Onevision GraphQL RLE mask."""

    # Mask dimensions, presumably [height, width] — TODO confirm against producer.
    size: List[int]
    # Run-length-encoded counts string.
    counts: str
    # Memory order of the encoding ("F" is what the mutations emit).
    order: str


@strawberry.type
class RLEMaskForObject:
    """Type for RLE mask associated with a specific object id."""

    object_id: int
    rle_mask: RLEMask


@strawberry.type
class RLEMaskListOnFrame:
    """Type for a list of object-associated RLE masks on a specific video frame."""

    frame_index: int
    rle_mask_list: List[RLEMaskForObject]
@strawberry.input
class StartSessionInput:
    """Input for starting an inference session on a video."""

    # Video path relative to the data root.
    path: str


@strawberry.type
class StartSession:
    """Result of starting a session."""

    session_id: str


@strawberry.input
class PingInput:
    """Input for keeping a session alive."""

    session_id: str


@strawberry.type
class Pong:
    """Result of a ping."""

    success: bool


@strawberry.input
class CloseSessionInput:
    """Input for closing an existing session."""

    session_id: str


@strawberry.type
class CloseSession:
    """Result of closing a session."""

    success: bool
@strawberry.input
class AddPointsInput:
    """Input for adding click points for one object on one frame."""

    session_id: str
    frame_index: int
    # When True, previously added points for this object are discarded first.
    clear_old_points: bool
    object_id: int
    # Per-point labels; parallel to `points`.
    labels: List[int]
    # Point coordinates as [x, y] pairs.
    points: List[List[float]]


@strawberry.input
class ClearPointsInFrameInput:
    """Input for clearing one object's points on a single frame."""

    session_id: str
    frame_index: int
    object_id: int


@strawberry.input
class ClearPointsInVideoInput:
    """Input for clearing all points across the whole video."""

    session_id: str


@strawberry.type
class ClearPointsInVideo:
    """Result of clearing all points in a video."""

    success: bool
@strawberry.input
class RemoveObjectInput:
    """Input for removing an object from a session."""

    session_id: str
    object_id: int


@strawberry.input
class PropagateInVideoInput:
    """Input for propagating masks through the video from a start frame."""

    session_id: str
    start_frame_index: int


@strawberry.input
class CancelPropagateInVideoInput:
    """Input for cancelling an in-flight propagation."""

    session_id: str


@strawberry.type
class CancelPropagateInVideo:
    """Result of cancelling propagation."""

    success: bool


@strawberry.type
class SessionExpiration:
    """Expiration bookkeeping for a session."""

    session_id: str
    # Expiration timestamps/TTL — units (epoch seconds?) not visible here; TODO confirm.
    expiration_time: int
    max_expiration_time: int
    ttl: int

View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import subprocess
from glob import glob
from pathlib import Path
from typing import Dict, Optional
import imagesize
from app_conf import GALLERY_PATH, POSTERS_PATH, POSTERS_PREFIX
from data.data_types import Video
from tqdm import tqdm
def preload_data() -> Dict[str, Video]:
    """
    Preload data including gallery videos and their posters.

    Returns a mapping from video code to Video. Since Python 3.7 dicts
    preserve insertion order, so iterating `.values()` later yields videos
    in discovery order.
    https://stackoverflow.com/questions/39980323/are-dictionaries-ordered-in-python-3-6
    """
    pattern = os.path.join(GALLERY_PATH, "**/*.mp4")
    discovered = glob(pattern, recursive=True)
    # tqdm gives progress feedback while posters are generated per video.
    gallery_videos = (get_video(path, GALLERY_PATH) for path in tqdm(discovered))
    return {video.code: video for video in gallery_videos}
def get_video(
    filepath: os.PathLike,
    absolute_path: Path,
    file_key: Optional[str] = None,
    generate_poster: bool = True,
    width: Optional[int] = None,
    height: Optional[int] = None,
    verbose: Optional[bool] = False,
) -> Video:
    """
    Build a `Video` object for the file at `filepath`.

    Args:
        filepath: Path of the video file on disk.
        absolute_path: Directory the video lives under; the stored video
            path is made relative to this directory's parent.
        file_key: Optional storage key to use as the video path instead of
            the relative file path (used for uploads).
        generate_poster: When True, extract the first frame with ffmpeg as
            a poster image and measure width/height from it (overriding any
            passed-in values).
        width: Known width in pixels; used when no poster is generated.
        height: Known height in pixels; used when no poster is generated.
        verbose: When True, let ffmpeg write to stdout/stderr.

    Returns:
        The populated `Video` object.

    Raises:
        RuntimeError: If poster generation is requested but ffmpeg is not
            found on PATH.
    """
    # Use absolute_path to include the parent directory in the video path.
    video_path = os.path.relpath(filepath, absolute_path.parent)
    poster_path = None
    if generate_poster:
        poster_id = os.path.splitext(os.path.basename(filepath))[0]
        poster_filename = f"{poster_id}.jpg"
        poster_path = f"{POSTERS_PREFIX}/{poster_filename}"

        # Extract the first frame from the video as the poster image.
        poster_output_path = os.path.join(POSTERS_PATH, poster_filename)
        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            # Fail loudly instead of passing None as argv[0] to subprocess.
            raise RuntimeError(
                "ffmpeg is required to generate posters but was not found on PATH"
            )
        subprocess.call(
            [
                ffmpeg,
                "-y",
                "-i",
                str(filepath),
                "-pix_fmt",
                "yuv420p",
                "-frames:v",
                "1",
                "-update",
                "1",
                "-strict",
                "unofficial",
                str(poster_output_path),
            ],
            stdout=None if verbose else subprocess.DEVNULL,
            stderr=None if verbose else subprocess.DEVNULL,
        )

        # Extract video width and height from poster. This is important to
        # optimize rendering previews in the mosaic video preview.
        width, height = imagesize.get(poster_output_path)

    return Video(
        code=video_path,
        path=video_path if file_key is None else file_key,
        poster_path=poster_path,
        width=width,
        height=height,
    )

View File

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable
def resolve_videos(node_ids: Iterable[str], required: bool = False):
    """
    Resolve videos given node ids.

    When `required` is True, a missing id raises KeyError; otherwise the
    missing entry resolves to None.
    """
    # Imported lazily to avoid a circular import with the store module.
    from data.store import get_videos

    videos = get_videos()
    if required:
        return [videos[node_id] for node_id in node_ids]
    return [videos.get(node_id) for node_id in node_ids]

View File

@@ -0,0 +1,357 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import os
import shutil
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
import av
import strawberry
from app_conf import (
DATA_PATH,
DEFAULT_VIDEO_PATH,
MAX_UPLOAD_VIDEO_DURATION,
UPLOADS_PATH,
UPLOADS_PREFIX,
)
from data.data_types import (
AddPointsInput,
CancelPropagateInVideo,
CancelPropagateInVideoInput,
ClearPointsInFrameInput,
ClearPointsInVideo,
ClearPointsInVideoInput,
CloseSession,
CloseSessionInput,
RemoveObjectInput,
RLEMask,
RLEMaskForObject,
RLEMaskListOnFrame,
StartSession,
StartSessionInput,
Video,
)
from data.loader import get_video
from data.store import get_videos
from data.transcoder import get_video_metadata, transcode, VideoMetadata
from inference.data_types import (
AddPointsRequest,
CancelPropagateInVideoRequest,
CancelPropagateInVideoRequest,
ClearPointsInFrameRequest,
ClearPointsInVideoRequest,
CloseSessionRequest,
RemoveObjectRequest,
StartSessionRequest,
)
from inference.predictor import InferenceAPI
from strawberry import relay
from strawberry.file_uploads import Upload
@strawberry.type
class Query:
    @strawberry.field
    def default_video(self) -> Video:
        """
        Return the default video.

        The default video can be set with the DEFAULT_VIDEO_PATH environment
        variable. It will return the video that matches this path. If no video
        is found, it will return the first video.
        """
        videos_by_code = get_videos()
        # Prefer the video whose path matches the configured default.
        matching = next(
            (v for v in videos_by_code.values() if v.path == DEFAULT_VIDEO_PATH),
            None,
        )
        if matching is not None:
            return matching
        # Fallback: the first registered video.
        return next(iter(videos_by_code.values()))

    @relay.connection(relay.ListConnection[Video])
    def videos(
        self,
    ) -> Iterable[Video]:
        """
        Return all available videos.
        """
        return get_videos().values()
@strawberry.type
class Mutation:
    @strawberry.mutation
    def upload_video(
        self,
        file: Upload,
        start_time_sec: Optional[float] = None,
        duration_time_sec: Optional[float] = None,
    ) -> Video:
        """
        Receive a video file and store it in the configured S3 bucket.

        The upload is trimmed/transcoded via `process_video` before being
        registered; no poster is generated for uploads.
        """
        max_time = MAX_UPLOAD_VIDEO_DURATION
        filepath, file_key, vm = process_video(
            file,
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )
        video = get_video(
            filepath,
            UPLOADS_PATH,
            file_key=file_key,
            width=vm.width,
            height=vm.height,
            generate_poster=False,
        )
        return video

    @strawberry.mutation
    def start_session(
        self, input: StartSessionInput, info: strawberry.Info
    ) -> StartSession:
        """Start a new inference session for the video at `input.path`."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = StartSessionRequest(
            type="start_session",
            path=f"{DATA_PATH}/{input.path}",
        )
        response = inference_api.start_session(request=request)
        return StartSession(session_id=response.session_id)

    @strawberry.mutation
    def close_session(
        self, input: CloseSessionInput, info: strawberry.Info
    ) -> CloseSession:
        """Close an existing inference session."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CloseSessionRequest(
            type="close_session",
            session_id=input.session_id,
        )
        response = inference_api.close_session(request)
        return CloseSession(success=response.success)

    @strawberry.mutation
    def add_points(
        self, input: AddPointsInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Add click points for an object and return the frame's new masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = AddPointsRequest(
            type="add_points",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
            points=input.points,
            labels=input.labels,
            clear_old_points=input.clear_old_points,
        )
        # (Fixed misspelled local "reponse" for consistency with siblings.)
        response = inference_api.add_points(request)
        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=[
                RLEMaskForObject(
                    object_id=r.object_id,
                    rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
                )
                for r in response.results
            ],
        )

    @strawberry.mutation
    def remove_object(
        self, input: RemoveObjectInput, info: strawberry.Info
    ) -> List[RLEMaskListOnFrame]:
        """Remove an object and return updated masks for every affected frame."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = RemoveObjectRequest(
            type="remove_object", session_id=input.session_id, object_id=input.object_id
        )
        response = inference_api.remove_object(request)
        return [
            RLEMaskListOnFrame(
                frame_index=res.frame_index,
                rle_mask_list=[
                    RLEMaskForObject(
                        object_id=r.object_id,
                        rle_mask=RLEMask(
                            counts=r.mask.counts, size=r.mask.size, order="F"
                        ),
                    )
                    for r in res.results
                ],
            )
            for res in response.results
        ]

    @strawberry.mutation
    def clear_points_in_frame(
        self, input: ClearPointsInFrameInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Clear one object's points on one frame and return its new masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInFrameRequest(
            type="clear_points_in_frame",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
        )
        response = inference_api.clear_points_in_frame(request)
        return RLEMaskListOnFrame(
            frame_index=response.frame_index,
            rle_mask_list=[
                RLEMaskForObject(
                    object_id=r.object_id,
                    rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"),
                )
                for r in response.results
            ],
        )

    @strawberry.mutation
    def clear_points_in_video(
        self, input: ClearPointsInVideoInput, info: strawberry.Info
    ) -> ClearPointsInVideo:
        """Clear all points across the whole video."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInVideoRequest(
            type="clear_points_in_video",
            session_id=input.session_id,
        )
        response = inference_api.clear_points_in_video(request)
        return ClearPointsInVideo(success=response.success)

    @strawberry.mutation
    def cancel_propagate_in_video(
        self, input: CancelPropagateInVideoInput, info: strawberry.Info
    ) -> CancelPropagateInVideo:
        """Cancel an in-flight mask propagation for the session."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CancelPropagateInVideoRequest(
            type="cancel_propagate_in_video",
            session_id=input.session_id,
        )
        response = inference_api.cancel_propagate_in_video(request)
        return CancelPropagateInVideo(success=response.success)
def get_file_hash(video_path_or_file) -> str:
    """
    Return the SHA-256 hex digest of a file's contents.

    Accepts either a filesystem path (str) or a binary file-like object.
    A file-like object is rewound to the start first and is left positioned
    at EOF afterwards (same as the previous implementation).

    Args:
        video_path_or_file: Path string or binary file object.

    Returns:
        Hex-encoded SHA-256 digest.
    """
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024  # stream: avoid loading whole videos into memory
    if isinstance(video_path_or_file, str):
        with open(video_path_or_file, "rb") as in_f:
            for chunk in iter(lambda: in_f.read(chunk_size), b""):
                digest.update(chunk)
    else:
        video_path_or_file.seek(0)
        for chunk in iter(lambda: video_path_or_file.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def _get_start_sec_duration_sec(
start_time_sec: Union[float, None],
duration_time_sec: Union[float, None],
max_time: float,
) -> Tuple[float, float]:
default_seek_t = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0"))
if start_time_sec is None:
start_time_sec = default_seek_t
if duration_time_sec is not None:
duration_time_sec = min(duration_time_sec, max_time)
else:
duration_time_sec = max_time
return start_time_sec, duration_time_sec
def process_video(
    file: Upload,
    max_time: float,
    start_time_sec: Optional[float] = None,
    duration_time_sec: Optional[float] = None,
) -> Tuple[str, str, VideoMetadata]:
    """
    Process a file upload, including video trimming and validation checks.

    Args:
        file: The uploaded video file object.
        max_time: Maximum allowed duration in seconds; longer requests are capped.
        start_time_sec: Optional trim start time in seconds.
        duration_time_sec: Optional trim duration in seconds.

    Returns:
        Tuple of (filepath on disk, upload file key, transcoded video metadata).
        (Annotation fixed: the function returns a 3-tuple, not the 4-tuple
        previously declared.)

    Raises:
        Exception: If the upload is not a valid video, has no video stream,
            lacks width/height/duration metadata, or transcodes to an empty
            video.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_path = f"{tempdir}/in.mp4"
        out_path = f"{tempdir}/out.mp4"
        with open(in_path, "wb") as in_f:
            in_f.write(file.read())

        try:
            video_metadata = get_video_metadata(in_path)
        except av.InvalidDataError:
            raise Exception("not valid video file")
        if video_metadata.num_video_streams == 0:
            raise Exception("video container does not contain a video stream")
        if video_metadata.width is None or video_metadata.height is None:
            raise Exception("video container does not contain width or height metadata")
        if video_metadata.duration_sec in (None, 0):
            # (Fixed previously garbled error message.)
            raise Exception("video container does not contain duration metadata")

        start_time_sec, duration_time_sec = _get_start_sec_duration_sec(
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # Transcode video to make sure videos returned to the app are all in
        # the same format, duration, resolution, fps.
        transcode(
            in_path,
            out_path,
            video_metadata,
            seek_t=start_time_sec,
            duration_time_sec=duration_time_sec,
        )
        os.remove(in_path)  # don't need original video now

        out_video_metadata = get_video_metadata(out_path)
        if out_video_metadata.num_video_frames == 0:
            raise Exception(
                "transcode produced empty video; check seek time or your input video"
            )

        # Name the stored file after its content hash so identical uploads
        # deduplicate to the same key.
        with open(out_path, "rb") as file_data:
            file_hash = get_file_hash(file_data)
            file_data.seek(0)
        file_key = UPLOADS_PREFIX + "/" + f"{file_hash}.mp4"
        filepath = os.path.join(UPLOADS_PATH, f"{file_hash}.mp4")
        shutil.move(out_path, filepath)
    return filepath, file_key, out_video_metadata
# GraphQL schema entry point wiring the Query and Mutation roots together.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
)

View File

@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
from data.data_types import Video
ALL_VIDEOS: Dict[str, Video] = []
def set_videos(videos: Dict[str, Video]) -> None:
"""
Set the videos available in the backend. The data is kept in-memory, but a future change could replace the
in-memory storage with a database backend. This would also be more efficient when querying videos given a
dataset name etc.
"""
global ALL_VIDEOS
ALL_VIDEOS = videos
def get_videos() -> Dict[str, Video]:
    """
    Return the videos currently registered in the backend.
    """
    # `global` is unnecessary for a read; module-level lookup suffices.
    return ALL_VIDEOS

View File

@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import ast
import math
import os
import shutil
import subprocess
from dataclasses import dataclass
from typing import Optional
import av
from app_conf import FFMPEG_NUM_THREADS
from dataclasses_json import dataclass_json
# Version tag for the transcode pipeline (not referenced in this file —
# presumably used by callers to invalidate stale transcodes; TODO confirm).
TRANSCODE_VERSION = 1


@dataclass_json
@dataclass
class VideoMetadata:
    """Metadata probed from a video container via PyAV."""

    # Overall duration: max of container and video-stream durations.
    duration_sec: Optional[float]
    # Duration reported by the video stream itself.
    video_duration_sec: Optional[float]
    # Duration reported by the container.
    container_duration_sec: Optional[float]
    fps: Optional[float]
    width: Optional[int]
    height: Optional[int]
    num_video_frames: int
    num_video_streams: int
    # Start time of the first video stream, in seconds.
    video_start_time: float
def transcode(
    in_path: str,
    out_path: str,
    in_metadata: Optional[VideoMetadata],
    seek_t: float,
    duration_time_sec: float,
):
    """
    Transcode `in_path` into `out_path`, trimming to
    [seek_t, seek_t + duration_time_sec].

    Encoding knobs (codec, CRF, fps, max dimensions, verbosity) are read
    from VIDEO_ENCODE_* environment variables with sane defaults.
    """
    env = os.environ
    normalize_video(
        in_path=in_path,
        out_path=out_path,
        max_w=int(env.get("VIDEO_ENCODE_MAX_WIDTH", "1280")),
        max_h=int(env.get("VIDEO_ENCODE_MAX_HEIGHT", "720")),
        seek_t=seek_t,
        max_time=duration_time_sec,
        in_metadata=in_metadata,
        codec=env.get("VIDEO_ENCODE_CODEC", "libx264"),
        crf=int(env.get("VIDEO_ENCODE_CRF", "23")),
        fps=int(env.get("VIDEO_ENCODE_FPS", "24")),
        verbose=ast.literal_eval(env.get("VIDEO_ENCODE_VERBOSE", "False")),
    )
def get_video_metadata(path: str) -> VideoMetadata:
    """
    Probe the container at `path` with PyAV and return its VideoMetadata.

    When the container has no video stream, width/height/fps stay None and
    the frame count stays 0.
    """
    with av.open(path) as cont:
        num_video_streams = len(cont.streams.video)
        width, height, fps = None, None, None
        video_duration_sec = 0
        # Container duration is reported in av.time_base units; 0 if absent.
        container_duration_sec = float((cont.duration or 0) / av.time_base)
        video_start_time = 0.0
        rotation_deg = 0
        num_video_frames = 0
        if num_video_streams > 0:
            video_stream = cont.streams.video[0]
            assert video_stream.time_base is not None
            # for rotation, see: https://github.com/PyAV-Org/PyAV/pull/1249
            rotation_deg = video_stream.side_data.get("DISPLAYMATRIX", 0)
            num_video_frames = video_stream.frames
            video_start_time = float(video_stream.start_time * video_stream.time_base)
            width, height = video_stream.width, video_stream.height
            # NOTE(review): float(guessed_rate) would raise TypeError if
            # guessed_rate is None, making the `fps is None` fallback below
            # unreachable — confirm intended behavior.
            fps = float(video_stream.guessed_rate)
            fps_avg = video_stream.average_rate
            if video_stream.duration is not None:
                video_duration_sec = float(
                    video_stream.duration * video_stream.time_base
                )
            if fps is None:
                # Fall back to the stream's average rate.
                fps = float(fps_avg)
            if not math.isnan(rotation_deg) and int(rotation_deg) in (
                90,
                -90,
                270,
                -270,
            ):
                # Quarter-turn display rotations swap the effective dimensions.
                width, height = height, width
        # Some containers only report one of the two durations; take the longer.
        duration_sec = max(container_duration_sec, video_duration_sec)
        return VideoMetadata(
            duration_sec=duration_sec,
            container_duration_sec=container_duration_sec,
            video_duration_sec=video_duration_sec,
            video_start_time=video_start_time,
            fps=fps,
            width=width,
            height=height,
            num_video_streams=num_video_streams,
            num_video_frames=num_video_frames,
        )
def normalize_video(
    in_path: str,
    out_path: str,
    max_w: int,
    max_h: int,
    seek_t: float,
    max_time: float,
    in_metadata: Optional[VideoMetadata],
    codec: str = "libx264",
    crf: int = 23,
    fps: int = 24,
    verbose: bool = False,
):
    """
    Re-encode a video with ffmpeg: trim to [seek_t, seek_t + max_time],
    rescale to fit within max_w x max_h preserving aspect ratio, and
    normalize fps / codec / pixel format.

    Args:
        in_path: Input video path.
        out_path: Output video path (overwritten if present).
        max_w: Maximum output width in pixels.
        max_h: Maximum output height in pixels.
        seek_t: Trim start time in seconds.
        max_time: Trim duration in seconds.
        in_metadata: Pre-probed metadata for `in_path`; probed here if None.
        codec: ffmpeg video codec name.
        crf: Constant rate factor (quality).
        fps: Target output frame rate.
        verbose: When True, print the command and let ffmpeg write to
            stdout/stderr.

    Raises:
        RuntimeError: If ffmpeg is not found on PATH.
        AssertionError: If the input has no video stream or is missing
            width/height metadata.
    """
    if in_metadata is None:
        in_metadata = get_video_metadata(in_path)
    assert in_metadata.num_video_streams > 0, "no video stream present"

    w, h = in_metadata.width, in_metadata.height
    assert w is not None, "width not available"
    assert h is not None, "height not available"

    # Rescale to max_w:max_h if needed & preserve aspect ratio.
    # Bug fix: the limits were previously hard-coded to 1280/720, silently
    # ignoring the max_w/max_h parameters.
    r = w / h
    if r < 1:
        h = min(max_h, h)
        w = h * r
    else:
        w = min(max_w, w)
        h = w / r

    # h264 cannot encode w/ odd dimensions.
    w = int(w)
    h = int(h)
    if w % 2 != 0:
        w += 1
    if h % 2 != 0:
        h += 1

    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg is None:
        # Fail loudly instead of passing None as argv[0] to subprocess.
        raise RuntimeError("ffmpeg not found on PATH; cannot transcode video")
    cmd = [
        ffmpeg,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # global threads
        "-ss",
        f"{seek_t:.2f}",
        "-t",
        f"{max_time:.2f}",
        "-i",
        in_path,
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # decode (or filter..?) threads
        "-vf",
        f"fps={fps},scale={w}:{h},setsar=1:1",
        "-c:v",
        codec,
        "-crf",
        f"{crf}",
        "-pix_fmt",
        "yuv420p",
        "-threads",
        f"{FFMPEG_NUM_THREADS}",  # encode threads
        out_path,
        "-y",
    ]
    if verbose:
        print(" ".join(cmd))
    subprocess.call(
        cmd,
        stdout=None if verbose else subprocess.DEVNULL,
        stderr=None if verbose else subprocess.DEVNULL,
    )