Skip to content

Commit

Permalink
Fix issue #104
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jul 12, 2023
1 parent 0b7bfda commit d986974
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 32 deletions.
11 changes: 10 additions & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@

1. 推理使用
- 脚本使用:
- ⚠️初始化RapidOCR可不提供`config.yaml`,默认使用安装目录下的`config.yaml`。如有自定义需求,可直接通过初始化参数传入。详细参数参考下面命令行部分,和`config.yaml`基本对应。
- ⚠️注意:初始化RapidOCR可不提供`config.yaml`,默认使用安装目录下的`config.yaml`。如有自定义需求:
- 一是可直接通过初始化参数传入。详细参数参考下面命令行部分,和`config.yaml`基本对应。
- 二是复制`config.yaml`,自行更改,然后初始化给出。e.g. `engine = RapidOCR(config_path="custom.yaml")`
- 输入:`Union[str, np.ndarray, bytes, Path]`
- 输出:
- 有值:`([[文本框坐标], 文本内容, 置信度], 推理时间)`
Expand All @@ -68,6 +70,7 @@

# RapidOCR可传入参数参考下面的命令行部分
rapid_ocr = RapidOCR()
# rapid_ocr = RapidOCR(config_path='custom.yaml')

img_path = 'tests/test_files/ch_en_num.jpg'

Expand Down Expand Up @@ -96,6 +99,7 @@
[--print_verbose PRINT_VERBOSE]
[--min_height MIN_HEIGHT]
[--width_height_ratio WIDTH_HEIGHT_RATIO]
[--det_use_cuda DET_USE_CUDA]
[--det_model_path DET_MODEL_PATH]
[--det_limit_side_len DET_LIMIT_SIDE_LEN]
[--det_limit_type {max,min}]
Expand All @@ -104,11 +108,13 @@
[--det_unclip_ratio DET_UNCLIP_RATIO]
[--det_use_dilation DET_USE_DILATION]
[--det_score_mode {slow,fast}]
[--cls_use_cuda CLS_USE_CUDA]
[--cls_model_path CLS_MODEL_PATH]
[--cls_image_shape CLS_IMAGE_SHAPE]
[--cls_label_list CLS_LABEL_LIST]
[--cls_batch_num CLS_BATCH_NUM]
[--cls_thresh CLS_THRESH]
[--rec_use_cuda REC_USE_CUDA]
[--rec_model_path REC_MODEL_PATH]
[--rec_img_shape REC_IMAGE_SHAPE]
[--rec_batch_num REC_BATCH_NUM]
Expand All @@ -127,6 +133,7 @@
--width_height_ratio WIDTH_HEIGHT_RATIO

Det:
--det_use_cuda DET_USE_CUDA
--det_model_path DET_MODEL_PATH
--det_limit_side_len DET_LIMIT_SIDE_LEN
--det_limit_type {max,min}
Expand All @@ -137,13 +144,15 @@
--det_score_mode {slow,fast}

Cls:
--cls_use_cuda CLS_USE_CUDA
--cls_model_path CLS_MODEL_PATH
--cls_image_shape CLS_IMAGE_SHAPE
--cls_label_list CLS_LABEL_LIST
--cls_batch_num CLS_BATCH_NUM
--cls_thresh CLS_THRESH

Rec:
--rec_use_cuda REC_USE_CUDA
--rec_model_path REC_MODEL_PATH
--rec_img_shape REC_IMAGE_SHAPE
--rec_batch_num REC_BATCH_NUM
Expand Down
12 changes: 9 additions & 3 deletions python/rapidocr_onnxruntime/rapid_ocr_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
import importlib
from pathlib import Path
from typing import Union
from typing import Optional, Union

import cv2
import numpy as np
Expand All @@ -18,17 +18,23 @@


class RapidOCR:
def __init__(self, **kwargs):
config_path = str(root_dir / "config.yaml")
def __init__(self, config_path: Optional[str] = None, **kwargs):
if config_path is None:
config_path = str(root_dir / "config.yaml")

