support gsam2 image predictor model
This commit is contained in:
36
tools/README.md
Normal file
36
tools/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
## SAM 2 toolkits
|
||||
|
||||
This directory provides toolkits for additional SAM 2 use cases.
|
||||
|
||||
### Semi-supervised VOS inference
|
||||
|
||||
The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset.
|
||||
|
||||
After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`.
|
||||
```bash
|
||||
python ./tools/vos_inference.py \
|
||||
--sam2_cfg sam2_hiera_b+.yaml \
|
||||
--sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \
|
||||
--base_video_dir /path-to-davis-2017/JPEGImages/480p \
|
||||
--input_mask_dir /path-to-davis-2017/Annotations/480p \
|
||||
--video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \
|
||||
--output_mask_dir ./outputs/davis_2017_pred_pngs
|
||||
```
|
||||
(replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset)
|
||||
|
||||
To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag.
|
||||
```bash
|
||||
python ./tools/vos_inference.py \
|
||||
--sam2_cfg sam2_hiera_b+.yaml \
|
||||
--sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \
|
||||
--base_video_dir /path-to-sav-val/JPEGImages_24fps \
|
||||
--input_mask_dir /path-to-sav-val/Annotations_6fps \
|
||||
--video_list_file /path-to-sav-val/sav_val.txt \
|
||||
--per_obj_png_file \
|
||||
--output_mask_dir ./outputs/sav_val_pred_pngs
|
||||
```
|
||||
(replace `/path-to-sav-val` with the path to SA-V val)
|
||||
|
||||
Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above.
|
||||
|
||||
**Note: a limitation of the `vos_inference.py` script above is that currently it only supports VOS datasets where all objects to track already appear on frame 0 in each video** (and therefore it doesn't apply to some datasets such as [LVOS](https://lingyihongfd.github.io/lvos.github.io/) that have objects only appearing in the middle of a video).
|
320
tools/vos_inference.py
Normal file
320
tools/vos_inference.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
|
||||
# the PNG palette for DAVIS 2017 dataset
|
||||
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
|
||||
|
||||
|
||||
def load_ann_png(path):
|
||||
"""Load a PNG file as a mask and its palette."""
|
||||
mask = Image.open(path)
|
||||
palette = mask.getpalette()
|
||||
mask = np.array(mask).astype(np.uint8)
|
||||
return mask, palette
|
||||
|
||||
|
||||
def save_ann_png(path, mask, palette):
|
||||
"""Save a mask as a PNG file with the given palette."""
|
||||
assert mask.dtype == np.uint8
|
||||
assert mask.ndim == 2
|
||||
output_mask = Image.fromarray(mask)
|
||||
output_mask.putpalette(palette)
|
||||
output_mask.save(path)
|
||||
|
||||
|
||||
def get_per_obj_mask(mask):
|
||||
"""Split a mask into per-object masks."""
|
||||
object_ids = np.unique(mask)
|
||||
object_ids = object_ids[object_ids > 0].tolist()
|
||||
per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
|
||||
return per_obj_mask
|
||||
|
||||
|
||||
def put_per_obj_mask(per_obj_mask, height, width):
|
||||
"""Combine per-object masks into a single mask."""
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
object_ids = sorted(per_obj_mask)[::-1]
|
||||
for object_id in object_ids:
|
||||
object_mask = per_obj_mask[object_id]
|
||||
object_mask = object_mask.reshape(height, width)
|
||||
mask[object_mask] = object_id
|
||||
return mask
|
||||
|
||||
|
||||
def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file):
|
||||
"""Load masks from a directory as a dict of per-object masks."""
|
||||
if not per_obj_png_file:
|
||||
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
|
||||
input_mask, input_palette = load_ann_png(input_mask_path)
|
||||
per_obj_input_mask = get_per_obj_mask(input_mask)
|
||||
else:
|
||||
per_obj_input_mask = {}
|
||||
# each object is a directory in "{object_id:%03d}" format
|
||||
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
|
||||
object_id = int(object_name)
|
||||
input_mask_path = os.path.join(
|
||||
input_mask_dir, video_name, object_name, f"{frame_name}.png"
|
||||
)
|
||||
input_mask, input_palette = load_ann_png(input_mask_path)
|
||||
per_obj_input_mask[object_id] = input_mask > 0
|
||||
|
||||
return per_obj_input_mask, input_palette
|
||||
|
||||
|
||||
def save_masks_to_dir(
|
||||
output_mask_dir,
|
||||
video_name,
|
||||
frame_name,
|
||||
per_obj_output_mask,
|
||||
height,
|
||||
width,
|
||||
per_obj_png_file,
|
||||
output_palette,
|
||||
):
|
||||
"""Save masks to a directory as PNG files."""
|
||||
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
|
||||
if not per_obj_png_file:
|
||||
output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
|
||||
output_mask_path = os.path.join(
|
||||
output_mask_dir, video_name, f"{frame_name}.png"
|
||||
)
|
||||
save_ann_png(output_mask_path, output_mask, output_palette)
|
||||
else:
|
||||
for object_id, object_mask in per_obj_output_mask.items():
|
||||
object_name = f"{object_id:03d}"
|
||||
os.makedirs(
|
||||
os.path.join(output_mask_dir, video_name, object_name),
|
||||
exist_ok=True,
|
||||
)
|
||||
output_mask = object_mask.reshape(height, width).astype(np.uint8)
|
||||
output_mask_path = os.path.join(
|
||||
output_mask_dir, video_name, object_name, f"{frame_name}.png"
|
||||
)
|
||||
save_ann_png(output_mask_path, output_mask, output_palette)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
def vos_inference(
|
||||
predictor,
|
||||
base_video_dir,
|
||||
input_mask_dir,
|
||||
output_mask_dir,
|
||||
video_name,
|
||||
score_thresh=0.0,
|
||||
use_all_masks=False,
|
||||
per_obj_png_file=False,
|
||||
):
|
||||
"""Run VOS inference on a single video with the given predictor."""
|
||||
# load the video frames and initialize the inference state on this video
|
||||
video_dir = os.path.join(base_video_dir, video_name)
|
||||
frame_names = [
|
||||
os.path.splitext(p)[0]
|
||||
for p in os.listdir(video_dir)
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||
]
|
||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
inference_state = predictor.init_state(
|
||||
video_path=video_dir, async_loading_frames=False
|
||||
)
|
||||
height = inference_state["video_height"]
|
||||
width = inference_state["video_width"]
|
||||
input_palette = None
|
||||
|
||||
# fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
|
||||
if not use_all_masks:
|
||||
# use only the first video's ground-truth mask as the input mask
|
||||
input_frame_inds = [0]
|
||||
else:
|
||||
# use all mask files available in the input_mask_dir as the input masks
|
||||
if not per_obj_png_file:
|
||||
input_frame_inds = [
|
||||
idx
|
||||
for idx, name in enumerate(frame_names)
|
||||
if os.path.exists(
|
||||
os.path.join(input_mask_dir, video_name, f"{name}.png")
|
||||
)
|
||||
]
|
||||
else:
|
||||
input_frame_inds = [
|
||||
idx
|
||||
for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
|
||||
for idx, name in enumerate(frame_names)
|
||||
if os.path.exists(
|
||||
os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
|
||||
)
|
||||
]
|
||||
input_frame_inds = sorted(set(input_frame_inds))
|
||||
|
||||
# add those input masks to SAM 2 inference state before propagation
|
||||
for input_frame_idx in input_frame_inds:
|
||||
per_obj_input_mask, input_palette = load_masks_from_dir(
|
||||
input_mask_dir=input_mask_dir,
|
||||
video_name=video_name,
|
||||
frame_name=frame_names[input_frame_idx],
|
||||
per_obj_png_file=per_obj_png_file,
|
||||
)
|
||||
for object_id, object_mask in per_obj_input_mask.items():
|
||||
predictor.add_new_mask(
|
||||
inference_state=inference_state,
|
||||
frame_idx=input_frame_idx,
|
||||
obj_id=object_id,
|
||||
mask=object_mask,
|
||||
)
|
||||
|
||||
# run propagation throughout the video and collect the results in a dict
|
||||
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
|
||||
output_palette = input_palette or DAVIS_PALETTE
|
||||
video_segments = {} # video_segments contains the per-frame segmentation results
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
|
||||
inference_state
|
||||
):
|
||||
per_obj_output_mask = {
|
||||
out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
|
||||
for i, out_obj_id in enumerate(out_obj_ids)
|
||||
}
|
||||
video_segments[out_frame_idx] = per_obj_output_mask
|
||||
|
||||
# write the output masks as palette PNG files to output_mask_dir
|
||||
for out_frame_idx, per_obj_output_mask in video_segments.items():
|
||||
save_masks_to_dir(
|
||||
output_mask_dir=output_mask_dir,
|
||||
video_name=video_name,
|
||||
frame_name=frame_names[out_frame_idx],
|
||||
per_obj_output_mask=per_obj_output_mask,
|
||||
height=height,
|
||||
width=width,
|
||||
per_obj_png_file=per_obj_png_file,
|
||||
output_palette=output_palette,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sam2_cfg",
|
||||
type=str,
|
||||
default="sam2_hiera_b+.yaml",
|
||||
help="SAM 2 model configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sam2_checkpoint",
|
||||
type=str,
|
||||
default="./checkpoints/sam2_hiera_b+.pt",
|
||||
help="path to the SAM 2 model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_video_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory containing videos (as JPEG files) to run VOS prediction on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_mask_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory containing input masks (as PNG files) of each video",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_list_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="text file containing the list of video names to run VOS prediction on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_mask_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory to save the output masks (as PNG files)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--score_thresh",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="threshold for the output mask logits (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_all_masks",
|
||||
action="store_true",
|
||||
help="whether to use all available PNG files in input_mask_dir "
|
||||
"(default without this flag: just the first PNG file as input to the SAM 2 model; "
|
||||
"usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_obj_png_file",
|
||||
action="store_true",
|
||||
help="whether use separate per-object PNG files for input and output masks "
|
||||
"(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
|
||||
"note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply_postprocessing",
|
||||
action="store_true",
|
||||
help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
|
||||
"(we don't apply such post-processing in the SAM 2 model evaluation)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
|
||||
hydra_overrides_extra = [
|
||||
"++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
|
||||
]
|
||||
predictor = build_sam2_video_predictor(
|
||||
config_file=args.sam2_cfg,
|
||||
ckpt_path=args.sam2_checkpoint,
|
||||
apply_postprocessing=args.apply_postprocessing,
|
||||
hydra_overrides_extra=hydra_overrides_extra,
|
||||
)
|
||||
|
||||
if args.use_all_masks:
|
||||
print("using all available masks in input_mask_dir as input to the SAM 2 model")
|
||||
else:
|
||||
print(
|
||||
"using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
|
||||
)
|
||||
# if a video list file is provided, read the video names from the file
|
||||
# (otherwise, we use all subdirectories in base_video_dir)
|
||||
if args.video_list_file is not None:
|
||||
with open(args.video_list_file, "r") as f:
|
||||
video_names = [v.strip() for v in f.readlines()]
|
||||
else:
|
||||
video_names = [
|
||||
p
|
||||
for p in os.listdir(args.base_video_dir)
|
||||
if os.path.isdir(os.path.join(args.base_video_dir, p))
|
||||
]
|
||||
print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")
|
||||
|
||||
for n_video, video_name in enumerate(video_names):
|
||||
print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
|
||||
vos_inference(
|
||||
predictor=predictor,
|
||||
base_video_dir=args.base_video_dir,
|
||||
input_mask_dir=args.input_mask_dir,
|
||||
output_mask_dir=args.output_mask_dir,
|
||||
video_name=video_name,
|
||||
score_thresh=args.score_thresh,
|
||||
use_all_masks=args.use_all_masks,
|
||||
per_obj_png_file=args.per_obj_png_file,
|
||||
)
|
||||
|
||||
print(
|
||||
f"completed VOS prediction on {len(video_names)} videos -- "
|
||||
f"output masks saved to {args.output_mask_dir}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user