
* Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * Update Readme.md * remove submodule * add mPLUG MiniGPT4 * Update Readme.md * Update Readme.md * Update Readme.md --------- Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
422 lines
15 KiB
Python
422 lines
15 KiB
Python
# Copyright (c) Alibaba. All rights reserved.
|
|
import inspect
|
|
import warnings
|
|
import functools
|
|
from functools import partial
|
|
from typing import Any, Dict, Optional
|
|
from collections import abc
|
|
from inspect import getfullargspec
|
|
|
|
|
|
def is_seq_of(seq, expected_type, seq_type=None):
    """Check that ``seq`` is a sequence whose items all have a given type.

    Args:
        seq (Sequence): The object to be checked.
        expected_type (type): Required type of every item in ``seq``.
        seq_type (type, optional): Required container type of ``seq``. When
            omitted, any ``collections.abc.Sequence`` is accepted.

    Returns:
        bool: True if ``seq`` has the expected container type and every item
        is an instance of ``expected_type``; False otherwise.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence

    if not isinstance(seq, container_type):
        return False
    # ``all`` short-circuits on the first mismatching item, like a manual loop.
    return all(isinstance(element, expected_type) for element in seq)
|
|
|
|
|
|
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecated and try to
    replace the deprecated ``src_arg_name`` with ``dst_arg_name``.

    Args:
        name_dict (dict):
            key (str): Deprecated argument names.
            val (str): Expected argument names.
        cls_name (str, optional): Name prefixed to the function name in
            warning messages (e.g. ``'Registry'`` gives
            ``Registry.func``). Default: None.

    Returns:
        func: A decorator. The function it returns emits a
        ``DeprecationWarning`` when a deprecated argument name is used. A
        deprecated keyword argument is renamed to its replacement before the
        wrapped function is called.
    """

    def api_warning_wrapper(old_func):
        # The signature and display name of ``old_func`` never change, so
        # compute them once at decoration time instead of on every call.
        args_info = getfullargspec(old_func)
        func_name = old_func.__name__
        if cls_name is not None:
            func_name = f'{cls_name}.{func_name}'

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            if args:
                # Parameter names bound positionally in this call.
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        # Positional values cannot be renamed; only warn.
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        # Mirror the rename in the local name list so later
                        # entries of ``name_dict`` see the new name.
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:
                        # Passing both the old and the new name is ambiguous.
                        assert dst_arg_name not in kwargs, (
                            'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` to '
                            f'new key `{dst_arg_name}`, but got them '
                            'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name}` will be '
                            'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')

                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # apply converted arguments to the decorated method
            return old_func(*args, **kwargs)

        return new_func

    return api_warning_wrapper
|
|
|
|
|
|
|
|
def build_from_cfg(cfg: Dict,
                   registry: 'Registry',
                   default_args: Optional[Dict] = None) -> Any:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='ResNet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict): Config dict. It should contain the key "type" unless
            ``default_args`` provides it.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence.

    Returns:
        object: The constructed object (or the result of calling the
        registered function).

    Raises:
        TypeError: If ``cfg``, ``registry``, ``default_args`` or the resolved
            "type" has an unsupported type.
        KeyError: If neither ``cfg`` nor ``default_args`` has "type", or a
            string "type" is not registered in ``registry``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Shallow copy so popping "type" and merging defaults never mutate cfg.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name, so re-raise the same
        # exception type with the class/function name prefixed. Some
        # exception types cannot be rebuilt from a single message (e.g.
        # UnicodeDecodeError); re-raise those unchanged rather than masking
        # the real error with a constructor failure.
        try:
            new_exc = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            new_exc = None
        if new_exc is None:
            raise
        raise new_exc from e
|
|
|
|
|
|
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the
            top-level package name of the module that instantiates the
            registry, e.g. mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Validate ``parent`` before reading ``parent.build_func`` so a bad
        # argument fails on this assertion rather than with an AttributeError.
        if parent is not None:
            assert isinstance(parent, Registry)
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # infer_scope() inspects two frames up, so it must be called directly
        # from __init__ to resolve to the module that created the registry.
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            parent._add_children(self)
        self.parent = parent

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The top-level package name of the module that instantiates the
        registry (two frames above this call) will be returned.

        Example:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            str: The inferred scope name.
        """
        # We access the caller using inspect.currentframe() instead of
        # inspect.stack() for performance reasons. See details in PR #1844
        frame = inspect.currentframe()
        try:
            # get the frame where `infer_scope()` is called
            infer_scope_caller = frame.f_back.f_back
            module = inspect.getmodule(infer_scope_caller)
            if module is not None:
                module_name = module.__name__
            else:
                # getmodule() returns None for code without a backing module
                # file (e.g. exec'd code); use the frame's globals instead.
                module_name = infer_scope_caller.f_globals.get('__name__', '')
        finally:
            # Break the frame -> local -> frame reference cycle promptly.
            del frame
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Returns:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format, optionally prefixed
                with a scope, e.g. ``'mmdet.ResNet'``.

        Returns:
            class: The corresponding class or function, or ``None`` if
            ``key`` cannot be resolved.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            return self._module_dict.get(real_key)
        # get from self._children
        if scope in self._children:
            return self._children[scope].get(real_key)
        # goto root and let it resolve the scope among its children
        root = self
        while root.parent is not None:
            root = root.parent
        if root is self:
            # Already at the root and no registry owns ``scope``.
            return None
        return root.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    @deprecated_api_warning(name_dict=dict(module_class='module'))
    def _register_module(self, module, module_name=None, force=False):
        # Store ``module`` under one or more names: ``module_name`` may be
        # None (use ``module.__name__``), a str, or a sequence of str.
        if not inspect.isclass(module) and not inspect.isfunction(module):
            raise TypeError('module must be a class or a function, '
                            f'but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module

    def deprecated_register_module(self, cls=None, force=False):
        # Old-style API: register_module(cls) or @register_module with
        # ``force`` only; always emits a DeprecationWarning.
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | Sequence[str] | None): The module name(s) to be
                registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class or function to be registered.

        Returns:
            The registered ``module`` when it is given, otherwise a decorator
            that registers and returns its argument.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register