synthetics_handwritten_OCR/OCR_earsing/latent_diffusion/ldm_erase_text.py

import argparse, os, sys, glob
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def make_batch(image, mask_pil_image, img_size, device):
    """Build the inpainting batch: image, binary mask, and masked image,
    all as tensors on `device`, rescaled to [-1, 1]."""
    if isinstance(image, str):
        pil_image = Image.open(image).convert("RGB")
    else:
        pil_image = image.convert("RGB")
    if img_size is not None:
        pil_image = pil_image.resize(img_size)
    image = np.array(pil_image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)

    # Binarize the mask: 1 marks the text region to be erased.
    if img_size is not None:
        mask = np.array(mask_pil_image.convert("L").resize(img_size))
    else:
        mask = np.array(mask_pil_image.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = (1 - mask) * image

    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        batch[k] = batch[k] * 2.0 - 1.0  # rescale from [0, 1] to [-1, 1]
    return batch


def erase_text_from_image(img_path,
                          mask_pil_img,
                          model,
                          device,
                          opt,
                          img_size=None,
                          steps=None):
    """Inpaint the masked (text) region of an image with a latent-diffusion
    model and return the result as a PIL image."""
    sampler = DDIMSampler(model)
    with torch.no_grad():
        with model.ema_scope():
            # Fall back to the command-line options when no explicit
            # image size or step count is given.
            if img_size is None:
                img_size = opt.img_size
            if steps is None:
                steps = opt.steps
            batch = make_batch(
                img_path, mask_pil_img, img_size=img_size, device=device)

            # Encode the masked image and concatenate the downsampled mask
            # as conditioning for the diffusion model.
            c = model.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(
                batch["mask"], size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # Latent sample shape: drop the mask channel from the conditioning.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = sampler.sample(
                S=steps,
                conditioning=c,
                batch_size=c.shape[0],
                shape=shape,
                verbose=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)

            # Map everything back to [0, 1] and composite: keep the original
            # pixels outside the mask, use the prediction inside it.
            image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
            mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
            predicted_image = torch.clamp(
                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            inpainted = (1 - mask) * image + mask * predicted_image
            inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
            return Image.fromarray(inpainted.astype(np.uint8))
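

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original script: the config,
    # checkpoint, and default paths below are assumptions and must be
    # adapted to the actual project layout.
    parser = argparse.ArgumentParser(
        description="Erase text from an image with LDM inpainting")
    parser.add_argument("--image", required=True, help="path to the input image")
    parser.add_argument("--mask", required=True, help="path to the binary text mask")
    parser.add_argument("--config", default="models/ldm/inpainting_big/config.yaml",
                        help="model config (assumed path)")
    parser.add_argument("--ckpt", default="models/ldm/inpainting_big/last.ckpt",
                        help="model checkpoint (assumed path)")
    parser.add_argument("--outdir", default="outputs/erased")
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--img_size", type=int, nargs=2, default=[512, 512])
    opt = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate the latent-diffusion inpainting model from its config
    # and load the checkpoint weights.
    config = OmegaConf.load(opt.config)
    model = instantiate_from_config(config.model)
    model.load_state_dict(
        torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
    model = model.to(device)

    os.makedirs(opt.outdir, exist_ok=True)
    result = erase_text_from_image(
        opt.image,
        Image.open(opt.mask),
        model,
        device,
        opt,
        img_size=tuple(opt.img_size),
        steps=opt.steps)
    result.save(os.path.join(opt.outdir, os.path.basename(opt.image)))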