Skip to content

Commit

Permalink
chore: Optim code logic and update README
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 20, 2024
1 parent 171d494 commit 42f9e16
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 32 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
</div>

### 简介
主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图、研报等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。

目前支持三种类别的版面分析模型:中文、英文和表格版面分析模型,具体可参见下面表格:

|`model_type`| 版面类型 | 模型名称 | 支持类别|
| :------ | :----- | :------ | :----- |
|`pp_layout_table`| 表格 | `layout_table.onnx` |`table` |
| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`text title list table figure` |
| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |
| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |
| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption` <br> `toc` |
|`pp_layout_table`| 表格 | `layout_table.onnx` |`["table"]` |
| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`["text", "title", "list", "table", "figure"]` |
| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` |
| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` |
| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `['text', 'title', 'header', 'footer', 'figure', 'figure_caption', 'table', 'table_caption', 'toc']` |

PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)

Expand Down
2 changes: 1 addition & 1 deletion docs/doc_whl_rapid_layout.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
See [link](https://github.com/RapidAI/RapidStructure) for details.
See [link](https://github.com/RapidAI/RapidLayout) for details.
27 changes: 20 additions & 7 deletions rapid_layout/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
DownloadModel,
LoadImage,
OrtInferSession,
PicoDetPostProcess,
PPPostProcess,
PPPreProcess,
VisLayout,
YOLOv8PostProcess,
Expand All @@ -36,32 +36,39 @@


class RapidLayout:

def __init__(
self,
model_type: str = "pp_layout_cdla",
model_path: Union[str, Path, None] = None,
conf_thres: float = 0.5,
iou_thres: float = 0.5,
use_cuda: bool = False,
use_dml: bool = False,
):
if not self.check_of(conf_thres):
raise ValueError(f"conf_thres {conf_thres} is outside of range [0, 1]")

if not self.check_of(iou_thres):
raise ValueError(f"iou_thres {conf_thres} is outside of range [0, 1]")

self.model_type = model_type
config = {
"model_path": self.get_model_path(model_type, model_path),
"use_cuda": use_cuda,
"use_dml": use_dml,
}
self.session = OrtInferSession(config)
labels = self.session.get_character_list()
logger.info("%s contains %s", model_type, labels)

# pp
self.pp_preprocess = PPPreProcess(img_size=(800, 608))
self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres)
self.pp_postprocess = PPPostProcess(labels, conf_thres, iou_thres)

# yolov8
self.yolov8_input_shape = (640, 640)
self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
self.yolov8_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
self.yolov8_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)

self.load_img = LoadImage()

Expand Down Expand Up @@ -97,9 +104,9 @@ def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
return boxes, scores, class_names, elapse

def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
input_tensor = self.yolo_preprocess(img)
input_tensor = self.yolov8_preprocess(img)
outputs = self.session(input_tensor)
boxes, scores, class_names = self.yolo_postprocess(
boxes, scores, class_names = self.yolov8_postprocess(
outputs, ori_img_shape, self.yolov8_input_shape
)
return boxes, scores, class_names
Expand All @@ -117,6 +124,12 @@ def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH)
return DEFAULT_MODEL_PATH

@staticmethod
def check_of(thres: float) -> bool:
if 0 <= thres <= 1.0:
return True
return False


def main():
parser = argparse.ArgumentParser()
Expand Down
12 changes: 2 additions & 10 deletions rapid_layout/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import yaml

from .download_model import DownloadModel
from .infer_engine import OrtInferSession
from .load_image import LoadImage
from .load_image import LoadImage, LoadImageError
from .logger import get_logger
from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess
from .post_prepross import PPPostProcess, YOLOv8PostProcess
from .pre_procss import PPPreProcess, YOLOv8PreProcess
from .vis_res import VisLayout


def read_yaml(yaml_path):
with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
5 changes: 1 addition & 4 deletions rapid_layout/utils/post_prepross.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np


class PicoDetPostProcess:
class PPPostProcess:
def __init__(self, labels, conf_thres=0.4, iou_thres=0.5):
self.labels = labels
self.strides = [8, 16, 32, 64]
Expand Down Expand Up @@ -247,7 +247,6 @@ def area_of(left_top, right_bottom):


class YOLOv8PostProcess:

def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
self.labels = labels
self.conf_threshold = conf_thres
Expand Down Expand Up @@ -297,7 +296,6 @@ def extract_boxes(self, predictions):
return boxes

def rescale_boxes(self, boxes):

# Rescale boxes to original image dimensions
input_shape = np.array(
[self.input_width, self.input_height, self.input_width, self.input_height]
Expand Down Expand Up @@ -332,7 +330,6 @@ def nms(boxes, scores, iou_threshold):


def multiclass_nms(boxes, scores, class_ids, iou_threshold):

unique_class_ids = np.unique(class_ids)

keep_boxes = []
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
onnxruntime>=1.7.0
PyYAML>=6.0
opencv_python>=4.5.1.48
numpy>=1.21.6,<2
Pillow
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_readme():
include_package_data=True,
install_requires=read_txt("requirements.txt"),
packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"],
package_data={"": ["layout_cdla.onnx", "*.yaml"]},
package_data={"": ["layout_cdla.onnx"]},
keywords=["ppstructure,layout,rapidocr,rapid_layout"],
classifiers=[
"Programming Language :: Python :: 3.6",
Expand Down
24 changes: 22 additions & 2 deletions tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,37 @@
sys.path.append(str(root_dir))

from rapid_layout import RapidLayout
from rapid_layout.utils import LoadImageError

test_file_dir = cur_dir / "test_files"
img_path = test_file_dir / "layout.png"

img = cv2.imread(str(img_path))


def test_iou_outside_thres():
with pytest.raises(ValueError) as exc:
engine = RapidLayout(iou_thres=1.2)
assert exc.type is ValueError


def test_conf_outside_thres():
with pytest.raises(ValueError) as exc:
engine = RapidLayout(conf_thres=1.2)
assert exc.type is ValueError


def test_empty():
with pytest.raises(LoadImageError) as exc:
engine = RapidLayout()
engine(None)
assert exc.type is LoadImageError


@pytest.mark.parametrize(
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
def test_multi_input(img_content):
def test_pp_layout(img_content):
engine = RapidLayout()
boxes, scores, class_names, *elapse = engine(img_content)
assert len(boxes) == 15
Expand All @@ -32,7 +52,7 @@ def test_multi_input(img_content):
@pytest.mark.parametrize(
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
def test_yolov8_input(img_content):
def test_yolov8_layout(img_content):
engine = RapidLayout(model_type="yolov8n_layout_paper")
boxes, scores, class_names, *elapse = engine(img_content)
assert len(boxes) == 11

0 comments on commit 42f9e16

Please sign in to comment.