2024-09-28 08:20:56 -07:00
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import logging
|
|
|
|
from typing import Any, Generator
|
|
|
|
|
|
|
|
from app_conf import (
|
|
|
|
GALLERY_PATH,
|
|
|
|
GALLERY_PREFIX,
|
|
|
|
POSTERS_PATH,
|
|
|
|
POSTERS_PREFIX,
|
|
|
|
UPLOADS_PATH,
|
|
|
|
UPLOADS_PREFIX,
|
|
|
|
)
|
|
|
|
from data.loader import preload_data
|
|
|
|
from data.schema import schema
|
|
|
|
from data.store import set_videos
|
|
|
|
from flask import Flask, make_response, Request, request, Response, send_from_directory
|
|
|
|
from flask_cors import CORS
|
|
|
|
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
|
|
|
|
from inference.multipart import MultipartResponseBuilder
|
|
|
|
from inference.predictor import InferenceAPI
|
|
|
|
from strawberry.flask.views import GraphQLView
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Flask application object; every route below is registered on it.
app = Flask(__name__)
# Enable flask-cors for the app with credentials support turned on.
cors = CORS(app=app, supports_credentials=True)

# Preload the video data at import time and register it with data.store.
videos = preload_data()
set_videos(videos)

# One InferenceAPI instance, shared by the propagation streaming route and the
# GraphQL view context.
inference_api = InferenceAPI()
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/healthy")
def healthy() -> Response:
    """Health-check endpoint: always answers HTTP 200 with the body "OK"."""
    ok_response = make_response("OK", 200)
    return ok_response
|
|
|
|
|
|
|
|
|
|
|
|
@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve a gallery video file from ``GALLERY_PATH``.

    Args:
        path: File path relative to ``GALLERY_PATH``, taken from the URL.

    Returns:
        The file response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises; the original error is
            chained as ``__cause__``.
    """
    try:
        return send_from_directory(GALLERY_PATH, path)
    # Catch Exception instead of using a bare ``except`` so KeyboardInterrupt /
    # SystemExit are not converted, and chain the cause so the real failure
    # remains visible in logs and tracebacks.
    except Exception as err:
        # NOTE(review): send_from_directory signals a missing file with
        # werkzeug's NotFound (HTTP 404); re-raising it as ValueError makes
        # Flask answer 500 instead. Consider letting HTTP exceptions propagate.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
|
|
|
|
@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve a poster image file from ``POSTERS_PATH``.

    Args:
        path: File path relative to ``POSTERS_PATH``, taken from the URL.

    Returns:
        The file response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises; the original error is
            chained as ``__cause__``.
    """
    try:
        return send_from_directory(POSTERS_PATH, path)
    # Catch Exception instead of using a bare ``except`` so KeyboardInterrupt /
    # SystemExit are not converted, and chain the cause so the real failure
    # remains visible in logs and tracebacks.
    except Exception as err:
        # NOTE(review): send_from_directory signals a missing file with
        # werkzeug's NotFound (HTTP 404); re-raising it as ValueError makes
        # Flask answer 500 instead. Consider letting HTTP exceptions propagate.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
|
|
|
|
@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve an uploaded video file from ``UPLOADS_PATH``.

    Args:
        path: File path relative to ``UPLOADS_PATH``, taken from the URL.

    Returns:
        The file response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises; the original error is
            chained as ``__cause__``.
    """
    try:
        return send_from_directory(UPLOADS_PATH, path)
    # Catch Exception instead of using a bare ``except`` so KeyboardInterrupt /
    # SystemExit are not converted, and chain the cause so the real failure
    # remains visible in logs and tracebacks.
    except Exception as err:
        # NOTE(review): send_from_directory signals a missing file with
        # werkzeug's NotFound (HTTP 404); re-raising it as ValueError makes
        # Flask answer 500 instead. Consider letting HTTP exceptions propagate.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Protect this route with a ToS permission check.
|
|
|
|
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream mask-propagation results for a session as a multipart response.

    The JSON request body must be an object with a required ``session_id`` and
    an optional ``start_frame_index`` (defaults to 0).

    Returns:
        A streaming ``Response`` whose parts are produced lazily by
        ``gen_track_with_mask_stream`` using the boundary ``"frame"``, or a
        400 response when the body is not a JSON object containing
        ``session_id``.
    """
    data = request.json
    # Previously a non-object body or a missing "session_id" raised
    # TypeError/KeyError/AttributeError and surfaced as a 500; report the
    # client error explicitly instead.
    if not isinstance(data, dict) or "session_id" not in data:
        return make_response(
            "request body must be a JSON object containing 'session_id'", 400
        )

    args = {
        "session_id": data["session_id"],
        "start_frame_index": data.get("start_frame_index", 0),
    }

    boundary = "frame"
    frame = gen_track_with_mask_stream(boundary, **args)
    return Response(frame, mimetype="multipart/x-savi-stream; boundary=" + boundary)
|
|
|
|
|
|
|
|
|
|
|
|
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield one multipart part per chunk produced by mask propagation.

    Builds a ``PropagateInVideoRequest`` for ``session_id`` starting at
    ``start_frame_index`` and runs ``inference_api.propagate_in_video`` inside
    ``inference_api.autocast_context()``. Each chunk is serialized with
    ``to_json()``, UTF-8 encoded, and wrapped by ``MultipartResponseBuilder``
    using ``boundary``.

    Args:
        boundary: Multipart boundary string separating the parts.
        session_id: Inference session to propagate in.
        start_frame_index: Frame index at which propagation starts.

    Yields:
        The encoded multipart message for each propagation chunk.
    """
    with inference_api.autocast_context():
        # Named to avoid shadowing flask's module-level ``request`` proxy.
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )

        for chunk in inference_api.propagate_in_video(request=propagate_request):
            part_headers = {
                "Content-Type": "application/json; charset=utf-8",
                # Frame counters are hard-coded to "-1"; presumably this means
                # "not reported" — confirm against the client's parser.
                "Frame-Current": "-1",
                "Frame-Total": "-1",
                "Mask-Type": "RLE[]",
            }
            part_body = chunk.to_json().encode("UTF-8")
            part = MultipartResponseBuilder.build(
                boundary=boundary,
                headers=part_headers,
                body=part_body,
            )
            yield part.get_message()
|
|
|
|
|
|
|
|
|
|
|
|
class MyGraphQLView(GraphQLView):
    """Strawberry GraphQL view whose context carries the shared InferenceAPI."""

    def get_context(self, request: Request, response: Response) -> Any:
        """Return the GraphQL context for a request.

        ``request`` and ``response`` are not consulted. Every call returns a
        fresh dict holding the module-level ``inference_api`` under the key
        ``"inference_api"``, so resolvers can reach it through the context.
        """
        context = {"inference_api": inference_api}
        return context
|
|
|
|
|
|
|
|
|
|
|
|
# Register the GraphQL endpoint on the Flask app.
app.add_url_rule(
    "/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # Recent Strawberry releases require multipart request support to be
        # enabled explicitly on each view.
        # https://github.com/strawberry-graphql/strawberry/issues/3655
        multipart_uploads_enabled=True,
        # Reject queries sent via HTTP GET.
        # https://strawberry.rocks/docs/operations/deployment
        # https://strawberry.rocks/docs/integrations/flask
        allow_queries_via_get=False,
    ),
)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start Flask's built-in server on all network interfaces, port 5000.
    app.run(port=5000, host="0.0.0.0")
|