Skip to content

Commit

Permalink
Merge pull request #65 from RapidAI/fix_table_cls_preprocess
Browse files Browse the repository at this point in the history
fix: fix table cls preprocess
  • Loading branch information
Joker1212 authored Oct 30, 2024
2 parents 881a164 + 2c3939b commit eaaf4d3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lineless_table_rec/utils_table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], thehold: float = 0.2) -> List[An
cur[0], next[0], axis="y", threhold=thehold
)
if c_idx:
dis = max(next_box[0] - cur_box[1], 0)
dis = max(next_box[0] - cur_box[0], 0)
blank_str = int(dis / threshold) * " "
cur[1] = cur[1] + blank_str + next[1]
xmin = min(cur_box[0], next_box[0])
Expand Down
4 changes: 2 additions & 2 deletions table_cls/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from PIL import Image

from .utils import InputType, LoadImage, OrtInferSession
from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop

cur_dir = Path(__file__).resolve().parent
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, model_path):

def preprocess(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))
img = resize_and_center_crop(img, 640)
img = np.array(img, dtype=np.float32) / 255
img = img.transpose(2, 0, 1) # HWC to CHW
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
Expand Down
34 changes: 34 additions & 0 deletions table_cls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,37 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")


def resize_and_center_crop(image, output_size=640):
    """
    Scale the image so its shorter side equals ``output_size``, then cut a
    centered ``output_size`` x ``output_size`` window out of the result.

    The aspect ratio is preserved during scaling; only the overflow along the
    longer side is discarded by the crop.

    :param image: input image array (H, W, C)
    :param output_size: side length used for both scaling and cropping,
        defaults to 640
    :return: processed image array (output_size, output_size, C)
    """
    src_h, src_w = image.shape[:2]

    # Pin the shorter side to output_size and scale the other side by the
    # same factor (truncated to an int, so it is never below output_size).
    if src_w < src_h:
        target_w, target_h = output_size, int(output_size * src_h / src_w)
    else:
        target_w, target_h = int(output_size * src_w / src_h), output_size

    # cv2.resize takes the destination size as (width, height).
    scaled = cv2.resize(
        image, (target_w, target_h), interpolation=cv2.INTER_LINEAR
    )

    # Top-left corner of the centered crop window; one of the two offsets is
    # always 0 because one side already equals output_size.
    x0 = (target_w - output_size) // 2
    y0 = (target_h - output_size) // 2

    return scaled[y0 : y0 + output_size, x0 : x0 + output_size]
2 changes: 1 addition & 1 deletion wired_table_rec/utils_table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[A
cur[0], next[0], axis="y", threhold=threhold
)
if c_idx:
dis = max(next_box[0] - cur_box[1], 0)
dis = max(next_box[0] - cur_box[0], 0)
blank_str = int(dis / threshold) * " "
cur[1] = cur[1] + blank_str + next[1]
xmin = min(cur_box[0], next_box[0])
Expand Down

0 comments on commit eaaf4d3

Please sign in to comment.