add readme (#10)
* Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * remove submodule * add mPLUG MiniGPT4 * Update Readme.md * Update Readme.md * Update Readme.md --------- Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
This commit is contained in:
0
models/MiniGPT4/minigpt4/common/__init__.py
Normal file
0
models/MiniGPT4/minigpt4/common/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
468
models/MiniGPT4/minigpt4/common/config.py
Normal file
468
models/MiniGPT4/minigpt4/common/config.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
|
||||
class Config:
    """Build the merged run / model / dataset configuration for a job.

    The YAML file named by ``args.cfg_path`` is loaded with OmegaConf, combined
    with the default model and dataset configuration files resolved through the
    registry, and finally overridden by the options given in ``args.options``.
    The merged result is exposed through ``get_config()`` and the ``run_cfg``,
    ``model_cfg`` and ``datasets_cfg`` properties.
    """

    def __init__(self, args):
        """Load, merge and store the configuration.

        Args:
            args: object providing ``cfg_path`` (path of the YAML file) and
                ``options`` (list of overrides, or ``None``).
        """
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # NOTE: runner validation (self._validate_runner_config) is currently
        # not invoked here.

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        """Turn the raw option list into an OmegaConf config via a dot-list."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        """Merge the model's default config file with the ``model`` section.

        Args:
            config: loaded configuration containing a ``model`` section.
            **kwargs: user overrides. The model type is looked up under the
                flat key ``"model.model_type"`` and, failing that, under a
                nested ``model`` mapping (the shape produced by unpacking a
                dot-list config).

        Returns:
            The merged model configuration (customized config > default config).

        Raises:
            AssertionError: if the model section, the registered model class
                or the model type is missing.
        """
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            # Unpacking a nested config yields a top-level "model" mapping
            # rather than a flat dotted key, so look there as well.
            user_model = kwargs.get("model", None)
            if user_model is not None and hasattr(user_model, "get"):
                model_type = user_model.get("model_type", None)
        if not model_type:
            # No user selection: fall back to the type from the config file.
            model_type = model.get("model_type", None)

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        """Return the ``run`` section of ``config`` wrapped under the key ``"run"``."""
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        """Merge every dataset's default config file with its user section.

        Raises:
            KeyError: if ``config`` has no ``datasets`` section.
        """
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        """Normalize options to ``key=value`` strings.

        ``opts`` may already be in ``key=value`` form (decided from its first
        element) or be a flat ``[key, value, key, value, ...]`` sequence.
        ``None`` is treated as an empty list.
        """
        if opts is None:
            return []

        if len(opts) == 0:
            return opts

        if "=" in opts[0]:
            return opts

        # Pair up alternating keys and values.
        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        """Return the merged configuration."""
        return self.config

    @property
    def run_cfg(self):
        """The ``run`` section of the merged configuration."""
        return self.config.run

    @property
    def datasets_cfg(self):
        """The ``datasets`` section of the merged configuration."""
        return self.config.datasets

    @property
    def model_cfg(self):
        """The ``model`` section of the merged configuration."""
        return self.config.model

    def pretty_print(self):
        """Log the run, dataset and model sections as indented JSON."""
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            logging.info(f"\n======== {dataset} =======")
            dataset_config = self.config.datasets[dataset]
            logging.info(self._convert_node_to_json(dataset_config))

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Serialize an OmegaConf node (interpolations resolved) to sorted JSON."""
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged configuration as plain Python containers."""
        return OmegaConf.to_container(self.config)
|
||||
|
||||
|
||||
def node_to_dict(node):
    """Convert an OmegaConf node into plain Python containers."""
    plain = OmegaConf.to_container(node)
    return plain
|
||||
|
||||
|
||||
class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined, raise error if not.
        2. when type mismatches are found, the validator will raise an error.
        3. a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        """Description of one supported option: name, type, choices and help."""

        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None  # converted value, filled in by validate()
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        # Maps option name -> _Argument.
        self.arguments = dict()

        # Maps option name -> raw value; populated by validate().
        self.parsed_args = None

    def __getitem__(self, key):
        """Return the value stored for ``key`` by the last ``validate`` call."""
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """Check a dict-like config against the registered arguments.

        Every key must have been registered, its value must be convertible by
        the argument's ``type`` (when one is given) and must be one of the
        argument's ``choices`` (when given). The raw values are recorded in
        ``parsed_args`` so they can be read back through ``validator[key]``.

        Args:
            config: mapping of option name to value; ``None`` is treated as
                an empty mapping.

        Returns:
            ``config``, unchanged.

        Raises:
            AssertionError: on an unknown option or a value outside ``choices``.
            ValueError: if a value cannot be converted to the argument's type.
        """
        items = {} if config is None else config
        parsed = {}

        for k, v in items.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            argument = self.arguments[k]

            if argument.type is not None:
                try:
                    argument.val = argument.type(v)
                except (TypeError, ValueError) as err:
                    # Conversions such as int([1]) raise TypeError, not
                    # ValueError; report both uniformly.
                    raise ValueError(f"{k} is not a valid {argument.type}.") from err

            if argument.choices is not None:
                assert (
                    v in argument.choices
                ), f"""{k} must be one of {argument.choices}."""

            parsed[k] = v

        self.parsed_args = parsed
        return config

    def format_arguments(self):
        """Return the sorted argument names rendered as a list string."""
        return str([f"{k}" for k in sorted(self.arguments)])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())
|
||||
|
||||
|
||||
def create_runner_config_validator():
    """Build the ``ConfigValidator`` describing every supported runner option.

    Choices for ``lr_sched`` and ``task`` are read from the registry at call
    time, so schedulers and tasks must be registered before this is called.

    Returns:
        A ``ConfigValidator`` with all runner arguments registered.
    """
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        # A plain ``dict`` is used because typing aliases such as
        # Dict[str, float] cannot be called to convert a value.
        type=dict,
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["generate", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator
|
137
models/MiniGPT4/minigpt4/common/dist_utils.py
Normal file
137
models/MiniGPT4/minigpt4/common/dist_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import timm.models.hub as timm_hub
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
    """Replace the built-in ``print`` with a version gated on ``is_master``.

    After this call ``print`` only produces output when ``is_master`` is true
    or when the caller passes the extra keyword ``force=True``. The ``force``
    keyword is removed before delegating to the original ``print``.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """Return True only if ``torch.distributed`` is available and initialized."""
    # Availability is tested first; the initialization query is only made
    # when the distributed package is usable.
    return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
||||
|
||||
|
||||
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
||||
|
||||
|
||||
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
|
||||
|
||||
|
||||
def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from environment variables.

    ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` are used when the first two are
    present; otherwise ``SLURM_PROCID`` is used. If neither is set,
    ``args.distributed`` is set to False and nothing is initialised.
    Otherwise the rank, GPU index, backend and ``distributed`` flag are
    written onto ``args``, the NCCL process group is created, and printing is
    restricted to rank 0.

    Args:
        args: mutable namespace; ``dist_url`` and ``world_size`` are read from
            it. NOTE(review): the SLURM branch does not set ``world_size``,
            so it presumably must already be present on ``args`` — confirm.
    """
    env = os.environ

    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        # Derive the local device index from the global rank.
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"

    banner = "| distributed init (rank {}, world {}): {}".format(
        args.rank, args.world_size, args.dist_url
    )
    print(banner, flush=True)

    # allow auto-downloading and de-compressing
    group_timeout = datetime.timedelta(days=365)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=group_timeout,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def get_dist_info():
    """Return ``(rank, world_size)`` for the current process.

    Falls back to ``(0, 1)`` when ``torch.distributed`` is unavailable or the
    default process group has not been initialised (non-distributed training).
    """
    # Check availability before querying initialisation, consistent with
    # is_dist_avail_and_initialized(). This also replaces a lexicographic
    # string comparison against torch.__version__.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()

    # non-distributed training
    return 0, 1
|
||||
|
||||
|
||||
def main_process(func):
    """Decorator: run ``func`` only on rank 0.

    On any other rank the wrapped call does nothing and returns ``None``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank, _ = get_dist_info()
        if current_rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Returns:
        The path of the cached file, built from the timm cache directory and
        the basename of the URL path.
    """
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    # a hack to sync the file path across processes: every process rebuilds
    # the path locally instead of receiving it from the main process.
    file_name = os.path.basename(torch.hub.urlparse(url).path)
    return os.path.join(timm_hub.get_cache_dir(), file_name)
|
24
models/MiniGPT4/minigpt4/common/gradcam.py
Normal file
24
models/MiniGPT4/minigpt4/common/gradcam.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from scipy.ndimage import filters
|
||||
from skimage import transform as skimage_transform
|
||||
|
||||
|
||||
def getAttMap(img, attMap, blur=True, overlap=True):
    """Resize an attention map to an image and optionally overlay it.

    The map is shifted/scaled to start at 0 (and peak at 1 when it is not
    constant), resized to ``img.shape[:2]`` with cubic interpolation,
    optionally Gaussian-blurred and re-normalised, then colourised with the
    "jet" colormap.

    Args:
        img: image array; only ``img.shape[:2]`` and element-wise products
            with it are used here. Presumably H x W x 3 in [0, 1] — confirm
            against callers.
        attMap: 2-D attention map. It is copied, so the caller's array is not
            modified.
        blur: if True, smooth the resized map with a Gaussian whose sigma is
            2% of the larger image side.
        overlap: if True, return the blend of ``img`` and the colourised map;
            if False, return the (single-channel) processed map itself.

    Returns:
        The blended image when ``overlap`` is True, otherwise the processed
        attention map.
    """
    # Imported from the public namespace; ``scipy.ndimage.filters`` is a
    # deprecated alias.
    from scipy.ndimage import gaussian_filter

    # Work on a floating-point copy so the in-place arithmetic below neither
    # mutates the caller's data nor fails on integer input.
    attMap = np.array(attMap, copy=True)
    if not np.issubdtype(attMap.dtype, np.floating):
        attMap = attMap.astype(np.float64)

    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        peak = attMap.max()
        # Guard against 0/0 on a constant map.
        if peak > 0:
            attMap /= peak
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    # Drop the alpha channel returned by the colormap.
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        weight = (attMap**0.7).reshape(attMap.shape + (1,))
        attMap = 1 * (1 - weight) * img + weight * attMapV
    return attMap
|
195
models/MiniGPT4/minigpt4/common/logger.py
Normal file
195
models/MiniGPT4/minigpt4/common/logger.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from minigpt4.common import dist_utils
|
||||
|
||||
|
||||
class SmoothedValue(object):
    """Keep a running series of values.

    Exposes statistics over the most recent ``window_size`` values (median,
    avg, max, value) as well as the average over everything seen so far
    (global_avg).
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        # Pack count and total into one tensor so a single all-reduce
        # covers both.
        packed = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device="cuda"
        )
        dist.barrier()
        dist.all_reduce(packed)
        reduced_count, reduced_total = packed.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
|
||||
|
||||
|
||||
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress."""

    def __init__(self, delimiter="\t"):
        # Meters are created on first use by update().
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are reduced with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        Raises:
            AttributeError: if ``attr`` is neither a meter nor an instance
                attribute.
        """
        # Read ``meters`` through __dict__: going through ``self.meters``
        # would re-enter __getattr__ and recurse forever when the instance
        # has no ``meters`` yet (e.g. while being copied or unpickled).
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return every meter's global average as one delimited string."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A status line (index, ETA, meters, iteration time, data-loading time
        and, when CUDA is available, peak allocated memory in MB) is printed
        every ``print_freq`` items and for the last item. A total-time
        summary is printed once the iterable is exhausted.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq: number of items between status lines.
            header: optional prefix for each printed line.
        """
        i = 0
        if not header:
            header = ""
        total = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the running index to the width of the total count.
        space_fmt = ":" + str(len(str(total))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if use_cuda:
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the item.
            data_time.update(time.time() - end)
            yield obj
            # Time for the whole iteration, including the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = {
                    "eta": eta_string,
                    "meters": str(self),
                    "time": str(iter_time),
                    "data": str(data_time),
                }
                if use_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                # max(..., 1) avoids a ZeroDivisionError on an empty iterable.
                header, total_time_str, total_time / max(total, 1)
            )
        )
|
||||
|
||||
|
||||
class AttrDict(dict):
    """A ``dict`` whose items are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the dict itself as the attribute namespace, so attribute
        # access and item access share the same storage.
        self.__dict__ = self
|
||||
|
||||
|
||||
def setup_logger():
    """Configure root logging to a stream handler.

    The main process logs at INFO level; every other process logs at WARN.
    """
    if dist_utils.is_main_process():
        level = logging.INFO
    else:
        level = logging.WARN

    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
|
119
models/MiniGPT4/minigpt4/common/optims.py
Normal file
119
models/MiniGPT4/minigpt4/common/optims.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
|
||||
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    """Linear warm-up during epoch 0, then per-epoch exponential step decay."""

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.decay_rate = decay_rate

        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start the warm-up at init_lr".
        if warmup_start_lr >= 0:
            self.warmup_start_lr = warmup_start_lr
        else:
            self.warmup_start_lr = init_lr

    def step(self, cur_epoch, cur_step):
        """Set the optimizer's learning rate for the given epoch and step."""
        if cur_epoch != 0:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return

        # Warm-up is applied throughout the first epoch only.
        warmup_lr_schedule(
            step=cur_step,
            optimizer=self.optimizer,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
|
||||
|
||||
|
||||
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    """Linear warm-up for ``warmup_steps`` iterations, then cosine decay.

    Progress is measured in global iterations
    (``cur_epoch * iters_per_epoch + cur_step``).
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start the warm-up at init_lr".
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        """Set the optimizer's learning rate for the given epoch and step."""
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                # Use the global step so a warm-up longer than one epoch
                # keeps ramping instead of restarting at each epoch boundary.
                # Identical to cur_step while cur_epoch == 0.
                step=total_cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
|
||||
|
||||
|
||||
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's lr on a half-cosine from ``init_lr`` to ``min_lr``.

    The rate equals ``init_lr`` at ``epoch == 0`` and ``min_lr`` at
    ``epoch == max_epoch``.
    """
    half_range = (init_lr - min_lr) * 0.5
    lr = half_range * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for group in optimizer.param_groups:
        group["lr"] = lr
|
||||
|
||||
|
||||
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Set every param group's lr on a linear ramp from ``init_lr``, capped at ``max_lr``.

    ``max_step`` values below 1 are treated as 1.
    """
    ramp_length = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / ramp_length
    lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group["lr"] = lr
|
||||
|
||||
|
||||
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's lr to ``init_lr * decay_rate**epoch``, floored at ``min_lr``."""
    decayed = init_lr * (decay_rate**epoch)
    lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group["lr"] = lr
|
329
models/MiniGPT4/minigpt4/common/registry.py
Normal file
329
models/MiniGPT4/minigpt4/common/registry.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
|
||||
class Registry:
    """Shared registry of named classes, filesystem paths and arbitrary state.

    Everything is stored in the class-level ``mapping`` dict, so the class and
    all of its instances see the same registrations.
    """

    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def _store_unique(cls, mapping_key, name, obj):
        """Store ``obj`` under ``name`` in ``mapping[mapping_key]`` and return it.

        Raises:
            KeyError: if ``name`` is already present in that table.
        """
        table = cls.mapping[mapping_key]
        if name in table:
            raise KeyError(
                "Name '{}' already registered for {}.".format(name, table[name])
            )
        table[name] = obj
        return obj

    @classmethod
    def register_builder(cls, name):
        r"""Return a decorator registering a dataset builder under key 'name'

        Args:
            name: Key with which the builder will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit BaseDatasetBuilder.
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry

            @registry.register_builder("my_dataset")
            class MyBuilder(BaseDatasetBuilder):
                ...
        """

        def wrap(builder_cls):
            # Imported lazily, at decoration time rather than module import.
            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            return cls._store_unique("builder_name_mapping", name, builder_cls)

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Return a decorator registering a task under key 'name'

        Args:
            name: Key with which the task will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit BaseTask.
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(task_cls):
            from minigpt4.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            return cls._store_unique("task_name_mapping", name, task_cls)

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Return a decorator registering a model under key 'name'

        Args:
            name: Key with which the model will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit BaseModel.
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(model_cls):
            from minigpt4.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            return cls._store_unique("model_name_mapping", name, model_cls)

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Return a decorator registering a processor under key 'name'

        Args:
            name: Key with which the processor will be registered.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit BaseProcessor.
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(processor_cls):
            from minigpt4.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            return cls._store_unique("processor_name_mapping", name, processor_cls)

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Return a decorator registering an LR scheduler under key 'name'

        No base-class check is performed for schedulers.

        Args:
            name: Key with which the scheduler will be registered.

        Raises (when the decorator is applied):
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(lr_sched_cls):
            return cls._store_unique("lr_scheduler_name_mapping", name, lr_sched_cls)

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Return a decorator registering a runner under key 'name'

        No base-class check is performed for runners.

        Args:
            name: Key with which the runner will be registered.

        Raises (when the decorator is applied):
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """

        def wrap(runner_cls):
            return cls._store_unique("runner_name_mapping", name, runner_cls)

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.
            path (str): The path to store.

        Raises:
            AssertionError: 'path' is not a str.
            KeyError: 'name' is already registered.

        Usage:

            from minigpt4.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        A dotted name such as "a.b.c" is stored as nested dicts under
        ``mapping["state"]``; missing intermediate dicts are created. An
        existing value under the final key is overwritten.

        Args:
            name: Key with which the item will be registered.
            obj: The item to store.

        Usage::

            from minigpt4.common.registry import registry

            registry.register("config", {})
        """
        *parents, leaf = name.split(".")
        current = cls.mapping["state"]

        for part in parents:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[leaf] = obj

    @classmethod
    def get_builder_class(cls, name):
        """Return the builder class registered under ``name``, or None."""
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        """Return the model class registered under ``name``, or None."""
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        """Return the task class registered under ``name``, or None."""
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor class registered under ``name``, or None."""
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        """Return the LR scheduler class registered under ``name``, or None."""
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        """Return the runner class registered under ``name``, or None."""
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        """Return the sorted names of all registered runners."""
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        """Return the sorted names of all registered models."""
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        """Return the sorted names of all registered tasks."""
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        """Return the sorted names of all registered processors."""
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        """Return the sorted names of all registered LR schedulers."""
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        """Return the sorted names of all registered dataset builders."""
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        """Return the path registered under ``name``, or None."""
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        A dotted name is resolved one component at a time through nested
        containers. If a component is missing, or an intermediate value has
        no ``get`` method to descend into, ``default`` is returned.

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Default: False
        """
        original_name = name
        value = cls.mapping["state"]
        for subname in name.split("."):
            # Duck-typed on purpose: any mapping-like object with ``get`` can
            # be traversed. A leaf reached too early counts as "key not found"
            # instead of raising AttributeError.
            getter = getattr(value, "get", None)
            if getter is None:
                value = default
                break
            value = getter(subname, default)
            if value is default:
                break

        # The warning is only emitted when a "writer" object has been
        # registered in the state.
        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Dotted names are resolved through nested dicts, mirroring
        ``register``. Only plain dicts are descended into.

        Args:
            name: Key which needs to be removed.

        Returns:
            The removed item, or None if nothing was stored under 'name'.

        Usage::

            from minigpt4.common.registry import registry

            config = registry.unregister("config")
        """
        state = cls.mapping["state"]
        # A literal top-level key wins, matching the previous behaviour.
        if name in state:
            return state.pop(name)

        *parents, leaf = name.split(".")
        current = state
        for part in parents:
            if not isinstance(current, dict) or part not in current:
                return None
            current = current[part]
        if not isinstance(current, dict):
            return None
        return current.pop(leaf, None)
|
||||
|
||||
|
||||
# Module-level instance that other modules import
# (``from minigpt4.common.registry import registry``).
registry = Registry()
|
424
models/MiniGPT4/minigpt4/common/utils.py
Normal file
424
models/MiniGPT4/minigpt4/common/utils.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
import urllib
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from iopath.common.download import download
|
||||
from iopath.common.file_io import file_lock, g_pathmgr
|
||||
from minigpt4.common.registry import registry
|
||||
from torch.utils.model_zoo import tqdm
|
||||
from torchvision.datasets.utils import (
|
||||
check_integrity,
|
||||
download_file_from_google_drive,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
def now():
    """Return the current local time as ``YYYYmmddHHMM`` minus its last character."""
    from datetime import datetime

    stamp = datetime.now().strftime("%Y%m%d%H%M")
    # Dropping the final character removes the units digit of the minutes.
    return stamp[:-1]
|
||||
|
||||
|
||||
def is_url(url_or_filename):
    """Return True when the parsed scheme of the input is ``http`` or ``https``.

    NOTE: this module defines ``is_url`` again further down; that later
    definition is the one bound to the name once the module is imported.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
||||
|
||||
|
||||
def get_cache_path(rel_path):
    """Join ``rel_path`` onto the registered "cache_root" path and expand ``~``."""
    cache_root = registry.get_path("cache_root")
    joined = os.path.join(cache_root, rel_path)
    return os.path.expanduser(joined)
|
||||
|
||||
|
||||
def get_abs_path(rel_path):
    """Join ``rel_path`` onto the registered "library_root" path."""
    library_root = registry.get_path("library_root")
    return os.path.join(library_root, rel_path)
|
||||
|
||||
|
||||
def load_json(filename):
    """Open ``filename`` in text mode and return its parsed JSON content."""
    with open(filename, "r") as handle:
        parsed = json.load(handle)
    return parsed
|
||||
|
||||
|
||||
# The following are adapted from torchvision and vissl
|
||||
# torchvision: https://github.com/pytorch/vision
|
||||
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
||||
|
||||
|
||||
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns:
        bool: True if the directory already existed or was created, False if
        checking for or creating it raised an error (a message is printed).

    NOTE: this module defines ``makedir`` again further down; that later
    definition is the one bound to the name once the module is imported.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    # Catch Exception rather than BaseException so that KeyboardInterrupt and
    # SystemExit still propagate instead of being reported as a failed mkdir.
    except Exception:
        print(f"Error creating directory: {dir_path}")
    return is_success
|
||||
|
||||
|
||||
def get_redirected_url(url: str):
    """
    Follow redirects for ``url`` and return the final URL.

    If the response records no redirect history, the input URL is returned
    unchanged.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            was_redirected = bool(response.history)
            return response.url if was_redirected else url
|
||||
|
||||
|
||||
def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive.

    Only the path component of the URL is inspected, so share links carrying
    a query string (e.g. ``.../view?usp=sharing``) are accepted too.

    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp

    Raises:
        AssertionError: if the last path segment is not ``view``.
    """
    # Splitting the parsed path keeps "?usp=sharing" and "#fragment" suffixes
    # out of the final segment.
    splits = urlparse(view_url).path.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"
|
||||
|
||||
|
||||
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive into ``output_path/output_file_name``.

    Google Drive asks for confirmation when a file is too large for its
    virus scan. The first request collects any ``download_warning*`` cookie
    value and appends it to the URL as ``&confirm=...``; the second request
    streams the content to disk with a progress bar.
    """
    import requests

    with requests.Session() as session:

        # Probe request: pick up the confirmation token, if any.
        with session.get(url, stream=True, allow_redirects=True) as probe:
            for cookie_name, cookie_value in probe.cookies.items():
                if cookie_name.startswith("download_warning"):
                    url = url + "&confirm=" + cookie_value

        # Actual download, streamed block by block.
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            destination = os.path.join(output_path, output_file_name)
            expected_bytes = int(response.headers.get("Content-length", 0))
            with open(destination, "wb") as sink:
                from tqdm import tqdm

                with tqdm(total=expected_bytes) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        sink.write(block)
                        progress_bar.update(len(block))
|
||||
|
||||
|
||||
def _get_google_drive_file_id(url: str) -> Optional[str]:
    """Extract the file id from a Google Drive ``/file/d/<id>`` URL.

    Returns None when the host does not start with ``drive.google.com`` or
    ``docs.google.com``, or when the path does not start with ``/file/d/``.
    """
    parsed = urlparse(url)

    host_match = re.match(r"(drive|docs)[.]google[.]com", parsed.netloc)
    if host_match is None:
        return None

    id_match = re.match(r"/file/d/(?P<id>[^/]*)", parsed.path)
    return None if id_match is None else id_match.group("id")
|
||||
|
||||
|
||||
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """Stream ``url`` into ``filename`` in chunks, showing a progress bar.

    Args:
        url: Address to fetch; the request is sent with User-Agent "vissl".
        filename: Destination path, opened for binary writing before the
            request is made.
        chunk_size: Maximum number of bytes read per iteration.
    """
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # read() returns bytes, so the end-of-stream sentinel must be
                # b"" (a str sentinel never compares equal).
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    # Advance by what was actually read; the final chunk is
                    # usually shorter than chunk_size.
                    pbar.update(len(chunk))
                    fh.write(chunk)
|
||||
|
||||
|
||||
def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    An existing local file that passes the integrity check is reused. Google
    Drive links are delegated to ``download_file_from_google_drive``. A failed
    https download is retried once over http.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
                                  If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        RuntimeError: if the downloaded file fails the integrity check.
    """
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # Nothing to do when a verified copy is already on disk.
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # Resolve the redirect chain before deciding how to download.
    url = get_redirected_url(url)

    drive_file_id = _get_google_drive_file_id(url)
    if drive_file_id is not None:
        return download_file_from_google_drive(drive_file_id, root, filename, md5)

    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] != "https":
            raise e
        # Fall back to plain http for https URLs.
        url = url.replace("https:", "http:")
        print(
            "Failed download. Trying https -> http instead."
            " Downloading " + url + " to " + fpath
        )
        _urlretrieve(url, fpath)

    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
|
||||
|
||||
|
||||
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive with ``download_url`` and unpack it.

    Args:
        url: Location of the archive.
        download_root: Directory the archive is saved to (``~`` is expanded).
        extract_root: Directory to extract into; defaults to ``download_root``.
        filename: Name for the saved archive; defaults to the URL basename.
        md5: Optional checksum forwarded to ``download_url``.
        remove_finished: Forwarded to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)
|
||||
|
||||
|
||||
def cache_url(url: str, cache_dir: str) -> str:
    """
    Download ``url`` into ``cache_dir`` unless a cached copy already exists.

    The URL's directory path is mirrored below ``cache_dir``. The existence
    check and download run while holding a file lock on the target path.

    Returns:
        The path of the cached file.
    """
    url_path = urlparse(url).path.lstrip("/")
    dirname = os.path.join(cache_dir, os.path.dirname(url_path))
    makedir(dirname)

    filename = url.rsplit("/", 1)[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        already_cached = os.path.isfile(cached)
        if not already_cached:
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached
|
||||
|
||||
|
||||
# TODO (prigoyal): convert this into RAII-style API
|
||||
def create_file_symlink(file1, file2):
    """
    Create a symlink at ``file2`` pointing to ``file1``.

    Anything already present at ``file2`` is removed first. Failures are
    logged, not raised. Useful during model checkpointing to point at the
    latest successful checkpoint.
    """
    try:
        link_present = g_pathmgr.exists(file2)
        if link_present:
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as err:
        logging.info(f"Could NOT create symlink. Error: {err}")
|
||||
|
||||
|
||||
def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Write ``data`` to ``filename`` in the format selected by its extension.

    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    For .json, one JSON document plus a newline is appended to the file by
    default; pass ``append_to_json=False`` to overwrite instead.

    Raises:
        Exception: if the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")

    extension = os.path.splitext(filename)[1]
    if extension in (".pkl", ".pickle"):
        with g_pathmgr.open(filename, "wb") as sink:
            pickle.dump(data, sink, pickle.HIGHEST_PROTOCOL)
    elif extension == ".npy":
        with g_pathmgr.open(filename, "wb") as sink:
            np.save(sink, data)
    elif extension == ".json":
        # The two JSON variants differ only in the open mode.
        json_mode = "a" if append_to_json else "w"
        with g_pathmgr.open(filename, json_mode) as sink:
            sink.write(json.dumps(data, sort_keys=True) + "\n")
            sink.flush()
    elif extension == ".yaml":
        with g_pathmgr.open(filename, "w") as sink:
            sink.write(yaml.dump(data))
            sink.flush()
    else:
        raise Exception(f"Saving {extension} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
|
||||
|
||||
|
||||
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .txt (list of lines), .pkl, .pickle, .npy, .json, .yaml, .csv
    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.

    Args:
        filename: Path of the file; its extension selects the loader.
        mmap_mode: Forwarded to ``np.load`` for .npy files when truthy.
        verbose: Log the file name before loading.
        allow_pickle: Forwarded to ``np.load`` for .npy files.

    Raises:
        Exception: if the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        # SECURITY NOTE(review): unpickling can execute arbitrary code; only
        # load pickle files from trusted sources.
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                # First attempt: mmap through the path manager's file object.
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                # Second attempt: let numpy open the path itself, still mmapped.
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                # NOTE(review): this handler covers non-ValueError failures of
                # the first attempt only. An error raised inside the
                # ValueError handler above is not caught here, so the message
                # below may not describe what actually failed.
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
|
||||
|
||||
|
||||
def abspath(resource_path: str):
    """
    Return an absolute version of ``resource_path``.

    Inputs that start with a ``<word>://`` prefix (such as "http://" or
    "manifold://") are returned untouched.
    """
    has_scheme_prefix = re.match(r"^\w+://", resource_path) is not None
    if has_scheme_prefix:
        return resource_path
    return os.path.abspath(resource_path)
|
||||
|
||||
|
||||
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Returns:
        bool: True if the directory already existed or was created, False if
        checking for or creating it raised an error (the failure is logged).
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    # Catch Exception rather than BaseException so that KeyboardInterrupt and
    # SystemExit still propagate instead of being logged as a failed mkdir.
    except Exception:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success
|
||||
|
||||
|
||||
def is_url(input_url):
    """
    Return True if the input starts with ``http://`` or ``https://``.

    The match is case-insensitive.
    """
    return re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
||||
|
||||
|
||||
def cleanup_dir(dir):
    """
    Delete the directory ``dir`` and everything in it, if it exists.

    Useful for reclaiming storage used by training artifacts such as
    checkpoints and data. A missing path is silently ignored.
    """
    if not os.path.exists(dir):
        return
    logging.info(f"Deleting directory: {dir}")
    shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")
|
||||
|
||||
|
||||
def get_file_size(filename):
    """
    Return the size of ``filename`` in MB, where 1 MB is 1024**2 bytes.
    """
    bytes_per_mb = float(1024**2)
    return os.path.getsize(filename) / bytes_per_mb
|
Reference in New Issue
Block a user