diff --git a/README.md b/README.md index 1b6c44f..f14eec1 100644 --- a/README.md +++ b/README.md @@ -13,17 +13,17 @@ ### 简介 -主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。 +主要是做文档类图像的版面分析。具体来说,就是分析给定的文档类别图像(论文截图、研报等),定位其中类别和位置,如标题、段落、表格和图片等各个部分。 目前支持三种类别的版面分析模型:中文、英文和表格版面分析模型,具体可参见下面表格: |`model_type`| 版面类型 | 模型名称 | 支持类别| | :------ | :----- | :------ | :----- | -|`pp_layout_table`| 表格 | `layout_table.onnx` |`table` | -| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`text title list table figure` | -| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `text title figure figure_caption table table_caption`
`header footer reference equation` | -| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption`
`header footer reference equation` | -| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption`
`toc` | +|`pp_layout_table`| 表格 | `layout_table.onnx` |`["table"]` | +| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`["text", "title", "list", "table", "figure"]` | +| `pp_layout_table`| 中文 | `layout_cdla.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` | +| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `['text', 'title', 'figure', 'figure_caption', 'table', 'table_caption', 'header', 'footer', 'reference', 'equation']` | +| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `['text', 'title', 'header', 'footer', 'figure', 'figure_caption', 'table', 'table_caption', 'toc']` | PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md) diff --git a/docs/doc_whl_rapid_layout.md b/docs/doc_whl_rapid_layout.md index 8163c11..9c23ec9 100644 --- a/docs/doc_whl_rapid_layout.md +++ b/docs/doc_whl_rapid_layout.md @@ -1 +1 @@ -See [link](https://github.com/RapidAI/RapidStructure) for details. +See [link](https://github.com/RapidAI/RapidLayout) for details. diff --git a/rapid_layout/main.py b/rapid_layout/main.py index 14cb89f..0cc7408 100644 --- a/rapid_layout/main.py +++ b/rapid_layout/main.py @@ -13,7 +13,7 @@ DownloadModel, LoadImage, OrtInferSession, - PicoDetPostProcess, + PPPostProcess, PPPreProcess, VisLayout, YOLOv8PostProcess, @@ -36,7 +36,6 @@ class RapidLayout: - def __init__( self, model_type: str = "pp_layout_cdla", @@ -44,11 +43,19 @@ def __init__( conf_thres: float = 0.5, iou_thres: float = 0.5, use_cuda: bool = False, + use_dml: bool = False, ): + if not self.check_of(conf_thres): + raise ValueError(f"conf_thres {conf_thres} is outside of range [0, 1]") + + if not self.check_of(iou_thres): + raise ValueError(f"iou_thres {conf_thres} is outside of range [0, 1]") + self.model_type = model_type config = { "model_path": self.get_model_path(model_type, model_path), "use_cuda": use_cuda, + "use_dml": use_dml, } self.session = OrtInferSession(config) labels = self.session.get_character_list() @@ -56,12 +63,12 @@ def __init__( # pp self.pp_preprocess = PPPreProcess(img_size=(800, 608)) - self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres) + self.pp_postprocess = PPPostProcess(labels, conf_thres, iou_thres) # yolov8 self.yolov8_input_shape = (640, 640) - self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape) - self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres) + self.yolov8_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape) + self.yolov8_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres) self.load_img = LoadImage() @@ -97,9 +104,9 @@ def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]): return boxes, scores, class_names, elapse def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]): - input_tensor = self.yolo_preprocess(img) + input_tensor = self.yolov8_preprocess(img) outputs = self.session(input_tensor) - boxes, scores, class_names = self.yolo_postprocess( + boxes, scores, class_names = self.yolov8_postprocess( outputs, ori_img_shape, self.yolov8_input_shape ) return boxes, scores, class_names @@ -117,6 +124,12 @@ def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str: logger.info("model url is None, using the default model %s", DEFAULT_MODEL_PATH) return DEFAULT_MODEL_PATH + @staticmethod + def check_of(thres: float) -> bool: + if 0 <= thres <= 1.0: + return True + return False + def main(): parser = argparse.ArgumentParser() diff --git a/rapid_layout/utils/__init__.py b/rapid_layout/utils/__init__.py index 0cadd2a..ed27313 100644 --- a/rapid_layout/utils/__init__.py +++ b/rapid_layout/utils/__init__.py @@ -1,18 +1,10 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import yaml - from .download_model import DownloadModel from .infer_engine import OrtInferSession -from .load_image import LoadImage +from .load_image import LoadImage, LoadImageError from .logger import get_logger -from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess +from .post_prepross import PPPostProcess, YOLOv8PostProcess from .pre_procss import PPPreProcess, YOLOv8PreProcess from .vis_res import VisLayout - - -def read_yaml(yaml_path): - with open(yaml_path, "rb") as f: - data = yaml.load(f, Loader=yaml.Loader) - return data diff --git a/rapid_layout/utils/post_prepross.py b/rapid_layout/utils/post_prepross.py index eacd006..516fc33 100644 --- a/rapid_layout/utils/post_prepross.py +++ b/rapid_layout/utils/post_prepross.py @@ -6,7 +6,7 @@ import numpy as np -class PicoDetPostProcess: +class PPPostProcess: def __init__(self, labels, conf_thres=0.4, iou_thres=0.5): self.labels = labels self.strides = [8, 16, 32, 64] @@ -247,7 +247,6 @@ def area_of(left_top, right_bottom): class YOLOv8PostProcess: - def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5): self.labels = labels self.conf_threshold = conf_thres @@ -297,7 +296,6 @@ def extract_boxes(self, predictions): return boxes def rescale_boxes(self, boxes): - # Rescale boxes to original image dimensions input_shape = np.array( [self.input_width, self.input_height, self.input_width, self.input_height] @@ -332,7 +330,6 @@ def nms(boxes, scores, iou_threshold): def multiclass_nms(boxes, scores, class_ids, iou_threshold): - unique_class_ids = np.unique(class_ids) keep_boxes = [] diff --git a/requirements.txt b/requirements.txt index 3b141d5..b004a57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ onnxruntime>=1.7.0 -PyYAML>=6.0 opencv_python>=4.5.1.48 numpy>=1.21.6,<2 Pillow diff --git a/setup.py b/setup.py index 464eb69..7b48c64 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def get_readme(): include_package_data=True, install_requires=read_txt("requirements.txt"), packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"], - package_data={"": ["layout_cdla.onnx", "*.yaml"]}, + package_data={"": ["layout_cdla.onnx"]}, keywords=["ppstructure,layout,rapidocr,rapid_layout"], classifiers=[ "Programming Language :: Python :: 3.6", diff --git a/tests/test_layout.py b/tests/test_layout.py index fa56d7b..78b639a 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -13,6 +13,7 @@ sys.path.append(str(root_dir)) from rapid_layout import RapidLayout +from rapid_layout.utils import LoadImageError test_file_dir = cur_dir / "test_files" img_path = test_file_dir / "layout.png" @@ -20,10 +21,29 @@ img = cv2.imread(str(img_path)) +def test_iou_outside_thres(): + with pytest.raises(ValueError) as exc: + engine = RapidLayout(iou_thres=1.2) + assert exc.type is ValueError + + +def test_conf_outside_thres(): + with pytest.raises(ValueError) as exc: + engine = RapidLayout(conf_thres=1.2) + assert exc.type is ValueError + + +def test_empty(): + with pytest.raises(LoadImageError) as exc: + engine = RapidLayout() + engine(None) + assert exc.type is LoadImageError + + @pytest.mark.parametrize( "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img] ) -def test_multi_input(img_content): +def test_pp_layout(img_content): engine = RapidLayout() boxes, scores, class_names, *elapse = engine(img_content) assert len(boxes) == 15 @@ -32,7 +52,7 @@ def test_multi_input(img_content): @pytest.mark.parametrize( "img_content", [img_path, str(img_path), open(img_path, "rb").read(), img] ) -def test_yolov8_input(img_content): +def test_yolov8_layout(img_content): engine = RapidLayout(model_type="yolov8n_layout_paper") boxes, scores, class_names, *elapse = engine(img_content) assert len(boxes) == 11