From 2c3939b8c5184691d65cd933a1cc3d1b12378a3d Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Wed, 30 Oct 2024 14:43:12 +0800 Subject: [PATCH] fix: fix table cls preprocess --- lineless_table_rec/utils_table_recover.py | 2 +- table_cls/main.py | 4 +-- table_cls/utils.py | 34 +++++++++++++++++++++++ wired_table_rec/utils_table_recover.py | 2 +- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index 31ccfc5..67ea181 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -289,7 +289,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], thehold: float = 0.2) -> List[An cur[0], next[0], axis="y", threhold=thehold ) if c_idx: - dis = max(next_box[0] - cur_box[1], 0) + dis = max(next_box[0] - cur_box[0], 0) blank_str = int(dis / threshold) * " " cur[1] = cur[1] + blank_str + next[1] xmin = min(cur_box[0], next_box[0]) diff --git a/table_cls/main.py b/table_cls/main.py index 554a973..179c0cc 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -5,7 +5,7 @@ import numpy as np from PIL import Image -from .utils import InputType, LoadImage, OrtInferSession +from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop cur_dir = Path(__file__).resolve().parent q_cls_model_path = cur_dir / "models" / "table_cls.onnx" @@ -70,7 +70,7 @@ def __init__(self, model_path): def preprocess(self, img): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, (640, 640)) + img = resize_and_center_crop(img, 640) img = np.array(img, dtype=np.float32) / 255 img = img.transpose(2, 0, 1) # HWC to CHW img = np.expand_dims(img, axis=0) # Add batch dimension, only one image diff --git a/table_cls/utils.py b/table_cls/utils.py index db64b3a..ce404a6 100644 --- a/table_cls/utils.py +++ b/table_cls/utils.py @@ -178,3 +178,37 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: def 
verify_exist(file_path: Union[str, Path]): if not Path(file_path).exists(): raise LoadImageError(f"{file_path} does not exist.") + + +def resize_and_center_crop(image, output_size=640): + """ + Resize the image so that its shorter side equals the given size, then center-crop it. + + :param image: input image array (H, W, C) + :param output_size: side length of the image after resizing and cropping, 640 by default + :return: processed image array (output_size, output_size, C) + """ + # Get the image height and width + height, width = image.shape[:2] + # Compute the resized dimensions + if width < height: + new_width = output_size + new_height = int(output_size * height / width) + else: + new_width = int(output_size * width / height) + new_height = output_size + + # Resize the image + image_resize = cv2.resize( + image, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # Compute the center-crop coordinates + left = (new_width - output_size) // 2 + top = (new_height - output_size) // 2 + right = left + output_size + bottom = top + output_size + + # Center crop + image_cropped = image_resize[top:bottom, left:right] + return image_cropped diff --git a/wired_table_rec/utils_table_recover.py b/wired_table_rec/utils_table_recover.py index 3462d86..01f451e 100644 --- a/wired_table_rec/utils_table_recover.py +++ b/wired_table_rec/utils_table_recover.py @@ -383,7 +383,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[A cur[0], next[0], axis="y", threhold=threhold ) if c_idx: - dis = max(next_box[0] - cur_box[1], 0) + dis = max(next_box[0] - cur_box[0], 0) blank_str = int(dis / threshold) * " " cur[1] = cur[1] + blank_str + next[1] xmin = min(cur_box[0], next_box[0])