First commit OCR_earsing and Synthetics Handwritten Recognition awesome repo
This commit is contained in:
157
OCR_earsing/latent_diffusion/taming/util.py
Normal file
157
OCR_earsing/latent_diffusion/taming/util.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import os, hashlib
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
URL_MAP = {
|
||||
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
||||
}
|
||||
|
||||
CKPT_MAP = {
|
||||
"vgg_lpips": "vgg.pth"
|
||||
}
|
||||
|
||||
MD5_MAP = {
|
||||
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
||||
}
|
||||
|
||||
|
||||
def download(url, local_path, chunk_size=1024):
|
||||
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
||||
with requests.get(url, stream=True) as r:
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
||||
with open(local_path, "wb") as f:
|
||||
for data in r.iter_content(chunk_size=chunk_size):
|
||||
if data:
|
||||
f.write(data)
|
||||
pbar.update(chunk_size)
|
||||
|
||||
|
||||
def md5_hash(path):
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
return hashlib.md5(content).hexdigest()
|
||||
|
||||
|
||||
def get_ckpt_path(name, root, check=False):
|
||||
assert name in URL_MAP
|
||||
path = os.path.join(root, CKPT_MAP[name])
|
||||
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
||||
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
||||
download(URL_MAP[name], path)
|
||||
md5 = md5_hash(path)
|
||||
assert md5 == MD5_MAP[name], md5
|
||||
return path
|
||||
|
||||
|
||||
class KeyNotFoundError(Exception):
|
||||
def __init__(self, cause, keys=None, visited=None):
|
||||
self.cause = cause
|
||||
self.keys = keys
|
||||
self.visited = visited
|
||||
messages = list()
|
||||
if keys is not None:
|
||||
messages.append("Key not found: {}".format(keys))
|
||||
if visited is not None:
|
||||
messages.append("Visited: {}".format(visited))
|
||||
messages.append("Cause:\n{}".format(cause))
|
||||
message = "\n".join(messages)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def retrieve(
|
||||
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
|
||||
):
|
||||
"""Given a nested list or dict return the desired value at key expanding
|
||||
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
|
||||
is done in-place.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
list_or_dict : list or dict
|
||||
Possibly nested list or dictionary.
|
||||
key : str
|
||||
key/to/value, path like string describing all keys necessary to
|
||||
consider to get to the desired value. List indices can also be
|
||||
passed here.
|
||||
splitval : str
|
||||
String that defines the delimiter between keys of the
|
||||
different depth levels in `key`.
|
||||
default : obj
|
||||
Value returned if :attr:`key` is not found.
|
||||
expand : bool
|
||||
Whether to expand callable nodes on the path or not.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The desired value or if :attr:`default` is not ``None`` and the
|
||||
:attr:`key` is not found returns ``default``.
|
||||
|
||||
Raises
|
||||
------
|
||||
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
|
||||
``None``.
|
||||
"""
|
||||
|
||||
keys = key.split(splitval)
|
||||
|
||||
success = True
|
||||
try:
|
||||
visited = []
|
||||
parent = None
|
||||
last_key = None
|
||||
for key in keys:
|
||||
if callable(list_or_dict):
|
||||
if not expand:
|
||||
raise KeyNotFoundError(
|
||||
ValueError(
|
||||
"Trying to get past callable node with expand=False."
|
||||
),
|
||||
keys=keys,
|
||||
visited=visited,
|
||||
)
|
||||
list_or_dict = list_or_dict()
|
||||
parent[last_key] = list_or_dict
|
||||
|
||||
last_key = key
|
||||
parent = list_or_dict
|
||||
|
||||
try:
|
||||
if isinstance(list_or_dict, dict):
|
||||
list_or_dict = list_or_dict[key]
|
||||
else:
|
||||
list_or_dict = list_or_dict[int(key)]
|
||||
except (KeyError, IndexError, ValueError) as e:
|
||||
raise KeyNotFoundError(e, keys=keys, visited=visited)
|
||||
|
||||
visited += [key]
|
||||
# final expansion of retrieved value
|
||||
if expand and callable(list_or_dict):
|
||||
list_or_dict = list_or_dict()
|
||||
parent[last_key] = list_or_dict
|
||||
except KeyNotFoundError as e:
|
||||
if default is None:
|
||||
raise e
|
||||
else:
|
||||
list_or_dict = default
|
||||
success = False
|
||||
|
||||
if not pass_success:
|
||||
return list_or_dict
|
||||
else:
|
||||
return list_or_dict, success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = {"keya": "a",
|
||||
"keyb": "b",
|
||||
"keyc":
|
||||
{"cc1": 1,
|
||||
"cc2": 2,
|
||||
}
|
||||
}
|
||||
from omegaconf import OmegaConf
|
||||
config = OmegaConf.create(config)
|
||||
print(config)
|
||||
retrieve(config, "keya")
|
||||
|
Reference in New Issue
Block a user