diff --git a/README.md b/README.md
index 1b6c44f..f14eec1 100644
--- a/README.md
+++ b/README.md
@@ -13,17 +13,17 @@
### 简介
-主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
+主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图、研报等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。
目前支持三种类别的版面分析模型:中文、英文和表格版面分析模型,具体可参见下面表格:
|`model_type`| 版面类型 | 模型名称 | 支持类别|
| :------ | :----- | :------ | :----- |
-|`pp_layout_table`| 表格 | `layout_table.onnx` |`table` |
-| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`text title list table figure` |
-| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `text title figure figure_caption table table_caption`
`header footer reference equation` |
-| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption`
`header footer reference equation` |
-| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption`
`toc` |
+|`pp_layout_table`| 表格 | `layout_table.onnx` |`["table"]` |
+| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`["text", "title", "list", "table", "figure"]` |
+| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` |
+| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` |
+| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `['text', 'title', 'header', 'footer', 'figure', 'figure_caption', 'table', 'table_caption', 'toc']` |
PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
diff --git a/docs/doc_whl_rapid_layout.md b/docs/doc_whl_rapid_layout.md
index 8163c11..9c23ec9 100644
--- a/docs/doc_whl_rapid_layout.md
+++ b/docs/doc_whl_rapid_layout.md
@@ -1 +1 @@
-See [link](https://github.com/RapidAI/RapidStructure) for details.
+See [link](https://github.com/RapidAI/RapidLayout) for details.
diff --git a/rapid_layout/main.py b/rapid_layout/main.py
index 14cb89f..0cc7408 100644
--- a/rapid_layout/main.py
+++ b/rapid_layout/main.py
@@ -13,7 +13,7 @@
DownloadModel,
LoadImage,
OrtInferSession,
- PicoDetPostProcess,
+ PPPostProcess,
PPPreProcess,
VisLayout,
YOLOv8PostProcess,
@@ -36,7 +36,6 @@
class RapidLayout:
-
def __init__(
self,
model_type: str = "pp_layout_cdla",
@@ -44,11 +43,19 @@ def __init__(
conf_thres: float = 0.5,
iou_thres: float = 0.5,
use_cuda: bool = False,
+ use_dml: bool = False,
):
+ if not self.check_of(conf_thres):
+ raise ValueError(f"conf_thres {conf_thres} is outside of range [0, 1]")
+
+ if not self.check_of(iou_thres):
+ raise ValueError(f"iou_thres {conf_thres} is outside of range [0, 1]")
+
self.model_type = model_type
config = {
"model_path": self.get_model_path(model_type, model_path),
"use_cuda": use_cuda,
+ "use_dml": use_dml,
}
self.session = OrtInferSession(config)
labels = self.session.get_character_list()
@@ -56,12 +63,12 @@ def __init__(
# pp
self.pp_preprocess = PPPreProcess(img_size=(800, 608))
- self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres)
+ self.pp_postprocess = PPPostProcess(labels, conf_thres, iou_thres)
# yolov8
self.yolov8_input_shape = (640, 640)
- self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
- self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
+ self.yolov8_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
+ self.yolov8_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)
self.load_img = LoadImage()
@@ -97,9 +104,9 @@ def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
return boxes, scores, class_names, elapse
def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
- input_tensor = self.yolo_preprocess(img)
+ input_tensor = self.yolov8_preprocess(img)
outputs = self.session(input_tensor)
- boxes, scores, class_names = self.yolo_postprocess(
+ boxes, scores, class_names = self.yolov8_postprocess(
outputs, ori_img_shape, self.yolov8_input_shape
)
return boxes, scores, class_names
@@ -117,6 +124,12 @@ def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH)
return DEFAULT_MODEL_PATH
+ @staticmethod
+ def check_of(thres: float) -> bool:
+ if 0 <= thres <= 1.0:
+ return True
+ return False
+
def main():
parser = argparse.ArgumentParser()
diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py
index 0cadd2a..ed27313 100644
--- a/rapid_layout/utils/__init__.py
+++ b/rapid_layout/utils/__init__.py
@@ -1,18 +1,10 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
-import yaml
-
from .download_model import DownloadModel
from .infer_engine import OrtInferSession
-from .load_image import LoadImage
+from .load_image import LoadImage, LoadImageError
from .logger import get_logger
-from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess
+from .post_prepross import PPPostProcess, YOLOv8PostProcess
from .pre_procss import PPPreProcess, YOLOv8PreProcess
from .vis_res import VisLayout
-
-
-def read_yaml(yaml_path):
- with open(yaml_path, "rb") as f:
- data = yaml.load(f, Loader=yaml.Loader)
- return data
diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py
index eacd006..516fc33 100644
--- a/rapid_layout/utils/post_prepross.py
+++ b/rapid_layout/utils/post_prepross.py
@@ -6,7 +6,7 @@
import numpy as np
-class PicoDetPostProcess:
+class PPPostProcess:
def __init__(self, labels, conf_thres=0.4, iou_thres=0.5):
self.labels = labels
self.strides = [8, 16, 32, 64]
@@ -247,7 +247,6 @@ def area_of(left_top, right_bottom):
class YOLOv8PostProcess:
-
def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
self.labels = labels
self.conf_threshold = conf_thres
@@ -297,7 +296,6 @@ def extract_boxes(self, predictions):
return boxes
def rescale_boxes(self, boxes):
-
# Rescale boxes to original image dimensions
input_shape = np.array(
[self.input_width, self.input_height, self.input_width, self.input_height]
@@ -332,7 +330,6 @@ def nms(boxes, scores, iou_threshold):
def multiclass_nms(boxes, scores, class_ids, iou_threshold):
-
unique_class_ids = np.unique(class_ids)
keep_boxes = []
diff --git a/requirements.txt b/requirements.txt
index 3b141d5..b004a57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
onnxruntime>=1.7.0
-PyYAML>=6.0
opencv_python>=4.5.1.48
numpy>=1.21.6,<2
Pillow
diff --git a/setup.py b/setup.py
index 464eb69..7b48c64 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@ def get_readme():
include_package_data=True,
install_requires=read_txt("requirements.txt"),
packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"],
- package_data={"": ["layout_cdla.onnx", "*.yaml"]},
+ package_data={"": ["layout_cdla.onnx"]},
keywords=["ppstructure,layout,rapidocr,rapid_layout"],
classifiers=[
"Programming Language :: Python :: 3.6",
diff --git a/tests/test_layout.py b/tests/test_layout.py
index fa56d7b..78b639a 100644
--- a/tests/test_layout.py
+++ b/tests/test_layout.py
@@ -13,6 +13,7 @@
sys.path.append(str(root_dir))
from rapid_layout import RapidLayout
+from rapid_layout.utils import LoadImageError
test_file_dir = cur_dir / "test_files"
img_path = test_file_dir / "layout.png"
@@ -20,10 +21,29 @@
img = cv2.imread(str(img_path))
+def test_iou_outside_thres():
+ with pytest.raises(ValueError) as exc:
+ engine = RapidLayout(iou_thres=1.2)
+ assert exc.type is ValueError
+
+
+def test_conf_outside_thres():
+ with pytest.raises(ValueError) as exc:
+ engine = RapidLayout(conf_thres=1.2)
+ assert exc.type is ValueError
+
+
+def test_empty():
+ with pytest.raises(LoadImageError) as exc:
+ engine = RapidLayout()
+ engine(None)
+ assert exc.type is LoadImageError
+
+
@pytest.mark.parametrize(
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
-def test_multi_input(img_content):
+def test_pp_layout(img_content):
engine = RapidLayout()
boxes, scores, class_names, *elapse = engine(img_content)
assert len(boxes) == 15
@@ -32,7 +52,7 @@ def test_multi_input(img_content):
@pytest.mark.parametrize(
"img_content", [img_path, str(img_path), open(img_path, "rb").read(), img]
)
-def test_yolov8_input(img_content):
+def test_yolov8_layout(img_content):
engine = RapidLayout(model_type="yolov8n_layout_paper")
boxes, scores, class_names, *elapse = engine(img_content)
assert len(boxes) == 11