add readme (#10)
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* remove submodule
* add mPLUG MiniGPT4
* Update Readme.md
* Update Readme.md
* Update Readme.md

---------

Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
models/MiniGPT4/minigpt4/datasets/__init__.py (new file, 0 lines)
models/MiniGPT4/minigpt4/datasets/builders/__init__.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
    CCSBUBuilder,
    LaionBuilder,
    CCSBUAlignBuilder
)
from minigpt4.common.registry import registry

__all__ = [
    "CCSBUBuilder",
    "LaionBuilder",
    "CCSBUAlignBuilder"
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Example

    >>> dataset = load_dataset("coco_caption", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])
    """
    if cfg_path is None:
        cfg = None
    else:
        cfg = load_dataset_config(cfg_path)

    try:
        builder = registry.get_builder_class(name)(cfg)
    except TypeError:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        exit(1)

    if vis_path is not None:
        if data_type is None:
            # use the default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset


class DatasetZoo:
    def __init__(self) -> None:
        self.dataset_zoo = {
            k: list(v.DATASET_CONFIG_DICT.keys())
            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
        }

    def get_names(self):
        return list(self.dataset_zoo.keys())


dataset_zoo = DatasetZoo()
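For reference, a minimal sketch of how this module's entry points are meant to be called (assumes a MiniGPT-4 checkout with minigpt4 importable and the aligned data already in the default cache; the dataset name must match one of the builders registered above):

from minigpt4.datasets.builders import load_dataset, dataset_zoo

# List every registered builder, then build all splits of one of them
# with its default config.
print(dataset_zoo.get_names())           # e.g. ['cc_sbu', 'cc_sbu_align', 'laion']
datasets = load_dataset("cc_sbu_align")  # dict keyed by split, e.g. {'train': ...}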
models/MiniGPT4/minigpt4/datasets/builders/base_dataset_builder.py (new file, 236 lines)
@@ -0,0 +1,236 @@
"""
This file is from LAVIS.
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor


class BaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make a full path if relative.
          (2) basename/dirname: will be suffixed with the base name of the URL if a dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):
        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create datasets, organized by split, that inherit torch.utils.data.Dataset.

        # build() can be dataset-specific. Overwrite to customize.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets


def load_dataset_config(cfg_path):
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]

    return cfg
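New builders follow a small recipe: subclass BaseDatasetBuilder, point train_dataset_cls/eval_dataset_cls at dataset classes, and declare a default config path. A hedged sketch using the caption datasets from this commit; the registry key and YAML path are hypothetical:

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset


@registry.register_builder("my_caption")  # hypothetical key
class MyCaptionBuilder(BaseDatasetBuilder):
    train_dataset_cls = CaptionDataset
    eval_dataset_cls = CaptionEvalDataset

    # Hypothetical path; the YAML must provide a datasets.<name> entry with
    # data_type, build_info.annotations and build_info.<data_type>.storage.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/my_caption/defaults.yaml"}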
models/MiniGPT4/minigpt4/datasets/builders/image_text_pair_builder.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import os
import logging
import warnings

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset


@registry.register_builder("cc_sbu")
class CCSBUBuilder(BaseDatasetBuilder):
    train_dataset_cls = CCSBUDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            location=build_info.storage,
        ).inner_dataset

        return datasets


@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
    train_dataset_cls = LaionDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors[split],
            text_processor=self.text_processors[split],
            location=build_info.storage,
        ).inner_dataset

        return datasets


@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
    train_dataset_cls = CCSBUAlignDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/cc_sbu/align.yaml",
    }

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
            vis_root=os.path.join(storage_path, 'image'),
        )

        return datasets
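CCSBUAlignBuilder only reads build_info.storage and expects filter_cap.json plus an image/ directory beneath it. A sketch of the config it consumes, built with OmegaConf instead of YAML; the storage path is an assumption:

from omegaconf import OmegaConf
from minigpt4.datasets.builders.image_text_pair_builder import CCSBUAlignBuilder

cfg = OmegaConf.create({
    "data_type": "images",  # required by BaseDatasetBuilder.__init__
    "build_info": {"storage": "/data/cc_sbu_align"},  # assumed download location
})
datasets = CCSBUAlignBuilder(cfg).build_datasets()  # {'train': CCSBUAlignDataset}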
models/MiniGPT4/minigpt4/datasets/data_utils.py (new file, 196 lines)
@@ -0,0 +1,196 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm

import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset

from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset


decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")


class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipelines doesn't define a sample ratio; defaulting to 1.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            yield next(select_datastream)


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


def reorg_datasets_by_split(datasets):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by name.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]}.
    """
    # if len(datasets) == 1:
    #     return datasets[list(datasets.keys())[0]]
    # else:
    reorg_datasets = dict()

    # reorganize by split
    for _, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
            else:
                reorg_datasets[split_name].append(dataset_split)

    return reorg_datasets


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not
    support generic IterableDataset because that requires creating separate samplers.

    Now only supports concatenating training datasets and assumes validation and testing
    have only a single dataset. This is because metrics should not be computed on the
    concatenated datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.
    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # if len(iterable_datasets) > 0:
            # concatenate map-style datasets and iterable-style datasets separately
            if len(iterable_datasets) > 1:
                chained_datasets = (
                    ChainDataset(iterable_datasets)
                )
            elif len(iterable_datasets) == 1:
                chained_datasets = iterable_datasets[0]
            else:
                chained_datasets = None

            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets
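Together, reorg_datasets_by_split and concat_datasets turn per-builder outputs into one trainable object per split. A runnable toy sketch with map-style stand-ins for real builder outputs:

import torch
from torch.utils.data import TensorDataset

from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split

# Two toy {split: dataset} dicts, as returned by builder.build_datasets().
ds_a = {"train": TensorDataset(torch.zeros(4, 3))}
ds_b = {"train": TensorDataset(torch.ones(6, 3))}

by_split = reorg_datasets_by_split({"a": ds_a, "b": ds_b})  # {'train': [ds, ds]}
train = concat_datasets(by_split)["train"]                  # ConcatDataset of length 10
print(len(train))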
models/MiniGPT4/minigpt4/datasets/datasets/base_dataset.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to annotation files
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with same underlying collater implementations

        all_keys = set()
        for s in samples:
            all_keys.update(s)

        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        return self.datasets[0].collater(samples_shared_keys)
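BaseDataset assumes only that each annotation file is JSON with a top-level "annotations" list; all other fields are dataset-specific. A runnable sketch with illustrative field names:

import json
import os
import tempfile

from minigpt4.datasets.datasets.base_dataset import BaseDataset

ann_path = os.path.join(tempfile.mkdtemp(), "ann.json")
with open(ann_path, "w") as f:
    json.dump({"annotations": [{"image_id": "42", "caption": "a cat"}]}, f)

ds = BaseDataset(ann_paths=[ann_path])
print(len(ds), ds.annotation[0]["instance_id"])  # -> 1 0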
models/MiniGPT4/minigpt4/datasets/datasets/caption_datasets.py (new file, 85 lines)
@@ -0,0 +1,85 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class CaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{:0>12}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to annotation files
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        return {
            "image": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }
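The '{:0>12}' format in CaptionDataset pads numeric image ids to the 12-digit zero-padded file names used by COCO-style image folders:

print('{:0>12}.jpg'.format(397133))  # -> 000000397133.jpg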
models/MiniGPT4/minigpt4/datasets/datasets/cc_sbu_dataset.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset


class CCSBUDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }


class CCSBUAlignDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann["caption"]

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
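Both classes expose the webdataset pipeline as inner_dataset, yielding {"image", "text_input"} dicts; the builders above return it directly. A consumption sketch, assuming local shards exist at the brace-expanded pattern and using identity processors:

from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset

ds = CCSBUDataset(
    vis_processor=BaseProcessor(),   # identity stand-in for a real image processor
    text_processor=BaseProcessor(),
    location="/data/cc_sbu/{00000..01255}.tar",  # assumed shard pattern
)
sample = next(iter(ds.inner_dataset))
print(sample["text_input"])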
models/MiniGPT4/minigpt4/datasets/datasets/dataloader_utils.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from minigpt4.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Loader]): List of Iterator loaders.
        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert that all loaders have a __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # randomly sample from each loader by ratio
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            is_tuple = isinstance(batch, tuple)
            if is_tuple:
                task, batch = batch

            if is_tuple:
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, list) or isinstance(batch, tuple):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class IterLoader:
    """
    A wrapper to convert a DataLoader into an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
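A hedged sketch of how these wrappers compose for training: each DataLoader becomes an infinite IterLoader, and MultiIterLoader mixes the streams by ratio (the datasets are toy stand-ins):

import torch
from torch.utils.data import DataLoader, TensorDataset

from minigpt4.datasets.datasets.dataloader_utils import IterLoader, MultiIterLoader

loader_a = IterLoader(DataLoader(TensorDataset(torch.zeros(32, 3)), batch_size=8))
loader_b = IterLoader(DataLoader(TensorDataset(torch.ones(32, 3)), batch_size=8))

# Draw from loader_a twice as often as loader_b; the combined stream never ends.
mixed = MultiIterLoader([loader_a, loader_b], ratios=[2, 1])
batch = next(mixed)
# When CUDA is available, PrefetchLoader(loader) additionally overlaps
# host-to-GPU copies with compute on a side stream.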
models/MiniGPT4/minigpt4/datasets/datasets/laion_dataset.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset


class LaionDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
        super().__init__(vis_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }