
Commit

Merge pull request #21 from Joker1212/lineless
feature: optimize lineless table rec
SWHL authored Sep 11, 2024
2 parents a7dfe47 + 9da519c commit 8ccbaa5
Showing 3 changed files with 547 additions and 138 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -156,4 +156,5 @@ long1.jpg
*.pdmodel

.DS_Store
*.npy
*.npy
/lineless_table_rec/output/
196 changes: 145 additions & 51 deletions lineless_table_rec/main.py
@@ -3,6 +3,7 @@
# @Contact: liekkaskono@163.com
import argparse
import logging
import os
import time
import traceback
from pathlib import Path
@@ -12,13 +13,13 @@
import numpy as np
from rapidocr_onnxruntime import RapidOCR

from .lineless_table_process import DetProcess, get_affine_transform_upper_left
from .utils import InputType, LoadImage, OrtInferSession
from .utils_table_recover import (
from lineless_table_process import DetProcess, get_affine_transform_upper_left
from utils import InputType, LoadImage, OrtInferSession
from utils_table_recover import (
get_rotate_crop_image,
match_ocr_cell,
plot_html_table,
sorted_boxes,
sorted_ocr_boxes, box_4_2_poly_to_box_4_1, match_ocr_cell,
filter_duplicated_box, gather_ocr_list_by_row, plot_rec_box_with_logic_info, plot_rec_box, format_html,
)

cur_dir = Path(__file__).resolve().parent
@@ -28,9 +29,9 @@

class LinelessTableRecognition:
def __init__(
self,
detect_model_path: Union[str, Path] = detect_model_path,
process_model_path: Union[str, Path] = process_model_path,
self,
detect_model_path: Union[str, Path] = detect_model_path,
process_model_path: Union[str, Path] = process_model_path,
):
self.mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
self.std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)
@@ -45,32 +46,70 @@ def __init__(
self.det_process = DetProcess()
self.ocr = RapidOCR()

def __call__(self, content: InputType) -> str:
def __call__(self, content: InputType):
ss = time.perf_counter()
img = self.load_img(content)

ocr_res, _ = self.ocr(img)

input_info = self.preprocess(img)
try:
polygons, slct_logi = self.infer(input_info)
logi_points = self.filter_logi_points(slct_logi)
            # match OCR results to cells
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_res, polygons)
            # if a detected cell has no OCR result, run recognition on it directly to fill the gap
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
            # convert to an intermediate format: fix cell box coordinates and gather the physical box, logical box and OCR boxes into one dict for later processing
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
            # filter out contained and overlapping cell boxes
deleted_idx_set = filter_duplicated_box([table_box_ocr['t_box'] for table_box_ocr in t_rec_ocr_list])
t_rec_ocr_list = [t_rec_ocr_list[i] for i in range(len(t_rec_ocr_list)) if i not in deleted_idx_set]
            # build a 2-D grid of rows and columns and merge the OCR boxes of cells that occupy the same rows and columns
t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list)
            # TODO: use the grid and the unmatched OCR boxes to try filling OCR results into single-row / single-column cells
            # sort the OCR results inside each cell and merge those on the same row
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
            # render as HTML
logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
cell_box_det_map = {
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr['t_ocr_res']]
for i, t_box_ocr in enumerate(t_rec_ocr_list)
}
table_str = plot_html_table(logi_points, cell_box_det_map)

sorted_polygons = sorted_boxes(polygons)

cell_box_map = match_ocr_cell(sorted_polygons, ocr_res)
cell_box_map = self.re_rec(img, sorted_polygons, cell_box_map)

logi_points = self.sort_logi_by_polygons(
sorted_polygons, polygons, logi_points
)

table_str = plot_html_table(logi_points, cell_box_map)
            # sort outputs for visualization and result verification; can be removed in production builds
_, idx_list = sorted_ocr_boxes([t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list])
t_rec_ocr_list = [t_rec_ocr_list[i] for i in idx_list]
sorted_polygons = [t_box_ocr['t_box'] for t_box_ocr in t_rec_ocr_list]
sorted_logi_points = [t_box_ocr['t_logic_box'] for t_box_ocr in t_rec_ocr_list]
ocr_boxes_res = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
table_elapse = time.perf_counter() - ss
return table_str, table_elapse
return table_str, table_elapse, sorted_polygons, sorted_logi_points, sorted_ocr_boxes_res
except Exception:
logging.warning(traceback.format_exc())
return "", 0.0
return "", 0.0, None, None, None

    def transform_res(self, cell_box_det_map: Dict[int, List[Any]], polygons: np.ndarray,
                      logi_points: List[np.ndarray]) -> List[Dict[str, Any]]:
res = []
for i in range(len(polygons)):
ocr_res_list = cell_box_det_map.get(i)
if not ocr_res_list:
continue
xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
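            # t_box is the union bounding box of all OCR boxes matched to this cell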
dict_res = {
                # xmin, ymin, xmax, ymax
't_box': [xmin, ymin, xmax, ymax],
# row_start,row_end,col_start,col_end
't_logic_box': logi_points[i].tolist(),
                # [[xmin, ymin, xmax, ymax], text]
't_ocr_res': [[box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]] for ocr_det in ocr_res_list]
}
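            # e.g. (illustrative values only): {'t_box': [10, 12, 180, 40], 't_logic_box': [0, 0, 1, 1],
            #                                   't_ocr_res': [[[10, 12, 180, 40], 'header text']]}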
res.append(dict_res)
return res

def preprocess(self, img: np.ndarray) -> Dict[str, Any]:
height, width = img.shape[:2]
@@ -115,52 +154,107 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
)
return slct_output_dets, slct_logi

def filter_logi_points(self, slct_logi: np.ndarray) -> Dict[str, Any]:
def sort_and_gather_ocr_res(self, res):
for i, dict_res in enumerate(res):
dict_res['t_ocr_res'] = gather_ocr_list_by_row(dict_res['t_ocr_res'])
_, sorted_idx = sorted_ocr_boxes([ocr_det[0] for ocr_det in dict_res['t_ocr_res']])
dict_res['t_ocr_res'] = [dict_res['t_ocr_res'][i] for i in sorted_idx]
return res

def handle_overlap_row_col(self, res):
max_row, max_col = 0, 0
for dict_res in res:
            max_row = max(max_row, dict_res['t_logic_box'][1] + 1)  # +1 because the end index is inclusive
            max_col = max(max_col, dict_res['t_logic_box'][3] + 1)  # +1 because the end index is inclusive
        # create a 2-D grid to store the elements of sorted_logi_points
grid = [[None] * max_col for _ in range(max_row)]
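        # grid[row][col] will hold the cell dict covering that (row, col); None means the slot is empty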
        # fill the grid with the elements of sorted_logi_points
deleted_idx = set()
for i, dict_res in enumerate(res):
if i in deleted_idx:
continue
row_start, row_end, col_start, col_end = dict_res['t_logic_box']
for row in range(row_start, row_end + 1):
if i in deleted_idx:
continue
for col in range(col_start, col_end + 1):
if i in deleted_idx:
continue
exist_dict_res = grid[row][col]
if not exist_dict_res:
grid[row][col] = dict_res
continue
if exist_dict_res['t_logic_box'] == dict_res['t_logic_box']:
exist_dict_res['t_ocr_res'].extend(dict_res['t_ocr_res'])
deleted_idx.add(i)
                        # correct the cell box coordinates (union of both boxes)
exist_dict_res['t_box'] = [min(exist_dict_res['t_box'][0], dict_res['t_box'][0]),
min(exist_dict_res['t_box'][1], dict_res['t_box'][1]),
max(exist_dict_res['t_box'][2], dict_res['t_box'][2]),
max(exist_dict_res['t_box'][3], dict_res['t_box'][3]),
]
continue

        # remove the overlapping boxes
res = [res[i] for i in range(len(res)) if i not in deleted_idx]
return res, grid

@staticmethod
    def filter_logi_points(slct_logi: np.ndarray) -> List[np.ndarray]:
for logic_points in slct_logi[0]:
            # when start/end coordinates are nearly equal, snap them to their mean so rounding cannot produce crossed indices (r_s > r_e or c_s > c_e)
if abs(logic_points[0] - logic_points[1]) < 0.2:
row = (logic_points[0] + logic_points[1]) / 2
logic_points[0] = row
logic_points[1] = row
if abs(logic_points[2] - logic_points[3]) < 0.2:
col = (logic_points[2] + logic_points[3]) / 2
logic_points[2] = col
logic_points[3] = col
logi_floor = np.floor(slct_logi)
dev = slct_logi - logi_floor
slct_logi = np.where(dev > 0.5, logi_floor + 1, logi_floor)
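        # e.g. (illustrative) 1.6 rounds up to 2 and 1.4 rounds down to 1, giving integer row/column indices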
return slct_logi[0]

@staticmethod
def sort_logi_by_polygons(
sorted_polygons: np.ndarray, polygons: np.ndarray, logi_points: np.ndarray
) -> np.ndarray:
sorted_idx = []
for v in sorted_polygons:
loc_idx = np.argwhere(v[0, 0] == polygons[:, 0, 0]).squeeze()
sorted_idx.append(int(loc_idx))
logi_points = logi_points[sorted_idx]
return logi_points
return slct_logi[0].astype(np.int32)

def re_rec(
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
) -> Dict[int, List[str]]:
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
    ) -> Dict[int, List[Any]]:
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
for k, v in cell_box_map.items():
if v[0]:
#
for i in range(sorted_polygons.shape[0]):
if cell_box_map.get(i):
continue

crop_img = get_rotate_crop_image(img, sorted_polygons[k])
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
pad_img = cv2.copyMakeBorder(
crop_img, 2, 2, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
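            # the crop is padded with a wide white border before recognition (presumably to help with very small cells)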
rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
cell_box_map[k] = [rec_res[0][0]]
box = sorted_polygons[i]
text = [rec[0] for rec in rec_res]
scores = [rec[1] for rec in rec_res]
cell_box_map[i] = [[box, "".join(text), min(scores)]]
return cell_box_map


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-img", "--img_path", type=str, required=True)
parser.add_argument( "--output_dir", default= "./output", type=str)
args = parser.parse_args()

# args.img_path = '../images/image (78).png'
table_rec = LinelessTableRecognition()
table_str, elapse = table_rec(args.img_path)
print(table_str)
print(f"cost: {elapse:.5f}")
    html, elapse, polygons, logic_points, ocr_res = table_rec(args.img_path)
    print(f"cost: {elapse:.5f}")
complete_html = format_html(html)
os.makedirs(os.path.dirname(f'{args.output_dir}/table.html'), exist_ok=True)
with open(f'{args.output_dir}/table.html', 'w', encoding='utf-8') as file:
file.write(complete_html)
plot_rec_box_with_logic_info(args.img_path, f'{args.output_dir}/table_rec_box.jpg', logic_points, polygons)
plot_rec_box(args.img_path, f'{args.output_dir}/ocr_box.jpg', ocr_res)



if __name__ == "__main__":
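
For reference, a minimal usage sketch of the updated recognizer (not part of the commit). It mirrors main() above and the new five-value return of __call__; the image path is a placeholder, and running from inside lineless_table_rec/ is assumed, since this commit switches to absolute local imports.

# Minimal usage sketch (assumptions: run from inside lineless_table_rec/; "table.png" is a placeholder path)
from main import LinelessTableRecognition

table_rec = LinelessTableRecognition()
html, elapse, polygons, logic_points, ocr_res = table_rec("table.png")
print(f"cost: {elapse:.5f}")
print(html)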

