126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
|
import logging
|
||
|
import os
|
||
|
import sys
|
||
|
import torch
|
||
|
|
||
|
# Download checkpoints.
# Every step is skipped when its result is already present, so re-running the
# script does not reinstall SAM or re-fetch the multi-GB checkpoints. (A bare
# `wget URL` on an existing file saves a new `<name>.1`, `<name>.2`, ... copy.)
import importlib.util
import subprocess


def _run_command(argv):
    """Run *argv* (a list, no shell); return True on success, log a warning otherwise."""
    try:
        completed = subprocess.run(argv, check=False)
    except OSError as err:  # e.g. the executable (wget) is not installed
        logging.getLogger(__name__).warning("Could not run %s: %s", argv[0], err)
        return False
    if completed.returncode != 0:
        logging.getLogger(__name__).warning(
            "Command %s failed with exit status %d", argv, completed.returncode
        )
        return False
    return True


def _download_if_missing(url, dest):
    """Fetch *url* into *dest* with wget unless *dest* already exists."""
    if os.path.exists(dest):
        return
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file that a later run
    # would mistake for a complete checkpoint.
    partial = dest + ".part"
    if _run_command(["wget", "-O", partial, url]):
        os.replace(partial, dest)


if importlib.util.find_spec("segment_anything") is None:
    # Use the running interpreter's pip rather than whichever `pip` is on PATH.
    if _run_command([sys.executable, "-m", "pip", "install",
                     "git+https://github.com/facebookresearch/segment-anything.git"]):
        # Let the import system notice the freshly installed package.
        importlib.invalidate_caches()

_download_if_missing('https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
                     'sam_vit_h_4b8939.pth')
_download_if_missing('https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1',
                     'last.ckpt')
|
||
|
|
||
|
from PIL import Image
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
|
||
|
# OCR
|
||
|
from utils.ocr_utils import ocr_extraction
|
||
|
from utils.easy_ocr_utils import easy_ocr_extraction
|
||
|
|
||
|
# SAM
|
||
|
from segment_anything import SamPredictor, sam_model_registry
|
||
|
|
||
|
# Diffusion model.
# The latent_diffusion directory is put on sys.path, presumably so that modules
# inside it can import each other as top-level packages (confirm). Resolve it
# relative to this file instead of the current working directory so the script
# also works when launched from another directory.
_LATENT_DIFFUSION_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'latent_diffusion')
if _LATENT_DIFFUSION_DIR not in sys.path:
    sys.path.append(_LATENT_DIFFUSION_DIR)

from latent_diffusion.ldm_erase_text import erase_text_from_image, instantiate_from_config, OmegaConf
|
||
|
|
||
|
# Module-level logger named after this module; handlers and levels are left to
# whatever application configures logging.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def multi_mask2one_mask(masks):
    """Merge a stack of per-box masks into a single 0/255 mask (their union).

    Args:
        masks: array-like of shape (N, C, H, W), either a torch tensor (on any
            device) or a NumPy array. Non-zero entries mark foreground. The
            caller in this file uses ``multimask_output=False``, presumably
            giving C == 1; larger C is merged too.

    Returns:
        NumPy integer array of shape (H, W, 1): 255 where any mask is set,
        0 elsewhere. An empty stack (N == 0) yields an all-zero mask.
    """
    if isinstance(masks, torch.Tensor):
        # NumPy cannot read CUDA tensors directly; copy to host memory first.
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks)
    _, _, h, w = masks.shape
    # Union over every mask (and channel) at once instead of summing in a loop.
    union = masks.reshape(-1, h, w).any(axis=0)
    return np.where(union, 255, 0).reshape(h, w, 1)
|
||
|
|
||
|
|
||
|
def numpy2PIL(numpy_image):
    """Return a PIL ``Image`` built from *numpy_image* cast to ``uint8``.

    Values are cast, not clipped, so callers should pass data already in the
    0..255 range (run_earse passes a 0/255 mask).
    """
    pixels_u8 = numpy_image.astype(np.uint8)
    return Image.fromarray(pixels_u8)
