From 9da519ca95edcf94cb2be1225cbf47d630b5bbbf Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Tue, 10 Sep 2024 23:38:21 +0800 Subject: [PATCH] feature: optimize lineless table rec --- .gitignore | 3 +- lineless_table_rec/main.py | 196 ++++++--- lineless_table_rec/utils_table_recover.py | 486 ++++++++++++++++++---- 3 files changed, 547 insertions(+), 138 deletions(-) diff --git a/.gitignore b/.gitignore index 4c21ff5..14d5622 100755 --- a/.gitignore +++ b/.gitignore @@ -156,4 +156,5 @@ long1.jpg *.pdmodel .DS_Store -*.npy \ No newline at end of file +*.npy +/lineless_table_rec/output/ diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index 3d95a81..e2b8f10 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -3,6 +3,7 @@ # @Contact: liekkaskono@163.com import argparse import logging +import os import time import traceback from pathlib import Path @@ -12,13 +13,13 @@ import numpy as np from rapidocr_onnxruntime import RapidOCR -from .lineless_table_process import DetProcess, get_affine_transform_upper_left -from .utils import InputType, LoadImage, OrtInferSession -from .utils_table_recover import ( +from lineless_table_process import DetProcess, get_affine_transform_upper_left +from utils import InputType, LoadImage, OrtInferSession +from utils_table_recover import ( get_rotate_crop_image, - match_ocr_cell, plot_html_table, - sorted_boxes, + sorted_ocr_boxes, box_4_2_poly_to_box_4_1, match_ocr_cell, + filter_duplicated_box, gather_ocr_list_by_row, plot_rec_box_with_logic_info, plot_rec_box, format_html, ) cur_dir = Path(__file__).resolve().parent @@ -28,9 +29,9 @@ class LinelessTableRecognition: def __init__( - self, - detect_model_path: Union[str, Path] = detect_model_path, - process_model_path: Union[str, Path] = process_model_path, + self, + detect_model_path: Union[str, Path] = detect_model_path, + process_model_path: Union[str, Path] = process_model_path, ): self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3) self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3) @@ -45,32 +46,70 @@ def __init__( self.det_process = DetProcess() self.ocr = RapidOCR() - def __call__(self, content: InputType) -> str: + def __call__(self, content: InputType): ss = time.perf_counter() img = self.load_img(content) - ocr_res, _ = self.ocr(img) - input_info = self.preprocess(img) try: polygons, slct_logi = self.infer(input_info) logi_points = self.filter_logi_points(slct_logi) + # ocr 结果匹配 + cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons) + # 如果有识别框没有ocr结果,直接进行rec补充 + cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map) + # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理 + t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points) + # 拆分包含和重叠的识别框 + deleted_idx_set = filter_duplicated_box([table_box_ocr['t_box'] for table_box_ocr in t_rec_ocr_list]) + t_rec_ocr_list = [t_rec_ocr_list[i] for i in range(len(t_rec_ocr_list)) if i not in deleted_idx_set] + # 生成行列对应的二维表格, 合并同行同列识别框中的的ocr识别框 + t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list) + # todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中 + # 将同一个识别框中的ocr结果排序并同行合并 + t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list) + # 渲染为html + logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list] + cell_box_det_map = { + i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr['t_ocr_res']] + for i, t_box_ocr in enumerate(t_rec_ocr_list) + } + table_str = plot_html_table(logi_points, cell_box_det_map) - sorted_polygons = sorted_boxes(polygons) - - cell_box_map = match_ocr_cell(sorted_polygons, ocr_res) - cell_box_map = self.re_rec(img, sorted_polygons, cell_box_map) - - logi_points = self.sort_logi_by_polygons( - sorted_polygons, polygons, logi_points - ) - - table_str = plot_html_table(logi_points, cell_box_map) + # 输出可视化排序,用于验证结果,生产版本可以去掉 + _, idx_list = sorted_ocr_boxes([t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list]) + t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list] + sorted_polygons = [t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list] + sorted_logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list] + ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res] + sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) table_elapse = time.perf_counter() - ss - return table_str, table_elapse + return table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res except Exception: logging.warning(traceback.format_exc()) - return "", 0.0 + return "", 0.0, None, None, None + + def transform_res(self, cell_box_det_map: dict[int, List[any]], polygons: np.ndarray, + logi_points: list[np.ndarray]) -> list[dict[str, any]]: + res = [] + for i in range(len(polygons)): + ocr_res_list = cell_box_det_map.get(i) + if not ocr_res_list: + continue + xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list]) + ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list]) + xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list]) + ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list]) + dict_res = { + # xmin,xmax,ymin,ymax + 't_box': [xmin, ymin, xmax, ymax], + # row_start,row_end,col_start,col_end + 't_logic_box': logi_points[i].tolist(), + # [[xmin,xmax,ymin,ymax], text] + 't_ocr_res': [[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]] for ocr_det in ocr_res_list] + } + res.append(dict_res) + return res def preprocess(self, img: np.ndarray) -> Dict[str, Any]: height, width = img.shape[:2] @@ -115,52 +154,107 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]: ) return slct_output_dets, slct_logi - def filter_logi_points(self, slct_logi: np.ndarray) -> Dict[str, Any]: + def sort_and_gather_ocr_res(self, res): + for i, dict_res in enumerate(res): + dict_res['t_ocr_res'] = gather_ocr_list_by_row(dict_res['t_ocr_res']) + _, sorted_idx = sorted_ocr_boxes([ocr_det[0] for ocr_det in dict_res['t_ocr_res']]) + dict_res['t_ocr_res'] = [dict_res['t_ocr_res'][i] for i in sorted_idx] + return res + + def handle_overlap_row_col(self, res): + max_row, max_col = 0, 0 + for dict_res in res: + max_row = max(max_row, dict_res['t_logic_box'][1] + 1) # 加1是因为结束下标是包含在内的 + max_col = max(max_col, dict_res['t_logic_box'][3] + 1) # 加1是因为结束下标是包含在内的 + # 创建一个二维数组来存储 sorted_logi_points 中的元素 + grid = [[None] * max_col for _ in range(max_row)] + # 将 sorted_logi_points 中的元素填充到 grid 中 + deleted_idx = set() + for i, dict_res in enumerate(res): + if i in deleted_idx: + continue + row_start, row_end, col_start, col_end = dict_res['t_logic_box'] + for row in range(row_start, row_end + 1): + if i in deleted_idx: + continue + for col in range(col_start, col_end + 1): + if i in deleted_idx: + continue + exist_dict_res = grid[row][col] + if not exist_dict_res: + grid[row][col] = dict_res + continue + if exist_dict_res['t_logic_box'] == dict_res['t_logic_box']: + exist_dict_res['t_ocr_res'].extend(dict_res['t_ocr_res']) + deleted_idx.add(i) + # 修正识别框坐标 + exist_dict_res['t_box'] = [min(exist_dict_res['t_box'][0], dict_res['t_box'][0]), + min(exist_dict_res['t_box'][1], dict_res['t_box'][1]), + max(exist_dict_res['t_box'][2], dict_res['t_box'][2]), + max(exist_dict_res['t_box'][3], dict_res['t_box'][3]), + ] + continue + + # 去掉重叠框 + res = [res[i] for i in range(len(res)) if i not in deleted_idx] + return res, grid + + @staticmethod + def filter_logi_points(slct_logi: np.ndarray) -> list[np.ndarray]: + for logic_points in slct_logi[0]: + # 修正坐标接近导致的r_e > r_s 或 c_e > c_s + if abs(logic_points[0] - logic_points[1]) < 0.2: + row = (logic_points[0] + logic_points[1]) / 2 + logic_points[0] = row + logic_points[1] = row + if abs(logic_points[2] - logic_points[3]) < 0.2: + col = (logic_points[2] + logic_points[3]) / 2 + logic_points[2] = col + logic_points[3] = col logi_floor = np.floor(slct_logi) dev = slct_logi - logi_floor slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor) - return slct_logi[0] - - @staticmethod - def sort_logi_by_polygons( - sorted_polygons: np.ndarray, polygons: np.ndarray, logi_points: np.ndarray - ) -> np.ndarray: - sorted_idx = [] - for v in sorted_polygons: - loc_idx = np.argwhere(v[0, 0] == polygons[:, 0, 0]).squeeze() - sorted_idx.append(int(loc_idx)) - logi_points = logi_points[sorted_idx] - return logi_points + return slct_logi[0].astype(np.int32) def re_rec( - self, - img: np.ndarray, - sorted_polygons: np.ndarray, - cell_box_map: Dict[int, List[str]], - ) -> Dict[int, List[str]]: + self, + img: np.ndarray, + sorted_polygons: np.ndarray, + cell_box_map: Dict[int, List[str]], + ) -> Dict[int, List[any]]: """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" - for k, v in cell_box_map.items(): - if v[0]: + # + for i in range(sorted_polygons.shape[0]): + if cell_box_map.get(i): continue - - crop_img = get_rotate_crop_image(img, sorted_polygons[k]) + crop_img = get_rotate_crop_image(img, sorted_polygons[i]) pad_img = cv2.copyMakeBorder( - crop_img, 2, 2, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255) + crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255) ) rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True) - cell_box_map[k] = [rec_res[0][0]] + box = sorted_polygons[i] + text = [rec[0] for rec in rec_res] + scores = [rec[1] for rec in rec_res] + cell_box_map[i] = [[box, "".join(text), min(scores)]] return cell_box_map def main(): parser = argparse.ArgumentParser() parser.add_argument("-img", "--img_path", type=str, required=True) + parser.add_argument( "--output_dir", default= "./output", type=str) args = parser.parse_args() - + # args.img_path = '../images/image (78).png' table_rec = LinelessTableRecognition() - table_str, elapse = table_rec(args.img_path) - print(table_str) - print(f"cost: {elapse:.5f}") + html, elasp, polygons, logic_points, ocr_res = table_rec(args.img_path) + print(f"cost: {elasp:.5f}") + complete_html = format_html(html) + os.makedirs(os.path.dirname(f'{args.output_dir}/table.html'), exist_ok=True) + with open(f'{args.output_dir}/table.html', 'w', encoding='utf-8') as file: + file.write(complete_html) + plot_rec_box_with_logic_info(args.img_path, f'{args.output_dir}/table_rec_box.jpg', logic_points, polygons) + plot_rec_box(args.img_path, f'{args.output_dir}/ocr_box.jpg', ocr_res) + if __name__ == "__main__": diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index 56ac401..a26cac7 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -1,12 +1,14 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import os import random -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Iterable, Union, Any import cv2 import numpy as np import shapely +from numpy import ndarray from shapely.geometry import MultiPoint, Polygon @@ -26,8 +28,8 @@ def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray: for i in range(num_boxes - 1): for j in range(i, -1, -1): if ( - abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 - and _boxes[j + 1][0][0] < _boxes[j][0][0] + abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 + and _boxes[j + 1][0][0] < _boxes[j][0][0] ): _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j] else: @@ -66,6 +68,227 @@ def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float: return float(inter_area) / union_area +def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: + """ + :param table_boxes: [[xmin,ymin,xmax,ymax]] + :return: + """ + delete_idx = set() + for i in range(len(table_boxes)): + polygons_i = table_boxes[i] + if i in delete_idx: + continue + for j in range(i + 1, len(table_boxes)): + if j in delete_idx: + continue + # 下一个box + polygons_j = table_boxes[j] + # 重叠关系先记录,后续删除掉 + if calculate_iou(polygons_i, polygons_j) > 0.8: + delete_idx.add(j) + continue + # 是否存在包含关系 + contained_idx = is_box_contained(polygons_i, polygons_j) + if contained_idx == 2: + delete_idx.add(j) + elif contained_idx == 1: + delete_idx.add(i) + return delete_idx + + +def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: iou: float 0-1 + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + # 不相交直接退出检测 + if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2: + return 0.0 + # 计算交集 + inter_x1 = max(b1_x1, b2_x1) + inter_y1 = max(b1_y1, b2_y1) + inter_x2 = min(b1_x2, b2_x2) + inter_y2 = min(b1_y2, b2_y2) + i_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) + + # 计算并集 + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + u_area = b1_area + b2_area - i_area + + # 避免除零错误,如果区域小到乘积为0,认为是错误识别,直接去掉 + if u_area == 0: + return 1 + # 检查完全包含 + iou = i_area / u_area + return iou + + +def caculate_single_axis_iou(box1: list | np.ndarray, box2: list | np.ndarray, axis='x') -> float: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: iou: float 0-1 + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1 + b2_x1, b2_y1, b2_x2, b2_y2 = box2 + if axis == 'x': + i_min = max(b1_x1, b2_x1) + i_max = min(b1_x2, b2_x2) + u_area = max(b1_x2, b2_x2) - min(b1_x1, b2_x1) + else: + i_min = max(b1_y1, b2_y1) + i_max = min(b1_y2, b2_y2) + u_area = max(b1_y2, b2_y2) - min(b1_y1, b2_y1) + i_area = max(i_max - i_min, 0) + if u_area == 0: + return 1 + return i_area / u_area + + +def is_box_contained(box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2) -> int | None: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: 1: box1 is contained 2: box2 is contained None: no contain these + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + # 不相交直接退出检测 + if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2: + return None + # 计算box2的总面积 + b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + + # 计算box1和box2的交集 + intersect_x1 = max(b1_x1, b2_x1) + intersect_y1 = max(b1_y1, b2_y1) + intersect_x2 = min(b1_x2, b2_x2) + intersect_y2 = min(b1_y2, b2_y2) + + # 计算交集的面积 + intersect_area = max(0, intersect_x2 - intersect_x1) * max(0, intersect_y2 - intersect_y1) + + # 计算外面的面积 + b1_outside_area = b1_area - intersect_area + b2_outside_area = b2_area - intersect_area + + # 计算外面的面积占box2总面积的比例 + ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0 + ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0 + + if ratio_b1 < threshold: + return 1 + if ratio_b2 < threshold: + return 2 + # 判断比例是否大于阈值 + return None + + +def is_single_axis_contained(box1: list | np.ndarray, box2: list | np.ndarray, axis='x', threshold=0.2) -> int | None: + """ + :param box1: Iterable [xmin,ymin,xmax,ymax] + :param box2: Iterable [xmin,ymin,xmax,ymax] + :return: 1: box1 is contained 2: box2 is contained None: no contain these + """ + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + + # 计算轴重叠大小 + if axis == 'x': + b1_area = (b1_x2 - b1_x1) + b2_area = (b2_x2 - b2_x1) + i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1) + else: + b1_area = (b1_y2 - b1_y1) + b2_area = (b2_y2 - b2_y1) + i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1) + # 计算外面的面积 + b1_outside_area = b1_area - i_area + b2_outside_area = b2_area - i_area + + ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0 + ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0 + if ratio_b1 < threshold: + return 1 + if ratio_b2 < threshold: + return 2 + return None + + +def sorted_ocr_boxes(dt_boxes: np.ndarray | list) -> tuple[np.ndarray | list, list[int]]: + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax) + return: + sorted boxes(array) with (xmin, ymin, xmax, ymax) + """ + num_boxes = len(dt_boxes) + indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)] + sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes, indices = zip(*sorted_boxes_with_idx) + indices = list(indices) + _boxes = [dt_boxes[i] for i in indices] + for i in range(num_boxes - 1): + for j in range(i, -1, -1): + c_idx = is_single_axis_contained(_boxes[j], _boxes[j + 1], axis='y') + if c_idx is not None and _boxes[j + 1][0] < _boxes[j][0]: + _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j] + indices[j], indices[j + 1] = indices[j + 1], indices[j] + else: + break + return _boxes, indices + + +def gather_ocr_list_by_row(ocr_list: list[list[list[float], str]]) -> list[list[list[float], str]]: + """ + :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] + :return: + """ + for i in range(len(ocr_list)): + if not ocr_list[i]: + continue + for j in range(i + 1, len(ocr_list)): + if not ocr_list[j]: + continue + cur = ocr_list[i] + next = ocr_list[j] + cur_box = cur[0] + next_box = next[0] + c_idx = is_single_axis_contained(cur[0], next[0], axis='y') + if c_idx: + cur[1] = cur[1] + next[1] + xmin = min(cur_box[0], next_box[0]) + xmax = max(cur_box[2], next_box[2]) + ymin = min(cur_box[1], next_box[1]) + ymax = max(cur_box[3], next_box[3]) + cur_box[0] = xmin + cur_box[1] = ymin + cur_box[2] = xmax + cur_box[3] = ymax + ocr_list[j] = None + ocr_list = [x for x in ocr_list if x] + return ocr_list + +def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: + xmin, ymin, xmax, ymax = tuple(poly_box) + return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + + +def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[ndarray[Any, Any]]: + """ + 将poly_box转换为box_4_1 + :param poly_box: + :return: + """ + return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]] + + def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray: """合并相邻iou大于阈值的框""" combine_iou_thresh = 0.1 @@ -122,91 +345,84 @@ def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray: return polygons -def match_ocr_cell( - polygons: np.ndarray, ocr_res: List[Tuple[np.ndarray, str, str]] -) -> Dict[int, List]: - cell_box_map = {} - dt_boxes, rec_res, _ = list(zip(*ocr_res)) - dt_boxes = np.array(dt_boxes) - iou_thresh = 0.009 - for i, cell_box in enumerate(polygons): - ious = [compute_poly_iou(dt_box, cell_box) for dt_box in dt_boxes] - - # 对有iou的值,计算是否存在包含关系。如存在→iou=1 - have_iou_idxs = np.argwhere(ious) - if have_iou_idxs.size > 0: - have_iou_idxs = have_iou_idxs.squeeze(1) - for idx in have_iou_idxs: - if is_inclusive_each_other(cell_box, dt_boxes[idx]): - ious[idx] = 1.0 - - if all(x <= iou_thresh for x in ious): - # 说明这个cell中没有文本 - cell_box_map.setdefault(i, []).append("") - continue - - same_cell_idxs = np.argwhere(np.array(ious) >= iou_thresh).squeeze(1) - one_cell_txts = "\n".join([rec_res[idx] for idx in same_cell_idxs]) - cell_box_map.setdefault(i, []).append(one_cell_txts) - return cell_box_map - - -def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray): - """判断两个多边形框是否存在包含关系 - - Args: - box1 (np.ndarray): (4, 2) - box2 (np.ndarray): (4, 2) - - Returns: - bool: 是否存在包含关系 +def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray): """ - poly1 = Polygon(box1) - poly2 = Polygon(box2) - - poly1_area = poly1.convex_hull.area - poly2_area = poly2.convex_hull.area - - if poly1_area > poly2_area: - box_max = box1 - box_min = box2 - else: - box_max = box2 - box_min = box1 - - x0, y0 = np.min(box_min[:, 0]), np.min(box_min[:, 1]) - x1, y1 = np.max(box_min[:, 0]), np.max(box_min[:, 1]) - - edge_x0, edge_y0 = np.min(box_max[:, 0]), np.min(box_max[:, 1]) - edge_x1, edge_y1 = np.max(box_max[:, 0]), np.max(box_max[:, 1]) - - if x0 >= edge_x0 and y0 >= edge_y0 and x1 <= edge_x1 and y1 <= edge_y1: - return True - return False - - -def plot_html_table(logi_points: np.ndarray, cell_box_map: Dict[int, List[str]]) -> str: - logi_points = logi_points.astype(np.int32) - table_dict = {} - for cell_idx, v in enumerate(logi_points): - cur_row = v[0] - cur_txt = "\n".join(cell_box_map.get(cell_idx)) - sr, er, sc, ec = v.tolist() - rowspan, colspan = er - sr + 1, ec - sc + 1 - table_str = f'{cur_txt}' - # table_str = f'
{cur_txt}
' - table_dict.setdefault(cur_row, []).append(table_str) - - new_table_dict = {} - for k, v in table_dict.items(): - new_table_dict[k] = [""] + v + [""] + :param dt_rec_boxes: [[(4.2), text, score]] + :param pred_bboxes: shap (4,2) + :return: + """ + matched = {} + not_match_orc_boxes = [] + for i, gt_box in enumerate(dt_rec_boxes): + for j, pred_box in enumerate(pred_bboxes): + pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]] + ocr_boxes = gt_box[0] + # xmin,ymin,xmax,ymax + ocr_box = (ocr_boxes[0][0], ocr_boxes[0][1], ocr_boxes[2][0], ocr_boxes[2][1]) + contained = is_box_contained(ocr_box, pred_box, 0.6) + if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8: + if j not in matched.keys(): + matched[j] = [gt_box] + else: + matched[j].append(gt_box) + else: + not_match_orc_boxes.append(gt_box) + + return matched, not_match_orc_boxes + + +def plot_html_table(logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]]) -> str: + # 初始化最大行数和列数 + max_row = 0 + max_col = 0 + # 计算最大行数和列数 + for point in logi_points: + max_row = max(max_row, point[1] + 1) # 加1是因为结束下标是包含在内的 + max_col = max(max_col, point[3] + 1) # 加1是因为结束下标是包含在内的 + + # 创建一个二维数组来存储 sorted_logi_points 中的元素 + grid = [[None] * max_col for _ in range(max_row)] + + # 将 sorted_logi_points 中的元素填充到 grid 中 + for i, logic_point in enumerate(logi_points): + row_start, row_end, col_start, col_end = logic_point[0],logic_point[1],logic_point[2],logic_point[3] + for row in range(row_start, row_end + 1): + for col in range(col_start, col_end + 1): + grid[row][col] = (i, row_start, row_end, col_start, col_end) + + # 创建表格 + table_html = "\n" + + # 遍历每行 + for row in range(max_row): + empty_temp = True + temp = " \n" + + # 遍历每一列 + for col in range(max_col): + if not grid[row][col]: + temp += " \n" + else: + i, row_start, row_end, col_start, col_end = grid[row][col] + if not cell_box_map.get(i): + continue + empty_temp = False + if (row == row_start and col == col_start): + ocr_rec_text = cell_box_map.get(i) + text = "
".join(ocr_rec_text) + if not text.strip(): + continue + # 如果是起始单元格 + row_span = row_end - row_start + 1 + col_span = col_end - col_start + 1 + cell_content = f"\n" + temp += cell_content + if not empty_temp: + table_html = table_html + temp + " \n" + + table_html += "
{text}
" + return table_html - html_start = """""" - # html_start = """
""" - html_end = "
" - html_middle = "".join([vv for v in new_table_dict.values() for vv in v]) - table_str = f"{html_start}{html_middle}{html_end}" - return table_str def vis_table(img: np.ndarray, polygons: np.ndarray) -> np.ndarray: @@ -223,7 +439,105 @@ def vis_table(img: np.ndarray, polygons: np.ndarray) -> np.ndarray: cv2.putText(img, str(i), poly[0], font, 1, (0, 0, 255), 2) return img +def plot_rec_box_with_logic_info(img_path, output_path, logic_points, sorted_polygons): + """ + :param img_path + :param output_path + :param logic_points: [row_start,row_end,col_start,col_end] + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 读取原图 + img = cv2.imread(img_path) + img = cv2.copyMakeBorder(img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]) + # 绘制 polygons 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 1.0 # 原先是0.5 + thickness = 2 # 原先是1 + + cv2.putText( + img, + f'{idx}-{logic_points[idx]}', + (x1, y1), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + cv2.imwrite(output_path, img) +def plot_rec_box(img_path, output_path, sorted_polygons): + """ + :param img_path + :param output_path + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 处理ocr_res + img = cv2.imread(img_path) + img = cv2.copyMakeBorder(img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]) + # 绘制 ocr_res 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 1.0 # 原先是0.5 + thickness = 2 # 原先是1 + + cv2.putText( + img, + str(idx), + (x1, y1), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + cv2.imwrite(output_path, img) + +def format_html(html): + + return f""" + + + + + Complex Table Example + + + + {html} + + + """ def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray: img_crop_width = int( max(