add readme (#10)

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* remove submodule

* add mPLUG MiniGPT4

* Update Readme.md

* Update Readme.md

* Update Readme.md

---------

Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
This commit is contained in:
lz
2023-06-01 09:57:03 +08:00
committed by GitHub
parent 64f7eb334d
commit 3213a65d96
275 changed files with 16059 additions and 6 deletions

View File

@@ -0,0 +1,24 @@
from .processors.builder import build_processors
from .xgpt3_dataset import MultiModalDataset
def train_valid_test_datasets_provider(data_path, config, tokenizer, seq_length=1024):
"""Build train and valid datasets."""
print('> building train and validation datasets for mPLUG-Owl ...')
train_ds, valid_ds = build_train_valid_test_datasets(
input_file=data_path,
tokenizer=tokenizer,
max_length=seq_length,
config=config)
print("> finished creating mPLUG-Owl datasets ...")
return train_ds, valid_ds
def build_train_valid_test_datasets(input_file, tokenizer, max_length=80, config=None):
train_processors = build_processors(config['train_processors'])
valid_processors = build_processors(config['valid_processors'])
assert len(input_file) == 2 # If you have files more than 2, modify code at here or merger them into train and dev
train_ds = MultiModalDataset(input_file[0], tokenizer, train_processors, max_length)
valid_ds = MultiModalDataset(input_file[1], tokenizer, valid_processors, max_length)
test_ds = None
return (train_ds, valid_ds)

View File

@@ -0,0 +1,9 @@
# Copyright (c) Alibaba. All rights reserved.
from .builder import PROCESSORS, build_processors
from .default_processor import DefaultProcessor
from .caption_processor import CaptionProcessor
__all__ = [
'PROCESSORS', 'build_processors',
'DefaultProcessor', 'CaptionProcessor'
]

View File

@@ -0,0 +1,12 @@
import os
import numpy as np
from data_utils.registry import Registry, build_from_cfg
PROCESSORS = Registry('processors')
def build_processors(processors_cfg):
processors = dict()
for task, processor in processors_cfg.items():
processors[task] = build_from_cfg(processor, PROCESSORS)
return processors

View File

@@ -0,0 +1,53 @@
import torch
from torchvision import transforms
from PIL import Image
import random
from data_utils.randaugment import RandomAugment
from .builder import PROCESSORS
@PROCESSORS.register_module()
class CaptionProcessor:
def __init__(self, image_size=224, min_scale = 0.5, randaug=False):
self.image_size = image_size
self.min_scale = min_scale
if randaug:
self.image_transform = transforms.Compose([
transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
else:
self.image_transform = transforms.Compose([
transforms.RandomResizedCrop(image_size,scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
self.text_transform = None
def __call__(self, image, text):
assert image or text
if image:
image_input = self.image_transform(image)
else:
image_input = None
if text:
if isinstance(text["prompt"], list):
prompt = random.choice(text["prompt"])
else:
prompt = text["prompt"]
text_input = dict(
prompt=prompt,
completion=text["text"],
)
else:
text_input = None
return image_input, text_input

View File

@@ -0,0 +1,42 @@
import torch
from torchvision import transforms
from PIL import Image
import random
from data_utils.randaugment import RandomAugment
from .builder import PROCESSORS
@PROCESSORS.register_module()
class DefaultProcessor:
def __init__(self, image_size=224):
self.image_size = image_size
self.image_transform = transforms.Compose([
transforms.Resize((image_size, image_size),interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
self.text_transform = None
def __call__(self, image, text):
assert image or text
if image:
image_input = self.image_transform(image)
else:
image_input = None
if text:
if isinstance(text["prompt"], list):
prompt = random.choice(text["prompt"])
else:
prompt = text["prompt"]
text_input = dict(
prompt=prompt,
completion=text["text"],
)
else:
text_input = None
return image_input, text_input

View File

@@ -0,0 +1,345 @@
import cv2
import numpy as np
from PIL import Image
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0: return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.posterize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Contrast
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, returnPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
self.returnPIL = returnPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
if self.returnPIL:
img = img.astype('uint8')
img = Image.fromarray(img)
return img
if __name__ == '__main__':
a = RandomAugment()
img = np.random.randn(32, 32, 3)
a(img)

View File

@@ -0,0 +1,422 @@
# Copyright (c) Alibaba. All rights reserved.
import inspect
import warnings
import functools
from functools import partial
from typing import Any, Dict, Optional
from collections import abc
from inspect import getfullargspec
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def deprecated_api_warning(name_dict, cls_name=None):
"""A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
Args:
name_dict(dict):
key (str): Deprecate argument names.
val (str): Expected argument names.
Returns:
func: New function.
"""
def api_warning_wrapper(old_func):
@functools.wraps(old_func)
def new_func(*args, **kwargs):
# get the arg spec of the decorated method
args_info = getfullargspec(old_func)
# get name of the function
func_name = old_func.__name__
if cls_name is not None:
func_name = f'{cls_name}.{func_name}'
if args:
arg_names = args_info.args[:len(args)]
for src_arg_name, dst_arg_name in name_dict.items():
if src_arg_name in arg_names:
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
'instead', DeprecationWarning)
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
if kwargs:
for src_arg_name, dst_arg_name in name_dict.items():
if src_arg_name in kwargs:
assert dst_arg_name not in kwargs, (
f'The expected behavior is to replace '
f'the deprecated key `{src_arg_name}` to '
f'new key `{dst_arg_name}`, but got them '
f'in the arguments at the same time, which '
f'is confusing. `{src_arg_name} will be '
f'deprecated in the future, please '
f'use `{dst_arg_name}` instead.')
warnings.warn(
f'"{src_arg_name}" is deprecated in '
f'`{func_name}`, please use "{dst_arg_name}" '
'instead', DeprecationWarning)
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
# apply converted arguments to the decorated method
output = old_func(*args, **kwargs)
return output
return new_func
return api_warning_wrapper
def build_from_cfg(cfg: Dict,
registry: 'Registry',
default_args: Optional[Dict] = None) -> Any:
"""Build a module from config dict when it is a class configuration, or
call a function from config dict when it is a function configuration.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
>>> # Returns an instantiated object
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
>>> # Return a result of the calling function
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
class Registry:
"""A registry to map strings to classes or functions.
Registered object could be built from registry. Meanwhile, registered
functions could be called from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = MODELS.build(dict(type='resnet50'))
Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, func:`build_from_cfg` is used if neither ``parent`` or
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
@staticmethod
def infer_scope():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
>>> # in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
str: The inferred scope name.
"""
# We access the caller using inspect.currentframe() instead of
# inspect.stack() for performance reasons. See details in PR #1844
frame = inspect.currentframe()
# get the frame where `infer_scope()` is called
infer_scope_caller = frame.f_back.f_back
filename = inspect.getmodule(infer_scope_caller).__name__
split_filename = filename.split('.')
return split_filename[0]
@staticmethod
def split_scope_key(key):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
Return:
tuple[str | None, str]: The former element is the first scope of
the key, which can be ``None``. The latter is the remaining key.
"""
split_index = key.find('.')
if split_index != -1:
return key[:split_index], key[split_index + 1:]
else:
return None, key
@property
def name(self):
return self._name
@property
def scope(self):
return self._scope
@property
def module_dict(self):
return self._module_dict
@property
def children(self):
return self._children
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def _add_children(self, registry):
"""Add children for a registry.
The ``registry`` will be added as children based on its scope.
The parent registry could build objects from children registry.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(type='mmdet.ResNet'))
"""
assert isinstance(registry, Registry)
assert registry.scope is not None
assert registry.scope not in self.children, \
f'scope {registry.scope} exists in {self.name} registry'
self.children[registry.scope] = registry
@deprecated_api_warning(name_dict=dict(module_class='module'))
def _register_module(self, module, module_name=None, force=False):
if not inspect.isclass(module) and not inspect.isfunction(module):
raise TypeError('module must be a class or a function, '
f'but got {type(module)}')
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.',
DeprecationWarning)
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class or function to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
self._register_module(module=module, module_name=name, force=force)
return module
return _register

View File

@@ -0,0 +1,256 @@
import json
import logging
import os
import random
import re
import time
import traceback
import warnings
from io import BytesIO
import h5py
import numpy as np
import torch
from icecream import ic
from PIL import Image, ImageFile
from torch.utils.data import Dataset, Subset
from utils import get_args
from .processors import build_processors
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
def load_jsonl(filename):
with open(filename, "r", encoding="utf-8") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
class MultiModalDataset(Dataset):
"""MultiModal dataset"""
def __init__(self, input_files, tokenizer, processors,
max_length=2048,
media_tokens=['<image>']):
args = get_args()
self.dataset = []
if isinstance(input_files, str):
input_files = [input_files]
for input_file in input_files:
self.dataset += load_jsonl(input_file)
self.tokenizer = tokenizer
self.max_length = max_length
self.processors = processors
self.media_tokens = {k: -int(i+1) for i, k in enumerate(media_tokens)}
self.media_lengths = {'<image>': 1+64}
print("num_media_token: ", self.media_lengths)
self.bucket = {}
print(len(self.dataset))
def __len__(self):
return len(self.dataset)
def _load_img(self, images):
if isinstance(images, str):
images = [images]
image_pils = []
for image_url in images:
image = Image.open(image_url).convert('RGB')
image_pils.append(image)
return image_pils
def process_data(self, data, processor=None):
# Process Image if exists
if 'image' in data and len(data['image']) > 0:
if 'image_data' in data:
images = data['image_data']
else:
image_urls = data['image']
images = self._load_img(image_urls)
if processor:
images = [processor(image=image, text=None)[0]
for image in images]
images = torch.stack(images, dim=0)
else:
images = None
# Process Text
text = {
"prompt": data.get('prompt', ""),
"text": data["text"]
}
if processor:
text = processor(image=None, text=text)[1]
return images, text
def __getitem__(self, index):
data = self.dataset[index]
task_type = data.get('task_type', 'dummy_default').split(
'_')[-1] # Get processor type
while True:
try:
# use for processing image-text pairs
image, text = self.process_data(
data, self.processors[task_type])
text_input = self._extract_text_token_from_conversation(
text, self.max_length, index)
except Exception as e:
traceback.print_exc()
# print(e)
#logging.info("Get image:{} from oss failed, retry.".format(ann))
time.sleep(0.1)
index = 0 if index == (len(self) - 1) else index + 1
data = self.dataset[index]
task_type = data.get(
'task_type', 'dummy_default').split('_')[-1]
continue
break
batch_data = {
"image": image,
"text": text_input
}
return batch_data
def _extract_text_token_from_conversation(self, data, max_length, index):
# output enc_chunk
enc_chunk = []
if self.tokenizer.bos_token_id > 0:
prompt_chunk = [self.tokenizer.bos_token_id]
else:
prompt_chunk = []
conversation = data["completion"]
# For Text only data
if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
chunk_strs = re.split(f'({pattern})', conversation)
prompt_length = -1
stop_flag = False
for idx, chunk_str in enumerate(chunk_strs):
if idx == 0:
enc_chunk = prompt_chunk + \
self.tokenizer(chunk_str, add_special_tokens=False)[
'input_ids']
enc_length = len(enc_chunk)
label_chunk = [0] * enc_length
else:
if chunk_strs[idx-1] == 'AI: ':
curr_chunk = self.tokenizer(
chunk_str, add_special_tokens=False)['input_ids']
if enc_length + len(curr_chunk) >= max_length:
curr_chunk = curr_chunk[:max_length-enc_length]
stop_flag = True
curr_chunk += [self.tokenizer.eos_token_id]
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [1] * len(curr_chunk)
else:
curr_chunk = self.tokenizer(
chunk_str, add_special_tokens=False)['input_ids']
if enc_length + len(curr_chunk) >= max_length + 1:
curr_chunk = curr_chunk[:max_length+1-enc_length]
stop_flag = True
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [0] * len(curr_chunk)
if stop_flag:
break
# For Image-Text Data
else:
enc_length = 0
prompt_length = -2
pattern = '|'.join(
map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
chunk_strs = re.split(f'({pattern})', conversation)
chunk_strs = [x for x in chunk_strs if len(x) > 0]
for idx, chunk_str in enumerate(chunk_strs):
if enc_length >= max_length + 1:
break
if idx == 0:
enc_chunk = prompt_chunk + \
self.tokenizer(chunk_str, add_special_tokens=False)[
'input_ids']
enc_length = len(enc_chunk)
label_chunk = [0] * enc_length
else:
if chunk_str in self.media_tokens:
# [CLS] + 256 + [EOS]
if enc_length + self.media_lengths[chunk_str] > max_length + 1:
break
else:
enc_chunk += [self.media_tokens[chunk_str]
] * self.media_lengths[chunk_str]
enc_length += self.media_lengths[chunk_str]
label_chunk += [0] * self.media_lengths[chunk_str]
else:
if chunk_strs[idx-1] == 'AI: ':
curr_chunk = self.tokenizer(
chunk_str, add_special_tokens=False)['input_ids']
if enc_length + len(curr_chunk) >= max_length:
curr_chunk = curr_chunk[:max_length-enc_length]
curr_chunk += [self.tokenizer.eos_token_id]
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [1] * len(curr_chunk)
else:
curr_chunk = self.tokenizer(
chunk_str, add_special_tokens=False)['input_ids']
if enc_length + len(curr_chunk) >= max_length + 1:
curr_chunk = curr_chunk[:max_length +
1-enc_length]
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [0] * len(curr_chunk)
if enc_length < max_length + 1:
padding_chunk = [self.tokenizer.pad_token_id] * \
(max_length + 1 - enc_length)
padding_length = len(padding_chunk)
label_chunk += [0] * (max_length + 1 - enc_length)
enc_chunk = enc_chunk + padding_chunk
else:
padding_length = 0
assert enc_length + padding_length == max_length + \
1, (index, prompt_length, enc_length,
padding_length, max_length + 1)
assert len(label_chunk) == max_length + \
1, (len(label_chunk), max_length + 1)
non_padding_mask = [1 if i < enc_length -
1 else 0 for i in range(max_length)]
enc_chunk = torch.tensor(enc_chunk).long()
non_padding_mask = torch.tensor(non_padding_mask).long()
prompt_mask = torch.tensor(label_chunk)[1:].long()
prompt_length = torch.tensor([prompt_length]).long()
# Create loss mask
if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
non_media_mask = torch.ones_like(non_padding_mask).long()
else:
tmp_enc_chunk = enc_chunk.clone()
tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
tmp_enc_chunk[tmp_enc_chunk < 0] = 0
non_media_mask = torch.tensor(tmp_enc_chunk).long()
non_media_mask = non_media_mask[1:].long()
return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
"non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}