if not Path(config_path).exists():
raise FileExistsError(f"{config_path} does not exist!")

config = read_yaml(config_path)
config = concat_model_path(config)

if kwargs:
updater = UpdateParameters()
config = updater(config, **kwargs)

print(config)
exit()

global_config = config["Global"]
self.print_verbose = global_config["print_verbose"]
self.text_score = global_config["text_score"]
Expand Down
74 changes: 46 additions & 28 deletions python/rapidocr_onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from io import BytesIO
from pathlib import Path
from typing import Union
from typing import Dict, List, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -200,6 +200,7 @@ def init_args():
global_group.add_argument("--width_height_ratio", type=int, default=8)

det_group = parser.add_argument_group(title="Det")
det_group.add_argument("--det_use_cuda", action='store_true', default=False)
det_group.add_argument("--det_model_path", type=str, default=None)
det_group.add_argument("--det_limit_side_len", type=float, default=736)
det_group.add_argument(
Expand All @@ -214,13 +215,15 @@ def init_args():
)

cls_group = parser.add_argument_group(title="Cls")
cls_group.add_argument("--cls_use_cuda", action='store_true', default=False)
cls_group.add_argument("--cls_model_path", type=str, default=None)
cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192])
cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"])
cls_group.add_argument("--cls_batch_num", type=int, default=6)
cls_group.add_argument("--cls_thresh", type=float, default=0.9)

rec_group = parser.add_argument_group(title="Rec")
rec_group.add_argument("--rec_use_cuda", action='store_true', default=False)
rec_group.add_argument("--rec_model_path", type=str, default=None)
rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
rec_group.add_argument("--rec_batch_num", type=int, default=6)
Expand Down Expand Up @@ -262,37 +265,52 @@ def update_global_params(self, config, global_dict):
return config

def update_det_params(self, config, det_dict):
if det_dict:
det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()}
if not det_dict["model_path"]:
det_dict["model_path"] = str(root_dir / config["model_path"])
config.update(det_dict)
if not det_dict:
return config

det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()}
model_path = det_dict.get('model_path', None)
if not model_path:
det_dict["model_path"] = str(root_dir / config["model_path"])

config.update(det_dict)
return config

def update_cls_params(self, config, cls_dict):
if cls_dict:
need_remove_prefix = ["cls_label_list", "cls_model_path"]
new_cls_dict = {}
for k, v in cls_dict.items():
if k in need_remove_prefix:
k = k.split("cls_")[1]
new_cls_dict[k] = v

if not new_cls_dict["model_path"]:
new_cls_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_cls_dict)
if not cls_dict:
return config

need_remove_prefix = ["cls_label_list", "cls_model_path", "cls_use_cuda"]
new_cls_dict = self.remove_prefix(cls_dict, 'cls_', need_remove_prefix)

model_path = new_cls_dict.get('model_path', None)
if model_path:
new_cls_dict["model_path"] = str(root_dir / config["model_path"])

config.update(new_cls_dict)
return config

def update_rec_params(self, config, rec_dict):
if rec_dict:
need_remove_prefix = ["rec_model_path"]
new_rec_dict = {}
for k, v in rec_dict.items():
if k in need_remove_prefix:
k = k.split("rec_")[1]
new_rec_dict[k] = v

if not new_rec_dict["model_path"]:
new_rec_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_rec_dict)
if not rec_dict:
return config

need_remove_prefix = ["rec_model_path", "rec_use_cuda"]
new_rec_dict = self.remove_prefix(rec_dict, 'rec_', need_remove_prefix)

model_path = new_rec_dict.get('model_path', None)
if not model_path:
new_rec_dict["model_path"] = str(root_dir / config["model_path"])

config.update(new_rec_dict)
return config

@staticmethod
def remove_prefix(
config: Dict[str, str], prefix: str, remove_params: List[str]
) -> Dict[str, str]:
new_rec_dict = {}
for k, v in config.items():
if k in remove_params:
k = k.split(prefix)[1]
new_rec_dict[k] = v
return new_rec_dict

0 comments on commit d986974

Please sign in to comment.