SAM2.1 checkpoints + training code + Demo
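For context, the argparse defaults in this diff now point at the SAM 2.1 config and checkpoint. Below is a minimal sketch of building a video predictor against the new defaults; the `build_sam2_video_predictor` import and the relative paths follow the SAM 2 repo layout and are assumptions, not part of this diff:

    from sam2.build_sam import build_sam2_video_predictor

    # the two defaults introduced by this commit (see the argparse hunk below)
    predictor = build_sam2_video_predictor(
        config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
        ckpt_path="./checkpoints/sam2.1_hiera_b+.pt",
    )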
@@ -6,6 +6,7 @@
 import argparse
 import os
+from collections import defaultdict

 import numpy as np
 import torch
@@ -53,20 +54,27 @@ def put_per_obj_mask(per_obj_mask, height, width):
     return mask


-def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file):
+def load_masks_from_dir(
+    input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
+):
     """Load masks from a directory as a dict of per-object masks."""
     if not per_obj_png_file:
         input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
+        if allow_missing and not os.path.exists(input_mask_path):
+            return {}, None
         input_mask, input_palette = load_ann_png(input_mask_path)
         per_obj_input_mask = get_per_obj_mask(input_mask)
     else:
         per_obj_input_mask = {}
         input_palette = None
         # each object is a directory in "{object_id:%03d}" format
         for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
             object_id = int(object_name)
             input_mask_path = os.path.join(
                 input_mask_dir, video_name, object_name, f"{frame_name}.png"
             )
+            if allow_missing and not os.path.exists(input_mask_path):
+                continue
             input_mask, input_palette = load_ann_png(input_mask_path)
             per_obj_input_mask[object_id] = input_mask > 0
     return per_obj_input_mask, input_palette
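The new `allow_missing` flag makes `load_masks_from_dir` tolerate frames that have no annotation on disk instead of raising: a missing single-PNG file yields `({}, None)`, and a missing per-object PNG is simply skipped. A small sketch of the new behavior, under illustrative directory and video names:

    # hypothetical layouts handled by load_masks_from_dir:
    #   masks/video0/00000.png          <- one palette PNG per frame
    #   masks/video0/001/00000.png      <- per-object dirs in "{object_id:%03d}" format
    per_obj_mask, palette = load_masks_from_dir(
        input_mask_dir="masks",
        video_name="video0",
        frame_name="00042",        # suppose this frame has no annotation on disk
        per_obj_png_file=False,
        allow_missing=True,        # new: soft-skip instead of a FileNotFoundError
    )
    assert per_obj_mask == {} and palette is None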
@@ -156,17 +164,44 @@ def vos_inference(
             os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
         )
     ]
+    # check and make sure we got at least one input frame
+    if len(input_frame_inds) == 0:
+        raise RuntimeError(
+            f"In {video_name=}, got no input masks in {input_mask_dir=}. "
+            "Please make sure the input masks are available in the correct format."
+        )
     input_frame_inds = sorted(set(input_frame_inds))

     # add those input masks to SAM 2 inference state before propagation
+    object_ids_set = None
     for input_frame_idx in input_frame_inds:
-        per_obj_input_mask, input_palette = load_masks_from_dir(
-            input_mask_dir=input_mask_dir,
-            video_name=video_name,
-            frame_name=frame_names[input_frame_idx],
-            per_obj_png_file=per_obj_png_file,
-        )
+        try:
+            per_obj_input_mask, input_palette = load_masks_from_dir(
+                input_mask_dir=input_mask_dir,
+                video_name=video_name,
+                frame_name=frame_names[input_frame_idx],
+                per_obj_png_file=per_obj_png_file,
+            )
+        except FileNotFoundError as e:
+            raise RuntimeError(
+                f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
+                "Please add the `--track_object_appearing_later_in_video` flag "
+                "for VOS datasets that don't have all objects to track appearing "
+                "in the first frame (such as LVOS or YouTube-VOS)."
+            ) from e
+        # get the list of object ids to track from the first input frame
+        if object_ids_set is None:
+            object_ids_set = set(per_obj_input_mask)
         for object_id, object_mask in per_obj_input_mask.items():
+            # check and make sure no new object ids appear only in later frames
+            if object_id not in object_ids_set:
+                raise RuntimeError(
+                    f"In {video_name=}, got a new {object_id=} appearing only in a "
+                    f"later {input_frame_idx=} (but not appearing in the first frame). "
+                    "Please add the `--track_object_appearing_later_in_video` flag "
+                    "for VOS datasets that don't have all objects to track appearing "
+                    "in the first frame (such as LVOS or YouTube-VOS)."
+                )
             predictor.add_new_mask(
                 inference_state=inference_state,
                 frame_idx=input_frame_idx,
@@ -174,6 +209,14 @@ def vos_inference(
                 mask=object_mask,
             )

+    # check and make sure we have at least one object to track
+    if object_ids_set is None or len(object_ids_set) == 0:
+        raise RuntimeError(
+            f"In {video_name=}, got no object ids on {input_frame_inds=}. "
+            "Please add the `--track_object_appearing_later_in_video` flag "
+            "for VOS datasets that don't have all objects to track appearing "
+            "in the first frame (such as LVOS or YouTube-VOS)."
+        )
     # run propagation throughout the video and collect the results in a dict
     os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
     output_palette = input_palette or DAVIS_PALETTE
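The checks added above pin down the input contract of `vos_inference`: every object to track must come with a mask in the first annotated frame, and each mask is handed to the predictor one object id at a time. A toy illustration of the per-object mask dict that flows into `predictor.add_new_mask` (shapes and ids are made up):

    import numpy as np

    height, width = 480, 854  # illustrative frame size
    # what load_masks_from_dir() returns for one frame: {object_id: bool (H, W) array}
    per_obj_input_mask = {
        1: np.zeros((height, width), dtype=bool),
        2: np.zeros((height, width), dtype=bool),
    }
    per_obj_input_mask[1][100:200, 300:400] = True  # object 1 occupies a box
    # vos_inference() then calls predictor.add_new_mask(...) once per (frame_idx, obj_id)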
@@ -201,18 +244,138 @@ def vos_inference(
         )


+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def vos_separate_inference_per_object(
+    predictor,
+    base_video_dir,
+    input_mask_dir,
+    output_mask_dir,
+    video_name,
+    score_thresh=0.0,
+    use_all_masks=False,
+    per_obj_png_file=False,
+):
+    """
+    Run VOS inference on a single video with the given predictor.
+
+    Unlike `vos_inference`, this function runs inference separately for each object
+    in a video, which could be applied to datasets like LVOS or YouTube-VOS that
+    don't have all objects to track appearing in the first frame (i.e. some objects
+    might appear only later in the video).
+    """
+    # load the video frames and initialize the inference state on this video
+    video_dir = os.path.join(base_video_dir, video_name)
+    frame_names = [
+        os.path.splitext(p)[0]
+        for p in os.listdir(video_dir)
+        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+    ]
+    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+    inference_state = predictor.init_state(
+        video_path=video_dir, async_loading_frames=False
+    )
+    height = inference_state["video_height"]
+    width = inference_state["video_width"]
+    input_palette = None
+
+    # collect all the object ids and their input masks
+    inputs_per_object = defaultdict(dict)
+    for idx, name in enumerate(frame_names):
+        if per_obj_png_file or os.path.exists(
+            os.path.join(input_mask_dir, video_name, f"{name}.png")
+        ):
+            per_obj_input_mask, input_palette = load_masks_from_dir(
+                input_mask_dir=input_mask_dir,
+                video_name=video_name,
+                frame_name=frame_names[idx],
+                per_obj_png_file=per_obj_png_file,
+                allow_missing=True,
+            )
+            for object_id, object_mask in per_obj_input_mask.items():
+                # skip empty masks
+                if not np.any(object_mask):
+                    continue
+                # if `use_all_masks=False`, we only use the first mask for each object
+                if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
+                    continue
+                print(f"adding mask from frame {idx} as input for {object_id=}")
+                inputs_per_object[object_id][idx] = object_mask
+
+    # run inference separately for each object in the video
+    object_ids = sorted(inputs_per_object)
+    output_scores_per_object = defaultdict(dict)
+    for object_id in object_ids:
+        # add those input masks to SAM 2 inference state before propagation
+        input_frame_inds = sorted(inputs_per_object[object_id])
+        predictor.reset_state(inference_state)
+        for input_frame_idx in input_frame_inds:
+            predictor.add_new_mask(
+                inference_state=inference_state,
+                frame_idx=input_frame_idx,
+                obj_id=object_id,
+                mask=inputs_per_object[object_id][input_frame_idx],
+            )
+
+        # run propagation throughout the video and collect the results in a dict
+        for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
+            inference_state,
+            start_frame_idx=min(input_frame_inds),
+            reverse=False,
+        ):
+            obj_scores = out_mask_logits.cpu().numpy()
+            output_scores_per_object[object_id][out_frame_idx] = obj_scores
+
+    # post-processing: consolidate the per-object scores into per-frame masks
+    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
+    output_palette = input_palette or DAVIS_PALETTE
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    for frame_idx in range(len(frame_names)):
+        scores = torch.full(
+            size=(len(object_ids), 1, height, width),
+            fill_value=-1024.0,
+            dtype=torch.float32,
+        )
+        for i, object_id in enumerate(object_ids):
+            if frame_idx in output_scores_per_object[object_id]:
+                scores[i] = torch.from_numpy(
+                    output_scores_per_object[object_id][frame_idx]
+                )
+
+        if not per_obj_png_file:
+            scores = predictor._apply_non_overlapping_constraints(scores)
+        per_obj_output_mask = {
+            object_id: (scores[i] > score_thresh).cpu().numpy()
+            for i, object_id in enumerate(object_ids)
+        }
+        video_segments[frame_idx] = per_obj_output_mask
+
+    # write the output masks as palette PNG files to output_mask_dir
+    for frame_idx, per_obj_output_mask in video_segments.items():
+        save_masks_to_dir(
+            output_mask_dir=output_mask_dir,
+            video_name=video_name,
+            frame_name=frame_names[frame_idx],
+            per_obj_output_mask=per_obj_output_mask,
+            height=height,
+            width=width,
+            per_obj_png_file=per_obj_png_file,
+            output_palette=output_palette,
+        )
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--sam2_cfg",
         type=str,
-        default="sam2_hiera_b+.yaml",
+        default="configs/sam2.1/sam2.1_hiera_b+.yaml",
         help="SAM 2 model configuration file",
     )
     parser.add_argument(
         "--sam2_checkpoint",
         type=str,
-        default="./checkpoints/sam2_hiera_b+.pt",
+        default="./checkpoints/sam2.1_hiera_b+.pt",
         help="path to the SAM 2 model checkpoint",
     )
     parser.add_argument(
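In the new function, each object is propagated through the video independently, so a given frame may have scores for some objects and none for others. The consolidation loop above fills the missing entries with a large negative logit (-1024.0), so thresholding yields an empty mask for objects that were never predicted on that frame. A toy version of that step (sizes and ids are made up):

    import torch

    object_ids = [1, 2, 3]
    height, width, score_thresh = 4, 4, 0.0
    # frames with no prediction for an object keep the -1024.0 fill value
    scores = torch.full((len(object_ids), 1, height, width), fill_value=-1024.0)
    scores[0] = 2.5  # pretend object 1 got positive logits everywhere on this frame
    per_obj_output_mask = {
        object_id: (scores[i] > score_thresh).numpy()
        for i, object_id in enumerate(object_ids)
    }
    assert per_obj_output_mask[1].all() and not per_obj_output_mask[2].any()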
@@ -265,6 +428,12 @@ def main():
         help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
         "(we don't apply such post-processing in the SAM 2 model evaluation)",
     )
+    parser.add_argument(
+        "--track_object_appearing_later_in_video",
+        action="store_true",
+        help="whether to track objects that appear later in the video (i.e. not on the first frame; "
+        "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
+    )
     args = parser.parse_args()

     # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -299,16 +468,28 @@ def main():

     for n_video, video_name in enumerate(video_names):
         print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
-        vos_inference(
-            predictor=predictor,
-            base_video_dir=args.base_video_dir,
-            input_mask_dir=args.input_mask_dir,
-            output_mask_dir=args.output_mask_dir,
-            video_name=video_name,
-            score_thresh=args.score_thresh,
-            use_all_masks=args.use_all_masks,
-            per_obj_png_file=args.per_obj_png_file,
-        )
+        if not args.track_object_appearing_later_in_video:
+            vos_inference(
+                predictor=predictor,
+                base_video_dir=args.base_video_dir,
+                input_mask_dir=args.input_mask_dir,
+                output_mask_dir=args.output_mask_dir,
+                video_name=video_name,
+                score_thresh=args.score_thresh,
+                use_all_masks=args.use_all_masks,
+                per_obj_png_file=args.per_obj_png_file,
+            )
+        else:
+            vos_separate_inference_per_object(
+                predictor=predictor,
+                base_video_dir=args.base_video_dir,
+                input_mask_dir=args.input_mask_dir,
+                output_mask_dir=args.output_mask_dir,
+                video_name=video_name,
+                score_thresh=args.score_thresh,
+                use_all_masks=args.use_all_masks,
+                per_obj_png_file=args.per_obj_png_file,
+            )

     print(
         f"completed VOS prediction on {len(video_names)} videos -- "