First commit: OCR_earsing and Synthetic Handwritten Recognition awesome repo
352
OCR_earsing/latent_diffusion/taming/models/cond_transformer.py
Normal file
@@ -0,0 +1,352 @@
import os, math
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config
from taming.modules.util import SOSProvider


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class Net2NetTransformer(pl.LightningModule):
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 permuter_config=None,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 sos_token=0,
                 unconditional=False,
                 ):
        super().__init__()
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        if permuter_config is None:
            permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
        self.permuter = instantiate_from_config(config=permuter_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # iterate over a copy of the keys: deleting from a dict while
        # iterating its live key view raises a RuntimeError
        for k in list(sd.keys()):
            for ik in ignore_keys:
                if k.startswith(ik):
                    self.print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def init_first_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        self.first_stage_model = model

    def init_cond_stage_from_ckpt(self, config):
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        elif config == "__is_unconditional__" or self.be_unconditional:
            print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
                  f"Prepending {self.sos_token} as a sos token.")
            self.be_unconditional = True
            self.cond_stage_key = self.first_stage_key
            self.cond_stage_model = SOSProvider(self.sos_token)
        else:
            model = instantiate_from_config(config)
            model = model.eval()
            model.train = disabled_train
            self.cond_stage_model = model

    def forward(self, x, c):
        # one step to produce the logits
        _, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)

        if self.training and self.pkeep < 1.0:
            mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
                                                         device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask*z_indices+(1-mask)*r_indices
        else:
            a_indices = z_indices

        cz_indices = torch.cat((c_indices, a_indices), dim=1)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices
        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1])
        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, c_indices.shape[1]-1:]

        return logits, target
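
Editor's note: the pkeep branch above implements token-level input corruption during training; each ground-truth index survives with probability pkeep and is otherwise swapped for a random vocabulary index. A minimal standalone sketch with toy shapes (not part of the file):

import torch

torch.manual_seed(0)
pkeep, vocab_size = 0.5, 16
z_indices = torch.arange(8).unsqueeze(0)                # ground-truth ids, shape (1, 8)
mask = torch.bernoulli(pkeep * torch.ones(z_indices.shape))
mask = mask.round().to(dtype=torch.int64)               # 1 = keep, 0 = replace
r_indices = torch.randint_like(z_indices, vocab_size)   # random replacement ids
a_indices = mask * z_indices + (1 - mask) * r_indices   # corrupted input sequence
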
    def top_k_logits(self, logits, k):
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out
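
Editor's note: a worked example of top_k_logits (values illustrative). Every logit below the k-th largest is set to -inf, so the subsequent softmax puts zero mass outside the top k:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
v, _ = torch.topk(logits, k=2)                 # v == tensor([[2.0, 1.0]])
out = logits.clone()
out[out < v[..., [-1]]] = -float('Inf')        # tensor([[2.0, -inf, 1.0, -inf]])
probs = F.softmax(out, dim=-1)                 # ~tensor([[0.731, 0.000, 0.269, 0.000]])
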

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None):
        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training
        if self.pkeep <= 0.0:
            # one pass suffices since input is pure noise anyway
            assert len(x.shape) == 2
            noise_shape = (x.shape[0], steps-1)
            # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
            noise = c.clone()[:, x.shape[1]-c.shape[1]:-1]
            x = torch.cat((x, noise), dim=1)
            logits, _ = self.transformer(x)
            # take all logits for now and scale by temp
            logits = logits / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                shape = probs.shape
                probs = probs.reshape(shape[0]*shape[1], shape[2])
                ix = torch.multinomial(probs, num_samples=1)
                probs = probs.reshape(shape[0], shape[1], shape[2])
                ix = ix.reshape(shape[0], shape[1])
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # cut off conditioning
            x = ix[:, c.shape[1]-1:]
        else:
            for k in range(steps):
                callback(k)
                assert x.size(1) <= block_size  # make sure model can see conditioning
                x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
                logits, _ = self.transformer(x_cond)
                # pluck the logits at the final step and scale by temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop probabilities to only the top k options
                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # append to the sequence and continue
                x = torch.cat((x, ix), dim=1)
            # cut off conditioning
            x = x[:, c.shape[1]:]
        return x
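
Editor's note: a hedged usage sketch of sample, mirroring the "half" sample in log_images below. model is assumed to be a trained Net2NetTransformer; x and c are an input batch and its conditioning:

quant_z, z_indices = model.encode_to_z(x)
_, c_indices = model.encode_to_c(c)
z_start = z_indices[:, :z_indices.shape[1] // 2]        # keep the first half
index_sample = model.sample(z_start, c_indices,
                            steps=z_indices.shape[1] - z_start.shape[1],
                            sample=True, top_k=100)
x_sample = model.decode_to_img(index_sample, quant_z.shape)
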

    @torch.no_grad()
    def encode_to_z(self, x):
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        indices = self.permuter(indices)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c):
        if self.downsample_cond_size > -1:
            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
        quant_c, _, [_, _, indices] = self.cond_stage_model.encode(c)
        if len(indices.shape) > 2:
            indices = indices.view(c.shape[0], -1)
        return quant_c, indices

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        index = self.permuter(index, reverse=True)
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
        log = dict()

        N = 4
        if lr_interface:
            # note: this path assumes a get_xc override that accepts these kwargs;
            # the get_xc defined below does not
            x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
        else:
            x, c = self.get_xc(batch, N)
        x = x.to(device=self.device)
        c = c.to(device=self.device)

        quant_z, z_indices = self.encode_to_z(x)
        quant_c, c_indices = self.encode_to_c(c)

        # create a "half" sample
        z_start_indices = z_indices[:, :z_indices.shape[1]//2]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1]-z_start_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample = self.decode_to_img(index_sample, quant_z.shape)

        # sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   temperature=temperature if temperature is not None else 1.0,
                                   sample=True,
                                   top_k=top_k if top_k is not None else 100,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)

        # det sample
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   sample=False,
                                   callback=callback if callback is not None else lambda k: None)
        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
            figure_size = (x_rec.shape[2], x_rec.shape[3])
            dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
            label_for_category_no = dataset.get_textual_label_for_category_no
            plotter = dataset.conditional_builders[self.cond_stage_key].plot
            log["conditioning"] = torch.zeros_like(log["reconstructions"])
            for i in range(quant_c.shape[0]):
                log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
            log["conditioning_rec"] = log["conditioning"]
        elif self.cond_stage_key != "image":
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # get image from segmentation mask
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        log["samples_half"] = x_sample
        log["samples_nopix"] = x_sample_nopix
        log["samples_det"] = x_sample_det
        return log

    def get_input(self, key, batch):
        x = batch[key]
        if len(x.shape) == 3:
            x = x[..., None]
        if len(x.shape) == 4:
            x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None):
        x = self.get_input(self.first_stage_key, batch)
        c = self.get_input(self.cond_stage_key, batch)
        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c

    def shared_step(self, batch, batch_idx):
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        return optimizer
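
Editor's note: to make the decay/no_decay bucketing concrete, here is a standalone toy version of the same logic (the two-layer model is hypothetical, not part of the file):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
decay, no_decay = set(), set()
for mn, m in model.named_modules():
    for pn, p in m.named_parameters(recurse=False):
        fpn = f"{mn}.{pn}" if mn else pn
        if pn.endswith("bias") or isinstance(m, torch.nn.LayerNorm):
            no_decay.add(fpn)          # biases and norm weights: no decay
        elif pn.endswith("weight") and isinstance(m, torch.nn.Linear):
            decay.add(fpn)             # linear weights: decayed
# decay == {'0.weight'}; no_decay == {'0.bias', '1.weight', '1.bias'}
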

@@ -0,0 +1,22 @@
from torch import Tensor


class DummyCondStage:
    def __init__(self, conditional_key):
        self.conditional_key = conditional_key
        self.train = None

    def eval(self):
        return self

    @staticmethod
    def encode(c: Tensor):
        return c, None, (None, None, c)

    @staticmethod
    def decode(c: Tensor):
        return c

    @staticmethod
    def to_rgb(c: Tensor):
        return c
404
OCR_earsing/latent_diffusion/taming/models/vqgan.py
Normal file
@@ -0,0 +1,404 @@
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff
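
Editor's note: a hedged shape walkthrough of the encode/quantize/decode round trip; the config values below are assumptions for illustration only.

# with ddconfig["z_channels"] == 256, embed_dim == 256 and a downsampling
# factor of 16, a 256x256 RGB batch moves through VQModel roughly as:
#   x:     (B, 3, 256, 256)
#   h:     (B, 256, 16, 16)   encoder output after quant_conv
#   quant: (B, 256, 16, 16)   nearest codebook entries (straight-through grads)
#   dec:   (B, 3, 256, 256)   post_quant_conv + decoder output
xrec, codebook_loss = model(x)   # model: a VQModel instance (assumed)
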

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
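
Editor's note: to_rgb maps an N-channel segmentation map to 3 channels via a fixed random 1x1 convolution, then rescales to [-1, 1]. A standalone sketch with toy sizes (not part of the file):

import torch
import torch.nn.functional as F

seg = F.one_hot(torch.randint(0, 5, (1, 16, 16)), num_classes=5)
seg = seg.permute(0, 3, 1, 2).float()          # (1, 5, 16, 16) one-hot mask
colorize = torch.randn(3, 5, 1, 1)             # fixed random projection
rgb = F.conv2d(seg, weight=colorize)           # (1, 3, 16, 16)
rgb = 2.*(rgb-rgb.min())/(rgb.max()-rgb.min()) - 1.   # rescale to [-1, 1]
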


class VQSegmentationModel(VQModel):
    def __init__(self, n_labels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        return opt_ae

    def training_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        total_loss = log_dict_ae["val/total_loss"]
        self.log("val/total_loss", total_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return aeloss

    @torch.no_grad()
    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            # convert logits to indices
            xrec = torch.argmax(xrec, dim=1, keepdim=True)
            xrec = F.one_hot(xrec, num_classes=x.shape[1])
            xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log


class VQNoDiscModel(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None
                 ):
        super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
                         ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
                         colorize_nlabels=colorize_nlabels)

    def training_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        # autoencode
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
        # pl.TrainResult was removed in PyTorch Lightning 1.0; log via self.log
        # like the other models in this file and return the loss directly
        self.log("train/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        # pl.EvalResult was likewise removed; self.log replaces it
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_ae)

        return aeloss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(list(self.encoder.parameters()) +
                                     list(self.decoder.parameters()) +
                                     list(self.quantize.parameters()) +
                                     list(self.quant_conv.parameters()) +
                                     list(self.post_quant_conv.parameters()),
                                     lr=self.learning_rate, betas=(0.5, 0.9))
        return optimizer


class GumbelVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):

        z_channels = ddconfig["z_channels"]
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp

        # checkpoint loading is deferred (ckpt_path=None above) until the
        # Gumbel quantizer has replaced the default one
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        self.quantize.temperature = self.temperature_scheduler(self.global_step)
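
Editor's note: temperature_scheduler_config goes through instantiate_from_config, so it can point at any callable of the global step. A hypothetical sketch of such a config entry; the target path and parameter names are assumptions, not verified against this repo:

temperature_scheduler_config = {
    "target": "taming.lr_scheduler.LambdaWarmUpCosineScheduler",  # assumed
    "params": {"warm_up_steps": 0, "max_decay_steps": 1000001,
               "lr_start": 0.9, "lr_max": 0.9, "lr_min": 1.0e-6},  # assumed
}
scheduler = instantiate_from_config(temperature_scheduler_config)
# called with the global step, it returns the annealed Gumbel temperature
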

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        raise NotImplementedError

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        # forward does not accept a return_pred_indices kwarg, so the plain call is used
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log


class EMAVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )
        self.quantize = EMAVectorQuantizer(n_embed=n_embed,
                                           embedding_dim=embed_dim,
                                           beta=0.25,
                                           remap=remap)

    def configure_optimizers(self):
        lr = self.learning_rate
        # Remove self.quantize from parameter list since it is updated via EMA
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []