# synthetics_handwritten_OCR/OCR_earsing/ocr_eraser.py

import logging
import os
import sys

import torch

# Install SAM and fetch the model checkpoints at import time,
# skipping downloads that are already on disk.
os.system("pip install git+https://github.com/facebookresearch/segment-anything.git")
if not os.path.exists("sam_vit_h_4b8939.pth"):
    os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")
if not os.path.exists("last.ckpt"):
    # Quote the URL so the shell does not glob the "?dl=1" query string.
    os.system("wget -O last.ckpt 'https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1'")
from PIL import Image
import numpy as np
import cv2

# OCR backends (repo-local utilities)
from utils.ocr_utils import ocr_extraction
from utils.easy_ocr_utils import easy_ocr_extraction

# SAM
from segment_anything import SamPredictor, sam_model_registry

# Latent-diffusion inpainting model
sys.path.append('latent_diffusion')
from latent_diffusion.ldm_erase_text import erase_text_from_image, instantiate_from_config, OmegaConf

logger = logging.getLogger(__name__)

def multi_mask2one_mask(masks):
    """Merge SAM's per-box masks of shape (N, 1, H, W) into one binary mask."""
    _, _, h, w = masks.shape
    for i, mask in enumerate(masks):
        # Bring each mask onto the CPU as an (H, W, 1) numpy array before accumulating.
        mask_image = mask.cpu().numpy().reshape(h, w, 1)
        whole_mask = mask_image if i == 0 else whole_mask + mask_image
    # Pixels covered by at least one mask become 255, everything else 0.
    whole_mask = np.where(whole_mask, 255, 0)
    return whole_mask

def numpy2PIL(numpy_image):
    out = Image.fromarray(numpy_image.astype(np.uint8))
    return out

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
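
# Usage sketch for show_mask (assumption: matplotlib is installed; this script
# does not import it, and the helper is not called by the pipeline below).
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# for mask in masks:  # masks as returned by sam_predictor.predict_torch
#     show_mask(mask.cpu().numpy(), ax)
# plt.show()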

def run_earse(img_path, sam_type, sam_checkpoint, config_path, model_checkpoint, device="cpu",
              img_size=(512, 512), steps=50, use_easy_ocr=False):
    """Detect text with OCR, segment it with SAM, then erase it via latent-diffusion inpainting."""
    img = cv2.imread(img_path)

    # SAM: load the segmentation model and wrap it in a predictor.
    sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
    sam = sam.to(device)
    sam_predictor = SamPredictor(sam)

    # Diffusion model: build from the config and load the inpainting weights.
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)
    model.load_state_dict(
        # map_location keeps CPU-only machines from failing on GPU-saved checkpoints.
        torch.load(model_checkpoint, map_location=device)["state_dict"],
        strict=False
    )
    model = model.to(device)

    # OCR: extract word boxes; the first four values of each entry are the box coordinates.
    if use_easy_ocr:
        word_info = easy_ocr_extraction(img_path)
    else:
        word_info = ocr_extraction(img_path)
    det_bboxes = [bbox[:4] for bbox in word_info]

    # Convert the boxes to a tensor, map them into SAM's input frame, and match the model device.
    det_bboxes = torch.tensor(det_bboxes)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        det_bboxes, img.shape[:2]
    ).to(device)

    # Predict one mask per detected text box.
    sam_predictor.set_image(img, image_format='BGR')
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Merge the per-box masks, then dilate so the mask fully covers the strokes.
    ori_mask = multi_mask2one_mask(masks=masks)
    mask_img = ori_mask[:, :, 0].astype('uint8')
    kernel = np.ones((5, 5), np.uint8)  # cv2.dilate expects an unsigned 8-bit kernel
    whole_mask = cv2.dilate(mask_img, kernel, iterations=2)
    mask_pil_image = numpy2PIL(numpy_image=whole_mask)

    # Inpaint the masked region with the diffusion model.
    result_img = erase_text_from_image(
        img_path=img_path,
        mask_pil_img=mask_pil_image,
        model=model,
        device=device,
        opt=None,
        img_size=img_size,
        steps=steps
    )
    result_img = cv2.cvtColor(np.array(result_img), cv2.COLOR_RGB2BGR)

    # Display the result (requires a display) and return it for callers that want to save it.
    cv2.namedWindow("Result Image", cv2.WINDOW_NORMAL)
    cv2.imshow("Result Image", result_img)
    cv2.waitKey(0)  # wait until a key is pressed
    cv2.destroyAllWindows()  # close the image window
    return result_img

if __name__ == "__main__":
    run_earse(
        img_path="Facture médecine douce-27746732_0.jpg",
        sam_type="vit_h",
        sam_checkpoint="sam_vit_h_4b8939.pth",
        config_path="latent_diffusion/inpainting_big/config.yaml",
        model_checkpoint="last.ckpt",
        device="cpu",
        img_size=(512, 512),
        steps=50,
        use_easy_ocr=True
    )
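
    # Headless alternative (sketch, not part of the original script): run_earse
    # returns the inpainted image, so it can be written to disk instead of
    # relying on the cv2.imshow window above. "erased.jpg" is an illustrative name.
    # result = run_earse(..., use_easy_ocr=True)
    # cv2.imwrite("erased.jpg", result)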