|
||
|
|
||
|
|
||
|
def show_mask(mask, ax, random_color=False):
    """Draw *mask* onto *ax* as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1), so it must contain exactly
            height * width elements.
        ax: any object with an ``imshow`` method (e.g. a Matplotlib Axes).
        random_color: when True use a random RGB colour, otherwise a fixed
            blue; alpha is 0.6 either way.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
|
||
|
|
||
|
|
||
|
def _load_sam_predictor(sam_type, sam_checkpoint, device):
    """Build the SAM model registered under *sam_type*, move it to *device*, wrap it in a predictor."""
    sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
    return SamPredictor(sam.to(device))


def _load_inpainting_model(config_path, model_checkpoint, device):
    """Instantiate the diffusion model from its OmegaConf config and load its weights."""
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to `device` afterwards.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
    checkpoint = torch.load(model_checkpoint, map_location="cpu")
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        logger.warning(
            "Checkpoint %s: %d missing and %d unexpected state_dict keys (ignored, strict=False)",
            model_checkpoint, len(missing), len(unexpected),
        )
    return model.to(device)


def _build_text_mask(img, word_info, sam_predictor, device):
    """Segment the OCR word boxes with SAM and return a dilated uint8 0/255 mask of img's size."""
    # assumes each OCR entry starts with an (x0, y0, x1, y1) box - TODO confirm
    # Boxes must live on the same device as SAM or predict_torch fails under CUDA.
    boxes = torch.tensor([info[:4] for info in word_info], dtype=torch.float, device=device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, img.shape[:2])
    sam_predictor.set_image(img, image_format='BGR')
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # Move to host memory before merging; the merge works in NumPy.
    merged = multi_mask2one_mask(masks=masks.cpu())
    mask_img = merged[:, :, 0].astype('uint8')
    # Grow the mask a little so glyph edges are also covered by the inpainting.
    kernel = np.ones((5, 5), np.int8)
    return cv2.dilate(mask_img, kernel, iterations=2)


def run_earse(img_path, sam_type, sam_checkpoint, config_path, model_checkpoint, device="cpu", img_size=(512, 512),
              steps=50, use_easy_ocr=False, show=True):
    """Detect text in an image, mask it with SAM and inpaint it with latent diffusion.

    Args:
        img_path: path of the image to process (read with ``cv2.imread``).
        sam_type: key into ``sam_model_registry`` (e.g. ``"vit_h"``).
        sam_checkpoint: path to the SAM weights.
        config_path: OmegaConf YAML file describing the inpainting model.
        model_checkpoint: checkpoint whose ``"state_dict"`` entry holds the
            inpainting weights.
        device: torch device string used for both models.
        img_size, steps: forwarded unchanged to ``erase_text_from_image``.
        use_easy_ocr: use ``easy_ocr_extraction`` instead of ``ocr_extraction``.
        show: when True (the default, as before) show the result in an OpenCV
            window and block until a key is pressed; pass False when headless.

    Returns:
        The result as a BGR NumPy array; the unmodified input image when no
        text was detected.

    Raises:
        FileNotFoundError: if ``img_path`` cannot be read as an image.
    """
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # Run OCR first so that the expensive models are not loaded when there is
    # nothing to erase (an empty box list would also break SAM).
    word_info = easy_ocr_extraction(img_path) if use_easy_ocr else ocr_extraction(img_path)

    if len(word_info) == 0:
        logger.info("No text detected in %s; nothing to erase.", img_path)
        result_img = img
    else:
        sam_predictor = _load_sam_predictor(sam_type, sam_checkpoint, device)
        model = _load_inpainting_model(config_path, model_checkpoint, device)
        whole_mask = _build_text_mask(img, word_info, sam_predictor, device)

        mask_pil_image = numpy2PIL(numpy_image=whole_mask)
        erased = erase_text_from_image(
            img_path=img_path,
            mask_pil_img=mask_pil_image,
            model=model,
            device=device,
            opt=None,
            img_size=img_size,
            steps=steps
        )
        # erase_text_from_image output is treated as RGB; OpenCV expects BGR.
        result_img = cv2.cvtColor(np.array(erased), cv2.COLOR_RGB2BGR)

    if show:
        cv2.namedWindow("Result Image", cv2.WINDOW_NORMAL)
        cv2.imshow("Result Image", result_img)
        cv2.waitKey(0)  # Wait until a key is pressed
        cv2.destroyAllWindows()  # Close the image window
    return result_img
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Demo: erase the text from a sample scanned document on the CPU,
    # using EasyOCR for detection and the checkpoints fetched at import time.
    demo_kwargs = dict(
        img_path="Facture médecine douce-27746732_0.jpg",
        sam_type="vit_h",
        sam_checkpoint="sam_vit_h_4b8939.pth",
        config_path="latent_diffusion/inpainting_big/config.yaml",
        model_checkpoint="last.ckpt",
        device="cpu",
        img_size=(512, 512),
        steps=50,
        use_easy_ocr=True,
    )
    run_earse(**demo_kwargs)
|