301 lines
10 KiB
Python
301 lines
10 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
# All rights reserved.
|
||
|
||
# This source code is licensed under the license found in the
|
||
# LICENSE file in the root directory of this source tree.
|
||
|
||
import glob
|
||
import json
|
||
import os
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
|
||
from PIL import Image as PILImage
|
||
|
||
try:
|
||
from pycocotools import mask as mask_utils
|
||
except:
|
||
pass
|
||
|
||
|
||
class JSONSegmentLoader:
    """Segment loader for masklets stored as COCO RLEs in a JSON file.

    The JSON top level is either a list of per-frame annotations, or a dict that
    holds that list under ``"masklet"`` (falling back to ``"masks"``) and may carry
    an ``"fps"`` entry (a number, or a list whose first element is used).
    Each per-frame annotation is indexed by object id; a ``None`` entry means the
    object has no mask in that frame.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file.
            ann_every: annotations are provided every ``ann_every``-th frame. When
                the JSON dict has an ``"fps"`` entry this is replaced by
                ``frames_fps // fps``.
            frames_fps: frame rate of the video frames (used with ``"fps"``).
            valid_obj_ids: ids of the objects to consider when sampling this video,
                or ``None`` to keep every object.

        Raises:
            NotImplementedError: if the JSON top level is neither a list nor a dict.
        """
        # Annotations in the json are provided every ann_every th frame
        self.ann_every = ann_every
        # Ids of the objects to consider when sampling this video
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            masklet_field_name = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[masklet_field_name]
            if "fps" in data:
                fps = data["fps"]
                annotations_fps = int(fps[0] if isinstance(fps, list) else fps)
                # Annotation rate must evenly divide the frame rate (and be > 0,
                # otherwise the modulo below would divide by zero).
                assert annotations_fps > 0 and frames_fps % annotations_fps == 0, (
                    f"frames_fps={frames_fps} is not a multiple of "
                    f"annotation fps={annotations_fps}"
                )
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError(
                f"Unsupported annotation format: {type(data).__name__}"
            )

    def load(self, frame_id, obj_ids=None):
        """Decode the object masks of one annotated frame.

        Args:
            frame_id: frame index; must be a multiple of ``self.ann_every``.
            obj_ids: optional ids of the sampled objects; only these are kept.

        Returns:
            dict mapping each kept object id (in ascending order) to its decoded
            ``(h, w)`` mask tensor, or to ``None`` when the object has no mask in
            this frame.
        """
        assert frame_id % self.ann_every == 0
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            # Remove the masklets that have been filtered out for this video
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Only keep the objects that have been sampled
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Map each object id to its position in rle_mask_filtered (None = no mask),
        # so that all present masks can be decoded in a single call.
        id_2_idx = {}
        rle_mask_filtered = []
        for obj_id in valid_objs_ids:
            if rle_mask[obj_id] is None:
                id_2_idx[obj_id] = None
            else:
                id_2_idx[obj_id] = len(rle_mask_filtered)
                rle_mask_filtered.append(rle_mask[obj_id])

        # pycocotools.mask.decode reads the mask size from the first RLE, so an
        # empty list cannot be decoded; skip decoding when nothing is present.
        raw_segments = None
        if rle_mask_filtered:
            raw_segments = torch.from_numpy(
                mask_utils.decode(rle_mask_filtered)
            ).permute(2, 0, 1)  # (num_obj, h, w)

        return {
            obj_id: None if idx is None else raw_segments[idx]
            for obj_id, idx in id_2_idx.items()
        }

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """For each object, list the frames that carry a valid (non-None) mask.

        Args:
            num_frames_min: if given, drop objects with fewer valid frames.

        Returns:
            dict mapping obj_id -> list of frame ids (annotation index * ann_every).
            Empty when there are no annotated frames.
        """
        if not self.frame_annots:
            return {}
        # The number of objects is taken from the first annotated frame
        num_objects = len(self.frame_annots[0])

        # The result dict associates each obj_id with the id of its valid frames
        res = {obj_id: [] for obj_id in range(num_objects)}
        for annot_idx, annot in enumerate(self.frame_annots):
            frame_id = int(annot_idx * self.ann_every)
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(frame_id)

        if num_frames_min is not None:
            # Remove masklets that have less than num_frames_min valid masks
            res = {
                obj_id: valid_frames
                for obj_id, valid_frames in res.items()
                if len(valid_frames) >= num_frames_min
            }
        return res
