from typing import List, Tuple, Union
import logging
import sys
import math

import cv2
import numpy as np
import numpy.typing as npt

from cv_base.postprocess.db_postprocess import DBPostProcess
from cv_base.postprocess.pp_rec_postprocess import CTCLabelDecode
from cv_base.preprocess.db_preprocess import DBPreprocess
from cv_base.utils.perspective import four_point_transform
from infer_client.adapters import InferenceAdapter
from infer_client.adapters.onnx import OnnxInferenceAdapter
from infer_client.inference import Inference

logger = logging.getLogger(__name__)

DET_MODEL = "model_checkpoint/det_latin_words"
REC_MODEL = "model_checkpoint/rec_general_words_fr"
VOCAB_DICT = "model_checkpoint/latin_dict.txt"


def get_vocab_str(vocab_dict_path: str) -> str:
    """Read the character dictionary and flatten it into a single vocab string."""
    with open(vocab_dict_path, "r") as rf:
        vocab = rf.read()
    return vocab.replace("\n", "")


def box4point_to_box2point(box4point):
    # bounding box = [x0, y0, x1, y1, x2, y2, x3, y3]
    all_x = [box4point[2 * i] for i in range(4)]
    all_y = [box4point[2 * i + 1] for i in range(4)]
    return [min(all_x), min(all_y), max(all_x), max(all_y)]


def minimum_bounding_rectangle(points):
    from scipy.spatial import ConvexHull

    hull = ConvexHull(points)
    hull_points = points[hull.vertices]
    # Find the minimum-area (rotated) rectangle enclosing the hull
    rect = cv2.minAreaRect(hull_points.astype(np.float32))
    box = cv2.boxPoints(rect)
    # np.int0 was removed in recent NumPy releases; np.intp is the same alias
    return box.astype(np.intp)


class DetTextDocument(object):
    def __init__(
        self,
        adapter: InferenceAdapter,
        img_size: Union[int, Tuple[int, int]] = 736,
        thresh: float = 0.2,
        bbox_thresh: float = 0.5,
        unclip_ratio: float = 2.0,
    ):
        self.infer_obj = Inference(adapter)
        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} fail")
            sys.exit(1)
        self.pre_process = DBPreprocess(img_size=img_size)
        self.post_process = DBPostProcess(
            thresh=thresh,
            bbox_thresh=bbox_thresh,
            unclip_ratio=unclip_ratio,
            score_mode="fast",
        )

    def inference(
        self,
        img: npt.NDArray[np.uint8],
    ) -> Tuple[List[npt.NDArray[np.int32]], List[float]]:
        pre_processed_img, _ = self.pre_process(img)
        outs_dbnet = self.infer_obj.inference({"input": pre_processed_img}, ["output"])
        if not outs_dbnet:
            return [], []
        _, bbox_lst, score_lst = self.post_process(
            outs_dbnet[0][0], img_shape=(img.shape[0], img.shape[1])
        )
        return [bbox.astype(np.int32) for bbox in bbox_lst], score_lst


class RecPreprocessImage(object):
    def __init__(self, rec_image_shape):
        self.rec_image_shape = rec_image_shape

    def __call__(self, img):
        imgC, imgH, imgW = self.rec_image_shape[:3]
        max_wh_ratio = imgW / imgH
        h, w = img.shape[:2]
        ratio = w * 1.0 / h
        max_wh_ratio = max(max_wh_ratio, ratio)
        imgW = int(imgH * max_wh_ratio)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, scale to [0, 1], then normalize to [-1, 1]
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # Right-pad to the fixed model width
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im[np.newaxis, :]


class RecTextDocument(object):
    def __init__(
        self,
        adapter: InferenceAdapter,
        vocab_dict: str,
    ):
        self.pre_process = RecPreprocessImage(rec_image_shape=[3, 48, 320])
        self.infer_obj = Inference(adapter)
        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} fail")
            sys.exit(1)
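        # Append the space character so the decoder can emit spaces;
        # use_space_char=True below presumably tells CTCLabelDecode to keep
        # it as a valid symbol (mirroring PaddleOCR's decoder of the same name)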
        vocab_dict += " "
        self.post_process = CTCLabelDecode(vocab_dict, use_space_char=True)

    def inference(self, img: npt.NDArray[np.uint8]) -> Tuple[str, float]:
        preprocessed_img = self.pre_process(img)
        if res := self.infer_obj.inference(
            {"x": preprocessed_img}, ["softmax_2.tmp_0"]
        ):
            sent, probs = self.post_process(res)[0]
            # The decoder may return a single mean confidence or a
            # per-character list; normalize to a list before averaging
            if not isinstance(probs, list):
                probs = [probs]
            return sent, round(sum(probs) / len(probs), 5)
        return "", 0.0


det_model = DetTextDocument(
    OnnxInferenceAdapter(model_name=DET_MODEL, version="1"),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(DET_MODEL),
    #     ssl=False,
    # ),
    img_size=736,
    unclip_ratio=1.7,
)

rec_model = RecTextDocument(
    OnnxInferenceAdapter(model_name=REC_MODEL, version="1"),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(REC_MODEL),
    #     ssl=False,
    # ),
    vocab_dict=get_vocab_str(VOCAB_DICT),
)


def ocr_extraction(img_path: str) -> List[list]:
    org_img_bgr = cv2.imread(img_path)
    if org_img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # Text detection
    bbox_lst, _ = det_model.inference(org_img_bgr)
    word_info = []
    for bbox in bbox_lst:
        # Rectify each detected quadrilateral, then recognize its text
        text_img = four_point_transform(org_img_bgr, bbox)
        word, _ = rec_model.inference(text_img)
        # Collapse the quad to an axis-aligned box; cast numpy scalars to
        # plain ints so cv2 drawing functions accept them
        word_info.append(
            [
                int(bbox[:, 0].min()),
                int(bbox[:, 1].min()),
                int(bbox[:, 0].max()),
                int(bbox[:, 1].max()),
                word,
            ]
        )
    return word_info


def visualize_ocr(img_path: str, word_info):
    org_img_bgr = cv2.imread(img_path)
    for x1, y1, x2, y2, word in word_info:
        cv2.rectangle(org_img_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            org_img_bgr,
            word,
            (x1, y1 - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 200, 0),
            2,
        )
    cv2.namedWindow("Image with Annotations", cv2.WINDOW_NORMAL)
    cv2.imshow("Image with Annotations", org_img_bgr)
    cv2.waitKey(0)  # Wait until a key is pressed
    cv2.destroyAllWindows()  # Close the image window
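

# A minimal usage sketch, assuming the ONNX checkpoints under model_checkpoint/
# are present; "sample.jpg" is a hypothetical input path, not a file shipped
# with this repo.
if __name__ == "__main__":
    extracted = ocr_extraction("sample.jpg")
    for x1, y1, x2, y2, word in extracted:
        print(f"({x1}, {y1}, {x2}, {y2}) -> {word}")
    visualize_ocr("sample.jpg", extracted)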