309 lines
9.8 KiB
Python
309 lines
9.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import glob
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
|
|
from typing import List, Optional
|
|
|
|
import pandas as pd
|
|
|
|
import torch
|
|
|
|
from iopath.common.file_io import g_pathmgr
|
|
|
|
from omegaconf.listconfig import ListConfig
|
|
|
|
from training.dataset.vos_segment_loader import (
|
|
JSONSegmentLoader,
|
|
MultiplePNGSegmentLoader,
|
|
PalettisedPNGSegmentLoader,
|
|
SA1BSegmentLoader,
|
|
)
|
|
|
|
|
|
@dataclass
class VOSFrame:
    """
    Record describing one frame of a video.

    Positional construction order is (frame_idx, image_path, data,
    is_conditioning_only); only the first two are required.
    """

    # Integer id of the frame. The dataset classes in this file derive it from
    # the image file name or from a running index.
    frame_idx: int
    # Path of the image file backing this frame.
    image_path: str
    # Optional tensor attached to the frame; left as None by the dataset
    # classes in this file (presumably filled in later by a loader -- confirm).
    data: Optional[torch.Tensor] = None
    # Boolean flag, False by default. NOTE(review): the name suggests the frame
    # is used for conditioning only; its consumers are not visible here.
    is_conditioning_only: Optional[bool] = False
|
|
|
|
|
|
@dataclass
class VOSVideo:
    """
    Record describing one video as an ordered list of frames.

    `len(video)` is the number of frames it holds.
    """

    # Name of the video as produced by the dataset that created it.
    video_name: str
    # Integer id of the video as assigned by the dataset that created it.
    video_id: int
    # Frames belonging to the video.
    frames: List[VOSFrame]

    def __len__(self):
        """Return how many frames the video contains."""
        frame_count = len(self.frames)
        return frame_count
|
|
|
|
|
|
class VOSRawDataset:
    """
    Base class for raw video datasets.

    Subclasses are expected to override `get_video`; the base implementation
    only raises.
    """

    def __init__(self):
        """Nothing to initialise in the base class."""

    def get_video(self, idx):
        """
        Return the dataset entry at position `idx`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class PNGRawDataset(VOSRawDataset):
    """
    Dataset of videos stored as one folder of JPEG frames per video, with the
    annotations of each video stored under a parallel ground-truth folder.

    Args:
        img_folder: root folder containing one sub-folder of frames per video.
        gt_folder: root folder containing one annotation sub-folder per video
            (in single object mode, each of its entries is one object).
        file_list_txt: optional text file naming the videos to use, one per
            line; extensions are stripped and blank lines are ignored. When
            omitted, every entry of `img_folder` is used.
        excluded_videos_list_txt: optional text file naming videos to drop,
            one per line; extensions are stripped.
        sample_rate: keep every `sample_rate`-th frame of a video.
        is_palette: when True masks are read with `PalettisedPNGSegmentLoader`,
            otherwise with `MultiplePNGSegmentLoader`.
        single_object_mode: when True every `<video>/<object>` pair found under
            `gt_folder` becomes its own dataset entry.
        truncate_video: when > 0, only the first `truncate_video` frames of a
            video are considered (applied before `sample_rate`).
        frames_sampling_mult: when True each entry is repeated in
            `video_names` once per item found in its frame folder.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                # Skip blank lines: an empty name would resolve to the root
                # folder itself instead of a video folder.
                subset = [
                    os.path.splitext(line.strip())[0] for line in f if line.strip()
                ]
        else:
            subset = os.listdir(self.img_folder)

        # Read and process excluded files if provided. A set keeps the
        # membership test below O(1) per video.
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep only the videos that are not excluded
        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

        if self.single_object_mode:
            # single object mode: one entry per "<video>/<object>"
            self.video_names = sorted(
                os.path.join(video_name, obj)
                for video_name in self.video_names
                for obj in os.listdir(os.path.join(self.gt_folder, video_name))
            )

        if frames_sampling_mult:
            video_names_mult = []
            for video_name in self.video_names:
                # Resolve the frame folder the same way get_video does, so the
                # count also works for "<video>/<object>" entries.
                num_frames = len(os.listdir(self._get_frame_root(video_name)))
                video_names_mult.extend([video_name] * num_frames)
            self.video_names = video_names_mult

    def _get_frame_root(self, video_name):
        """Return the folder that holds the frames of the entry `video_name`."""
        if self.single_object_mode:
            # Entries look like "<video>/<object>"; frames live in "<video>".
            video_name = os.path.dirname(video_name)
        return os.path.join(self.img_folder, video_name)

    def get_video(self, idx):
        """
        Build the dataset entry at position `idx`.

        Args:
            idx: index into `video_names`; it is also stored as the video id.

        Returns:
            A `(video, segment_loader)` tuple, where `video` is a `VOSVideo`
            listing the selected "*.jpg" frames in sorted path order and
            `segment_loader` is the mask loader built for this entry's
            annotation folder.

        Raises:
            ValueError: if a frame file name does not start with an integer.
        """
        video_name = self.video_names[idx]
        video_frame_root = self._get_frame_root(video_name)
        video_mask_root = os.path.join(self.gt_folder, video_name)

        if self.is_palette:
            segment_loader = PalettisedPNGSegmentLoader(video_mask_root)
        else:
            segment_loader = MultiplePNGSegmentLoader(
                video_mask_root, self.single_object_mode
            )

        all_frames = sorted(glob.glob(os.path.join(video_frame_root, "*.jpg")))
        if self.truncate_video > 0:
            all_frames = all_frames[: self.truncate_video]

        # The frame id is the integer before the first "." of the file name.
        frames = [
            VOSFrame(int(os.path.basename(fpath).split(".")[0]), image_path=fpath)
            for fpath in all_frames[:: self.sample_rate]
        ]
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        """Return the number of entries (after any per-frame repetition)."""
        return len(self.video_names)
