First commit OCR_earsing and Synthetics Handwritten Recognition awesome repo

This commit is contained in:
2024-09-16 19:11:05 +07:00
parent d27ba1890b
commit 17aa33a17a
72 changed files with 12419 additions and 1 deletions

View File

@@ -0,0 +1,352 @@
import os, math
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from main import instantiate_from_config
from taming.modules.util import SOSProvider
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class Net2NetTransformer(pl.LightningModule):
def __init__(self,
transformer_config,
first_stage_config,
cond_stage_config,
permuter_config=None,
ckpt_path=None,
ignore_keys=[],
first_stage_key="image",
cond_stage_key="depth",
downsample_cond_size=-1,
pkeep=1.0,
sos_token=0,
unconditional=False,
):
super().__init__()
self.be_unconditional = unconditional
self.sos_token = sos_token
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.init_first_stage_from_ckpt(first_stage_config)
self.init_cond_stage_from_ckpt(cond_stage_config)
if permuter_config is None:
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
self.permuter = instantiate_from_config(config=permuter_config)
self.transformer = instantiate_from_config(config=transformer_config)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.downsample_cond_size = downsample_cond_size
self.pkeep = pkeep
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
for k in sd.keys():
for ik in ignore_keys:
if k.startswith(ik):
self.print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def init_first_stage_from_ckpt(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.first_stage_model = model
def init_cond_stage_from_ckpt(self, config):
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__" or self.be_unconditional:
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
f"Prepending {self.sos_token} as a sos token.")
self.be_unconditional = True
self.cond_stage_key = self.first_stage_key
self.cond_stage_model = SOSProvider(self.sos_token)
else:
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.cond_stage_model = model
def forward(self, x, c):
# one step to produce the logits
_, z_indices = self.encode_to_z(x)
_, c_indices = self.encode_to_c(c)
if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices
cz_indices = torch.cat((c_indices, a_indices), dim=1)
# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
logits = logits[:, c_indices.shape[1]-1:]
return logits, target
def top_k_logits(self, logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[..., [-1]]] = -float('Inf')
return out
@torch.no_grad()
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
callback=lambda k: None):
x = torch.cat((c,x),dim=1)
block_size = self.transformer.get_block_size()
assert not self.transformer.training
if self.pkeep <= 0.0:
# one pass suffices since input is pure noise anyway
assert len(x.shape)==2
noise_shape = (x.shape[0], steps-1)
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
x = torch.cat((x,noise),dim=1)
logits, _ = self.transformer(x)
# take all logits for now and scale by temp
logits = logits / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = self.top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
shape = probs.shape
probs = probs.reshape(shape[0]*shape[1],shape[2])
ix = torch.multinomial(probs, num_samples=1)
probs = probs.reshape(shape[0],shape[1],shape[2])
ix = ix.reshape(shape[0],shape[1])
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# cut off conditioning
x = ix[:, c.shape[1]-1:]
else:
for k in range(steps):
callback(k)
assert x.size(1) <= block_size # make sure model can see conditioning
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
logits, _ = self.transformer(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = self.top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)
# cut off conditioning
x = x[:, c.shape[1]:]
return x
@torch.no_grad()
def encode_to_z(self, x):
quant_z, _, info = self.first_stage_model.encode(x)
indices = info[2].view(quant_z.shape[0], -1)
indices = self.permuter(indices)
return quant_z, indices
@torch.no_grad()
def encode_to_c(self, c):
if self.downsample_cond_size > -1:
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
if len(indices.shape) > 2:
indices = indices.view(c.shape[0], -1)
return quant_c, indices
@torch.no_grad()
def decode_to_img(self, index, zshape):
index = self.permuter(index, reverse=True)
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
quant_z = self.first_stage_model.quantize.get_codebook_entry(
index.reshape(-1), shape=bhwc)
x = self.first_stage_model.decode(quant_z)
return x
@torch.no_grad()
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
log = dict()
N = 4
if lr_interface:
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
else:
x, c = self.get_xc(batch, N)
x = x.to(device=self.device)
c = c.to(device=self.device)
quant_z, z_indices = self.encode_to_z(x)
quant_c, c_indices = self.encode_to_c(c)
# create a "half"" sample
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
index_sample = self.sample(z_start_indices, c_indices,
steps=z_indices.shape[1]-z_start_indices.shape[1],
temperature=temperature if temperature is not None else 1.0,
sample=True,
top_k=top_k if top_k is not None else 100,
callback=callback if callback is not None else lambda k: None)
x_sample = self.decode_to_img(index_sample, quant_z.shape)
# sample
z_start_indices = z_indices[:, :0]
index_sample = self.sample(z_start_indices, c_indices,
steps=z_indices.shape[1],
temperature=temperature if temperature is not None else 1.0,
sample=True,
top_k=top_k if top_k is not None else 100,
callback=callback if callback is not None else lambda k: None)
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
# det sample
z_start_indices = z_indices[:, :0]
index_sample = self.sample(z_start_indices, c_indices,
steps=z_indices.shape[1],
sample=False,
callback=callback if callback is not None else lambda k: None)
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
# reconstruction
x_rec = self.decode_to_img(z_indices, quant_z.shape)
log["inputs"] = x
log["reconstructions"] = x_rec
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
figure_size = (x_rec.shape[2], x_rec.shape[3])
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
label_for_category_no = dataset.get_textual_label_for_category_no
plotter = dataset.conditional_builders[self.cond_stage_key].plot
log["conditioning"] = torch.zeros_like(log["reconstructions"])
for i in range(quant_c.shape[0]):
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
log["conditioning_rec"] = log["conditioning"]
elif self.cond_stage_key != "image":
cond_rec = self.cond_stage_model.decode(quant_c)
if self.cond_stage_key == "segmentation":
# get image from segmentation mask
num_classes = cond_rec.shape[1]
c = torch.argmax(c, dim=1, keepdim=True)
c = F.one_hot(c, num_classes=num_classes)
c = c.squeeze(1).permute(0, 3, 1, 2).float()
c = self.cond_stage_model.to_rgb(c)
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
log["conditioning_rec"] = cond_rec
log["conditioning"] = c
log["samples_half"] = x_sample
log["samples_nopix"] = x_sample_nopix
log["samples_det"] = x_sample_det
return log
def get_input(self, key, batch):
x = batch[key]
if len(x.shape) == 3:
x = x[..., None]
if len(x.shape) == 4:
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
if x.dtype == torch.double:
x = x.float()
return x
def get_xc(self, batch, N=None):
x = self.get_input(self.first_stage_key, batch)
c = self.get_input(self.cond_stage_key, batch)
if N is not None:
x = x[:N]
c = c[:N]
return x, c
def shared_step(self, batch, batch_idx):
x, c = self.get_xc(batch)
logits, target = self(x, c)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
return loss
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch, batch_idx)
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch, batch_idx)
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
"""
Following minGPT:
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.transformer.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
return optimizer

View File

@@ -0,0 +1,22 @@
from torch import Tensor
class DummyCondStage:
def __init__(self, conditional_key):
self.conditional_key = conditional_key
self.train = None
def eval(self):
return self
@staticmethod
def encode(c: Tensor):
return c, None, (None, None, c)
@staticmethod
def decode(c: Tensor):
return c
@staticmethod
def to_rgb(c: Tensor):
return c

View File

@@ -0,0 +1,404 @@
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from main import instantiate_from_config
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.image_key = image_key
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
return x.float()
def training_step(self, batch, batch_idx, optimizer_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
rec_loss = log_dict_ae["val/rec_loss"]
self.log("val/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log("val/aeloss", aeloss,
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQSegmentationModel(VQModel):
def __init__(self, n_labels, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
return opt_ae
def training_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
def validation_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
total_loss = log_dict_ae["val/total_loss"]
self.log("val/total_loss", total_loss,
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
return aeloss
@torch.no_grad()
def log_images(self, batch, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
# convert logits to indices
xrec = torch.argmax(xrec, dim=1, keepdim=True)
xrec = F.one_hot(xrec, num_classes=x.shape[1])
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
return log
class VQNoDiscModel(VQModel):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None
):
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
colorize_nlabels=colorize_nlabels)
def training_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
output = pl.TrainResult(minimize=aeloss)
output.log("train/aeloss", aeloss,
prog_bar=True, logger=True, on_step=True, on_epoch=True)
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return output
def validation_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
rec_loss = log_dict_ae["val/rec_loss"]
output = pl.EvalResult(checkpoint_on=rec_loss)
output.log("val/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=True, on_epoch=True)
output.log("val/aeloss", aeloss,
prog_bar=True, logger=True, on_step=True, on_epoch=True)
output.log_dict(log_dict_ae)
return output
def configure_optimizers(self):
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=self.learning_rate, betas=(0.5, 0.9))
return optimizer
class GumbelVQ(VQModel):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
temperature_scheduler_config,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
kl_weight=1e-8,
remap=None,
):
z_channels = ddconfig["z_channels"]
super().__init__(ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=ignore_keys,
image_key=image_key,
colorize_nlabels=colorize_nlabels,
monitor=monitor,
)
self.loss.n_classes = n_embed
self.vocab_size = n_embed
self.quantize = GumbelQuantize(z_channels, embed_dim,
n_embed=n_embed,
kl_weight=kl_weight, temp_init=1.0,
remap=remap)
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def temperature_scheduling(self):
self.quantize.temperature = self.temperature_scheduler(self.global_step)
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode_code(self, code_b):
raise NotImplementedError
def training_step(self, batch, batch_idx, optimizer_idx):
self.temperature_scheduling()
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
x = self.get_input(batch, self.image_key)
xrec, qloss = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
rec_loss = log_dict_ae["val/rec_loss"]
self.log("val/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log("val/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def log_images(self, batch, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
# encode
h = self.encoder(x)
h = self.quant_conv(h)
quant, _, _ = self.quantize(h)
# decode
x_rec = self.decode(quant)
log["inputs"] = x
log["reconstructions"] = x_rec
return log
class EMAVQ(VQModel):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__(ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=ignore_keys,
image_key=image_key,
colorize_nlabels=colorize_nlabels,
monitor=monitor,
)
self.quantize = EMAVectorQuantizer(n_embed=n_embed,
embedding_dim=embed_dim,
beta=0.25,
remap=remap)
def configure_optimizers(self):
lr = self.learning_rate
#Remove self.quantize from parameter list since it is updated via EMA
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []

View File

@@ -0,0 +1,776 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True):
super().__init__()
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x, t=None):
#assert x.shape[2] == x.shape[3] == self.resolution
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, **ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class VUNet(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
in_channels, c_channels,
resolution, z_channels, use_timestep=False, **ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(c_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
self.z_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=1,
stride=1,
padding=0)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x, z):
#assert x.shape[2] == x.shape[3] == self.resolution
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
z = self.z_in(z)
h = torch.cat((h,z),dim=1)
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
nn.Conv2d(2*in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1,2,3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
ch_mult=(2,2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h

View File

@@ -0,0 +1,67 @@
import functools
import torch.nn as nn
from taming.modules.util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)

View File

@@ -0,0 +1,2 @@
from taming.modules.losses.vqperceptual import DummyLoss

View File

@@ -0,0 +1,123 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from taming.util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
""" A single linear layer which does a 1x1 conv """
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(), ] if (use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
return x/(norm_factor+eps)
def spatial_average(x, keepdim=True):
return x.mean([2,3],keepdim=keepdim)

View File

@@ -0,0 +1,22 @@
import torch.nn as nn
import torch.nn.functional as F
class BCELoss(nn.Module):
def forward(self, prediction, target):
loss = F.binary_cross_entropy_with_logits(prediction,target)
return loss, {}
class BCELossWithQuant(nn.Module):
def __init__(self, codebook_weight=1.):
super().__init__()
self.codebook_weight = codebook_weight
def forward(self, qloss, target, prediction, split):
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
loss = bce_loss + self.codebook_weight*qloss
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/bce_loss".format(split): bce_loss.detach().mean(),
"{}/quant_loss".format(split): qloss.detach().mean()
}

View File

@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
class DummyLoss(nn.Module):
def __init__(self):
super().__init__()
def adopt_weight(weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
return weight
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1. - logits_real))
loss_fake = torch.mean(F.relu(1. + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) +
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge"):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train"):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log

View File

@@ -0,0 +1,31 @@
import torch
class CoordStage(object):
def __init__(self, n_embed, down_factor):
self.n_embed = n_embed
self.down_factor = down_factor
def eval(self):
return self
def encode(self, c):
"""fake vqmodel interface"""
assert 0.0 <= c.min() and c.max() <= 1.0
b,ch,h,w = c.shape
assert ch == 1
c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
mode="area")
c = c.clamp(0.0, 1.0)
c = self.n_embed*c
c_quant = c.round()
c_ind = c_quant.to(dtype=torch.long)
info = None, None, c_ind
return c_quant, None, info
def decode(self, c):
c = c/self.n_embed
c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
mode="nearest")
return c

View File

@@ -0,0 +1,415 @@
"""
taken from: https://github.com/karpathy/minGPT/
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import top_k_top_p_filtering
logger = logging.getLogger(__name__)
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k,v in kwargs.items():
setattr(self, k, v)
class GPT1Config(GPTConfig):
""" GPT-1 like network roughly 125M params """
n_layer = 12
n_head = 12
n_embd = 768
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here but I am including an
explicit implementation here to show that there is nothing too scary here.
"""
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads
self.key = nn.Linear(config.n_embd, config.n_embd)
self.query = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
# regularization
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
mask = torch.tril(torch.ones(config.block_size,
config.block_size))
if hasattr(config, "n_unmasked"):
mask[:config.n_unmasked, :config.n_unmasked] = 1
self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, layer_past=None):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
present = torch.stack((k, v))
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
if layer_past is None:
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y, present # TODO: check that this does not break anything
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present: assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
super().__init__()
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
n_unmasked=n_unmasked)
# input embedding stem
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
self.apply(self._init_weights)
self.config = config
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
# inference only
assert not self.training
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
if past is not None:
assert past_length is not None
past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
past_shape = list(past.shape)
expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
else:
position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
x = self.drop(token_embeddings + position_embeddings)
presents = [] # accumulate over layers
for i, block in enumerate(self.blocks):
x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
presents.append(present)
x = self.ln_f(x)
logits = self.head(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
class DummyGPT(nn.Module):
# for debugging
def __init__(self, add_value=1):
super().__init__()
self.add_value = add_value
def forward(self, idx):
return idx + self.add_value, None
class CodeGPT(nn.Module):
"""Takes in semi-embeddings"""
def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
super().__init__()
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
n_layer=n_layer, n_head=n_head, n_embd=n_embd,
n_unmasked=n_unmasked)
# input embedding stem
self.tok_emb = nn.Linear(in_channels, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
self.apply(self._init_weights)
self.config = config
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.taming_cinln_f(x)
logits = self.head(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
#### sampling utils
def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[:, [-1]]] = -float('Inf')
return out
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
"""
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
the sequence, feeding the predictions back into the model each time. Clearly the sampling
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
of block_size, unlike an RNN that has an infinite context window.
"""
block_size = model.get_block_size()
model.eval()
for k in range(steps):
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
logits, _ = model(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)
return x
@torch.no_grad()
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
top_k=None, top_p=None, callback=None):
# x is conditioning
sample = x
cond_len = x.shape[1]
past = None
for n in range(steps):
if callback is not None:
callback(n)
logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
if past is None:
past = [present]
else:
past.append(present)
logits = logits[:, -1, :] / temperature
if top_k is not None:
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
if not sample_logits:
_, x = torch.topk(probs, k=1, dim=-1)
else:
x = torch.multinomial(probs, num_samples=1)
# append to the sequence and continue
sample = torch.cat((sample, x), dim=1)
del past
sample = sample[:, cond_len:] # cut conditioning off
return sample
#### clustering utils
class KMeans(nn.Module):
def __init__(self, ncluster=512, nc=3, niter=10):
super().__init__()
self.ncluster = ncluster
self.nc = nc
self.niter = niter
self.shape = (3,32,32)
self.register_buffer("C", torch.zeros(self.ncluster,nc))
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
def is_initialized(self):
return self.initialized.item() == 1
@torch.no_grad()
def initialize(self, x):
N, D = x.shape
assert D == self.nc, D
c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
for i in range(self.niter):
# assign all pixels to the closest codebook element
a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
# move each codebook element to be the mean of the pixels that assigned to it
c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
# re-assign any poorly positioned codebook elements
nanix = torch.any(torch.isnan(c), dim=1)
ndead = nanix.sum().item()
print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
self.C.copy_(c)
self.initialized.fill_(1)
def forward(self, x, reverse=False, shape=None):
if not reverse:
# flatten
bs,c,h,w = x.shape
assert c == self.nc
x = x.reshape(bs,c,h*w,1)
C = self.C.permute(1,0)
C = C.reshape(1,c,1,self.ncluster)
a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
return a
else:
# flatten
bs, HW = x.shape
"""
c = self.C.reshape( 1, self.nc, 1, self.ncluster)
c = c[bs*[0],:,:,:]
c = c[:,:,HW*[0],:]
x = x.reshape(bs, 1, HW, 1)
x = x[:,3*[0],:,:]
x = torch.gather(c, dim=3, index=x)
"""
x = self.C[x]
x = x.permute(0,2,1)
shape = shape if shape is not None else self.shape
x = x.reshape(bs, *shape)
return x

View File

@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import numpy as np
class AbstractPermuter(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x, reverse=False):
raise NotImplementedError
class Identity(AbstractPermuter):
def __init__(self):
super().__init__()
def forward(self, x, reverse=False):
return x
class Subsample(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
C = 1
indices = np.arange(H*W).reshape(C,H,W)
while min(H, W) > 1:
indices = indices.reshape(C,H//2,2,W//2,2)
indices = indices.transpose(0,2,4,1,3)
indices = indices.reshape(C*4,H//2, W//2)
H = H//2
W = W//2
C = C*4
assert H == W == 1
idx = torch.tensor(indices.ravel())
self.register_buffer('forward_shuffle_idx',
nn.Parameter(idx, requires_grad=False))
self.register_buffer('backward_shuffle_idx',
nn.Parameter(torch.argsort(idx), requires_grad=False))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
def mortonify(i, j):
"""(i,j) index to linear morton code"""
i = np.uint64(i)
j = np.uint64(j)
z = np.uint(0)
for pos in range(32):
z = (z |
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
)
return z
class ZCurve(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
idx = np.argsort(reverseidx)
idx = torch.tensor(idx)
reverseidx = torch.tensor(reverseidx)
self.register_buffer('forward_shuffle_idx',
idx)
self.register_buffer('backward_shuffle_idx',
reverseidx)
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class SpiralOut(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
assert H == W
size = W
indices = np.arange(size*size).reshape(size,size)
i0 = size//2
j0 = size//2-1
i = i0
j = j0
idx = [indices[i0, j0]]
step_mult = 0
for c in range(1, size//2+1):
step_mult += 1
# steps left
for k in range(step_mult):
i = i - 1
j = j
idx.append(indices[i, j])
# step down
for k in range(step_mult):
i = i
j = j + 1
idx.append(indices[i, j])
step_mult += 1
if c < size//2:
# step right
for k in range(step_mult):
i = i + 1
j = j
idx.append(indices[i, j])
# step up
for k in range(step_mult):
i = i
j = j - 1
idx.append(indices[i, j])
else:
# end reached
for k in range(step_mult-1):
i = i + 1
idx.append(indices[i, j])
assert len(idx) == size*size
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class SpiralIn(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
assert H == W
size = W
indices = np.arange(size*size).reshape(size,size)
i0 = size//2
j0 = size//2-1
i = i0
j = j0
idx = [indices[i0, j0]]
step_mult = 0
for c in range(1, size//2+1):
step_mult += 1
# steps left
for k in range(step_mult):
i = i - 1
j = j
idx.append(indices[i, j])
# step down
for k in range(step_mult):
i = i
j = j + 1
idx.append(indices[i, j])
step_mult += 1
if c < size//2:
# step right
for k in range(step_mult):
i = i + 1
j = j
idx.append(indices[i, j])
# step up
for k in range(step_mult):
i = i
j = j - 1
idx.append(indices[i, j])
else:
# end reached
for k in range(step_mult-1):
i = i + 1
idx.append(indices[i, j])
assert len(idx) == size*size
idx = idx[::-1]
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class Random(nn.Module):
def __init__(self, H, W):
super().__init__()
indices = np.random.RandomState(1).permutation(H*W)
idx = torch.tensor(indices.ravel())
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class AlternateParsing(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
indices = np.arange(W*H).reshape(H,W)
for i in range(1, H, 2):
indices[i, :] = indices[i, ::-1]
idx = indices.flatten()
assert len(idx) == H*W
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
if __name__ == "__main__":
p0 = AlternateParsing(16, 16)
print(p0.forward_shuffle_idx)
print(p0.backward_shuffle_idx)
x = torch.randint(0, 768, size=(11, 256))
y = p0(x)
xre = p0(y, reverse=True)
assert torch.equal(x, xre)
p1 = SpiralOut(2, 2)
print(p1.forward_shuffle_idx)
print(p1.backward_shuffle_idx)

View File

@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
def count_params(model):
total_params = sum(p.numel() for p in model.parameters())
return total_params
class ActNorm(nn.Module):
def __init__(self, num_features, logdet=False, affine=True,
allow_reverse_init=False):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:,:,None,None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height*width*torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:,:,None,None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class Labelator(AbstractEncoder):
"""Net2Net Interface for Class-Conditional Model"""
def __init__(self, n_classes, quantize_interface=True):
super().__init__()
self.n_classes = n_classes
self.quantize_interface = quantize_interface
def encode(self, c):
c = c[:,None]
if self.quantize_interface:
return c, None, [None, None, c.long()]
return c
class SOSProvider(AbstractEncoder):
# for unconditional training
def __init__(self, sos_token, quantize_interface=True):
super().__init__()
self.sos_token = sos_token
self.quantize_interface = quantize_interface
def encode(self, x):
# get batch size from data and replicate sos_token
c = torch.ones(x.shape[0], 1)*self.sos_token
c = c.long().to(x.device)
if self.quantize_interface:
return c, None, [None, None, c]
return c

View File

@@ -0,0 +1,445 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
# NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
# a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
# used wherever VectorQuantizer has been used before and is additionally
# more efficient.
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
"""
Inputs the output of the encoder network z and maps it to a discrete
one-hot vector that is the index of the closest embedding vector e_j
z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
quantization pipeline:
1. get encoder input (B,C,H,W)
2. flatten input to (B*H*W,C)
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())
## could possible replace this here
# #\start...
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# dtype min encodings: torch.float32
# min_encodings shape: torch.Size([2048, 512])
# min_encoding_indices.shape: torch.Size([2048, 1])
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end
# with:
# .........\start
#min_encoding_indices = torch.argmin(d, dim=1)
#z_q = self.embedding(min_encoding_indices)
# ......\end......... (TODO)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
# TODO: check for more easy handling with nn.Embedding
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:,None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantize(nn.Module):
"""
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
remap=None, unknown_index="random"):
super().__init__()
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.use_vqinterface = use_vqinterface
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed+1
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_embed
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
match = (inds[:,:,None]==used[None,None,...]).long()
new = match.argmax(-1)
unknown = match.sum(2)<1
if self.unknown_index == "random":
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, return_logits=False):
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:,self.used,...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:,self.used,...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
ind = soft_one_hot.argmax(dim=1)
if self.remap is not None:
ind = self.remap_to_used(ind)
if self.use_vqinterface:
if return_logits:
return z_q, diff, (None, None, ind), logits
return z_q, diff, (None, None, ind)
return z_q, diff, ind
def get_codebook_entry(self, indices, shape):
b, h, w, c = shape
assert b*h*w == indices.shape[0]
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
return z_q
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed+1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
match = (inds[:,:,None]==used[None,None,...]).long()
new = match.argmax(-1)
unknown = match.sum(2)<1
if self.unknown_index == "random":
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
assert rescale_logits==False, "Only for interface compatible with Gumbel"
assert return_logits==False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0],-1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
super().__init__()
self.decay = decay
self.eps = eps
weight = torch.randn(num_tokens, codebook_dim)
self.weight = nn.Parameter(weight, requires_grad = False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
self.update = True
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
)
#normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
self.weight.data.copy_(embed_normalized)
class EMAVectorQuantizer(nn.Module):
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
remap=None, unknown_index="random"):
super().__init__()
self.codebook_dim = codebook_dim
self.num_tokens = num_tokens
self.beta = beta
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed+1
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_embed
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
match = (inds[:,:,None]==used[None,None,...]).long()
new = match.argmax(-1)
unknown = match.sum(2)<1
if self.unknown_index == "random":
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape)>1
inds = inds.reshape(ishape[0],-1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
return back.reshape(ishape)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
#z, 'b c h w -> b h w c'
z = rearrange(z, 'b c h w -> b h w c')
z_flattened = z.reshape(-1, self.codebook_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
if self.training and self.embedding.update:
#EMA cluster size
encodings_sum = encodings.sum(0)
self.embedding.cluster_size_ema_update(encodings_sum)
#EMA embedding average
embed_sum = encodings.transpose(0,1) @ z_flattened
self.embedding.embed_avg_ema_update(embed_sum)
#normalize embed_avg and update weight
self.embedding.weight_update(self.num_tokens)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
#z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, 'b h w c -> b c h w')
return z_q, loss, (perplexity, encodings, encoding_indices)

View File

@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm
URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}
MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class KeyNotFoundError(Exception):
def __init__(self, cause, keys=None, visited=None):
self.cause = cause
self.keys = keys
self.visited = visited
messages = list()
if keys is not None:
messages.append("Key not found: {}".format(keys))
if visited is not None:
messages.append("Visited: {}".format(visited))
messages.append("Cause:\n{}".format(cause))
message = "\n".join(messages)
super().__init__(message)
def retrieve(
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
"""Given a nested list or dict return the desired value at key expanding
callable nodes if necessary and :attr:`expand` is ``True``. The expansion
is done in-place.
Parameters
----------
list_or_dict : list or dict
Possibly nested list or dictionary.
key : str
key/to/value, path like string describing all keys necessary to
consider to get to the desired value. List indices can also be
passed here.
splitval : str
String that defines the delimiter between keys of the
different depth levels in `key`.
default : obj
Value returned if :attr:`key` is not found.
expand : bool
Whether to expand callable nodes on the path or not.
Returns
-------
The desired value or if :attr:`default` is not ``None`` and the
:attr:`key` is not found returns ``default``.
Raises
------
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
``None``.
"""
keys = key.split(splitval)
success = True
try:
visited = []
parent = None
last_key = None
for key in keys:
if callable(list_or_dict):
if not expand:
raise KeyNotFoundError(
ValueError(
"Trying to get past callable node with expand=False."
),
keys=keys,
visited=visited,
)
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
last_key = key
parent = list_or_dict
try:
if isinstance(list_or_dict, dict):
list_or_dict = list_or_dict[key]
else:
list_or_dict = list_or_dict[int(key)]
except (KeyError, IndexError, ValueError) as e:
raise KeyNotFoundError(e, keys=keys, visited=visited)
visited += [key]
# final expansion of retrieved value
if expand and callable(list_or_dict):
list_or_dict = list_or_dict()
parent[last_key] = list_or_dict
except KeyNotFoundError as e:
if default is None:
raise e
else:
list_or_dict = default
success = False
if not pass_success:
return list_or_dict
else:
return list_or_dict, success
if __name__ == "__main__":
config = {"keya": "a",
"keyb": "b",
"keyc":
{"cc1": 1,
"cc2": 2,
}
}
from omegaconf import OmegaConf
config = OmegaConf.create(config)
print(config)
retrieve(config, "keya")