Skip to content

Commit

Permalink
feat: Support 360 yolov8 model
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 19, 2024
1 parent a794d08 commit 171d494
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 190 deletions.
29 changes: 19 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
|`pp_layout_table`| 表格 | `layout_table.onnx` |`table` |
| `pp_layout_publaynet`| 英文 | `layout_publaynet.onnx` |`text title list table figure` |
| `pp_layout_cdla`| 中文 | `layout_cdla.onnx` | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |
| `yolov8n_layout_paper`| 论文 | `yolov8n_layout_paper.onnx` | `text title figure figure_caption table table_caption` <br> `header footer reference equation` |
| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `text title header footer figure figure_caption table table_caption` <br> `toc` |

模型来源[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)

模型下载地址为:[百度网盘](https://pan.baidu.com/s/1PI9fksW6F6kQfJhwUkewWg?pwd=p29g) | [Google Drive](https://drive.google.com/drive/folders/1DAPWSN2zGQ-ED_Pz7RaJGTjfkN2-Mvsf?usp=sharing)
yolov8n系列来源:[360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)

模型下载地址为:[link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)

### 安装
由于模型较小,预先将中文版面分析模型(`layout_cdla.onnx`)打包进了whl包内,如果做中文版面分析,可直接安装使用
Expand All @@ -41,7 +45,7 @@ import cv2
from rapid_layout import RapidLayout, VisLayout

# model_type类型参见上表。指定不同model_type时,会自动下载相应模型到安装目录下。
layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
layout_engine = RapidLayout(conf_thres=0.5, model_type="pp_layout_cdla")

img = cv2.imread('test_images/layout.png')

Expand All @@ -55,18 +59,23 @@ if ploted_img is not None:
- 用法:
```bash
$ rapid_layout -h
usage: rapid_layout [-h] -img IMG_PATH [-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}]
[--box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}] [-v]
usage: rapid_layout [-h] -img IMG_PATH
[-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}]
[--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}]
[--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}]
[-v]

options:
-h, --help show this help message and exit
-img IMG_PATH, --img_path IMG_PATH
-h, --help show this help message and exit
-img IMG_PATH, --img_path IMG_PATH
Path to image for layout.
-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
-m {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}, --model_type {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}
Support model type
--box_threshold {pp_layout_cdla,pp_layout_publaynet,pp_layout_table}
--conf_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}
Box threshold, the range is [0, 1]
-v, --vis Whether to visualize the layout results.
--iou_thres {pp_layout_cdla,pp_layout_publaynet,pp_layout_table,yolov8n_layout_paper,yolov8n_layout_report}
IoU threshold, the range is [0, 1]
-v, --vis Whether to visualize the layout results.
```
- 示例:
```bash
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from rapid_layout import RapidLayout, VisLayout

layout_engine = RapidLayout(box_threshold=0.5, model_type="pp_layout_cdla")
layout_engine = RapidLayout(model_type="yolov8n_layout_paper")

img_path = "tests/test_files/layout.png"
img = cv2.imread(img_path)
Expand Down
24 changes: 0 additions & 24 deletions rapid_layout/config.yaml

This file was deleted.

95 changes: 62 additions & 33 deletions rapid_layout/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
LoadImage,
OrtInferSession,
PicoDetPostProcess,
PPPreProcess,
VisLayout,
create_operators,
YOLOv8PostProcess,
YOLOv8PreProcess,
get_logger,
read_yaml,
transform,
)

ROOT_DIR = Path(__file__).resolve().parent
Expand All @@ -29,64 +29,86 @@
"pp_layout_cdla": f"{ROOT_URL}/layout_cdla.onnx",
"pp_layout_publaynet": f"{ROOT_URL}/layout_publaynet.onnx",
"pp_layout_table": f"{ROOT_URL}/layout_table.onnx",
"yolov8n_layout_paper": f"{ROOT_URL}/yolov8n_layout_paper.onnx",
"yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
}
DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")


class RapidLayout:

def __init__(
self,
model_type: str = "pp_layout_cdla",
box_threshold: float = 0.5,
model_path: Union[str, Path, None] = None,
conf_thres: float = 0.5,
iou_thres: float = 0.5,
use_cuda: bool = False,
):
config_path = str(ROOT_DIR / "config.yaml")
config = read_yaml(config_path)
config["model_path"] = self.get_model_path(model_type)
config["use_cuda"] = use_cuda

self.model_type = model_type
config = {
"model_path": self.get_model_path(model_type, model_path),
"use_cuda": use_cuda,
}
self.session = OrtInferSession(config)
labels = self.session.get_character_list()
logger.info("%s contains %s", model_type, labels)

self.preprocess_op = create_operators(config["pre_process"])
# pp
self.pp_preprocess = PPPreProcess(img_size=(800, 608))
self.pp_postprocess = PicoDetPostProcess(labels, conf_thres, iou_thres)

# yolov8
self.yolov8_input_shape = (640, 640)
self.yolo_preprocess = YOLOv8PreProcess(img_size=self.yolov8_input_shape)
self.yolo_postprocess = YOLOv8PostProcess(labels, conf_thres, iou_thres)

config["post_process"]["score_threshold"] = box_threshold
self.postprocess_op = PicoDetPostProcess(labels, **config["post_process"])
self.load_img = LoadImage()

self.pp_layout_type = [
"pp_layout_cdla",
"pp_layout_publaynet",
"pp_layout_table",
]
self.yolov8_layout_type = ["yolov8n_layout_paper", "yolov8n_layout_report"]

def __call__(
self, img_content: Union[str, np.ndarray, bytes, Path]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], float]:
img = self.load_img(img_content)
ori_img_shape = img.shape[:2]

ori_im = img.copy()
data = transform({"image": img}, self.preprocess_op)
img = data[0]
if img is None:
return None, None, None, 0.0
if self.model_type in self.pp_layout_type:
return self.pp_layout(img, ori_img_shape)

img = np.expand_dims(img, axis=0)
img = img.copy()
if self.model_type in self.yolov8_layout_type:
return self.yolov8_layout(img, ori_img_shape)

preds, elapse = 0, 1
starttime = time.time()
raise ValueError(f"{self.model_type} is not supported.")

def pp_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
s_time = time.time()

img = self.pp_preprocess(img)
preds = self.session(img)
boxes, scores, class_names = self.pp_postprocess(ori_img_shape, img, preds)

score_list, boxes_list = [], []
num_outs = int(len(preds) / 2)
for out_idx in range(num_outs):
score_list.append(preds[out_idx])
boxes_list.append(preds[out_idx + num_outs])
elapse = time.time() - s_time
return boxes, scores, class_names, elapse

boxes, scores, class_names = self.postprocess_op(
ori_im, img, {"boxes": score_list, "boxes_num": boxes_list}
def yolov8_layout(self, img: np.ndarray, ori_img_shape: Tuple[int, int]):
input_tensor = self.yolo_preprocess(img)
outputs = self.session(input_tensor)
boxes, scores, class_names = self.yolo_postprocess(
outputs, ori_img_shape, self.yolov8_input_shape
)
elapse = time.time() - starttime
return boxes, scores, class_names, elapse
return boxes, scores, class_names

@staticmethod
def get_model_path(model_type: str) -> str:
def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str:
if model_path is not None:
return model_path

model_url = KEY_TO_MODEL_URL.get(model_type, None)
if model_url:
model_path = DownloadModel.download(model_url)
Expand All @@ -110,12 +132,19 @@ def main():
help="Support model type",
)
parser.add_argument(
"--box_threshold",
"--conf_thres",
type=float,
default=0.5,
choices=list(KEY_TO_MODEL_URL.keys()),
help="Box threshold, the range is [0, 1]",
)
parser.add_argument(
"--iou_thres",
type=float,
default=0.5,
choices=list(KEY_TO_MODEL_URL.keys()),
help="IoU threshold, the range is [0, 1]",
)
parser.add_argument(
"-v",
"--vis",
Expand All @@ -125,7 +154,7 @@ def main():
args = parser.parse_args()

layout_engine = RapidLayout(
model_type=args.model_type, box_threshold=args.box_threshold
model_type=args.model_type, conf_thres=args.conf_thres, iou_thres=args.iou_thres
)

img = cv2.imread(args.img_path)
Expand Down
4 changes: 2 additions & 2 deletions rapid_layout/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from .infer_engine import OrtInferSession
from .load_image import LoadImage
from .logger import get_logger
from .post_prepross import PicoDetPostProcess
from .pre_procss import create_operators, transform
from .post_prepross import PicoDetPostProcess, YOLOv8PostProcess
from .pre_procss import PPPreProcess, YOLOv8PreProcess
from .vis_res import VisLayout


Expand Down
Loading

0 comments on commit 171d494

Please sign in to comment.