First commit OCR_earsing and Synthetics Handwritten Recognition awesome repo
This commit is contained in:
31
OCR_earsing/latent_diffusion/taming/modules/misc/coord.py
Normal file
31
OCR_earsing/latent_diffusion/taming/modules/misc/coord.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
class CoordStage(object):
|
||||
def __init__(self, n_embed, down_factor):
|
||||
self.n_embed = n_embed
|
||||
self.down_factor = down_factor
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, c):
|
||||
"""fake vqmodel interface"""
|
||||
assert 0.0 <= c.min() and c.max() <= 1.0
|
||||
b,ch,h,w = c.shape
|
||||
assert ch == 1
|
||||
|
||||
c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
|
||||
mode="area")
|
||||
c = c.clamp(0.0, 1.0)
|
||||
c = self.n_embed*c
|
||||
c_quant = c.round()
|
||||
c_ind = c_quant.to(dtype=torch.long)
|
||||
|
||||
info = None, None, c_ind
|
||||
return c_quant, None, info
|
||||
|
||||
def decode(self, c):
|
||||
c = c/self.n_embed
|
||||
c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
|
||||
mode="nearest")
|
||||
return c
|
Reference in New Issue
Block a user