add readme (#10)
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* Update Readme.md
* remove submodule
* add mPLUG MiniGPT4
* Update Readme.md
* Update Readme.md
* Update Readme.md

---------

Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
24  models/mPLUG_Owl/pipeline/data_utils/__init__.py  Normal file
@@ -0,0 +1,24 @@
from .processors.builder import build_processors
from .xgpt3_dataset import MultiModalDataset


def train_valid_test_datasets_provider(data_path, config, tokenizer, seq_length=1024):
    """Build train and valid datasets."""
    print('> building train and validation datasets for mPLUG-Owl ...')
    train_ds, valid_ds = build_train_valid_test_datasets(
        input_file=data_path,
        tokenizer=tokenizer,
        max_length=seq_length,
        config=config)
    print("> finished creating mPLUG-Owl datasets ...")

    return train_ds, valid_ds


def build_train_valid_test_datasets(input_file, tokenizer, max_length=80, config=None):
    train_processors = build_processors(config['train_processors'])
    valid_processors = build_processors(config['valid_processors'])

    assert len(input_file) == 2  # If you have more than 2 files, modify the code here or merge them into train and dev
    train_ds = MultiModalDataset(input_file[0], tokenizer, train_processors, max_length)
    valid_ds = MultiModalDataset(input_file[1], tokenizer, valid_processors, max_length)
    test_ds = None
    return (train_ds, valid_ds)
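For orientation, a minimal usage sketch (not part of this commit): it assumes a Hugging Face tokenizer, two JSONL files (train and dev), and a config whose `train_processors` / `valid_processors` map task names to registry-style processor configs. All names and paths below are placeholders, and `MultiModalDataset` additionally expects the pipeline's `utils.get_args()` to have been initialized by the training entrypoint.

# Illustrative sketch only; tokenizer path, file names, and config values are placeholders.
from transformers import AutoTokenizer

from data_utils import train_valid_test_datasets_provider

config = {
    'train_processors': {'caption': dict(type='CaptionProcessor', image_size=224, randaug=True)},
    'valid_processors': {'caption': dict(type='DefaultProcessor', image_size=224)},
}
tokenizer = AutoTokenizer.from_pretrained('path/to/llama-tokenizer')  # placeholder checkpoint

train_ds, valid_ds = train_valid_test_datasets_provider(
    data_path=['train.jsonl', 'dev.jsonl'],   # exactly two files: train and dev
    config=config,
    tokenizer=tokenizer,
    seq_length=1024)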
9  models/mPLUG_Owl/pipeline/data_utils/processors/__init__.py  Normal file
@@ -0,0 +1,9 @@
# Copyright (c) Alibaba. All rights reserved.
from .builder import PROCESSORS, build_processors
from .default_processor import DefaultProcessor
from .caption_processor import CaptionProcessor

__all__ = [
    'PROCESSORS', 'build_processors',
    'DefaultProcessor', 'CaptionProcessor'
]
12  models/mPLUG_Owl/pipeline/data_utils/processors/builder.py  Normal file
@@ -0,0 +1,12 @@
import os
import numpy as np

from data_utils.registry import Registry, build_from_cfg

PROCESSORS = Registry('processors')

def build_processors(processors_cfg):
    processors = dict()
    for task, processor in processors_cfg.items():
        processors[task] = build_from_cfg(processor, PROCESSORS)
    return processors
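To show what `build_processors` consumes, a hedged sketch follows: each task name maps to a registry-style dict whose `type` names a class registered in `PROCESSORS`. The task names and parameter values here are illustrative, not taken from this commit.

# Illustrative sketch only; the task names and parameter values are made up for this example.
from data_utils.processors.builder import build_processors

processors_cfg = {
    'caption': dict(type='CaptionProcessor', image_size=224, min_scale=0.5, randaug=True),
    'default': dict(type='DefaultProcessor', image_size=224),
}
processors = build_processors(processors_cfg)
# -> {'caption': CaptionProcessor(...), 'default': DefaultProcessor(...)}
# Each value is instantiated via build_from_cfg, which pops 'type', looks it up in the
# PROCESSORS registry, and passes the remaining keys as constructor kwargs.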
53  models/mPLUG_Owl/pipeline/data_utils/processors/caption_processor.py  Normal file
@@ -0,0 +1,53 @@
import torch
from torchvision import transforms
from PIL import Image
import random

from data_utils.randaugment import RandomAugment
from .builder import PROCESSORS


@PROCESSORS.register_module()
class CaptionProcessor:
    def __init__(self, image_size=224, min_scale=0.5, randaug=False):
        self.image_size = image_size
        self.min_scale = min_scale

        if randaug:
            self.image_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
                RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                      'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ])
        else:
            self.image_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ])
        self.text_transform = None

    def __call__(self, image, text):
        assert image or text

        if image:
            image_input = self.image_transform(image)
        else:
            image_input = None

        if text:
            if isinstance(text["prompt"], list):
                prompt = random.choice(text["prompt"])
            else:
                prompt = text["prompt"]
            text_input = dict(
                prompt=prompt,
                completion=text["text"],
            )
        else:
            text_input = None
        return image_input, text_input
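A short, hedged sketch of how a processor instance is applied, mirroring the `image=..., text=...` calls made by `MultiModalDataset.process_data`; the image path and strings below are placeholders.

# Illustrative sketch only; 'cat.jpg' and the text strings are placeholders.
from PIL import Image

from data_utils.processors.caption_processor import CaptionProcessor

processor = CaptionProcessor(image_size=224, randaug=True)
image = Image.open('cat.jpg').convert('RGB')
text = {'prompt': '', 'text': 'Human: <image>\nHuman: Describe the image.\nAI: A cat sitting on a mat.'}

image_input, _ = processor(image=image, text=None)   # normalized (3, 224, 224) float tensor
_, text_input = processor(image=None, text=text)     # {'prompt': ..., 'completion': ...}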
42  models/mPLUG_Owl/pipeline/data_utils/processors/default_processor.py  Normal file
@@ -0,0 +1,42 @@
import torch
from torchvision import transforms
from PIL import Image
import random

from data_utils.randaugment import RandomAugment
from .builder import PROCESSORS


@PROCESSORS.register_module()
class DefaultProcessor:
    def __init__(self, image_size=224):
        self.image_size = image_size

        self.image_transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        self.text_transform = None

    def __call__(self, image, text):
        assert image or text

        if image:
            image_input = self.image_transform(image)
        else:
            image_input = None

        if text:
            if isinstance(text["prompt"], list):
                prompt = random.choice(text["prompt"])
            else:
                prompt = text["prompt"]
            text_input = dict(
                prompt=prompt,
                completion=text["text"],
            )
        else:
            text_input = None
        return image_input, text_input
345  models/mPLUG_Owl/pipeline/data_utils/randaugment.py  Normal file
@@ -0,0 +1,345 @@
import cv2
import numpy as np
from PIL import Image


## aug functions
def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    '''
    same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalizeHist
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def solarize_func(img, thresh=128):
    '''
    same output as PIL.ImageOps.solarize
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
    same output as PIL.ImageEnhance.Color
    '''
    ## implementation according to PIL definition, quite slow
    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    # out = blend(degenerate, img, factor)
    # M = (
    #     np.eye(3) * factor
    #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    # )[np.newaxis, np.newaxis, :]
    M = (
        np.float32([
            [0.886, -0.114, -0.114],
            [-0.587, 0.413, -0.587],
            [-0.299, -0.299, 0.701]]) * factor
        + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([
        (el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
    same output as PIL.ImageEnhance.Brightness
    '''
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
    The differences between this result and PIL are all on the 4 boundaries; the center
    areas are the same
    '''
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
    same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
    same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
    same output as PIL.ImageOps.posterize
    '''
    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out


### level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):

    def __init__(self, N=2, M=10, isPIL=False, returnPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        self.returnPIL = returnPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        if self.returnPIL:
            img = img.astype('uint8')
            img = Image.fromarray(img)
        return img


if __name__ == '__main__':
    a = RandomAugment()
    img = np.random.randn(32, 32, 3)
    a(img)
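Beyond the smoke test in `__main__`, a hedged sketch of standalone use on a PIL image: `isPIL=True` converts the input to a NumPy array before the sampled ops are applied, and `returnPIL=True` converts the result back to a PIL image. The file path and op list below are illustrative.

# Illustrative sketch only; the image path is a placeholder.
from PIL import Image

from data_utils.randaugment import RandomAugment

aug = RandomAugment(N=2, M=7, isPIL=True, returnPIL=True,
                    augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Rotate'])
img = Image.open('example.jpg').convert('RGB')
out = aug(img)  # two ops sampled per call, each applied with probability 0.5 at magnitude level 7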
422  models/mPLUG_Owl/pipeline/data_utils/registry.py  Normal file
@@ -0,0 +1,422 @@
# Copyright (c) Alibaba. All rights reserved.
import inspect
import warnings
import functools
from functools import partial
from typing import Any, Dict, Optional
from collections import abc
from inspect import getfullargspec


def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        exp_seq_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        exp_seq_type = seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    for item in seq:
        if not isinstance(item, expected_type):
            return False
    return True


def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecated and try to replace
    the deprecated src_arg_name with dst_arg_name.

    Args:
        name_dict(dict):
            key (str): Deprecated argument names.
            val (str): Expected argument names.

    Returns:
        func: New function.
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get name of the function
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'
            if args:
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:

                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` with the '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name}` will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')

                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # apply converted arguments to the decorated method
            output = old_func(*args, **kwargs)
            return output

        return new_func

    return api_warning_wrapper


def build_from_cfg(cfg: Dict,
                   registry: 'Registry',
                   default_args: Optional[Dict] = None) -> Any:
    """Build a module from a config dict when it is a class configuration, or
    call a function from a config dict when it is a function configuration.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='ResNet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')


class Registry:
    """A registry to map strings to classes or functions.

    Registered objects could be built from the registry. Meanwhile, registered
    functions could be called from the registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` nor
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            str: The inferred scope name.
        """
        # We access the caller using inspect.currentframe() instead of
        # inspect.stack() for performance reasons. See details in PR #1844
        frame = inspect.currentframe()
        # get the frame where `infer_scope()` is called
        infer_scope_caller = frame.f_back.f_back
        filename = inspect.getmodule(infer_scope_caller).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            tuple[str | None, str]: The former element is the first scope of
                the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as a child based on its scope.
        The parent registry could build objects from the children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    @deprecated_api_warning(name_dict=dict(module_class='module'))
    def _register_module(self, module, module_name=None, force=False):
        if not inspect.isclass(module) and not inspect.isfunction(module):
            raise TypeError('module must be a class or a function, '
                            f'but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class or function to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register
256  models/mPLUG_Owl/pipeline/data_utils/xgpt3_dataset.py  Normal file
@@ -0,0 +1,256 @@
import json
import logging
import os
import random
import re
import time
import traceback
import warnings
from io import BytesIO

import h5py
import numpy as np
import torch
from icecream import ic
from PIL import Image, ImageFile
from torch.utils.data import Dataset, Subset

from utils import get_args

from .processors import build_processors

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)


def load_jsonl(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]


class MultiModalDataset(Dataset):
    """MultiModal dataset"""

    def __init__(self, input_files, tokenizer, processors,
                 max_length=2048,
                 media_tokens=['<image>']):
        args = get_args()
        self.dataset = []
        if isinstance(input_files, str):
            input_files = [input_files]
        for input_file in input_files:
            self.dataset += load_jsonl(input_file)

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processors = processors
        self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
        self.media_lengths = {'<image>': 1 + 64}
        print("num_media_token: ", self.media_lengths)
        self.bucket = {}
        print(len(self.dataset))

    def __len__(self):
        return len(self.dataset)

    def _load_img(self, images):
        if isinstance(images, str):
            images = [images]
        image_pils = []
        for image_url in images:
            image = Image.open(image_url).convert('RGB')
            image_pils.append(image)
        return image_pils

    def process_data(self, data, processor=None):
        # Process the image if it exists
        if 'image' in data and len(data['image']) > 0:
            if 'image_data' in data:
                images = data['image_data']
            else:
                image_urls = data['image']
                images = self._load_img(image_urls)
            if processor:
                images = [processor(image=image, text=None)[0]
                          for image in images]
                images = torch.stack(images, dim=0)
        else:
            images = None

        # Process the text
        text = {
            "prompt": data.get('prompt', ""),
            "text": data["text"]
        }
        if processor:
            text = processor(image=None, text=text)[1]
        return images, text

    def __getitem__(self, index):
        data = self.dataset[index]
        task_type = data.get('task_type', 'dummy_default').split(
            '_')[-1]  # Get processor type
        while True:
            try:
                # used for processing image-text pairs
                image, text = self.process_data(
                    data, self.processors[task_type])

                text_input = self._extract_text_token_from_conversation(
                    text, self.max_length, index)

            except Exception as e:
                traceback.print_exc()
                # print(e)
                # logging.info("Get image:{} from oss failed, retry.".format(ann))
                time.sleep(0.1)
                index = 0 if index == (len(self) - 1) else index + 1
                data = self.dataset[index]
                task_type = data.get(
                    'task_type', 'dummy_default').split('_')[-1]
                continue
            break

        batch_data = {
            "image": image,
            "text": text_input
        }

        return batch_data

    def _extract_text_token_from_conversation(self, data, max_length, index):
        # output enc_chunk
        enc_chunk = []

        if self.tokenizer.bos_token_id > 0:
            prompt_chunk = [self.tokenizer.bos_token_id]
        else:
            prompt_chunk = []

        conversation = data["completion"]
        # For text-only data
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            prompt_length = -1
            stop_flag = False
            for idx, chunk_str in enumerate(chunk_strs):
                if idx == 0:
                    enc_chunk = prompt_chunk + \
                        self.tokenizer(chunk_str, add_special_tokens=False)[
                            'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_strs[idx - 1] == 'AI: ':
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length:
                            curr_chunk = curr_chunk[:max_length - enc_length]
                            stop_flag = True
                        curr_chunk += [self.tokenizer.eos_token_id]
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [1] * len(curr_chunk)
                    else:
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length + 1:
                            curr_chunk = curr_chunk[:max_length + 1 - enc_length]
                            stop_flag = True
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [0] * len(curr_chunk)
                    if stop_flag:
                        break

        # For image-text data
        else:
            enc_length = 0
            prompt_length = -2
            pattern = '|'.join(
                map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]
            for idx, chunk_str in enumerate(chunk_strs):
                if enc_length >= max_length + 1:
                    break

                if idx == 0:
                    enc_chunk = prompt_chunk + \
                        self.tokenizer(chunk_str, add_special_tokens=False)[
                            'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_str in self.media_tokens:
                        # [CLS] + 256 + [EOS]
                        if enc_length + self.media_lengths[chunk_str] > max_length + 1:
                            break
                        else:
                            enc_chunk += [self.media_tokens[chunk_str]
                                          ] * self.media_lengths[chunk_str]
                            enc_length += self.media_lengths[chunk_str]
                            label_chunk += [0] * self.media_lengths[chunk_str]
                    else:

                        if chunk_strs[idx - 1] == 'AI: ':
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length:
                                curr_chunk = curr_chunk[:max_length - enc_length]
                            curr_chunk += [self.tokenizer.eos_token_id]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [1] * len(curr_chunk)
                        else:
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length + 1:
                                curr_chunk = curr_chunk[:max_length +
                                                        1 - enc_length]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [0] * len(curr_chunk)

        if enc_length < max_length + 1:
            padding_chunk = [self.tokenizer.pad_token_id] * \
                (max_length + 1 - enc_length)
            padding_length = len(padding_chunk)
            label_chunk += [0] * (max_length + 1 - enc_length)
            enc_chunk = enc_chunk + padding_chunk
        else:
            padding_length = 0

        assert enc_length + padding_length == max_length + \
            1, (index, prompt_length, enc_length,
                padding_length, max_length + 1)
        assert len(label_chunk) == max_length + \
            1, (len(label_chunk), max_length + 1)
        non_padding_mask = [1 if i < enc_length -
                            1 else 0 for i in range(max_length)]

        enc_chunk = torch.tensor(enc_chunk).long()
        non_padding_mask = torch.tensor(non_padding_mask).long()
        prompt_mask = torch.tensor(label_chunk)[1:].long()
        prompt_length = torch.tensor([prompt_length]).long()

        # Create loss mask
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            non_media_mask = torch.ones_like(non_padding_mask).long()
        else:
            tmp_enc_chunk = enc_chunk.clone()
            tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
            tmp_enc_chunk[tmp_enc_chunk < 0] = 0
            non_media_mask = torch.tensor(tmp_enc_chunk).long()
            non_media_mask = non_media_mask[1:].long()
        return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
                "non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}
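For reference, a hedged sketch of the JSONL records this dataset reads, based only on the fields accessed above (`image`, optional `prompt`, `text`, `task_type`); all concrete values are placeholders. The last `_`-separated segment of `task_type` selects the processor, and the `<image>` token plus the `AI: ` / `\nHuman: ` markers inside `text` drive tokenization and the loss mask (tokens after `AI: ` are labeled as targets).

# Illustrative record only; paths, text, and the task name are placeholders.
import json

record = {
    "image": ["images/000001.jpg"],    # local image path(s); omit or leave empty for text-only samples
    "prompt": "",                      # optional; process_data defaults it to ""
    "text": "Human: <image>\nHuman: Describe the image.\nAI: A dog playing in the snow.",
    "task_type": "llava_sft_caption",  # the last '_' segment ('caption') selects the processor
}
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")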