[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
300
training/dataset/vos_segment_loader.py
Normal file
300
training/dataset/vos_segment_loader.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
try:
|
||||
from pycocotools import mask as mask_utils
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class JSONSegmentLoader:
|
||||
def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
|
||||
# Annotations in the json are provided every ann_every th frame
|
||||
self.ann_every = ann_every
|
||||
# Ids of the objects to consider when sampling this video
|
||||
self.valid_obj_ids = valid_obj_ids
|
||||
with open(video_json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
self.frame_annots = data
|
||||
elif isinstance(data, dict):
|
||||
masklet_field_name = "masklet" if "masklet" in data else "masks"
|
||||
self.frame_annots = data[masklet_field_name]
|
||||
if "fps" in data:
|
||||
if isinstance(data["fps"], list):
|
||||
annotations_fps = int(data["fps"][0])
|
||||
else:
|
||||
annotations_fps = int(data["fps"])
|
||||
assert frames_fps % annotations_fps == 0
|
||||
self.ann_every = frames_fps // annotations_fps
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def load(self, frame_id, obj_ids=None):
|
||||
assert frame_id % self.ann_every == 0
|
||||
rle_mask = self.frame_annots[frame_id // self.ann_every]
|
||||
|
||||
valid_objs_ids = set(range(len(rle_mask)))
|
||||
if self.valid_obj_ids is not None:
|
||||
# Remove the masklets that have been filtered out for this video
|
||||
valid_objs_ids &= set(self.valid_obj_ids)
|
||||
if obj_ids is not None:
|
||||
# Only keep the objects that have been sampled
|
||||
valid_objs_ids &= set(obj_ids)
|
||||
valid_objs_ids = sorted(list(valid_objs_ids))
|
||||
|
||||
# Construct rle_masks_filtered that only contains the rle masks we are interested in
|
||||
id_2_idx = {}
|
||||
rle_mask_filtered = []
|
||||
for obj_id in valid_objs_ids:
|
||||
if rle_mask[obj_id] is not None:
|
||||
id_2_idx[obj_id] = len(rle_mask_filtered)
|
||||
rle_mask_filtered.append(rle_mask[obj_id])
|
||||
else:
|
||||
id_2_idx[obj_id] = None
|
||||
|
||||
# Decode the masks
|
||||
raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
|
||||
2, 0, 1
|
||||
) # (num_obj, h, w)
|
||||
segments = {}
|
||||
for obj_id in valid_objs_ids:
|
||||
if id_2_idx[obj_id] is None:
|
||||
segments[obj_id] = None
|
||||
else:
|
||||
idx = id_2_idx[obj_id]
|
||||
segments[obj_id] = raw_segments[idx]
|
||||
return segments
|
||||
|
||||
def get_valid_obj_frames_ids(self, num_frames_min=None):
|
||||
# For each object, find all the frames with a valid (not None) mask
|
||||
num_objects = len(self.frame_annots[0])
|
||||
|
||||
# The result dict associates each obj_id with the id of its valid frames
|
||||
res = {obj_id: [] for obj_id in range(num_objects)}
|
||||
|
||||
for annot_idx, annot in enumerate(self.frame_annots):
|
||||
for obj_id in range(num_objects):
|
||||
if annot[obj_id] is not None:
|
||||
res[obj_id].append(int(annot_idx * self.ann_every))
|
||||
|
||||
if num_frames_min is not None:
|
||||
# Remove masklets that have less than num_frames_min valid masks
|
||||
for obj_id, valid_frames in list(res.items()):
|
||||
if len(valid_frames) < num_frames_min:
|
||||
res.pop(obj_id)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class PalettisedPNGSegmentLoader:
|
||||
def __init__(self, video_png_root):
|
||||
"""
|
||||
SegmentLoader for datasets with masks stored as palettised PNGs.
|
||||
video_png_root: the folder contains all the masks stored in png
|
||||
"""
|
||||
self.video_png_root = video_png_root
|
||||
# build a mapping from frame id to their PNG mask path
|
||||
# note that in some datasets, the PNG paths could have more
|
||||
# than 5 digits, e.g. "00000000.png" instead of "00000.png"
|
||||
png_filenames = os.listdir(self.video_png_root)
|
||||
self.frame_id_to_png_filename = {}
|
||||
for filename in png_filenames:
|
||||
frame_id, _ = os.path.splitext(filename)
|
||||
self.frame_id_to_png_filename[int(frame_id)] = filename
|
||||
|
||||
def load(self, frame_id):
|
||||
"""
|
||||
load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
# check the path
|
||||
mask_path = os.path.join(
|
||||
self.video_png_root, self.frame_id_to_png_filename[frame_id]
|
||||
)
|
||||
|
||||
# load the mask
|
||||
masks = PILImage.open(mask_path).convert("P")
|
||||
masks = np.array(masks)
|
||||
|
||||
object_id = pd.unique(masks.flatten())
|
||||
object_id = object_id[object_id != 0] # remove background (0)
|
||||
|
||||
# convert into N binary segmentation masks
|
||||
binary_segments = {}
|
||||
for i in object_id:
|
||||
bs = masks == i
|
||||
binary_segments[i] = torch.from_numpy(bs)
|
||||
|
||||
return binary_segments
|
||||
|
||||
def __len__(self):
|
||||
return
|
||||
|
||||
|
||||
class MultiplePNGSegmentLoader:
|
||||
def __init__(self, video_png_root, single_object_mode=False):
|
||||
"""
|
||||
video_png_root: the folder contains all the masks stored in png
|
||||
single_object_mode: whether to load only a single object at a time
|
||||
"""
|
||||
self.video_png_root = video_png_root
|
||||
self.single_object_mode = single_object_mode
|
||||
# read a mask to know the resolution of the video
|
||||
if self.single_object_mode:
|
||||
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
|
||||
else:
|
||||
tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
|
||||
tmp_mask = np.array(PILImage.open(tmp_mask_path))
|
||||
self.H = tmp_mask.shape[0]
|
||||
self.W = tmp_mask.shape[1]
|
||||
if self.single_object_mode:
|
||||
self.obj_id = (
|
||||
int(video_png_root.split("/")[-1]) + 1
|
||||
) # offset by 1 as bg is 0
|
||||
else:
|
||||
self.obj_id = None
|
||||
|
||||
def load(self, frame_id):
|
||||
if self.single_object_mode:
|
||||
return self._load_single_png(frame_id)
|
||||
else:
|
||||
return self._load_multiple_pngs(frame_id)
|
||||
|
||||
def _load_single_png(self, frame_id):
|
||||
"""
|
||||
load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
|
||||
binary_segments = {}
|
||||
|
||||
if os.path.exists(mask_path):
|
||||
mask = np.array(PILImage.open(mask_path))
|
||||
else:
|
||||
# if png doesn't exist, empty mask
|
||||
mask = np.zeros((self.H, self.W), dtype=bool)
|
||||
binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
|
||||
return binary_segments
|
||||
|
||||
def _load_multiple_pngs(self, frame_id):
|
||||
"""
|
||||
load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
|
||||
Args:
|
||||
frame_id: int, define the mask path
|
||||
Return:
|
||||
binary_segments: dict
|
||||
"""
|
||||
# get the path
|
||||
all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
|
||||
num_objects = len(all_objects)
|
||||
assert num_objects > 0
|
||||
|
||||
# load the masks
|
||||
binary_segments = {}
|
||||
for obj_folder in all_objects:
|
||||
# obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
|
||||
obj_id = int(obj_folder.split("/")[-1])
|
||||
obj_id = obj_id + 1 # offset 1 as bg is 0
|
||||
mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
|
||||
if os.path.exists(mask_path):
|
||||
mask = np.array(PILImage.open(mask_path))
|
||||
else:
|
||||
mask = np.zeros((self.H, self.W), dtype=bool)
|
||||
binary_segments[obj_id] = torch.from_numpy(mask > 0)
|
||||
|
||||
return binary_segments
|
||||
|
||||
def __len__(self):
|
||||
return
|
||||
|
||||
|
||||
class LazySegments:
|
||||
"""
|
||||
Only decodes segments that are actually used.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.segments = {}
|
||||
self.cache = {}
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
self.segments[key] = item
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
rle = self.segments[key]
|
||||
mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
|
||||
self.cache[key] = mask
|
||||
return mask
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.segments
|
||||
|
||||
def __len__(self):
|
||||
return len(self.segments)
|
||||
|
||||
def keys(self):
|
||||
return self.segments.keys()
|
||||
|
||||
|
||||
class SA1BSegmentLoader:
|
||||
def __init__(
|
||||
self,
|
||||
video_mask_path,
|
||||
mask_area_frac_thresh=1.1,
|
||||
video_frame_path=None,
|
||||
uncertain_iou=-1,
|
||||
):
|
||||
with open(video_mask_path, "r") as f:
|
||||
self.frame_annots = json.load(f)
|
||||
|
||||
if mask_area_frac_thresh <= 1.0:
|
||||
# Lazily read frame
|
||||
orig_w, orig_h = PILImage.open(video_frame_path).size
|
||||
area = orig_w * orig_h
|
||||
|
||||
self.frame_annots = self.frame_annots["annotations"]
|
||||
|
||||
rle_masks = []
|
||||
for frame_annot in self.frame_annots:
|
||||
if not frame_annot["area"] > 0:
|
||||
continue
|
||||
if ("uncertain_iou" in frame_annot) and (
|
||||
frame_annot["uncertain_iou"] < uncertain_iou
|
||||
):
|
||||
# uncertain_iou is stability score
|
||||
continue
|
||||
if (
|
||||
mask_area_frac_thresh <= 1.0
|
||||
and (frame_annot["area"] / area) >= mask_area_frac_thresh
|
||||
):
|
||||
continue
|
||||
rle_masks.append(frame_annot["segmentation"])
|
||||
|
||||
self.segments = LazySegments()
|
||||
for i, rle in enumerate(rle_masks):
|
||||
self.segments[i] = rle
|
||||
|
||||
def load(self, frame_idx):
|
||||
return self.segments
|
Reference in New Issue
Block a user