|
|
|
|
|
|
class SA1BRawDataset(VOSRawDataset):
    """
    Dataset of still images, each exposed as a video made of `num_frames`
    copies of the same image, with annotations read from one JSON file per
    image.

    Image files are expected at `<img_folder>/<name>.jpg` and annotations at
    `<gt_folder>/<name>.json`, where `<name>` ends in `_<int>` (e.g.
    `sa_123`).

    Args:
        img_folder: folder containing the ".jpg" images.
        gt_folder: folder containing the ".json" annotation files.
        file_list_txt: optional text file naming the images to use, one per
            line; extensions are stripped and blank lines are ignored. When
            omitted, every ".jpg" file of `img_folder` is used, in sorted
            order.
        excluded_videos_list_txt: optional text file naming images to drop,
            one per line; extensions are stripped.
        num_frames: number of frames each returned video contains.
        mask_area_frac_thresh: forwarded to `SA1BSegmentLoader`; the default
            of 1.1 means no filtering.
        uncertain_iou: forwarded to `SA1BSegmentLoader`; the default of -1
            means no filtering.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        num_frames=1,
        mask_area_frac_thresh=1.1,  # no filtering by default
        uncertain_iou=-1,  # no filtering by default
    ):
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.num_frames = num_frames
        self.mask_area_frac_thresh = mask_area_frac_thresh
        self.uncertain_iou = uncertain_iou  # stability score

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                subset = [
                    os.path.splitext(line.strip())[0] for line in f if line.strip()
                ]
        else:
            # os.listdir returns entries in arbitrary order; sort so that an
            # index always maps to the same image. splitext removes only the
            # final ".jpg", keeping any other dots of the file name.
            subset = sorted(
                os.path.splitext(path)[0]
                for path in os.listdir(self.img_folder)
                if path.endswith(".jpg")
            )

        # Read and process excluded files if provided. A set keeps the
        # membership test below O(1) per image.
        if excluded_videos_list_txt is not None:
            with g_pathmgr.open(excluded_videos_list_txt, "r") as f:
                excluded_files = {os.path.splitext(line.strip())[0] for line in f}
        else:
            excluded_files = set()

        # Keep only the names that are not excluded (existence on disk is not
        # checked here).
        self.video_names = [
            video_name for video_name in subset if video_name not in excluded_files
        ]

    def get_video(self, idx):
        """
        Build the dataset entry at position `idx`.

        Args:
            idx: index into `video_names`.

        Returns:
            A `(video, segment_loader)` tuple. `video` is a `VOSVideo` whose
            name is the part of the file name after the last "_" and whose id
            is that part converted to int; all of its frames point at the same
            image file. `segment_loader` is the `SA1BSegmentLoader` built for
            the image's JSON annotation file.

        Raises:
            ValueError: if the part of the name after the last "_" is not an
                integer.
        """
        video_name = self.video_names[idx]

        video_frame_path = os.path.join(self.img_folder, video_name + ".jpg")
        video_mask_path = os.path.join(self.gt_folder, video_name + ".json")

        segment_loader = SA1BSegmentLoader(
            video_mask_path,
            mask_area_frac_thresh=self.mask_area_frac_thresh,
            video_frame_path=video_frame_path,
            uncertain_iou=self.uncertain_iou,
        )

        frames = [
            VOSFrame(frame_idx, image_path=video_frame_path)
            for frame_idx in range(self.num_frames)
        ]
        image_id = video_name.split("_")[-1]  # filename is sa_{int}
        # video id needs to be image_id to be able to load correct annotation
        # file during eval
        video = VOSVideo(image_id, int(image_id), frames)
        return video, segment_loader

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.video_names)
|
|
|
|
|
|
class JSONRawDataset(VOSRawDataset):
    """
    Dataset where the annotations are in the format of SA-V json files.

    Frames are expected at `<img_folder>/<video>/<frame_id as 5 digits>.jpg`
    and annotations at `<gt_folder>/<video>_manual.json`.

    Args:
        img_folder: root folder containing one sub-folder of frames per video.
        gt_folder: folder containing the `<video>_manual.json` files.
        file_list_txt: optional text file naming the videos to use, one per
            line; extensions are stripped and blank lines are ignored. When
            omitted, every entry of `img_folder` is used.
        excluded_videos_list_txt: optional path, or list / tuple / ListConfig
            of paths, of text files naming videos to drop, one per line;
            extensions are stripped.
        sample_rate: keep every `sample_rate`-th frame of a video.
        rm_unannotated: when True, drop frames that have no complete
            annotation according to the segment loader.
        ann_every: forwarded to `JSONSegmentLoader`.
        frames_fps: forwarded to `JSONSegmentLoader`.

    Raises:
        NotImplementedError: if `excluded_videos_list_txt` has an unsupported
            type.
    """

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        rm_unannotated=True,
        ann_every=1,
        frames_fps=24,
    ):
        self.gt_folder = gt_folder
        self.img_folder = img_folder
        self.sample_rate = sample_rate
        self.rm_unannotated = rm_unannotated
        self.ann_every = ann_every
        self.frames_fps = frames_fps

        # Read and process excluded files if provided
        excluded_files = set()
        if excluded_videos_list_txt is not None:
            if isinstance(excluded_videos_list_txt, str):
                excluded_videos_lists = [excluded_videos_list_txt]
            elif isinstance(excluded_videos_list_txt, (list, tuple, ListConfig)):
                excluded_videos_lists = list(excluded_videos_list_txt)
            else:
                raise NotImplementedError(
                    "excluded_videos_list_txt must be a path or a list of paths, "
                    f"got {type(excluded_videos_list_txt).__name__}"
                )

            for excluded_list_path in excluded_videos_lists:
                # Use the same path manager as for file_list_txt so both kinds
                # of list files are opened the same way.
                with g_pathmgr.open(excluded_list_path, "r") as f:
                    excluded_files.update(
                        os.path.splitext(line.strip())[0] for line in f
                    )

        # Read the subset defined in file_list_txt
        if file_list_txt is not None:
            with g_pathmgr.open(file_list_txt, "r") as f:
                # Skip blank lines: an empty name would resolve to the root
                # folder itself instead of a video folder.
                subset = [
                    os.path.splitext(line.strip())[0] for line in f if line.strip()
                ]
        else:
            subset = os.listdir(self.img_folder)

        self.video_names = sorted(
            video_name for video_name in subset if video_name not in excluded_files
        )

    def get_video(self, video_idx):
        """
        Build the dataset entry at position `video_idx`.

        Args:
            video_idx: index into `video_names`; it is also stored as the
                video id.

        Returns:
            A `(video, segment_loader)` tuple, where `video` is a `VOSVideo`
            holding the selected frames and `segment_loader` is the
            `JSONSegmentLoader` built for the video's annotation file.

        Raises:
            ValueError: if an entry of the video's frame folder does not have
                an integer file stem.
        """
        video_name = self.video_names[video_idx]
        video_json_path = os.path.join(self.gt_folder, video_name + "_manual.json")
        segment_loader = JSONSegmentLoader(
            video_json_path=video_json_path,
            ann_every=self.ann_every,
            frames_fps=self.frames_fps,
        )

        frame_ids = [
            int(os.path.splitext(frame_name)[0])
            for frame_name in sorted(
                os.listdir(os.path.join(self.img_folder, video_name))
            )
        ]

        # Format the frame id with a format spec rather than the "%" operator,
        # so a "%" inside the video name cannot be taken for a directive.
        frames = [
            VOSFrame(
                frame_id,
                image_path=os.path.join(
                    self.img_folder, f"{video_name}/{frame_id:05d}.jpg"
                ),
            )
            for frame_id in frame_ids[:: self.sample_rate]
        ]

        if self.rm_unannotated:
            # Eliminate the frames that have not been annotated. The i-th
            # annotation entry corresponds to frame id i * ann_every.
            valid_frame_ids = {
                i * segment_loader.ann_every
                for i, annot in enumerate(segment_loader.frame_annots)
                if annot is not None and None not in annot
            }
            frames = [f for f in frames if f.frame_idx in valid_frame_ids]

        video = VOSVideo(video_name, video_idx, frames)
        return video, segment_loader

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_names)
|