First commit of the OCR_earsing and Synthetic Handwritten Recognition awesome repo
0	OCR_earsing/latent_diffusion/taming/__init__.py	Normal file
352	OCR_earsing/latent_diffusion/taming/models/cond_transformer.py	Normal file
@@ -0,0 +1,352 @@
|
||||
import os, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from main import instantiate_from_config
|
||||
from taming.modules.util import SOSProvider
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class Net2NetTransformer(pl.LightningModule):
|
||||
def __init__(self,
|
||||
transformer_config,
|
||||
first_stage_config,
|
||||
cond_stage_config,
|
||||
permuter_config=None,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
first_stage_key="image",
|
||||
cond_stage_key="depth",
|
||||
downsample_cond_size=-1,
|
||||
pkeep=1.0,
|
||||
sos_token=0,
|
||||
unconditional=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.be_unconditional = unconditional
|
||||
self.sos_token = sos_token
|
||||
self.first_stage_key = first_stage_key
|
||||
self.cond_stage_key = cond_stage_key
|
||||
self.init_first_stage_from_ckpt(first_stage_config)
|
||||
self.init_cond_stage_from_ckpt(cond_stage_config)
|
||||
if permuter_config is None:
|
||||
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
||||
self.permuter = instantiate_from_config(config=permuter_config)
|
||||
self.transformer = instantiate_from_config(config=transformer_config)
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.downsample_cond_size = downsample_cond_size
|
||||
self.pkeep = pkeep
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
for k in list(sd.keys()):
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
self.print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def init_first_stage_from_ckpt(self, config):
|
||||
model = instantiate_from_config(config)
|
||||
model = model.eval()
|
||||
model.train = disabled_train
|
||||
self.first_stage_model = model
|
||||
|
||||
def init_cond_stage_from_ckpt(self, config):
|
||||
if config == "__is_first_stage__":
|
||||
print("Using first stage also as cond stage.")
|
||||
self.cond_stage_model = self.first_stage_model
|
||||
elif config == "__is_unconditional__" or self.be_unconditional:
|
||||
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
||||
f"Prepending {self.sos_token} as a sos token.")
|
||||
self.be_unconditional = True
|
||||
self.cond_stage_key = self.first_stage_key
|
||||
self.cond_stage_model = SOSProvider(self.sos_token)
|
||||
else:
|
||||
model = instantiate_from_config(config)
|
||||
model = model.eval()
|
||||
model.train = disabled_train
|
||||
self.cond_stage_model = model
|
||||
|
||||
def forward(self, x, c):
|
||||
# one step to produce the logits
|
||||
_, z_indices = self.encode_to_z(x)
|
||||
_, c_indices = self.encode_to_c(c)
|
||||
|
||||
if self.training and self.pkeep < 1.0:
|
||||
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
||||
device=z_indices.device))
|
||||
mask = mask.round().to(dtype=torch.int64)
|
||||
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
||||
a_indices = mask*z_indices+(1-mask)*r_indices
|
||||
else:
|
||||
a_indices = z_indices
|
||||
|
||||
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
||||
|
||||
# target includes all sequence elements (no need to handle first one
|
||||
# differently because we are conditioning)
|
||||
target = z_indices
|
||||
# make the prediction
|
||||
logits, _ = self.transformer(cz_indices[:, :-1])
|
||||
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
||||
logits = logits[:, c_indices.shape[1]-1:]
|
||||
|
||||
return logits, target
|
||||
|
||||
def top_k_logits(self, logits, k):
|
||||
v, ix = torch.topk(logits, k)
|
||||
out = logits.clone()
|
||||
out[out < v[..., [-1]]] = -float('Inf')
|
||||
return out
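# Editor's note: a small worked example of the top-k filtering above (illustrative, not
# part of the original file). With logits [[1.0, 3.0, 2.0, 0.5]] and k=2, torch.topk keeps
# 3.0 and 2.0, and every entry below the k-th largest value (2.0) is set to -inf:
#     [[-inf, 3.0, 2.0, -inf]]
# so the subsequent softmax assigns zero probability to the filtered tokens.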
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
||||
callback=lambda k: None):
|
||||
x = torch.cat((c,x),dim=1)
|
||||
block_size = self.transformer.get_block_size()
|
||||
assert not self.transformer.training
|
||||
if self.pkeep <= 0.0:
|
||||
# one pass suffices since input is pure noise anyway
|
||||
assert len(x.shape)==2
|
||||
noise_shape = (x.shape[0], steps-1)
|
||||
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
||||
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
||||
x = torch.cat((x,noise),dim=1)
|
||||
logits, _ = self.transformer(x)
|
||||
# take all logits for now and scale by temp
|
||||
logits = logits / temperature
|
||||
# optionally crop probabilities to only the top k options
|
||||
if top_k is not None:
|
||||
logits = self.top_k_logits(logits, top_k)
|
||||
# apply softmax to convert to probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# sample from the distribution or take the most likely
|
||||
if sample:
|
||||
shape = probs.shape
|
||||
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
||||
ix = torch.multinomial(probs, num_samples=1)
|
||||
probs = probs.reshape(shape[0],shape[1],shape[2])
|
||||
ix = ix.reshape(shape[0],shape[1])
|
||||
else:
|
||||
_, ix = torch.topk(probs, k=1, dim=-1)
|
||||
# cut off conditioning
|
||||
x = ix[:, c.shape[1]-1:]
|
||||
else:
|
||||
for k in range(steps):
|
||||
callback(k)
|
||||
assert x.size(1) <= block_size # make sure model can see conditioning
|
||||
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
||||
logits, _ = self.transformer(x_cond)
|
||||
# pluck the logits at the final step and scale by temperature
|
||||
logits = logits[:, -1, :] / temperature
|
||||
# optionally crop probabilities to only the top k options
|
||||
if top_k is not None:
|
||||
logits = self.top_k_logits(logits, top_k)
|
||||
# apply softmax to convert to probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# sample from the distribution or take the most likely
|
||||
if sample:
|
||||
ix = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
_, ix = torch.topk(probs, k=1, dim=-1)
|
||||
# append to the sequence and continue
|
||||
x = torch.cat((x, ix), dim=1)
|
||||
# cut off conditioning
|
||||
x = x[:, c.shape[1]:]
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_to_z(self, x):
|
||||
quant_z, _, info = self.first_stage_model.encode(x)
|
||||
indices = info[2].view(quant_z.shape[0], -1)
|
||||
indices = self.permuter(indices)
|
||||
return quant_z, indices
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_to_c(self, c):
|
||||
if self.downsample_cond_size > -1:
|
||||
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
||||
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
||||
if len(indices.shape) > 2:
|
||||
indices = indices.view(c.shape[0], -1)
|
||||
return quant_c, indices
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_to_img(self, index, zshape):
|
||||
index = self.permuter(index, reverse=True)
|
||||
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
||||
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
||||
index.reshape(-1), shape=bhwc)
|
||||
x = self.first_stage_model.decode(quant_z)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
|
||||
log = dict()
|
||||
|
||||
N = 4
|
||||
if lr_interface:
|
||||
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
|
||||
else:
|
||||
x, c = self.get_xc(batch, N)
|
||||
x = x.to(device=self.device)
|
||||
c = c.to(device=self.device)
|
||||
|
||||
quant_z, z_indices = self.encode_to_z(x)
|
||||
quant_c, c_indices = self.encode_to_c(c)
|
||||
|
||||
# create a "half"" sample
|
||||
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
|
||||
index_sample = self.sample(z_start_indices, c_indices,
|
||||
steps=z_indices.shape[1]-z_start_indices.shape[1],
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
sample=True,
|
||||
top_k=top_k if top_k is not None else 100,
|
||||
callback=callback if callback is not None else lambda k: None)
|
||||
x_sample = self.decode_to_img(index_sample, quant_z.shape)
|
||||
|
||||
# sample
|
||||
z_start_indices = z_indices[:, :0]
|
||||
index_sample = self.sample(z_start_indices, c_indices,
|
||||
steps=z_indices.shape[1],
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
sample=True,
|
||||
top_k=top_k if top_k is not None else 100,
|
||||
callback=callback if callback is not None else lambda k: None)
|
||||
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
|
||||
|
||||
# det sample
|
||||
z_start_indices = z_indices[:, :0]
|
||||
index_sample = self.sample(z_start_indices, c_indices,
|
||||
steps=z_indices.shape[1],
|
||||
sample=False,
|
||||
callback=callback if callback is not None else lambda k: None)
|
||||
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
|
||||
|
||||
# reconstruction
|
||||
x_rec = self.decode_to_img(z_indices, quant_z.shape)
|
||||
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = x_rec
|
||||
|
||||
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
|
||||
figure_size = (x_rec.shape[2], x_rec.shape[3])
|
||||
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
|
||||
label_for_category_no = dataset.get_textual_label_for_category_no
|
||||
plotter = dataset.conditional_builders[self.cond_stage_key].plot
|
||||
log["conditioning"] = torch.zeros_like(log["reconstructions"])
|
||||
for i in range(quant_c.shape[0]):
|
||||
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
|
||||
log["conditioning_rec"] = log["conditioning"]
|
||||
elif self.cond_stage_key != "image":
|
||||
cond_rec = self.cond_stage_model.decode(quant_c)
|
||||
if self.cond_stage_key == "segmentation":
|
||||
# get image from segmentation mask
|
||||
num_classes = cond_rec.shape[1]
|
||||
|
||||
c = torch.argmax(c, dim=1, keepdim=True)
|
||||
c = F.one_hot(c, num_classes=num_classes)
|
||||
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
||||
c = self.cond_stage_model.to_rgb(c)
|
||||
|
||||
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
|
||||
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
|
||||
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
|
||||
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
|
||||
log["conditioning_rec"] = cond_rec
|
||||
log["conditioning"] = c
|
||||
|
||||
log["samples_half"] = x_sample
|
||||
log["samples_nopix"] = x_sample_nopix
|
||||
log["samples_det"] = x_sample_det
|
||||
return log
|
||||
|
||||
def get_input(self, key, batch):
|
||||
x = batch[key]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
if len(x.shape) == 4:
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
||||
if x.dtype == torch.double:
|
||||
x = x.float()
|
||||
return x
|
||||
|
||||
def get_xc(self, batch, N=None):
|
||||
x = self.get_input(self.first_stage_key, batch)
|
||||
c = self.get_input(self.cond_stage_key, batch)
|
||||
if N is not None:
|
||||
x = x[:N]
|
||||
c = c[:N]
|
||||
return x, c
|
||||
|
||||
def shared_step(self, batch, batch_idx):
|
||||
x, c = self.get_xc(batch)
|
||||
logits, target = self(x, c)
|
||||
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
|
||||
return loss
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.shared_step(batch, batch_idx)
|
||||
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss = self.shared_step(batch, batch_idx)
|
||||
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Following minGPT:
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
We are separating out all parameters of the model into two buckets: those that will experience
|
||||
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
||||
We are then returning the PyTorch optimizer object.
|
||||
"""
|
||||
# separate out all parameters to those that will and won't experience regularizing weight decay
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, )
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.transformer.named_modules():
|
||||
for pn, p in m.named_parameters():
|
||||
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
||||
|
||||
if pn.endswith('bias'):
|
||||
# all biases will not be decayed
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
||||
# weights of whitelist modules will be weight decayed
|
||||
decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
||||
# weights of blacklist modules will NOT be weight decayed
|
||||
no_decay.add(fpn)
|
||||
|
||||
# special case the position embedding parameter in the root GPT module as not decayed
|
||||
no_decay.add('pos_emb')
|
||||
|
||||
# validate that we considered every parameter
|
||||
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
||||
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
||||
% (str(param_dict.keys() - union_params), )
|
||||
|
||||
# create the pytorch optimizer object
|
||||
optim_groups = [
|
||||
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
|
||||
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
|
||||
return optimizer
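# Editor's note: a minimal standalone sketch of the same decay/no-decay bucketing on a
# toy module (illustrative only; the toy model and names below are not part of this file).
import torch

_toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
_decay, _no_decay = set(), set()
for _mn, _m in _toy.named_modules():
    for _pn, _ in _m.named_parameters(recurse=False):
        _fpn = f"{_mn}.{_pn}" if _mn else _pn
        if _pn.endswith("bias") or isinstance(_m, torch.nn.LayerNorm):
            _no_decay.add(_fpn)   # biases and norm weights: no weight decay
        elif _pn.endswith("weight") and isinstance(_m, torch.nn.Linear):
            _decay.add(_fpn)      # linear weights: weight decay
print(sorted(_decay))      # ['0.weight']
print(sorted(_no_decay))   # ['0.bias', '1.bias', '1.weight']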
|
@@ -0,0 +1,22 @@
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class DummyCondStage:
|
||||
def __init__(self, conditional_key):
|
||||
self.conditional_key = conditional_key
|
||||
self.train = None
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def encode(c: Tensor):
|
||||
return c, None, (None, None, c)
|
||||
|
||||
@staticmethod
|
||||
def decode(c: Tensor):
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def to_rgb(c: Tensor):
|
||||
return c
|
404	OCR_earsing/latent_diffusion/taming/models/vqgan.py	Normal file
@@ -0,0 +1,404 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from main import instantiate_from_config
|
||||
|
||||
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
from taming.modules.vqvae.quantize import GumbelQuantize
|
||||
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap, sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.image_key = image_key
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input):
|
||||
quant, diff, _ = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
||||
return x.float()
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
rec_loss = log_dict_ae["val/rec_loss"]
|
||||
self.log("val/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
||||
self.log("val/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
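# Editor's note: an illustrative encode/decode round trip with a tiny hypothetical config
# (the config values below are made up for this sketch; DummyLoss is used only so that no
# discriminator weights are needed, and instantiate_from_config is assumed to accept a
# plain {"target": ...} dict as elsewhere in this repo).
_ddconfig = dict(double_z=False, z_channels=4, resolution=64, in_channels=3, out_ch=3,
                 ch=32, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[], dropout=0.0)
_lossconfig = {"target": "taming.modules.losses.vqperceptual.DummyLoss"}
_vq = VQModel(ddconfig=_ddconfig, lossconfig=_lossconfig, n_embed=512, embed_dim=4)
_quant, _, (_, _, _indices) = _vq.encode(torch.randn(1, 3, 64, 64))
print(_quant.shape)               # torch.Size([1, 4, 32, 32]) -- downsampled by 2**(len(ch_mult)-1)
print(_vq.decode(_quant).shape)   # torch.Size([1, 3, 64, 64])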
|
||||
|
||||
|
||||
class VQSegmentationModel(VQModel):
|
||||
def __init__(self, n_labels, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return opt_ae
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
total_loss = log_dict_ae["val/total_loss"]
|
||||
self.log("val/total_loss", total_loss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
||||
return aeloss
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
# convert logits to indices
|
||||
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
||||
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
||||
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
return log
|
||||
|
||||
|
||||
class VQNoDiscModel(VQModel):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None
|
||||
):
|
||||
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
|
||||
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
|
||||
colorize_nlabels=colorize_nlabels)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
||||
output = pl.TrainResult(minimize=aeloss)
|
||||
output.log("train/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return output
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
||||
rec_loss = log_dict_ae["val/rec_loss"]
|
||||
output = pl.EvalResult(checkpoint_on=rec_loss)
|
||||
output.log("val/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
output.log("val/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
output.log_dict(log_dict_ae)
|
||||
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=self.learning_rate, betas=(0.5, 0.9))
|
||||
return optimizer
|
||||
|
||||
|
||||
class GumbelVQ(VQModel):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
temperature_scheduler_config,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
kl_weight=1e-8,
|
||||
remap=None,
|
||||
):
|
||||
|
||||
z_channels = ddconfig["z_channels"]
|
||||
super().__init__(ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=ignore_keys,
|
||||
image_key=image_key,
|
||||
colorize_nlabels=colorize_nlabels,
|
||||
monitor=monitor,
|
||||
)
|
||||
|
||||
self.loss.n_classes = n_embed
|
||||
self.vocab_size = n_embed
|
||||
|
||||
self.quantize = GumbelQuantize(z_channels, embed_dim,
|
||||
n_embed=n_embed,
|
||||
kl_weight=kl_weight, temp_init=1.0,
|
||||
remap=remap)
|
||||
|
||||
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def temperature_scheduling(self):
|
||||
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode_code(self, code_b):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
self.temperature_scheduling()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss = self(x)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
rec_loss = log_dict_ae["val/rec_loss"]
|
||||
self.log("val/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log("val/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def log_images(self, batch, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
# encode
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, _, _ = self.quantize(h)
|
||||
# decode
|
||||
x_rec = self.decode(quant)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = x_rec
|
||||
return log
|
||||
|
||||
|
||||
class EMAVQ(VQModel):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
):
|
||||
super().__init__(ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=ignore_keys,
|
||||
image_key=image_key,
|
||||
colorize_nlabels=colorize_nlabels,
|
||||
monitor=monitor,
|
||||
)
|
||||
self.quantize = EMAVectorQuantizer(n_embed=n_embed,
|
||||
embedding_dim=embed_dim,
|
||||
beta=0.25,
|
||||
remap=remap)
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
#Remove self.quantize from parameter list since it is updated via EMA
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
@@ -0,0 +1,776 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
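# Editor's note: an illustrative shape/value check for the function above (not part of
# the original file). Four integer timesteps map to 128-dim sinusoidal embeddings; the
# first half of each row holds the sin terms and the second half the cos terms.
_t = torch.tensor([0, 1, 10, 100])
_emb = get_timestep_embedding(_t, 128)
print(_emb.shape)     # torch.Size([4, 128])
print(_emb[0, :4])    # all zeros: sin(0) for every frequency
print(_emb[0, -4:])   # all ones: cos(0) for every frequency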
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
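# Editor's note: illustrative equivalence check (not part of the original file). The
# bmm-based attention above matches a plain einsum formulation on random inputs:
# w[b,i,j] = softmax_j(sum_c q[b,c,i] k[b,c,j] / sqrt(c)),  out[b,c,i] = sum_j v[b,c,j] w[b,i,j].
_b, _c, _hw = 2, 64, 16
_q, _k, _v = torch.randn(_b, _c, _hw), torch.randn(_b, _c, _hw), torch.randn(_b, _c, _hw)
_w = torch.softmax(torch.bmm(_q.permute(0, 2, 1), _k) * _c ** -0.5, dim=2)   # as in forward()
_ref = torch.bmm(_v, _w.permute(0, 2, 1))
_we = torch.softmax(torch.einsum('bci,bcj->bij', _q, _k) * _c ** -0.5, dim=-1)
_out = torch.einsum('bcj,bij->bci', _v, _we)
print(torch.allclose(_ref, _out, atol=1e-6))   # True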
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
|
||||
def forward(self, x, t=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, **ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
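# Editor's note: an illustrative shape check with a small hypothetical config (values made
# up for this sketch, not part of the original file). With ch_mult=(1, 2, 4) the encoder
# downsamples by 2**(len(ch_mult) - 1) = 4, and double_z doubles the latent channels.
_enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1, attn_resolutions=[],
               in_channels=3, resolution=256, z_channels=4, double_z=True, dropout=0.0)
print(_enc(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 8, 64, 64])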
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VUNet(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
||||
in_channels, c_channels,
|
||||
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(c_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
self.z_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
|
||||
def forward(self, x, z):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
z = self.z_in(z)
|
||||
h = torch.cat((h,z),dim=1)
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
nn.Conv2d(2*in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1,2,3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
||||
ch_mult=(2,2), dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
@@ -0,0 +1,67 @@
|
||||
import functools
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from taming.modules.util import ActNorm
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
"""Defines a PatchGAN discriminator as in Pix2Pix
|
||||
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
||||
"""
|
||||
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
||||
"""Construct a PatchGAN discriminator
|
||||
Parameters:
|
||||
input_nc (int) -- the number of channels in input images
|
||||
ndf (int) -- the number of filters in the last conv layer
|
||||
n_layers (int) -- the number of conv layers in the discriminator
|
||||
norm_layer -- normalization layer
|
||||
"""
|
||||
super(NLayerDiscriminator, self).__init__()
|
||||
if not use_actnorm:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
else:
|
||||
norm_layer = ActNorm
|
||||
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
||||
use_bias = norm_layer.func != nn.BatchNorm2d
|
||||
else:
|
||||
use_bias = norm_layer != nn.BatchNorm2d
|
||||
|
||||
kw = 4
|
||||
padw = 1
|
||||
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
||||
nf_mult = 1
|
||||
nf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2 ** n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2 ** n_layers, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, input):
|
||||
"""Standard forward."""
|
||||
return self.main(input)
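A minimal usage sketch of the PatchGAN discriminator above; the batch size, image size and printed output shape are illustrative assumptions, not values taken from this commit.

# Hedged usage sketch (assumed shapes): build the discriminator, apply the
# DCGAN-style init, and run a dummy batch to inspect the patch logits.
import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
fake = torch.randn(4, 3, 256, 256)   # assumed batch of RGB images
logits = disc(fake)                  # one logit per receptive-field patch
print(logits.shape)                  # torch.Size([4, 1, 30, 30]) for 256x256 inputs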
|
@@ -0,0 +1,2 @@
from taming.modules.losses.vqperceptual import DummyLoss
123
OCR_earsing/latent_diffusion/taming/modules/losses/lpips.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
from collections import namedtuple
|
||||
|
||||
from taming.util import get_ckpt_path
|
||||
|
||||
|
||||
class LPIPS(nn.Module):
|
||||
# Learned perceptual metric
|
||||
def __init__(self, use_dropout=True):
|
||||
super().__init__()
|
||||
self.scaling_layer = ScalingLayer()
|
||||
self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
|
||||
self.net = vgg16(pretrained=True, requires_grad=False)
|
||||
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
||||
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
||||
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
||||
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
||||
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
||||
self.load_from_pretrained()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_from_pretrained(self, name="vgg_lpips"):
|
||||
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
|
||||
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
||||
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, name="vgg_lpips"):
|
||||
if name != "vgg_lpips":
|
||||
raise NotImplementedError
|
||||
model = cls()
|
||||
ckpt = get_ckpt_path(name)
|
||||
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
||||
return model
|
||||
|
||||
def forward(self, input, target):
|
||||
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
||||
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
||||
for kk in range(len(self.chns)):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||
|
||||
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
||||
val = res[0]
|
||||
for l in range(1, len(self.chns)):
|
||||
val += res[l]
|
||||
return val
|
||||
|
||||
|
||||
class ScalingLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super(ScalingLayer, self).__init__()
|
||||
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
||||
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
||||
|
||||
def forward(self, inp):
|
||||
return (inp - self.shift) / self.scale
|
||||
|
||||
|
||||
class NetLinLayer(nn.Module):
|
||||
""" A single linear layer which does a 1x1 conv """
|
||||
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
||||
super(NetLinLayer, self).__init__()
|
||||
layers = [nn.Dropout(), ] if (use_dropout) else []
|
||||
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
|
||||
class vgg16(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(vgg16, self).__init__()
|
||||
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(4):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(4, 9):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(9, 16):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(16, 23):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(23, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1_2 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2_2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3_3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4_3 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5_3 = h
|
||||
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||
return out
|
||||
|
||||
|
||||
def normalize_tensor(x,eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
|
||||
return x/(norm_factor+eps)
|
||||
|
||||
|
||||
def spatial_average(x, keepdim=True):
|
||||
return x.mean([2,3],keepdim=keepdim)
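For orientation, a small hedged sketch of the intended call pattern for LPIPS above: two image batches scaled to [-1, 1] go in, one perceptual distance per sample comes out. The tensors and sizes are made up; the pretrained VGG and linear-layer weights are fetched inside LPIPS() via get_ckpt_path.

# Hedged usage sketch: inputs are assumed to lie in [-1, 1], matching ScalingLayer.
import torch

lpips = LPIPS().eval()
x = torch.rand(2, 3, 64, 64) * 2 - 1   # assumed reference batch
y = torch.rand(2, 3, 64, 64) * 2 - 1   # assumed reconstruction batch
with torch.no_grad():
    dist = lpips(x, y)                 # shape (2, 1, 1, 1); larger = more different
print(dist.flatten())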
|
||||
|
@@ -0,0 +1,22 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BCELoss(nn.Module):
|
||||
def forward(self, prediction, target):
|
||||
loss = F.binary_cross_entropy_with_logits(prediction,target)
|
||||
return loss, {}
|
||||
|
||||
|
||||
class BCELossWithQuant(nn.Module):
|
||||
def __init__(self, codebook_weight=1.):
|
||||
super().__init__()
|
||||
self.codebook_weight = codebook_weight
|
||||
|
||||
def forward(self, qloss, target, prediction, split):
|
||||
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
|
||||
loss = bce_loss + self.codebook_weight*qloss
|
||||
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
||||
"{}/quant_loss".format(split): qloss.detach().mean()
|
||||
}
|
@@ -0,0 +1,136 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
|
||||
|
||||
class DummyLoss(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
def hinge_d_loss(logits_real, logits_fake):
|
||||
loss_real = torch.mean(F.relu(1. - logits_real))
|
||||
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
||||
d_loss = 0.5 * (loss_real + loss_fake)
|
||||
return d_loss
|
||||
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
|
||||
d_loss = 0.5 * (
|
||||
torch.mean(torch.nn.functional.softplus(-logits_real)) +
|
||||
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
||||
return d_loss
|
||||
|
||||
|
||||
class VQLPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
||||
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
||||
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
||||
disc_ndf=64, disc_loss="hinge"):
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.codebook_weight = codebook_weight
|
||||
self.pixel_weight = pixelloss_weight
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
|
||||
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
||||
n_layers=disc_num_layers,
|
||||
use_actnorm=use_actnorm,
|
||||
ndf=disc_ndf
|
||||
).apply(weights_init)
|
||||
self.discriminator_iter_start = disc_start
|
||||
if disc_loss == "hinge":
|
||||
self.disc_loss = hinge_d_loss
|
||||
elif disc_loss == "vanilla":
|
||||
self.disc_loss = vanilla_d_loss
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
||||
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.disc_conditional = disc_conditional
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
||||
global_step, last_layer=None, cond=None, split="train"):
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0:
|
||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
else:
|
||||
p_loss = torch.tensor([0.0])
|
||||
|
||||
nll_loss = rec_loss
|
||||
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
nll_loss = torch.mean(nll_loss)
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
if cond is None:
|
||||
assert not self.disc_conditional
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
else:
|
||||
assert self.disc_conditional
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
||||
except RuntimeError:
|
||||
assert not self.training
|
||||
d_weight = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
||||
|
||||
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
||||
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
||||
"{}/p_loss".format(split): p_loss.detach().mean(),
|
||||
"{}/d_weight".format(split): d_weight.detach(),
|
||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
||||
}
|
||||
return loss, log
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
if cond is None:
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
else:
|
||||
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
|
||||
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
||||
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
||||
}
|
||||
return d_loss, log
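The loss above is written for a two-optimizer training loop (generator pass with optimizer_idx=0, discriminator pass with optimizer_idx=1). The sketch below shows the assumed calling convention with placeholder tensors; the conv weight standing in for the decoder's last layer is purely illustrative.

# Hedged usage sketch: all tensors below are placeholders for a real VQ autoencoder's outputs.
import torch

loss_fn = VQLPIPSWithDiscriminator(disc_start=1000)
last = torch.nn.Parameter(torch.randn(3, 3, 3, 3))            # stand-in for the decoder's final conv weight
inputs = torch.rand(2, 3, 64, 64) * 2 - 1                      # assumed images in [-1, 1]
recons = torch.nn.functional.conv2d(inputs, last, padding=1)   # fake "reconstructions" that depend on last
qloss = torch.tensor(0.1)                                      # stand-in codebook loss

ae_loss, log_ae = loss_fn(qloss, inputs, recons, optimizer_idx=0,
                          global_step=0, last_layer=last, split="train")
d_loss, log_d = loss_fn(qloss, inputs, recons, optimizer_idx=1,
                        global_step=0, split="train")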
|
31
OCR_earsing/latent_diffusion/taming/modules/misc/coord.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
class CoordStage(object):
|
||||
def __init__(self, n_embed, down_factor):
|
||||
self.n_embed = n_embed
|
||||
self.down_factor = down_factor
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, c):
|
||||
"""fake vqmodel interface"""
|
||||
assert 0.0 <= c.min() and c.max() <= 1.0
|
||||
b,ch,h,w = c.shape
|
||||
assert ch == 1
|
||||
|
||||
c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
|
||||
mode="area")
|
||||
c = c.clamp(0.0, 1.0)
|
||||
c = self.n_embed*c
|
||||
c_quant = c.round()
|
||||
c_ind = c_quant.to(dtype=torch.long)
|
||||
|
||||
info = None, None, c_ind
|
||||
return c_quant, None, info
|
||||
|
||||
def decode(self, c):
|
||||
c = c/self.n_embed
|
||||
c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
|
||||
mode="nearest")
|
||||
return c
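A short hedged sketch of the fake-VQ interface above: a single-channel coordinate map in [0, 1] is downsampled, quantized to n_embed levels, then decoded back by nearest-neighbour upsampling. All shapes are illustrative.

# Hedged usage sketch for CoordStage.encode / CoordStage.decode.
import torch

stage = CoordStage(n_embed=256, down_factor=4)
coords = torch.rand(1, 1, 32, 32)                 # assumed coordinate map in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(coords)  # quantized map and integer indices at 8x8
restored = stage.decode(c_quant)                  # back to 32x32
print(c_quant.shape, c_ind.shape, restored.shape)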
|
@@ -0,0 +1,415 @@
|
||||
"""
|
||||
taken from: https://github.com/karpathy/minGPT/
|
||||
GPT model:
|
||||
- the initial stem consists of a combination of token encoding and a positional encoding
|
||||
- the meat of it is a uniform sequence of Transformer blocks
|
||||
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
|
||||
- all blocks feed into a central residual pathway similar to resnets
|
||||
- the final decoder is a linear projection into a vanilla Softmax classifier
|
||||
"""
|
||||
|
||||
import math
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from transformers import top_k_top_p_filtering
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPTConfig:
|
||||
""" base GPT config, params common to all GPT versions """
|
||||
embd_pdrop = 0.1
|
||||
resid_pdrop = 0.1
|
||||
attn_pdrop = 0.1
|
||||
|
||||
def __init__(self, vocab_size, block_size, **kwargs):
|
||||
self.vocab_size = vocab_size
|
||||
self.block_size = block_size
|
||||
for k,v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class GPT1Config(GPTConfig):
|
||||
""" GPT-1 like network roughly 125M params """
|
||||
n_layer = 12
|
||||
n_head = 12
|
||||
n_embd = 768
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
"""
|
||||
A vanilla multi-head masked self-attention layer with a projection at the end.
|
||||
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
||||
explicit implementation here to show that there is nothing too scary here.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
assert config.n_embd % config.n_head == 0
|
||||
# key, query, value projections for all heads
|
||||
self.key = nn.Linear(config.n_embd, config.n_embd)
|
||||
self.query = nn.Linear(config.n_embd, config.n_embd)
|
||||
self.value = nn.Linear(config.n_embd, config.n_embd)
|
||||
# regularization
|
||||
self.attn_drop = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
||||
# output projection
|
||||
self.proj = nn.Linear(config.n_embd, config.n_embd)
|
||||
# causal mask to ensure that attention is only applied to the left in the input sequence
|
||||
mask = torch.tril(torch.ones(config.block_size,
|
||||
config.block_size))
|
||||
if hasattr(config, "n_unmasked"):
|
||||
mask[:config.n_unmasked, :config.n_unmasked] = 1
|
||||
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
|
||||
self.n_head = config.n_head
|
||||
|
||||
def forward(self, x, layer_past=None):
|
||||
B, T, C = x.size()
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
|
||||
present = torch.stack((k, v))
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
k = torch.cat((past_key, k), dim=-2)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
if layer_past is None:
|
||||
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
|
||||
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_drop(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
||||
|
||||
# output projection
|
||||
y = self.resid_drop(self.proj(y))
|
||||
return y, present # TODO: check that this does not break anything
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
""" an unassuming Transformer block """
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.ln1 = nn.LayerNorm(config.n_embd)
|
||||
self.ln2 = nn.LayerNorm(config.n_embd)
|
||||
self.attn = CausalSelfAttention(config)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(config.n_embd, 4 * config.n_embd),
|
||||
nn.GELU(), # nice
|
||||
nn.Linear(4 * config.n_embd, config.n_embd),
|
||||
nn.Dropout(config.resid_pdrop),
|
||||
)
|
||||
|
||||
def forward(self, x, layer_past=None, return_present=False):
|
||||
# TODO: check that training still works
|
||||
if return_present: assert not self.training
|
||||
# layer past: tuple of length two with B, nh, T, hs
|
||||
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
|
||||
|
||||
x = x + attn
|
||||
x = x + self.mlp(self.ln2(x))
|
||||
if layer_past is not None or return_present:
|
||||
return x, present
|
||||
return x
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
""" the full GPT language model, with a context size of block_size """
|
||||
def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
|
||||
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
||||
super().__init__()
|
||||
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
||||
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
||||
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
||||
n_unmasked=n_unmasked)
|
||||
# input embedding stem
|
||||
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
# transformer
|
||||
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
||||
# decoder head
|
||||
self.ln_f = nn.LayerNorm(config.n_embd)
|
||||
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.block_size = config.block_size
|
||||
self.apply(self._init_weights)
|
||||
self.config = config
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def get_block_size(self):
|
||||
return self.block_size
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, idx, embeddings=None, targets=None):
|
||||
# forward the GPT model
|
||||
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
||||
|
||||
if embeddings is not None: # prepend explicit embeddings
|
||||
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
||||
|
||||
t = token_embeddings.shape[1]
|
||||
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
||||
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.blocks(x)
|
||||
x = self.ln_f(x)
|
||||
logits = self.head(x)
|
||||
|
||||
# if we are given some desired targets also calculate the loss
|
||||
loss = None
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
|
||||
return logits, loss
|
||||
|
||||
def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
|
||||
# inference only
|
||||
assert not self.training
|
||||
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
||||
if embeddings is not None: # prepend explicit embeddings
|
||||
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
||||
|
||||
if past is not None:
|
||||
assert past_length is not None
|
||||
past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
|
||||
past_shape = list(past.shape)
|
||||
expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
|
||||
assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
|
||||
position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
|
||||
else:
|
||||
position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
|
||||
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
presents = [] # accumulate over layers
|
||||
for i, block in enumerate(self.blocks):
|
||||
x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
|
||||
presents.append(present)
|
||||
|
||||
x = self.ln_f(x)
|
||||
logits = self.head(x)
|
||||
# if we are given some desired targets also calculate the loss
|
||||
loss = None
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
|
||||
return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
|
||||
|
||||
|
||||
class DummyGPT(nn.Module):
|
||||
# for debugging
|
||||
def __init__(self, add_value=1):
|
||||
super().__init__()
|
||||
self.add_value = add_value
|
||||
|
||||
def forward(self, idx):
|
||||
return idx + self.add_value, None
|
||||
|
||||
|
||||
class CodeGPT(nn.Module):
|
||||
"""Takes in semi-embeddings"""
|
||||
def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
|
||||
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
|
||||
super().__init__()
|
||||
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
||||
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
|
||||
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
|
||||
n_unmasked=n_unmasked)
|
||||
# input embedding stem
|
||||
self.tok_emb = nn.Linear(in_channels, config.n_embd)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
# transformer
|
||||
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
|
||||
# decoder head
|
||||
self.ln_f = nn.LayerNorm(config.n_embd)
|
||||
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.block_size = config.block_size
|
||||
self.apply(self._init_weights)
|
||||
self.config = config
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def get_block_size(self):
|
||||
return self.block_size
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, idx, embeddings=None, targets=None):
|
||||
# forward the GPT model
|
||||
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
|
||||
|
||||
if embeddings is not None: # prepend explicit embeddings
|
||||
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
|
||||
|
||||
t = token_embeddings.shape[1]
|
||||
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
||||
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.blocks(x)
|
||||
x = self.ln_f(x)  # was self.taming_cinln_f(x), which is never defined on this class
|
||||
logits = self.head(x)
|
||||
|
||||
# if we are given some desired targets also calculate the loss
|
||||
loss = None
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
|
||||
return logits, loss
|
||||
|
||||
|
||||
|
||||
#### sampling utils
|
||||
|
||||
def top_k_logits(logits, k):
|
||||
v, ix = torch.topk(logits, k)
|
||||
out = logits.clone()
|
||||
out[out < v[:, [-1]]] = -float('Inf')
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
|
||||
"""
|
||||
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
|
||||
the sequence, feeding the predictions back into the model each time. Clearly the sampling
|
||||
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
|
||||
of block_size, unlike an RNN that has an infinite context window.
|
||||
"""
|
||||
block_size = model.get_block_size()
|
||||
model.eval()
|
||||
for k in range(steps):
|
||||
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
||||
logits, _ = model(x_cond)
|
||||
# pluck the logits at the final step and scale by temperature
|
||||
logits = logits[:, -1, :] / temperature
|
||||
# optionally crop probabilities to only the top k options
|
||||
if top_k is not None:
|
||||
logits = top_k_logits(logits, top_k)
|
||||
# apply softmax to convert to probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# sample from the distribution or take the most likely
|
||||
if sample:
|
||||
ix = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
_, ix = torch.topk(probs, k=1, dim=-1)
|
||||
# append to the sequence and continue
|
||||
x = torch.cat((x, ix), dim=1)
|
||||
|
||||
return x
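For context, the assumed calling pattern for the sampler above: a GPT built with a matching vocabulary and block size, a prefix of conditioning token indices, and a number of steps to extend it. All sizes below are made up.

# Hedged usage sketch: autoregressively extend a token prefix with the GPT class defined above.
import torch

gpt = GPT(vocab_size=512, block_size=128, n_layer=2, n_head=4, n_embd=64)
prefix = torch.randint(0, 512, (1, 16))        # assumed conditioning tokens
out = sample(gpt, prefix, steps=32, temperature=1.0, sample=True, top_k=50)
print(out.shape)                               # (1, 16 + 32)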
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
|
||||
top_k=None, top_p=None, callback=None):
|
||||
# x is conditioning
|
||||
sample = x
|
||||
cond_len = x.shape[1]
|
||||
past = None
|
||||
for n in range(steps):
|
||||
if callback is not None:
|
||||
callback(n)
|
||||
logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
|
||||
if past is None:
|
||||
past = [present]
|
||||
else:
|
||||
past.append(present)
|
||||
logits = logits[:, -1, :] / temperature
|
||||
if top_k is not None:
|
||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
if not sample_logits:
|
||||
_, x = torch.topk(probs, k=1, dim=-1)
|
||||
else:
|
||||
x = torch.multinomial(probs, num_samples=1)
|
||||
# append to the sequence and continue
|
||||
sample = torch.cat((sample, x), dim=1)
|
||||
del past
|
||||
sample = sample[:, cond_len:] # cut conditioning off
|
||||
return sample
|
||||
|
||||
|
||||
#### clustering utils
|
||||
|
||||
class KMeans(nn.Module):
|
||||
def __init__(self, ncluster=512, nc=3, niter=10):
|
||||
super().__init__()
|
||||
self.ncluster = ncluster
|
||||
self.nc = nc
|
||||
self.niter = niter
|
||||
self.shape = (3,32,32)
|
||||
self.register_buffer("C", torch.zeros(self.ncluster,nc))
|
||||
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
||||
|
||||
def is_initialized(self):
|
||||
return self.initialized.item() == 1
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize(self, x):
|
||||
N, D = x.shape
|
||||
assert D == self.nc, D
|
||||
c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
|
||||
for i in range(self.niter):
|
||||
# assign all pixels to the closest codebook element
|
||||
a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
|
||||
# move each codebook element to be the mean of the pixels that are assigned to it
|
||||
c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
|
||||
# re-assign any poorly positioned codebook elements
|
||||
nanix = torch.any(torch.isnan(c), dim=1)
|
||||
ndead = nanix.sum().item()
|
||||
print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
|
||||
c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
|
||||
|
||||
self.C.copy_(c)
|
||||
self.initialized.fill_(1)
|
||||
|
||||
|
||||
def forward(self, x, reverse=False, shape=None):
|
||||
if not reverse:
|
||||
# flatten
|
||||
bs,c,h,w = x.shape
|
||||
assert c == self.nc
|
||||
x = x.reshape(bs,c,h*w,1)
|
||||
C = self.C.permute(1,0)
|
||||
C = C.reshape(1,c,1,self.ncluster)
|
||||
a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
|
||||
return a
|
||||
else:
|
||||
# flatten
|
||||
bs, HW = x.shape
|
||||
"""
|
||||
c = self.C.reshape( 1, self.nc, 1, self.ncluster)
|
||||
c = c[bs*[0],:,:,:]
|
||||
c = c[:,:,HW*[0],:]
|
||||
x = x.reshape(bs, 1, HW, 1)
|
||||
x = x[:,3*[0],:,:]
|
||||
x = torch.gather(c, dim=3, index=x)
|
||||
"""
|
||||
x = self.C[x]
|
||||
x = x.permute(0,2,1)
|
||||
shape = shape if shape is not None else self.shape
|
||||
x = x.reshape(bs, *shape)
|
||||
|
||||
return x
|
@@ -0,0 +1,248 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractPermuter(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
def forward(self, x, reverse=False):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Identity(AbstractPermuter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
return x
|
||||
|
||||
|
||||
class Subsample(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
C = 1
|
||||
indices = np.arange(H*W).reshape(C,H,W)
|
||||
while min(H, W) > 1:
|
||||
indices = indices.reshape(C,H//2,2,W//2,2)
|
||||
indices = indices.transpose(0,2,4,1,3)
|
||||
indices = indices.reshape(C*4,H//2, W//2)
|
||||
H = H//2
|
||||
W = W//2
|
||||
C = C*4
|
||||
assert H == W == 1
|
||||
idx = torch.tensor(indices.ravel())
|
||||
self.register_buffer('forward_shuffle_idx',
|
||||
nn.Parameter(idx, requires_grad=False))
|
||||
self.register_buffer('backward_shuffle_idx',
|
||||
nn.Parameter(torch.argsort(idx), requires_grad=False))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
def mortonify(i, j):
|
||||
"""(i,j) index to linear morton code"""
|
||||
i = np.uint64(i)
|
||||
j = np.uint64(j)
|
||||
|
||||
z = np.uint(0)
|
||||
|
||||
for pos in range(32):
|
||||
z = (z |
|
||||
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
|
||||
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
|
||||
)
|
||||
return z
|
||||
|
||||
|
||||
class ZCurve(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
|
||||
idx = np.argsort(reverseidx)
|
||||
idx = torch.tensor(idx)
|
||||
reverseidx = torch.tensor(reverseidx)
|
||||
self.register_buffer('forward_shuffle_idx',
|
||||
idx)
|
||||
self.register_buffer('backward_shuffle_idx',
|
||||
reverseidx)
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class SpiralOut(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
assert H == W
|
||||
size = W
|
||||
indices = np.arange(size*size).reshape(size,size)
|
||||
|
||||
i0 = size//2
|
||||
j0 = size//2-1
|
||||
|
||||
i = i0
|
||||
j = j0
|
||||
|
||||
idx = [indices[i0, j0]]
|
||||
step_mult = 0
|
||||
for c in range(1, size//2+1):
|
||||
step_mult += 1
|
||||
# steps left
|
||||
for k in range(step_mult):
|
||||
i = i - 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step down
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
step_mult += 1
|
||||
if c < size//2:
|
||||
# step right
|
||||
for k in range(step_mult):
|
||||
i = i + 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step up
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j - 1
|
||||
idx.append(indices[i, j])
|
||||
else:
|
||||
# end reached
|
||||
for k in range(step_mult-1):
|
||||
i = i + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
assert len(idx) == size*size
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class SpiralIn(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
assert H == W
|
||||
size = W
|
||||
indices = np.arange(size*size).reshape(size,size)
|
||||
|
||||
i0 = size//2
|
||||
j0 = size//2-1
|
||||
|
||||
i = i0
|
||||
j = j0
|
||||
|
||||
idx = [indices[i0, j0]]
|
||||
step_mult = 0
|
||||
for c in range(1, size//2+1):
|
||||
step_mult += 1
|
||||
# steps left
|
||||
for k in range(step_mult):
|
||||
i = i - 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step down
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
step_mult += 1
|
||||
if c < size//2:
|
||||
# step right
|
||||
for k in range(step_mult):
|
||||
i = i + 1
|
||||
j = j
|
||||
idx.append(indices[i, j])
|
||||
|
||||
# step up
|
||||
for k in range(step_mult):
|
||||
i = i
|
||||
j = j - 1
|
||||
idx.append(indices[i, j])
|
||||
else:
|
||||
# end reached
|
||||
for k in range(step_mult-1):
|
||||
i = i + 1
|
||||
idx.append(indices[i, j])
|
||||
|
||||
assert len(idx) == size*size
|
||||
idx = idx[::-1]
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class Random(nn.Module):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
indices = np.random.RandomState(1).permutation(H*W)
|
||||
idx = torch.tensor(indices.ravel())
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
class AlternateParsing(AbstractPermuter):
|
||||
def __init__(self, H, W):
|
||||
super().__init__()
|
||||
indices = np.arange(W*H).reshape(H,W)
|
||||
for i in range(1, H, 2):
|
||||
indices[i, :] = indices[i, ::-1]
|
||||
idx = indices.flatten()
|
||||
assert len(idx) == H*W
|
||||
idx = torch.tensor(idx)
|
||||
self.register_buffer('forward_shuffle_idx', idx)
|
||||
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
|
||||
|
||||
def forward(self, x, reverse=False):
|
||||
if not reverse:
|
||||
return x[:, self.forward_shuffle_idx]
|
||||
else:
|
||||
return x[:, self.backward_shuffle_idx]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p0 = AlternateParsing(16, 16)
|
||||
print(p0.forward_shuffle_idx)
|
||||
print(p0.backward_shuffle_idx)
|
||||
|
||||
x = torch.randint(0, 768, size=(11, 256))
|
||||
y = p0(x)
|
||||
xre = p0(y, reverse=True)
|
||||
assert torch.equal(x, xre)
|
||||
|
||||
p1 = SpiralOut(2, 2)
|
||||
print(p1.forward_shuffle_idx)
|
||||
print(p1.backward_shuffle_idx)
|
130
OCR_earsing/latent_diffusion/taming/modules/util.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_params(model):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
return total_params
|
||||
|
||||
|
||||
class ActNorm(nn.Module):
|
||||
def __init__(self, num_features, logdet=False, affine=True,
|
||||
allow_reverse_init=False):
|
||||
assert affine
|
||||
super().__init__()
|
||||
self.logdet = logdet
|
||||
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
||||
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
||||
self.allow_reverse_init = allow_reverse_init
|
||||
|
||||
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
||||
|
||||
def initialize(self, input):
|
||||
with torch.no_grad():
|
||||
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
||||
mean = (
|
||||
flatten.mean(1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.permute(1, 0, 2, 3)
|
||||
)
|
||||
std = (
|
||||
flatten.std(1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.permute(1, 0, 2, 3)
|
||||
)
|
||||
|
||||
self.loc.data.copy_(-mean)
|
||||
self.scale.data.copy_(1 / (std + 1e-6))
|
||||
|
||||
def forward(self, input, reverse=False):
|
||||
if reverse:
|
||||
return self.reverse(input)
|
||||
if len(input.shape) == 2:
|
||||
input = input[:,:,None,None]
|
||||
squeeze = True
|
||||
else:
|
||||
squeeze = False
|
||||
|
||||
_, _, height, width = input.shape
|
||||
|
||||
if self.training and self.initialized.item() == 0:
|
||||
self.initialize(input)
|
||||
self.initialized.fill_(1)
|
||||
|
||||
h = self.scale * (input + self.loc)
|
||||
|
||||
if squeeze:
|
||||
h = h.squeeze(-1).squeeze(-1)
|
||||
|
||||
if self.logdet:
|
||||
log_abs = torch.log(torch.abs(self.scale))
|
||||
logdet = height*width*torch.sum(log_abs)
|
||||
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
||||
return h, logdet
|
||||
|
||||
return h
|
||||
|
||||
def reverse(self, output):
|
||||
if self.training and self.initialized.item() == 0:
|
||||
if not self.allow_reverse_init:
|
||||
raise RuntimeError(
|
||||
"Initializing ActNorm in reverse direction is "
|
||||
"disabled by default. Use allow_reverse_init=True to enable."
|
||||
)
|
||||
else:
|
||||
self.initialize(output)
|
||||
self.initialized.fill_(1)
|
||||
|
||||
if len(output.shape) == 2:
|
||||
output = output[:,:,None,None]
|
||||
squeeze = True
|
||||
else:
|
||||
squeeze = False
|
||||
|
||||
h = output / self.scale - self.loc
|
||||
|
||||
if squeeze:
|
||||
h = h.squeeze(-1).squeeze(-1)
|
||||
return h
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Labelator(AbstractEncoder):
|
||||
"""Net2Net Interface for Class-Conditional Model"""
|
||||
def __init__(self, n_classes, quantize_interface=True):
|
||||
super().__init__()
|
||||
self.n_classes = n_classes
|
||||
self.quantize_interface = quantize_interface
|
||||
|
||||
def encode(self, c):
|
||||
c = c[:,None]
|
||||
if self.quantize_interface:
|
||||
return c, None, [None, None, c.long()]
|
||||
return c
|
||||
|
||||
|
||||
class SOSProvider(AbstractEncoder):
|
||||
# for unconditional training
|
||||
def __init__(self, sos_token, quantize_interface=True):
|
||||
super().__init__()
|
||||
self.sos_token = sos_token
|
||||
self.quantize_interface = quantize_interface
|
||||
|
||||
def encode(self, x):
|
||||
# get batch size from data and replicate sos_token
|
||||
c = torch.ones(x.shape[0], 1)*self.sos_token
|
||||
c = c.long().to(x.device)
|
||||
if self.quantize_interface:
|
||||
return c, None, [None, None, c]
|
||||
return c
|
445
OCR_earsing/latent_diffusion/taming/modules/vqvae/quantize.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch import einsum
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
||||
____________________________________________
|
||||
Discretization bottleneck part of the VQ-VAE.
|
||||
Inputs:
|
||||
- n_e : number of embeddings
|
||||
- e_dim : dimension of embedding
|
||||
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
_____________________________________________
|
||||
"""
|
||||
|
||||
# NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
|
||||
# a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
|
||||
# used wherever VectorQuantizer has been used before and is additionally
|
||||
# more efficient.
|
||||
def __init__(self, n_e, e_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Inputs the output of the encoder network z and maps it to a discrete
|
||||
one-hot vector that is the index of the closest embedding vector e_j
|
||||
z (continuous) -> z_q (discrete)
|
||||
z.shape = (batch, channel, height, width)
|
||||
quantization pipeline:
|
||||
1. get encoder input (B,C,H,W)
|
||||
2. flatten input to (B*H*W,C)
|
||||
"""
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
||||
torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
## could possibly replace this block here
|
||||
# #\start...
|
||||
# find closest encodings
|
||||
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
||||
|
||||
min_encodings = torch.zeros(
|
||||
min_encoding_indices.shape[0], self.n_e).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# dtype min encodings: torch.float32
|
||||
# min_encodings shape: torch.Size([2048, 512])
|
||||
# min_encoding_indices.shape: torch.Size([2048, 1])
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
#.........\end
|
||||
|
||||
# with:
|
||||
# .........\start
|
||||
#min_encoding_indices = torch.argmin(d, dim=1)
|
||||
#z_q = self.embedding(min_encoding_indices)
|
||||
# ......\end......... (TODO)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
||||
torch.mean((z_q - z.detach()) ** 2)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
# TODO: check for more easy handling with nn.Embedding
|
||||
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
||||
min_encodings.scatter_(1, indices[:,None], 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class GumbelQuantize(nn.Module):
|
||||
"""
|
||||
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
||||
Gumbel Softmax trick quantizer
|
||||
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
||||
https://arxiv.org/abs/1611.01144
|
||||
"""
|
||||
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
|
||||
kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
|
||||
remap=None, unknown_index="random"):
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.n_embed = n_embed
|
||||
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
|
||||
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
||||
self.embed = nn.Embedding(n_embed, embedding_dim)
|
||||
|
||||
self.use_vqinterface = use_vqinterface
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed+1
|
||||
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices.")
|
||||
else:
|
||||
self.re_embed = n_embed
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape)>1
|
||||
inds = inds.reshape(ishape[0],-1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:,:,None]==used[None,None,...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2)<1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape)>1
|
||||
inds = inds.reshape(ishape[0],-1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
||||
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, return_logits=False):
|
||||
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
|
||||
hard = self.straight_through if self.training else True
|
||||
temp = self.temperature if temp is None else temp
|
||||
|
||||
logits = self.proj(z)
|
||||
if self.remap is not None:
|
||||
# continue only with used logits
|
||||
full_zeros = torch.zeros_like(logits)
|
||||
logits = logits[:,self.used,...]
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
||||
if self.remap is not None:
|
||||
# go back to all entries but unused set to zero
|
||||
full_zeros[:,self.used,...] = soft_one_hot
|
||||
soft_one_hot = full_zeros
|
||||
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
||||
|
||||
ind = soft_one_hot.argmax(dim=1)
|
||||
if self.remap is not None:
|
||||
ind = self.remap_to_used(ind)
|
||||
if self.use_vqinterface:
|
||||
if return_logits:
|
||||
return z_q, diff, (None, None, ind), logits
|
||||
return z_q, diff, (None, None, ind)
|
||||
return z_q, diff, ind
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
b, h, w, c = shape
|
||||
assert b*h*w == indices.shape[0]
|
||||
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
|
||||
if self.remap is not None:
|
||||
indices = self.unmap_to_all(indices)
|
||||
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
||||
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
|
||||
return z_q
|
||||
|
||||
|
||||
class VectorQuantizer2(nn.Module):
|
||||
"""
|
||||
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||
"""
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
|
||||
sane_index_shape=False, legacy=True):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed+1
|
||||
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices.")
|
||||
else:
|
||||
self.re_embed = n_e
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape)>1
|
||||
inds = inds.reshape(ishape[0],-1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:,:,None]==used[None,None,...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2)<1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape)>1
|
||||
inds = inds.reshape(ishape[0],-1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
||||
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||
assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
|
||||
assert rescale_logits==False, "Only for interface compatible with Gumbel"
|
||||
assert return_logits==False, "Only for interface compatible with Gumbel"
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
||||
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
perplexity = None
|
||||
min_encodings = None
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
||||
torch.mean((z_q - z.detach()) ** 2)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
||||
torch.mean((z_q - z.detach()) ** 2)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0],-1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
class EmbeddingEMA(nn.Module):
|
||||
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
weight = torch.randn(num_tokens, codebook_dim)
|
||||
self.weight = nn.Parameter(weight, requires_grad = False)
|
||||
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
|
||||
self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
|
||||
self.update = True
|
||||
|
||||
def forward(self, embed_id):
|
||||
return F.embedding(embed_id, self.weight)
|
||||
|
||||
def cluster_size_ema_update(self, new_cluster_size):
|
||||
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
||||
|
||||
def embed_avg_ema_update(self, new_embed_avg):
|
||||
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
||||
|
||||
def weight_update(self, num_tokens):
|
||||
n = self.cluster_size.sum()
|
||||
smoothed_cluster_size = (
|
||||
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
||||
)
|
||||
#normalize embedding average with smoothed cluster size
|
||||
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
||||
self.weight.data.copy_(embed_normalized)
|
||||
|
||||
|
||||
class EMAVectorQuantizer(nn.Module):
|
||||
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
|
||||
remap=None, unknown_index="random"):
|
||||
super().__init__()
|
||||
self.codebook_dim = embedding_dim  # constructor argument is embedding_dim, not codebook_dim
self.num_tokens = n_embed          # constructor argument is n_embed, not num_tokens
self.n_embed = n_embed             # referenced below when printing the remap summary
|
||||
self.beta = beta
|
||||
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed+1
|
||||
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices.")
|
||||
else:
|
||||
self.re_embed = n_embed
|
||||
|
||||
def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

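    # Illustrative round trip for the remapping helpers above (not part of the
    # original file): with used = [3, 7, 9] and re_embed = 3,
    #   remap_to_used(tensor([[3, 9, 7]]))  -> tensor([[0, 2, 1]])
    #   unmap_to_all(tensor([[0, 2, 1]]))   -> tensor([[3, 9, 7]])
    # Indices not present in `used` are replaced according to `unknown_index`.
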
def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        z = rearrange(z, 'b c h w -> b h w c')
        z_flattened = z.reshape(-1, self.codebook_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training and self.embedding.update:
            # EMA cluster size
            encodings_sum = encodings.sum(0)
            self.embedding.cluster_size_ema_update(encodings_sum)
            # EMA embedding average
            embed_sum = encodings.transpose(0, 1) @ z_flattened
            self.embedding.embed_avg_ema_update(embed_sum)
            # normalize embed_avg and update weight
            self.embedding.weight_update(self.num_tokens)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, (perplexity, encodings, encoding_indices)

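# Illustrative sketch (not part of the original file): quantizing a dummy
# latent with the EMA quantizer. Codebook size, embedding dimension and input
# shape are assumptions made for the example.
if __name__ == "__main__":
    quantizer = EMAVectorQuantizer(n_embed=32, embedding_dim=8, beta=0.25)
    z = torch.randn(2, 8, 16, 16)                      # (batch, channel, height, width)
    z_q, loss, (perplexity, _, indices) = quantizer(z)
    print(z_q.shape, loss.item(), perplexity.item())   # z_q keeps the shape of z
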
157
OCR_earsing/latent_diffusion/taming/util.py
Normal file
@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(len(data))


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path

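# Illustrative sketch (not part of the original file): fetching and
# md5-checking the LPIPS VGG weights. The cache directory is an arbitrary
# choice made for the example.
def _demo_get_vgg_lpips(root="cache/lpips"):
    # downloads vgg.pth on first call, verifies its md5, and returns the local path
    return get_ckpt_path("vgg_lpips", root=root, check=True)
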
class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)

def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path-like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or, if :attr:`default` is not ``None`` and
    :attr:`key` is not found, ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success

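# Illustrative sketch (not part of the original file): using `retrieve` on a
# nested config with a callable node and a default fallback. The keys below
# are made up for the example.
def _demo_retrieve():
    cfg = {"model": {"params": lambda: {"dim": 64}}}
    dim = retrieve(cfg, "model/params/dim")             # callable node is expanded in-place
    missing = retrieve(cfg, "model/nonexistent", default=-1)
    return dim, missing                                 # -> (64, -1)
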
if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")