SAM2.1 checkpoints + training code + Demo
This commit is contained in:
Haitham Khedr
2024-09-28 08:20:56 -07:00
parent 7e1596c0b6
commit aa9b8722d0
325 changed files with 38174 additions and 223 deletions

View File

@@ -9,8 +9,8 @@ The `vos_inference.py` script can be used to generate predictions for semi-super
After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`.
```bash
python ./tools/vos_inference.py \
--sam2_cfg sam2_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-davis-2017/JPEGImages/480p \
--input_mask_dir /path-to-davis-2017/Annotations/480p \
--video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \
@@ -21,8 +21,8 @@ python ./tools/vos_inference.py \
To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). When this flag is set, the script also saves the output masks as per-object PNG files under the `--output_mask_dir`.
```bash
python ./tools/vos_inference.py \
--sam2_cfg sam2_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \
--sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
--sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--base_video_dir /path-to-sav-val/JPEGImages_24fps \
--input_mask_dir /path-to-sav-val/Annotations_6fps \
--video_list_file /path-to-sav-val/sav_val.txt \
@@ -33,4 +33,4 @@ python ./tools/vos_inference.py \
Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above.
**Note: a limitation of the `vos_inference.py` script above is that currently it only supports VOS datasets where all objects to track already appear on frame 0 in each video** (and therefore it doesn't apply to some datasets such as [LVOS](https://lingyihongfd.github.io/lvos.github.io/) that have objects only appearing in the middle of a video).
Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**.

View File

@@ -6,6 +6,7 @@
import argparse
import os
from collections import defaultdict
import numpy as np
import torch
@@ -53,20 +54,27 @@ def put_per_obj_mask(per_obj_mask, height, width):
return mask
def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file):
def load_masks_from_dir(
input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
):
"""Load masks from a directory as a dict of per-object masks."""
if not per_obj_png_file:
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
return {}, None
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask = get_per_obj_mask(input_mask)
else:
per_obj_input_mask = {}
input_palette = None
# each object is a directory in "{object_id:%03d}" format
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
object_id = int(object_name)
input_mask_path = os.path.join(
input_mask_dir, video_name, object_name, f"{frame_name}.png"
)
if allow_missing and not os.path.exists(input_mask_path):
continue
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask[object_id] = input_mask > 0
@@ -156,17 +164,44 @@ def vos_inference(
os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
)
]
# check and make sure we got at least one input frame
if len(input_frame_inds) == 0:
raise RuntimeError(
f"In {video_name=}, got no input masks in {input_mask_dir=}. "
"Please make sure the input masks are available in the correct format."
)
input_frame_inds = sorted(set(input_frame_inds))
# add those input masks to SAM 2 inference state before propagation
object_ids_set = None
for input_frame_idx in input_frame_inds:
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[input_frame_idx],
per_obj_png_file=per_obj_png_file,
)
try:
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[input_frame_idx],
per_obj_png_file=per_obj_png_file,
)
except FileNotFoundError as e:
raise RuntimeError(
f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
) from e
# get the list of object ids to track from the first input frame
if object_ids_set is None:
object_ids_set = set(per_obj_input_mask)
for object_id, object_mask in per_obj_input_mask.items():
# check and make sure no new object ids appear only in later frames
if object_id not in object_ids_set:
raise RuntimeError(
f"In {video_name=}, got a new {object_id=} appearing only in a "
f"later {input_frame_idx=} (but not appearing in the first frame). "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
@@ -174,6 +209,14 @@ def vos_inference(
mask=object_mask,
)
# check and make sure we have at least one object to track
if object_ids_set is None or len(object_ids_set) == 0:
raise RuntimeError(
f"In {video_name=}, got no object ids on {input_frame_inds=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
# run propagation throughout the video and collect the results in a dict
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
@@ -201,18 +244,138 @@ def vos_inference(
)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_separate_inference_per_object(
    predictor,
    base_video_dir,
    input_mask_dir,
    output_mask_dir,
    video_name,
    score_thresh=0.0,
    use_all_masks=False,
    per_obj_png_file=False,
):
    """
    Run VOS inference on a single video with the given predictor.

    Unlike `vos_inference`, this function runs inference separately for each object
    in a video, which could be applied to datasets like LVOS or YouTube-VOS that
    don't have all objects to track appearing in the first frame (i.e. some objects
    might appear only later in the video).

    The per-object mask logits from each separate run are consolidated into
    per-frame masks and written as PNG files under `output_mask_dir`; nothing is
    returned.

    Args:
        predictor: SAM 2 video predictor (provides `init_state`, `reset_state`,
            `add_new_mask`, `propagate_in_video`, and
            `_apply_non_overlapping_constraints`).
        base_video_dir: directory containing one sub-directory of JPEG frames
            per video.
        input_mask_dir: directory with input annotation PNGs (either one palette
            PNG per frame, or per-object PNGs if `per_obj_png_file` is set).
        output_mask_dir: directory where the output mask PNGs are saved.
        video_name: name of the video sub-directory to process.
        score_thresh: logit threshold above which a pixel is foreground.
        use_all_masks: if True, use every available input mask for each object;
            otherwise only the first mask per object is used as input.
        per_obj_png_file: if True, read and write one PNG per object (masks may
            overlap); otherwise use a single palette PNG per frame and apply
            the predictor's non-overlapping constraints.
    """
    # load the video frames and initialize the inference state on this video
    video_dir = os.path.join(base_video_dir, video_name)
    frame_names = [
        os.path.splitext(p)[0]
        for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    # frame file names are assumed to be integers (e.g. "00012.jpg"), sorted
    # numerically so that frame indices match temporal order
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    inference_state = predictor.init_state(
        video_path=video_dir, async_loading_frames=False
    )
    height = inference_state["video_height"]
    width = inference_state["video_width"]
    input_palette = None

    # collect all the object ids and their input masks
    # maps object_id -> {frame_idx: binary mask}
    inputs_per_object = defaultdict(dict)
    for idx, name in enumerate(frame_names):
        # in per-object PNG mode we must probe every frame (missing files are
        # tolerated by `allow_missing=True`); otherwise only frames that have a
        # combined palette PNG are loaded
        if per_obj_png_file or os.path.exists(
            os.path.join(input_mask_dir, video_name, f"{name}.png")
        ):
            per_obj_input_mask, input_palette = load_masks_from_dir(
                input_mask_dir=input_mask_dir,
                video_name=video_name,
                frame_name=frame_names[idx],
                per_obj_png_file=per_obj_png_file,
                allow_missing=True,
            )
            for object_id, object_mask in per_obj_input_mask.items():
                # skip empty masks
                if not np.any(object_mask):
                    continue
                # if `use_all_masks=False`, we only use the first mask for each object
                if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
                    continue
                print(f"adding mask from frame {idx} as input for {object_id=}")
                inputs_per_object[object_id][idx] = object_mask

    # run inference separately for each object in the video
    object_ids = sorted(inputs_per_object)
    # maps object_id -> {frame_idx: mask logits array}
    output_scores_per_object = defaultdict(dict)
    for object_id in object_ids:
        # add those input masks to SAM 2 inference state before propagation
        input_frame_inds = sorted(inputs_per_object[object_id])
        # reset so each object is tracked independently of the others
        predictor.reset_state(inference_state)
        for input_frame_idx in input_frame_inds:
            predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=input_frame_idx,
                obj_id=object_id,
                mask=inputs_per_object[object_id][input_frame_idx],
            )

        # run propagation throughout the video and collect the results in a dict
        # (propagation starts at the object's earliest input frame, forward only)
        for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
            inference_state,
            start_frame_idx=min(input_frame_inds),
            reverse=False,
        ):
            obj_scores = out_mask_logits.cpu().numpy()
            output_scores_per_object[object_id][out_frame_idx] = obj_scores

    # post-processing: consolidate the per-object scores into per-frame masks
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    output_palette = input_palette or DAVIS_PALETTE
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for frame_idx in range(len(frame_names)):
        # -1024.0 acts as "no score" so untracked objects never pass score_thresh
        scores = torch.full(
            size=(len(object_ids), 1, height, width),
            fill_value=-1024.0,
            dtype=torch.float32,
        )
        for i, object_id in enumerate(object_ids):
            if frame_idx in output_scores_per_object[object_id]:
                scores[i] = torch.from_numpy(
                    output_scores_per_object[object_id][frame_idx]
                )

        # only enforce mutually exclusive masks in combined-PNG mode; per-object
        # PNG outputs are allowed to overlap
        if not per_obj_png_file:
            scores = predictor._apply_non_overlapping_constraints(scores)
        per_obj_output_mask = {
            object_id: (scores[i] > score_thresh).cpu().numpy()
            for i, object_id in enumerate(object_ids)
        }
        video_segments[frame_idx] = per_obj_output_mask

    # write the output masks as palette PNG files to output_mask_dir
    for frame_idx, per_obj_output_mask in video_segments.items():
        save_masks_to_dir(
            output_mask_dir=output_mask_dir,
            video_name=video_name,
            frame_name=frame_names[frame_idx],
            per_obj_output_mask=per_obj_output_mask,
            height=height,
            width=width,
            per_obj_png_file=per_obj_png_file,
            output_palette=output_palette,
        )
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--sam2_cfg",
type=str,
default="sam2_hiera_b+.yaml",
default="configs/sam2.1/sam2.1_hiera_b+.yaml",
help="SAM 2 model configuration file",
)
parser.add_argument(
"--sam2_checkpoint",
type=str,
default="./checkpoints/sam2_hiera_b+.pt",
default="./checkpoints/sam2.1_hiera_b+.pt",
help="path to the SAM 2 model checkpoint",
)
parser.add_argument(
@@ -265,6 +428,12 @@ def main():
help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
"(we don't apply such post-processing in the SAM 2 model evaluation)",
)
parser.add_argument(
"--track_object_appearing_later_in_video",
action="store_true",
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
)
args = parser.parse_args()
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -299,16 +468,28 @@ def main():
for n_video, video_name in enumerate(video_names):
print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
vos_inference(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
if not args.track_object_appearing_later_in_video:
vos_inference(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
else:
vos_separate_inference_per_object(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
print(
f"completed VOS prediction on {len(video_names)} videos -- "