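"""Document OCR pipeline: DBNet-style word detection followed by CTC word
recognition (French/Latin vocabulary), served through infer_client adapters.
"""
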
from typing import List, Tuple, Union

import logging
import math
import sys

import cv2
import numpy as np
import numpy.typing as npt

from cv_base.postprocess.db_postprocess import DBPostProcess
from cv_base.postprocess.pp_rec_postprocess import CTCLabelDecode
from cv_base.preprocess.db_preprocess import DBPreprocess
from cv_base.utils.perspective import four_point_transform
from infer_client.adapters import InferenceAdapter
from infer_client.adapters.onnx import OnnxInferenceAdapter
from infer_client.inference import Inference

logger = logging.getLogger(__name__)

DET_MODEL = "model_checkpoint/det_latin_words"
REC_MODEL = "model_checkpoint/rec_general_words_fr"
VOCAB_DICT = "model_checkpoint/latin_dict.txt"

def get_vocab_str(vocab_dict_path: str) -> str:
    """Read a character dictionary file (one character per line) into a single string."""
    with open(vocab_dict_path, "r", encoding="utf-8") as rf:
        vocab = rf.read()
    return vocab.replace("\n", "")
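
# For example, a dictionary file containing "a\nb\nc\n" yields the vocabulary
# string "abc"; RecTextDocument below appends the space character before decoding.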


def box4point_to_box2point(box4point):
    """Collapse a 4-point polygon [x0, y0, x1, y1, x2, y2, x3, y3] to [xmin, ymin, xmax, ymax]."""
    all_x = [box4point[2 * i] for i in range(4)]
    all_y = [box4point[2 * i + 1] for i in range(4)]
    return [min(all_x), min(all_y), max(all_x), max(all_y)]
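
# Example: box4point_to_box2point([10, 20, 50, 20, 50, 40, 10, 40]) -> [10, 20, 50, 40]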


def minimum_bounding_rectangle(points):
    """Return the 4 corners of the minimum-area (rotated) rectangle enclosing `points`."""
    from scipy.spatial import ConvexHull  # local import keeps scipy an optional dependency

    hull = ConvexHull(points)
    hull_points = points[hull.vertices]

    # Find the minimum-area rectangle around the convex hull
    rect = cv2.minAreaRect(hull_points.astype(np.float32))
    box = cv2.boxPoints(rect)
    return box.astype(np.intp)  # np.int0 was removed in NumPy 2.0
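
# Sketch with illustrative points: a tilted quadrilateral
#   minimum_bounding_rectangle(np.array([[0, 0], [4, 1], [5, 5], [1, 4]]))
# returns a (4, 2) integer array holding the rotated rectangle's corners.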


class DetTextDocument:
    """DBNet-style word detector: DB preprocessing, adapter inference, DB postprocessing."""

    def __init__(
        self,
        adapter: InferenceAdapter,
        img_size: Union[int, Tuple[int, int]] = 736,
        thresh: float = 0.2,
        bbox_thresh: float = 0.5,
        unclip_ratio: float = 2.0,
    ):
        self.infer_obj = Inference(adapter)

        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} failed")
            sys.exit(1)

        self.pre_process = DBPreprocess(img_size=img_size)
        self.post_process = DBPostProcess(
            thresh=thresh,
            bbox_thresh=bbox_thresh,
            unclip_ratio=unclip_ratio,
            score_mode="fast",
        )

    def inference(
        self,
        img: npt.NDArray[np.uint8],
    ) -> Tuple[List[npt.NDArray[np.int32]], List[float]]:
        pre_processed_img, _ = self.pre_process(img)
        outs_dbnet = self.infer_obj.inference({"input": pre_processed_img}, ["output"])

        if not outs_dbnet:
            return [], []

        # Map the probability map back to boxes in original-image coordinates
        _, bbox_lst, score_lst = self.post_process(
            outs_dbnet[0][0], img_shape=(img.shape[0], img.shape[1])
        )
        return [bbox.astype(np.int32) for bbox in bbox_lst], score_lst
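
# Usage sketch (the module-level det_model below is constructed this way):
#   boxes, scores = det_model.inference(cv2.imread("page.png"))  # "page.png" is a placeholder
# where each box is a (4, 2) int32 polygon in original-image coordinates.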


class RecPreprocessImage:
    """Resize a word crop to fixed height, normalize to [-1, 1], and right-pad to batch width."""

    def __init__(self, rec_image_shape):
        self.rec_image_shape = rec_image_shape

    def __call__(self, img):
        imgC, imgH, imgW = self.rec_image_shape[:3]
        max_wh_ratio = imgW / imgH
        h, w = img.shape[:2]
        ratio = w / h
        max_wh_ratio = max(max_wh_ratio, ratio)
        imgW = int(imgH * max_wh_ratio)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        # HWC uint8 -> CHW float32 in [-1, 1]
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # Zero-pad on the right so every crop reaches the model's input width
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im[np.newaxis, :]
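
# Worked example with rec_image_shape=[3, 48, 320]: a 100x400 crop has
# ratio 4.0, so it is resized to 48x192, mapped to [-1, 1], and zero-padded
# into a (1, 3, 48, 320) float32 batch.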


class RecTextDocument:
    """CTC-based word recognizer over a fixed-shape (3, 48, 320) input."""

    def __init__(
        self,
        adapter: InferenceAdapter,
        vocab_dict: str,
    ):
        self.pre_process = RecPreprocessImage(rec_image_shape=[3, 48, 320])
        self.infer_obj = Inference(adapter)

        if self.infer_obj.health():
            logger.info(f"Init {self.__class__.__name__} done")
        else:
            logger.error(f"Init {self.__class__.__name__} failed")
            sys.exit(1)

        # CTCLabelDecode expects the space character at the end of the vocabulary
        vocab_dict += " "
        self.post_process = CTCLabelDecode(vocab_dict, use_space_char=True)

    def inference(self, img: npt.NDArray[np.uint8]) -> Tuple[str, float]:
        preprocessed_img = self.pre_process(img)
        if res := self.infer_obj.inference(
            {"x": preprocessed_img}, ["softmax_2.tmp_0"]
        ):
            sent, prob = self.post_process(res)[0]
            return sent, round(float(prob), 5)
        return "", 0.0
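
# Usage sketch: for a rectified word crop,
#   word, conf = rec_model.inference(crop)
# returns the decoded string and its confidence rounded to 5 decimals.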


det_model = DetTextDocument(
    OnnxInferenceAdapter(
        model_name=DET_MODEL,
        version="1",
    ),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(DET_MODEL),
    #     ssl=False,
    # ),
    img_size=736,
    unclip_ratio=1.7,
)

rec_model = RecTextDocument(
    OnnxInferenceAdapter(
        model_name=REC_MODEL,
        version="1",
    ),
    # TritonInferenceAdapter(
    #     triton_server=TRITON_SERVER_URL,
    #     model_name=os.path.basename(REC_MODEL),
    #     ssl=False,
    # ),
    vocab_dict=get_vocab_str(VOCAB_DICT),
)
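
# The commented TritonInferenceAdapter blocks show a remote-serving alternative;
# enabling them would also require `import os`, the adapter's import, and a
# TRITON_SERVER_URL value, none of which are defined in this module.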


def ocr_extraction(img_path: str) -> List[list]:
    """Detect words in the image at `img_path` and return [xmin, ymin, xmax, ymax, word] entries."""
    org_img_bgr = cv2.imread(img_path)
    if org_img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # text detection
    bbox_lst, _ = det_model.inference(org_img_bgr)

    word_info = []
    for bbox in bbox_lst:
        # text recognition on the perspective-rectified crop
        text_img = four_point_transform(org_img_bgr, bbox)
        word, _ = rec_model.inference(text_img)
        word_info.append(
            [
                int(min(bbox[:, 0])),
                int(min(bbox[:, 1])),
                int(max(bbox[:, 0])),
                int(max(bbox[:, 1])),
                word,
            ]
        )

    return word_info
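
# Example return value for a two-word image (coordinates illustrative):
#   [[12, 8, 96, 40, "Bonjour"], [104, 8, 210, 40, "monde"]]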


def visualize_ocr(img_path: str, word_info):
    """Draw the recognized words and their boxes on the image and show it in a window."""
    org_img_bgr = cv2.imread(img_path)
    for x1, y1, x2, y2, word in word_info:
        cv2.rectangle(org_img_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            org_img_bgr,
            word,
            (x1, y1 - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 200, 0),
            2,
        )

    cv2.namedWindow("Image with Annotations", cv2.WINDOW_NORMAL)
    cv2.imshow("Image with Annotations", org_img_bgr)
    cv2.waitKey(0)  # wait until a key is pressed
    cv2.destroyAllWindows()  # close the image window
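

if __name__ == "__main__":
    # Minimal end-to-end sketch; "sample.png" is a placeholder path.
    info = ocr_extraction("sample.png")
    visualize_ocr("sample.png", info)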