First commit OCR_earsing and Synthetics Handwritten Recognition awesome repo
This commit is contained in:
130
OCR_earsing/latent_diffusion/taming/modules/util.py
Normal file
130
OCR_earsing/latent_diffusion/taming/modules/util.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_params(model):
    """Return the total number of elements over all of ``model.parameters()``."""
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements
|
||||
|
||||
|
||||
class ActNorm(nn.Module):
    """Per-channel affine normalization with data-dependent initialization.

    The forward map is ``h = scale * (input + loc)`` with one ``loc`` and one
    ``scale`` value per channel (dimension 1).  On the first call made in
    training mode, ``loc`` and ``scale`` are set from that batch so that the
    output has approximately zero mean and unit standard deviation per
    channel; afterwards they are ordinary learnable parameters.

    Args:
        num_features: Number of channels (size of dimension 1 of the input).
        logdet: If true, ``forward`` also returns a per-sample log-determinant.
        affine: Must be truthy; only the affine variant is implemented.
        allow_reverse_init: If true, the data-dependent initialization may
            also be triggered by the first ``reverse`` call.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine, "ActNorm only supports affine=True"
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # Stored as a buffer so the "already initialized" flag is saved and
        # restored together with the module's state dict.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _as_4d(tensor):
        """Expand a 2-D tensor to 4-D; return ``(tensor, was_expanded)``."""
        if len(tensor.shape) == 2:
            return tensor[:, :, None, None], True
        return tensor, False

    def initialize(self, input):
        """Set ``loc``/``scale`` from the per-channel statistics of ``input``.

        ``input`` must be 4-D with channels in dimension 1.
        """
        with torch.no_grad():
            num_channels = input.shape[1]
            # One row per channel, holding every value of that channel.
            per_channel = input.permute(1, 0, 2, 3).reshape(num_channels, -1)
            mean = per_channel.mean(1).view(1, num_channels, 1, 1)
            std = per_channel.std(1).view(1, num_channels, 1, 1)

            self.loc.data.copy_(-mean)
            # The small constant guards against division by a zero std.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Normalize ``input`` (or invert the normalization if ``reverse``).

        Accepts a 4-D tensor, or a 2-D tensor which is treated as having
        1x1 trailing dimensions and returned 2-D again.

        Returns:
            ``h``, or ``(h, logdet)`` when ``self.logdet`` is set, where
            ``logdet`` has one entry per batch element equal to
            ``height * width * sum(log|scale|)``.
        """
        if reverse:
            return self.reverse(input)

        input, squeeze = self._as_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            # Broadcast the scalar to one value per sample, matching the
            # dtype and device of the input.
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert ``forward``: return ``output / scale - loc``.

        Raises:
            RuntimeError: If called in training mode before initialization
                while ``allow_reverse_init`` is false.
        """
        # Expand 2-D inputs BEFORE the data-dependent init, exactly as
        # ``forward`` does; ``initialize`` requires a 4-D tensor.
        output, squeeze = self._as_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
    """Base ``nn.Module`` for encoders; subclasses are expected to override ``encode``."""

    def __init__(self):
        super(AbstractEncoder, self).__init__()

    def encode(self, *args, **kwargs):
        """Placeholder that accepts any arguments and always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class Labelator(AbstractEncoder):
    """Net2Net interface for a class-conditional model.

    ``encode`` adds a new axis at position 1 of the given labels.  With
    ``quantize_interface`` enabled it returns the 3-tuple
    ``(labels, None, [None, None, labels.long()])``; otherwise it returns
    the reshaped labels alone.
    """

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.quantize_interface = quantize_interface
        self.n_classes = n_classes

    def encode(self, c):
        """Return ``c`` with an extra axis inserted at position 1 (see class docs)."""
        labels = c[:, None]
        if not self.quantize_interface:
            return labels
        info = [None, None, labels.long()]
        return labels, None, info
|
||||
|
||||
|
||||
class SOSProvider(AbstractEncoder):
    """Provide a constant start-of-sequence token per sample (for unconditional training).

    Args:
        sos_token: Value used to fill the returned token tensor.
        quantize_interface: If true, ``encode`` returns the 3-tuple
            ``(c, None, [None, None, c])``; otherwise just ``c``.
    """

    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        """Return an int64 tensor of shape ``(x.shape[0], 1)`` filled with ``sos_token``.

        Only the batch size and device of ``x`` are used.
        """
        # Build the tensor as int64 directly on the target device: going
        # through a float32 intermediate would silently round integer tokens
        # above 2**24, and creating it on the CPU first costs an extra copy.
        c = torch.ones(x.shape[0], 1, dtype=torch.long, device=x.device) * self.sos_token
        # A non-integer sos_token promotes the product to float; truncate it
        # back to int64.
        c = c.long()
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
|
Reference in New Issue
Block a user