"""Erase text from an image.

Pipeline: OCR word boxes -> Segment Anything (SAM) masks for those boxes ->
latent-diffusion inpainting of the masked pixels -> show the result in an
OpenCV window.

Importing this module still bootstraps its runtime assets (the
``segment_anything`` package and two checkpoints in the working directory),
but each step is skipped when its result is already present.
"""
import importlib
import importlib.util
import logging
import os
import subprocess
import sys

import torch

logger = logging.getLogger(__name__)

_SAM_PACKAGE = "git+https://github.com/facebookresearch/segment-anything.git"
_SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
_SAM_CHECKPOINT_FILE = "sam_vit_h_4b8939.pth"
_LDM_CHECKPOINT_URL = "https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1"
_LDM_CHECKPOINT_FILE = "last.ckpt"


def _run_command(args):
    """Run ``args`` without a shell; return True when it exits with status 0."""
    try:
        return subprocess.run(args, check=False).returncode == 0
    except OSError:
        # The executable itself could not be started (e.g. wget not installed).
        # Stay best-effort: report it and let the later import/load fail loudly.
        logger.exception("Could not launch %s", args[0])
        return False


def _download(url, destination):
    """Fetch ``url`` into ``destination`` with wget unless the file already exists."""
    if os.path.exists(destination):
        return
    # Download to a side file and rename on success, so an interrupted transfer
    # is never mistaken for a complete checkpoint on the next run.
    partial = destination + ".part"
    if _run_command(["wget", "-O", partial, url]):
        os.replace(partial, destination)
        return
    logger.error("Download of %s failed", url)
    if os.path.exists(partial):
        os.remove(partial)


def _bootstrap_runtime_assets():
    """Install ``segment_anything`` and download both checkpoints if missing."""
    if importlib.util.find_spec("segment_anything") is None:
        # Use the running interpreter's pip rather than whichever `pip` is on PATH.
        if _run_command([sys.executable, "-m", "pip", "install", _SAM_PACKAGE]):
            # Make the freshly installed package visible to this process.
            importlib.invalidate_caches()
        else:
            logger.error("Installing %s failed", _SAM_PACKAGE)
    _download(_SAM_CHECKPOINT_URL, _SAM_CHECKPOINT_FILE)
    _download(_LDM_CHECKPOINT_URL, _LDM_CHECKPOINT_FILE)


# Must run before the imports below: `segment_anything` may not be installed yet.
_bootstrap_runtime_assets()

from PIL import Image  # noqa: E402
import numpy as np  # noqa: E402
import cv2  # noqa: E402

# OCR
from utils.ocr_utils import ocr_extraction  # noqa: E402
from utils.easy_ocr_utils import easy_ocr_extraction  # noqa: E402

# SAM
from segment_anything import SamPredictor, sam_model_registry  # noqa: E402

# Diffusion model
sys.path.append('latent_diffusion')
from latent_diffusion.ldm_erase_text import erase_text_from_image, instantiate_from_config, OmegaConf  # noqa: E402


def multi_mask2one_mask(masks):
    """Merge a stack of masks into one 0/255 mask.

    Args:
        masks: 4-D array or ``torch.Tensor`` whose last two dimensions are
            height and width. Tensors are moved to host memory first, so
            masks living on an accelerator are accepted.

    Returns:
        Integer NumPy array of shape ``(h, w, 1)`` holding 255 wherever any
        input mask is non-zero and 0 elsewhere. An empty stack yields an
        all-zero mask instead of raising.
    """
    _, _, h, w = masks.shape
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    stacked = np.asarray(masks).reshape(-1, h, w, 1)
    # Union over all masks. Unlike summing, this cannot wrap around for
    # narrow integer dtypes and is well defined for zero masks.
    covered = (stacked != 0).any(axis=0)
    return np.where(covered, 255, 0)


def numpy2PIL(numpy_image):
    """Convert a NumPy image to a PIL image after casting it to ``uint8``."""
    return Image.fromarray(numpy_image.astype(np.uint8))


def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as an RGBA overlay with alpha 0.6.

    Args:
        mask: Array whose last two dimensions are height and width.
        ax: Object exposing ``imshow`` (presumably a matplotlib Axes).
        random_color: Use a random RGB colour instead of the fixed blue.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def run_earse(img_path, sam_type, sam_checkpoint, config_path, model_checkpoint,
              device="cpu", img_size=(512, 512), steps=50, use_easy_ocr=False):
    """Detect text in an image, erase it by inpainting and display the result.

    Args:
        img_path: Path of the image to process.
        sam_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
        sam_checkpoint: Path to the SAM weights.
        config_path: OmegaConf YAML describing the diffusion model.
        model_checkpoint: Checkpoint file containing a ``"state_dict"`` entry.
        device: Torch device both models are moved to.
        img_size: Forwarded to ``erase_text_from_image``.
        steps: Forwarded to ``erase_text_from_image``.
        use_easy_ocr: Use ``easy_ocr_extraction`` instead of ``ocr_extraction``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.

    Returns:
        None. The result is shown in an OpenCV window until a key is pressed.
        When OCR finds no text a warning is logged and nothing is shown.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # OCR runs first so an image without text exits before the large models load.
    if use_easy_ocr:
        word_info = easy_ocr_extraction(img_path)
    else:
        word_info = ocr_extraction(img_path)

    # Only the first four values of each OCR entry are used as the box
    # (presumably x0, y0, x1, y1 in pixels, as SAM expects -- TODO confirm).
    det_bboxes = [bbox[:4] for bbox in word_info]
    if not det_bboxes:
        logger.warning("No text detected in %s; nothing to erase.", img_path)
        return

    # SAM
    sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
    sam = sam.to(device)
    sam_predictor = SamPredictor(sam)

    # Diffusion model
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model)
    # map_location="cpu" lets checkpoints saved on a GPU load on any host;
    # the model is moved to the requested device right after.
    checkpoint = torch.load(model_checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    model = model.to(device)

    # Box prompts must live on the same device as the SAM model.
    box_tensor = torch.tensor(det_bboxes, device=sam_predictor.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        box_tensor, img.shape[:2]
    )

    sam_predictor.set_image(img, image_format='BGR')
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    ori_mask = multi_mask2one_mask(masks=masks)
    mask_img = ori_mask[:, :, 0].astype('uint8')

    # Grow the mask so the inpainting also covers pixels around the glyphs.
    kernel = np.ones((5, 5), np.int8)
    whole_mask = cv2.dilate(mask_img, kernel, iterations=2)
    mask_pil_image = numpy2PIL(numpy_image=whole_mask)

    result_img = erase_text_from_image(
        img_path=img_path,
        mask_pil_img=mask_pil_image,
        model=model,
        device=device,
        opt=None,
        img_size=img_size,
        steps=steps
    )

    result_img = cv2.cvtColor(np.array(result_img), cv2.COLOR_RGB2BGR)
    cv2.namedWindow("Result Image", cv2.WINDOW_NORMAL)
    cv2.imshow("Result Image", result_img)
    cv2.waitKey(0)  # Wait until a key is pressed
    cv2.destroyAllWindows()  # Close the image window


if __name__ == "__main__":
    run_earse(
        img_path="Facture médecine douce-27746732_0.jpg",
        sam_type="vit_h",
        sam_checkpoint="sam_vit_h_4b8939.pth",
        config_path="latent_diffusion/inpainting_big/config.yaml",
        model_checkpoint="last.ckpt",
        device="cpu",
        img_size=(512, 512),
        steps=50,
        use_easy_ocr=True
    )