support gsam2 image predictor model
sav_dataset/utils/sav_utils.py (normal file, 175 lines added)
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the sav_dataset directory of this source tree.
import json
import os
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util


def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode the video and return its frames in RGB order.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    while video.isOpened():
        ret, frame = video.read()
        if ret:
            # OpenCV decodes frames as BGR; convert to RGB for plotting
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            video_frames.append(frame)
        else:
            break
    video.release()
    return video_frames
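

# Example (added note, not part of the original commit): a minimal usage
# sketch for decode_video. The clip path below is a hypothetical placeholder.
#
#     frames = decode_video("./sav_000001.mp4")  # hypothetical path
#     print(len(frames), frames[0].shape)        # e.g. 576 (1080, 1920, 3)

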
def show_anns(masks, colors: List, borders=True) -> None:
    """
    Show the annotation masks as a transparent overlay on the current axes.
    """
    # return if no masks
    if len(masks) == 0:
        return

    # sort masks by size so smaller masks are drawn on top of larger ones
    sorted_annot_and_color = sorted(
        zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True
    )
    H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1]

    canvas = np.ones((H, W, 4))
    canvas[:, :, 3] = 0  # set the alpha channel so the background is transparent
    contour_thickness = max(1, int(min(5, 0.01 * min(H, W))))
    for mask, color in sorted_annot_and_color:
        canvas[mask] = np.concatenate([color, [0.55]])
        if borders:
            contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
            )
            cv2.drawContours(
                canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness
            )

    ax = plt.gca()
    ax.imshow(canvas)
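

# Example (added note, not part of the original commit): show_anns expects
# boolean HxW masks and one RGB color per mask. A sketch with synthetic data:
#
#     masks = [np.zeros((480, 640), dtype=bool) for _ in range(2)]
#     masks[0][100:200, 100:200] = True
#     masks[1][150:300, 300:500] = True
#     colors = list(np.random.random((2, 3)))
#     plt.imshow(np.zeros((480, 640, 3)))  # black background image
#     show_anns(masks, colors)
#     plt.show()

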
class SAVDataset:
    """
    SAVDataset is a class to load the SAV dataset and visualize the annotations.
    """

    def __init__(self, sav_dir, annot_sample_rate=4):
        """
        Args:
            sav_dir: the directory of the SAV dataset
            annot_sample_rate: the frame sampling rate of the annotations.
                The annotations are aligned with the videos at 6 fps,
                so every 4th frame of a 24 fps video is annotated.
        """
        self.sav_dir = sav_dir
        self.annot_sample_rate = annot_sample_rate
        # fixed color tables reused across frames so the i-th masklet keeps a
        # consistent color in every visualized frame
        self.manual_mask_colors = np.random.random((256, 3))
        self.auto_mask_colors = np.random.random((256, 3))
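
    # Example (added note, not part of the original commit): the default
    # annot_sample_rate follows from the frame rates, 24 fps video / 6 fps
    # annotations = 4. The directory below is a hypothetical placeholder.
    #
    #     sav_dataset = SAVDataset(sav_dir="./sav_train")
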
    def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]:
        """
        Read the frames and downsample them to align with the annotations.
        Returns None if the video file does not exist.
        """
        if not os.path.exists(mp4_path):
            print(f"{mp4_path} doesn't exist.")
            return None
        else:
            # decode the video
            frames = decode_video(mp4_path)
            print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).")

            # downsample the frames to align with the annotations
            frames = frames[:: self.annot_sample_rate]
            print(
                f"Videos are annotated every {self.annot_sample_rate} frames. "
                "To align with the annotations, "
                f"downsample the video to {len(frames)} frames."
            )
            return frames
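
    # Example (added note, not part of the original commit): for a 24 s clip,
    # 24 s * 24 fps = 576 decoded frames, and frames[::4] keeps indices
    # 0, 4, 8, ... for 576 / 4 = 144 frames, one per annotated timestamp.
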
    def get_frames_and_annotations(
        self, video_id: str
    ) -> Tuple[Optional[List], Optional[Dict], Optional[Dict]]:
        """
        Get the frames and annotations for the given video.
        """
        # load the video
        mp4_path = os.path.join(self.sav_dir, video_id + ".mp4")
        frames = self.read_frames(mp4_path)
        if frames is None:
            return None, None, None

        # load the manual annotations
        manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json")
        if not os.path.exists(manual_annot_path):
            print(f"{manual_annot_path} doesn't exist. Something might be wrong.")
            manual_annot = None
        else:
            with open(manual_annot_path) as f:
                manual_annot = json.load(f)

        # load the auto annotations
        auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json")
        if not os.path.exists(auto_annot_path):
            print(f"{auto_annot_path} doesn't exist.")
            auto_annot = None
        else:
            with open(auto_annot_path) as f:
                auto_annot = json.load(f)

        return frames, manual_annot, auto_annot
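
    # Example (added note, not part of the original commit): as used in
    # visualize_annotation below, the annotation JSON carries a "masklet"
    # field indexed first by frame, then by object, where each entry is a
    # COCO-style RLE dict that pycocotools can decode. A hedged sketch:
    #
    #     annot["masklet"][frame_id][obj_idx]  # {"size": [H, W], "counts": ...}
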
    def visualize_annotation(
        self,
        frames: List[np.ndarray],
        auto_annot: Optional[Dict],
        manual_annot: Optional[Dict],
        annotated_frame_id: int,
        show_auto=True,
        show_manual=True,
    ) -> None:
        """
        Visualize the annotations on the annotated_frame_id.
        If show_manual is True, show the manual annotations.
        If show_auto is True, show the auto annotations.
        By default, show both auto and manual annotations.
        """

        if annotated_frame_id >= len(frames):
            print("invalid annotated_frame_id")
            return

        # collect the RLE masks and their colors for the requested frame
        rles = []
        colors = []
        if show_manual and manual_annot is not None:
            rles.extend(manual_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.manual_mask_colors[
                    : len(manual_annot["masklet"][annotated_frame_id])
                ]
            )
        if show_auto and auto_annot is not None:
            rles.extend(auto_annot["masklet"][annotated_frame_id])
            colors.extend(
                self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])]
            )

        plt.imshow(frames[annotated_frame_id])

        if len(rles) > 0:
            # decode the COCO-style RLEs into boolean masks and overlay them
            masks = [mask_util.decode(rle) > 0 for rle in rles]
            show_anns(masks, colors)
        else:
            print("No annotation will be shown")

        plt.axis("off")
        plt.show()
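

# Example (added note, not part of the original commit): an end-to-end usage
# sketch. The directory and video id below are hypothetical placeholders.
#
#     sav_dataset = SAVDataset(sav_dir="./sav_train")
#     frames, manual_annot, auto_annot = sav_dataset.get_frames_and_annotations(
#         "sav_000001"
#     )
#     if frames is not None:
#         sav_dataset.visualize_annotation(
#             frames, auto_annot, manual_annot, annotated_frame_id=0
#         )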