# synthetics_handwritten_OCR/OCR_earsing/utils/ocr_utils.py

from typing import List, Tuple, Union
import logging
import sys
import math

import cv2
import numpy as np
import numpy.typing as npt

from cv_base.postprocess.db_postprocess import DBPostProcess
from cv_base.postprocess.pp_rec_postprocess import CTCLabelDecode
from cv_base.preprocess.db_preprocess import DBPreprocess
from cv_base.utils.perspective import four_point_transform
from infer_client.adapters import InferenceAdapter
from infer_client.adapters.onnx import OnnxInferenceAdapter
from infer_client.inference import Inference

logger = logging.getLogger(__name__)

DET_MODEL = "model_checkpoint/det_latin_words"
REC_MODEL = "model_checkpoint/rec_general_words_fr"
VOCAB_DICT = "model_checkpoint/latin_dict.txt"


def get_vocab_str(vocab_dict_path: str) -> str:
    """Read the character dictionary and return it as a single string."""
    with open(vocab_dict_path, "r", encoding="utf-8") as rf:
        vocab = rf.read()
    return vocab.replace("\n", "")


def box4point_to_box2point(box4point):
    # bounding box = [x0, y0, x1, y1, x2, y2, x3, y3]
    all_x = [box4point[2 * i] for i in range(4)]
    all_y = [box4point[2 * i + 1] for i in range(4)]
    box2point = [min(all_x), min(all_y), max(all_x), max(all_y)]
    return box2point
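
# Example sketch (not in the original module): an axis-aligned unit square
# listed corner by corner keeps the same extents:
#   box4point_to_box2point([0, 0, 1, 0, 1, 1, 0, 1]) -> [0, 0, 1, 1]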


def minimum_bounding_rectangle(points):
    """Return the 4 corners of the minimum-area (possibly rotated) rectangle
    enclosing ``points``, an (N, 2) array."""
    from scipy.spatial import ConvexHull  # local import: scipy is only needed here

    hull = ConvexHull(points)
    hull_points = points[hull.vertices]
    # Find the minimum-area rectangle around the hull
    rect = cv2.minAreaRect(hull_points.astype(np.float32))
    box = cv2.boxPoints(rect)
    return box.astype(np.intp)  # np.int0 was removed in NumPy 2.0
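
# Usage sketch (requires scipy): unlike box4point_to_box2point, the returned
# rectangle can be rotated, e.g.
#   pts = np.array([[0, 0], [4, 1], [5, 5], [1, 4]], dtype=np.float64)
#   corners = minimum_bounding_rectangle(pts)  # (4, 2) array of int corners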


class DetTextDocument:
    """DBNet-based text detector: preprocess, run the model, decode boxes."""

    def __init__(
        self,
        adapter: InferenceAdapter,
        img_size: Union[int, Tuple[int, int]] = 736,
        thresh: float = 0.2,
        bbox_thresh: float = 0.5,
        unclip_ratio: float = 2.0,
    ):
        self.infer_obj = Inference(adapter)
        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} failed")
            sys.exit(1)
        self.pre_process = DBPreprocess(img_size=img_size)
        self.post_process = DBPostProcess(
            thresh=thresh,
            bbox_thresh=bbox_thresh,
            unclip_ratio=unclip_ratio,
            score_mode="fast",
        )

    def inference(
        self,
        img: npt.NDArray[np.uint8],
    ) -> Tuple[List[npt.NDArray[np.int32]], List[float]]:
        """Detect text regions; returns (4-point boxes, per-box scores)."""
        pre_processed_img, _ = self.pre_process(img)
        outs_dbnet = self.infer_obj.inference({"input": pre_processed_img}, ["output"])
        if not outs_dbnet:
            return [], []
        _, bbox_lst, score_lst = self.post_process(
            outs_dbnet[0][0], img_shape=(img.shape[0], img.shape[1])
        )
        return [bbox.astype(np.int32) for bbox in bbox_lst], score_lst
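
# Output sketch: DetTextDocument.inference returns one (4, 2) int32 polygon
# per detected word (in original-image pixel coordinates) and one confidence
# score per polygon; an empty inference result yields ([], []).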


class RecPreprocessImage:
    """Resize a word crop to the recognizer input, normalize to [-1, 1],
    and right-pad to a fixed width."""

    def __init__(self, rec_image_shape):
        self.rec_image_shape = rec_image_shape

    def __call__(self, img):
        imgC, imgH, imgW = self.rec_image_shape[:3]
        max_wh_ratio = imgW / imgH
        h, w = img.shape[:2]
        ratio = w * 1.0 / h
        max_wh_ratio = max(max_wh_ratio, ratio)
        imgW = int(imgH * max_wh_ratio)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, scale to [0, 1], then shift/scale to [-1, 1]
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # Zero-pad on the right up to the maximum width, add a batch axis
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im[np.newaxis, :]
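
# Shape sketch: with rec_image_shape=[3, 48, 320], a 30x120 BGR crop is
# resized to 48x192, normalized to [-1, 1], right-padded with zeros to
# width 320, and returned as a (1, 3, 48, 320) float32 batch.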


class RecTextDocument:
    """CTC-based word recognizer: preprocess crop, run model, decode text."""

    def __init__(
        self,
        adapter: InferenceAdapter,
        vocab_dict: str,
    ):
        self.pre_process = RecPreprocessImage(rec_image_shape=[3, 48, 320])
        self.infer_obj = Inference(adapter)
        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} failed")
            sys.exit(1)
        vocab_dict += " "  # append space so the decoder can emit spaces
        self.post_process = CTCLabelDecode(vocab_dict, use_space_char=True)

    def inference(self, img: npt.NDArray[np.uint8]) -> Tuple[str, float]:
        """Recognize one cropped word image; returns (text, confidence)."""
        preprocessed_img = self.pre_process(img)
        if res := self.infer_obj.inference(
            {"x": preprocessed_img}, ["softmax_2.tmp_0"]
        ):
            sent, prob = self.post_process(res)[0]
            return sent, round(float(prob), 5)
        return "", 0.0


det_model = DetTextDocument(
    OnnxInferenceAdapter(
        model_name=DET_MODEL,
        version="1",
    ),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(DET_MODEL),
    #     ssl=False,
    # ),
    img_size=736,
    unclip_ratio=1.7,
)
rec_model = RecTextDocument(
    OnnxInferenceAdapter(
        model_name=REC_MODEL,
        version="1",
    ),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(REC_MODEL),
    #     ssl=False,
    # ),
    vocab_dict=get_vocab_str(VOCAB_DICT),
)


def ocr_extraction(img_path: str) -> List[list]:
    """Detect and recognize words in an image.

    Returns a list of [x_min, y_min, x_max, y_max, word] entries.
    """
    org_img_bgr = cv2.imread(img_path)
    if org_img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # text detection
    bbox_lst, _ = det_model.inference(org_img_bgr)
    word_info = []
    for bbox in bbox_lst:
        # crop the detected quadrilateral, then run text recognition
        text_img = four_point_transform(org_img_bgr, bbox)
        word, _ = rec_model.inference(text_img)
        word_info.append(
            [min(bbox[:, 0]), min(bbox[:, 1]), max(bbox[:, 0]), max(bbox[:, 1]), word]
        )
    return word_info
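
# Result sketch: ocr_extraction("page.jpg") returns entries shaped like
#   [x_min, y_min, x_max, y_max, word], e.g. [12, 40, 118, 76, "Bonjour"]
# (the coordinate values here are illustrative only).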


def visualize_ocr(img_path: str, word_info):
    """Draw each recognized word and its box on the image and display it."""
    org_img_bgr = cv2.imread(img_path)
    for x1, y1, x2, y2, word in word_info:
        cv2.rectangle(org_img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
        cv2.putText(
            org_img_bgr,
            word,
            (int(x1), int(y1) - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 200, 0),
            2,
        )
    cv2.namedWindow("Image with Annotations", cv2.WINDOW_NORMAL)
    cv2.imshow("Image with Annotations", org_img_bgr)
    cv2.waitKey(0)  # wait until a key is pressed
    cv2.destroyAllWindows()  # close the image window
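

# Minimal usage sketch (not part of the original module): the image path is
# hypothetical; point it at a real file before running.
if __name__ == "__main__":
    sample_path = "samples/handwritten_page.jpg"  # hypothetical example path
    words = ocr_extraction(sample_path)
    for x1, y1, x2, y2, word in words:
        print(f"({x1}, {y1}, {x2}, {y2}): {word}")
    visualize_ocr(sample_path, words)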