add readme (#10)
* Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * remove submodule * add mPLUG MiniGPT4 * Update Readme.md * Update Readme.md * Update Readme.md --------- Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
This commit is contained in:
422
models/mPLUG_Owl/pipeline/data_utils/registry.py
Normal file
422
models/mPLUG_Owl/pipeline/data_utils/registry.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# Copyright (c) Alibaba. All rights reserved.
|
||||
import inspect
|
||||
import warnings
|
||||
import functools
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional
|
||||
from collections import abc
|
||||
from inspect import getfullargspec
|
||||
|
||||
|
||||
def is_seq_of(seq, expected_type, seq_type=None):
|
||||
"""Check whether it is a sequence of some type.
|
||||
Args:
|
||||
seq (Sequence): The sequence to be checked.
|
||||
expected_type (type): Expected type of sequence items.
|
||||
seq_type (type, optional): Expected sequence type.
|
||||
Returns:
|
||||
bool: Whether the sequence is valid.
|
||||
"""
|
||||
if seq_type is None:
|
||||
exp_seq_type = abc.Sequence
|
||||
else:
|
||||
assert isinstance(seq_type, type)
|
||||
exp_seq_type = seq_type
|
||||
if not isinstance(seq, exp_seq_type):
|
||||
return False
|
||||
for item in seq:
|
||||
if not isinstance(item, expected_type):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def deprecated_api_warning(name_dict, cls_name=None):
|
||||
"""A decorator to check if some arguments are deprecate and try to replace
|
||||
deprecate src_arg_name to dst_arg_name.
|
||||
Args:
|
||||
name_dict(dict):
|
||||
key (str): Deprecate argument names.
|
||||
val (str): Expected argument names.
|
||||
Returns:
|
||||
func: New function.
|
||||
"""
|
||||
|
||||
def api_warning_wrapper(old_func):
|
||||
|
||||
@functools.wraps(old_func)
|
||||
def new_func(*args, **kwargs):
|
||||
# get the arg spec of the decorated method
|
||||
args_info = getfullargspec(old_func)
|
||||
# get name of the function
|
||||
func_name = old_func.__name__
|
||||
if cls_name is not None:
|
||||
func_name = f'{cls_name}.{func_name}'
|
||||
if args:
|
||||
arg_names = args_info.args[:len(args)]
|
||||
for src_arg_name, dst_arg_name in name_dict.items():
|
||||
if src_arg_name in arg_names:
|
||||
warnings.warn(
|
||||
f'"{src_arg_name}" is deprecated in '
|
||||
f'`{func_name}`, please use "{dst_arg_name}" '
|
||||
'instead', DeprecationWarning)
|
||||
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
|
||||
if kwargs:
|
||||
for src_arg_name, dst_arg_name in name_dict.items():
|
||||
if src_arg_name in kwargs:
|
||||
|
||||
assert dst_arg_name not in kwargs, (
|
||||
f'The expected behavior is to replace '
|
||||
f'the deprecated key `{src_arg_name}` to '
|
||||
f'new key `{dst_arg_name}`, but got them '
|
||||
f'in the arguments at the same time, which '
|
||||
f'is confusing. `{src_arg_name} will be '
|
||||
f'deprecated in the future, please '
|
||||
f'use `{dst_arg_name}` instead.')
|
||||
|
||||
warnings.warn(
|
||||
f'"{src_arg_name}" is deprecated in '
|
||||
f'`{func_name}`, please use "{dst_arg_name}" '
|
||||
'instead', DeprecationWarning)
|
||||
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
|
||||
|
||||
# apply converted arguments to the decorated method
|
||||
output = old_func(*args, **kwargs)
|
||||
return output
|
||||
|
||||
return new_func
|
||||
|
||||
return api_warning_wrapper
|
||||
|
||||
|
||||
|
||||
def build_from_cfg(cfg: Dict,
|
||||
registry: 'Registry',
|
||||
default_args: Optional[Dict] = None) -> Any:
|
||||
"""Build a module from config dict when it is a class configuration, or
|
||||
call a function from config dict when it is a function configuration.
|
||||
|
||||
Example:
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
|
||||
>>> # Returns an instantiated object
|
||||
>>> @MODELS.register_module()
|
||||
>>> def resnet50():
|
||||
>>> pass
|
||||
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
|
||||
>>> # Return a result of the calling function
|
||||
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
||||
if 'type' not in cfg:
|
||||
if default_args is None or 'type' not in default_args:
|
||||
raise KeyError(
|
||||
'`cfg` or `default_args` must contain the key "type", '
|
||||
f'but got {cfg}\n{default_args}')
|
||||
if not isinstance(registry, Registry):
|
||||
raise TypeError('registry must be an mmcv.Registry object, '
|
||||
f'but got {type(registry)}')
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
args = cfg.copy()
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name} registry')
|
||||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
try:
|
||||
return obj_cls(**args)
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
raise type(e)(f'{obj_cls.__name__}: {e}')
|
||||
|
||||
|
||||
class Registry:
|
||||
"""A registry to map strings to classes or functions.
|
||||
|
||||
Registered object could be built from registry. Meanwhile, registered
|
||||
functions could be called from registry.
|
||||
|
||||
Example:
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||
>>> @MODELS.register_module()
|
||||
>>> def resnet50():
|
||||
>>> pass
|
||||
>>> resnet = MODELS.build(dict(type='resnet50'))
|
||||
|
||||
Please refer to
|
||||
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
||||
advanced usage.
|
||||
|
||||
Args:
|
||||
name (str): Registry name.
|
||||
build_func(func, optional): Build function to construct instance from
|
||||
Registry, func:`build_from_cfg` is used if neither ``parent`` or
|
||||
``build_func`` is specified. If ``parent`` is specified and
|
||||
``build_func`` is not given, ``build_func`` will be inherited
|
||||
from ``parent``. Default: None.
|
||||
parent (Registry, optional): Parent registry. The class registered in
|
||||
children registry could be built from parent. Default: None.
|
||||
scope (str, optional): The scope of registry. It is the key to search
|
||||
for children registry. If not specified, scope will be the name of
|
||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, name, build_func=None, parent=None, scope=None):
|
||||
self._name = name
|
||||
self._module_dict = dict()
|
||||
self._children = dict()
|
||||
self._scope = self.infer_scope() if scope is None else scope
|
||||
|
||||
# self.build_func will be set with the following priority:
|
||||
# 1. build_func
|
||||
# 2. parent.build_func
|
||||
# 3. build_from_cfg
|
||||
if build_func is None:
|
||||
if parent is not None:
|
||||
self.build_func = parent.build_func
|
||||
else:
|
||||
self.build_func = build_from_cfg
|
||||
else:
|
||||
self.build_func = build_func
|
||||
if parent is not None:
|
||||
assert isinstance(parent, Registry)
|
||||
parent._add_children(self)
|
||||
self.parent = parent
|
||||
else:
|
||||
self.parent = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self._module_dict)
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
def __repr__(self):
|
||||
format_str = self.__class__.__name__ + \
|
||||
f'(name={self._name}, ' \
|
||||
f'items={self._module_dict})'
|
||||
return format_str
|
||||
|
||||
@staticmethod
|
||||
def infer_scope():
|
||||
"""Infer the scope of registry.
|
||||
|
||||
The name of the package where registry is defined will be returned.
|
||||
|
||||
Example:
|
||||
>>> # in mmdet/models/backbone/resnet.py
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
The scope of ``ResNet`` will be ``mmdet``.
|
||||
|
||||
Returns:
|
||||
str: The inferred scope name.
|
||||
"""
|
||||
# We access the caller using inspect.currentframe() instead of
|
||||
# inspect.stack() for performance reasons. See details in PR #1844
|
||||
frame = inspect.currentframe()
|
||||
# get the frame where `infer_scope()` is called
|
||||
infer_scope_caller = frame.f_back.f_back
|
||||
filename = inspect.getmodule(infer_scope_caller).__name__
|
||||
split_filename = filename.split('.')
|
||||
return split_filename[0]
|
||||
|
||||
@staticmethod
|
||||
def split_scope_key(key):
|
||||
"""Split scope and key.
|
||||
|
||||
The first scope will be split from key.
|
||||
|
||||
Examples:
|
||||
>>> Registry.split_scope_key('mmdet.ResNet')
|
||||
'mmdet', 'ResNet'
|
||||
>>> Registry.split_scope_key('ResNet')
|
||||
None, 'ResNet'
|
||||
|
||||
Return:
|
||||
tuple[str | None, str]: The former element is the first scope of
|
||||
the key, which can be ``None``. The latter is the remaining key.
|
||||
"""
|
||||
split_index = key.find('.')
|
||||
if split_index != -1:
|
||||
return key[:split_index], key[split_index + 1:]
|
||||
else:
|
||||
return None, key
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def module_dict(self):
|
||||
return self._module_dict
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
return self._children
|
||||
|
||||
def get(self, key):
|
||||
"""Get the registry record.
|
||||
|
||||
Args:
|
||||
key (str): The class name in string format.
|
||||
|
||||
Returns:
|
||||
class: The corresponding class.
|
||||
"""
|
||||
scope, real_key = self.split_scope_key(key)
|
||||
if scope is None or scope == self._scope:
|
||||
# get from self
|
||||
if real_key in self._module_dict:
|
||||
return self._module_dict[real_key]
|
||||
else:
|
||||
# get from self._children
|
||||
if scope in self._children:
|
||||
return self._children[scope].get(real_key)
|
||||
else:
|
||||
# goto root
|
||||
parent = self.parent
|
||||
while parent.parent is not None:
|
||||
parent = parent.parent
|
||||
return parent.get(key)
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
return self.build_func(*args, **kwargs, registry=self)
|
||||
|
||||
def _add_children(self, registry):
|
||||
"""Add children for a registry.
|
||||
|
||||
The ``registry`` will be added as children based on its scope.
|
||||
The parent registry could build objects from children registry.
|
||||
|
||||
Example:
|
||||
>>> models = Registry('models')
|
||||
>>> mmdet_models = Registry('models', parent=models)
|
||||
>>> @mmdet_models.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = models.build(dict(type='mmdet.ResNet'))
|
||||
"""
|
||||
|
||||
assert isinstance(registry, Registry)
|
||||
assert registry.scope is not None
|
||||
assert registry.scope not in self.children, \
|
||||
f'scope {registry.scope} exists in {self.name} registry'
|
||||
self.children[registry.scope] = registry
|
||||
|
||||
@deprecated_api_warning(name_dict=dict(module_class='module'))
|
||||
def _register_module(self, module, module_name=None, force=False):
|
||||
if not inspect.isclass(module) and not inspect.isfunction(module):
|
||||
raise TypeError('module must be a class or a function, '
|
||||
f'but got {type(module)}')
|
||||
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in self._module_dict:
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'in {self.name}')
|
||||
self._module_dict[name] = module
|
||||
|
||||
def deprecated_register_module(self, cls=None, force=False):
|
||||
warnings.warn(
|
||||
'The old API of register_module(module, force=False) '
|
||||
'is deprecated and will be removed, please use the new API '
|
||||
'register_module(name=None, force=False, module=None) instead.',
|
||||
DeprecationWarning)
|
||||
if cls is None:
|
||||
return partial(self.deprecated_register_module, force=force)
|
||||
self._register_module(cls, force=force)
|
||||
return cls
|
||||
|
||||
def register_module(self, name=None, force=False, module=None):
|
||||
"""Register a module.
|
||||
|
||||
A record will be added to `self._module_dict`, whose key is the class
|
||||
name or the specified name, and value is the class itself.
|
||||
It can be used as a decorator or a normal function.
|
||||
|
||||
Example:
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> @backbones.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> @backbones.register_module(name='mnet')
|
||||
>>> class MobileNet:
|
||||
>>> pass
|
||||
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> backbones.register_module(ResNet)
|
||||
|
||||
Args:
|
||||
name (str | None): The module name to be registered. If not
|
||||
specified, the class name will be used.
|
||||
force (bool, optional): Whether to override an existing class with
|
||||
the same name. Default: False.
|
||||
module (type): Module class or function to be registered.
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
# NOTE: This is a walkaround to be compatible with the old api,
|
||||
# while it may introduce unexpected bugs.
|
||||
if isinstance(name, type):
|
||||
return self.deprecated_register_module(name, force=force)
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be either of None, an instance of str or a sequence'
|
||||
f' of str, but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
self._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
self._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
Reference in New Issue
Block a user