|
||
|
||
|
||
class PalettisedPNGSegmentLoader:
    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.

        video_png_root: the folder containing all the PNG masks of one video,
            one file per frame, named by the integer frame id.
        """
        self.video_png_root = video_png_root
        # Build a mapping from frame id to its PNG mask filename.
        # Note that in some datasets the PNG names can have more than 5 digits,
        # e.g. "00000000.png" instead of "00000.png", so parse rather than format.
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            frame_id, ext = os.path.splitext(filename)
            # Skip non-PNG entries (e.g. ".DS_Store") instead of failing in int()
            if ext.lower() != ".png":
                continue
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """
        Load the palettised mask of one frame and split it into binary masks.

        Args:
            frame_id: int, frame id; its PNG filename is looked up in
                ``self.frame_id_to_png_filename`` (KeyError if absent).
        Return:
            binary_segments: dict mapping each non-zero palette index (0 is the
                background) to a boolean tensor that is True where the mask has
                that index. Keys follow their order of first appearance.
        """
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        # Load the mask; the context manager releases the file handle.
        with PILImage.open(mask_path) as img:
            masks = np.array(img.convert("P"))

        object_id = pd.unique(masks.ravel())
        object_id = object_id[object_id != 0]  # remove background (0)

        # Convert into N binary segmentation masks
        return {i: torch.from_numpy(masks == i) for i in object_id}

    def __len__(self):
        """Return the number of PNG mask files found in video_png_root."""
        return len(self.frame_id_to_png_filename)
|
||
|
||
|
||
class MultiplePNGSegmentLoader:
    def __init__(self, video_png_root, single_object_mode=False):
        """
        video_png_root: the folder containing the PNG masks. In multi-object mode
            it holds one sub-folder per object (named by the integer object id);
            in single-object mode it is itself one such object folder.
        single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # Read one mask to know the resolution of the video
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        with PILImage.open(tmp_mask_path) as img:
            tmp_mask = np.array(img)
        self.H, self.W = tmp_mask.shape[:2]
        if self.single_object_mode:
            # The folder name is the object id; offset by 1 as bg is 0.
            # normpath tolerates a trailing path separator.
            self.obj_id = int(os.path.basename(os.path.normpath(video_png_root))) + 1
        else:
            self.obj_id = None

    def load(self, frame_id):
        """Return {obj_id: boolean mask tensor} for frame ``frame_id``."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        return self._load_multiple_pngs(frame_id)

    def _read_binary_mask(self, mask_path):
        """Read mask_path as a boolean (value > 0) tensor; all-False (H, W) if missing."""
        if os.path.exists(mask_path):
            with PILImage.open(mask_path) as img:
                mask = np.array(img)
        else:
            # If the png doesn't exist, the object is absent: empty mask
            mask = np.zeros((self.H, self.W), dtype=bool)
        return torch.from_numpy(mask > 0)

    def _load_single_png(self, frame_id):
        """
        Load a single png from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png').

        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict with the single key ``self.obj_id``
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        return {self.obj_id: self._read_binary_mask(mask_path)}

    def _load_multiple_pngs(self, frame_id):
        """
        Load one png per object folder (path: f'{video_png_root}/{obj_id}/{frame_id:05d}.png').

        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping (folder id + 1) to a boolean mask tensor
        """
        # Only sub-directories are object folders; ignore stray files.
        all_objects = sorted(
            path
            for path in glob.glob(os.path.join(self.video_png_root, "*"))
            if os.path.isdir(path)
        )
        assert len(all_objects) > 0

        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}; offset by 1 as bg is 0
            obj_id = int(os.path.basename(obj_folder)) + 1
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            binary_segments[obj_id] = self._read_binary_mask(mask_path)
        return binary_segments

    def __len__(self):
        """Return the number of distinct frames that have at least one PNG mask."""
        if self.single_object_mode:
            pattern = os.path.join(self.video_png_root, "*.png")
        else:
            pattern = os.path.join(self.video_png_root, "*", "*.png")
        return len({os.path.basename(path) for path in glob.glob(pattern)})
|
||
|
||
|
||
class LazySegments:
    """
    Only decodes segments that are actually used.

    Behaves like a read-mostly dict of key -> RLE: items are stored raw and
    decoded with pycocotools on first access, then served from a cache.
    """

    def __init__(self):
        self.segments = {}  # key -> RLE as passed to __setitem__
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item
        # Invalidate any mask decoded from a previous RLE stored under this key
        self.cache.pop(key, None)

    def __getitem__(self, key):
        """Return the decoded (h, w) mask for key, decoding it on first access."""
        if key not in self.cache:
            rle = self.segments[key]
            # decode() on a one-element list yields (h, w, 1); keep the (h, w) slice
            self.cache[key] = torch.from_numpy(mask_utils.decode([rle])).permute(
                2, 0, 1
            )[0]
        return self.cache[key]

    def __contains__(self, key):
        return key in self.segments

    def __len__(self):
        return len(self.segments)

    def __iter__(self):
        """Iterate over keys, like a dict (without decoding anything)."""
        return iter(self.segments)

    def keys(self):
        return self.segments.keys()
|
||
|
||
|
||
class SA1BSegmentLoader:
    """Segment loader for a single annotated frame stored as a JSON annotation file.

    Keeps the ``"segmentation"`` (RLE) of every entry in the JSON's
    ``"annotations"`` list that passes the filters below. The masks are decoded
    lazily through LazySegments.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON annotation file.
            mask_area_frac_thresh: drop masks whose area / frame area is >= this
                value. Values > 1.0 (the default 1.1) disable this filter.
            video_frame_path: path of the frame image; only opened (to read its
                size) when the area filter is enabled.
            uncertain_iou: drop annotations whose "uncertain_iou" is below this.

        Raises:
            ValueError: if the area filter is enabled but video_frame_path is None.
        """
        with open(video_mask_path, "r", encoding="utf-8") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            if video_frame_path is None:
                raise ValueError(
                    "video_frame_path is required when mask_area_frac_thresh <= 1.0"
                )
            # Only the frame size is needed; the context manager closes the file
            with PILImage.open(video_frame_path) as img:
                orig_w, orig_h = img.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            if not frame_annot["area"] > 0:
                continue
            if (
                "uncertain_iou" in frame_annot
                and frame_annot["uncertain_iou"] < uncertain_iou
            ):
                # uncertain_iou is stability score
                continue
            if filter_by_area and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the LazySegments of this frame; frame_idx is not used."""
        return self.segments
|