From d986974dcae5a1fb020f7bcb922b0e0b2328d626 Mon Sep 17 00:00:00 2001 From: SWHL Date: Wed, 12 Jul 2023 09:56:30 +0800 Subject: [PATCH] Fix issue #104 --- python/README.md | 11 ++- python/rapidocr_onnxruntime/rapid_ocr_api.py | 12 +++- python/rapidocr_onnxruntime/utils.py | 74 ++++++++++++-------- 3 files changed, 65 insertions(+), 32 deletions(-) diff --git a/python/README.md b/python/README.md index 19c829f17..6148a10c5 100755 --- a/python/README.md +++ b/python/README.md @@ -54,7 +54,9 @@ 1. 推理使用 - 脚本使用: - - ⚠️初始化RapidOCR可不提供`config.yaml`,默认使用安装目录下的`config.yaml`。如有自定义需求,可直接通过初始化参数传入。详细参数参考下面命令行部分,和`config.yaml`基本对应。 + - ⚠️注意:初始化RapidOCR可不提供`config.yaml`,默认使用安装目录下的`config.yaml`。如有自定义需求: + - 一是可直接通过初始化参数传入。详细参数参考下面命令行部分,和`config.yaml`基本对应。 + - 二是复制`config.yaml`,自行更改,然后初始化给出。e.g. `engine = RapidOCR(config_path="custom.yaml")` - 输入:`Union[str, np.ndarray, bytes, Path]` - 输出: - 有值:`([[文本框坐标], 文本内容, 置信度], 推理时间)`, @@ -68,6 +70,7 @@ # RapidOCR可传入参数参考下面的命令行部分 rapid_ocr = RapidOCR() + # rapid_ocr = RapidOCR(config_path='custom.yaml') img_path = 'tests/test_files/ch_en_num.jpg' @@ -96,6 +99,7 @@ [--print_verbose PRINT_VERBOSE] [--min_height MIN_HEIGHT] [--width_height_ratio WIDTH_HEIGHT_RATIO] + [--det_use_cuda DET_USE_CUDA] [--det_model_path DET_MODEL_PATH] [--det_limit_side_len DET_LIMIT_SIDE_LEN] [--det_limit_type {max,min}] @@ -104,11 +108,13 @@ [--det_unclip_ratio DET_UNCLIP_RATIO] [--det_use_dilation DET_USE_DILATION] [--det_score_mode {slow,fast}] + [--cls_use_cuda CLS_USE_CUDA] [--cls_model_path CLS_MODEL_PATH] [--cls_image_shape CLS_IMAGE_SHAPE] [--cls_label_list CLS_LABEL_LIST] [--cls_batch_num CLS_BATCH_NUM] [--cls_thresh CLS_THRESH] + [--rec_use_cuda REC_USE_CUDA] [--rec_model_path REC_MODEL_PATH] [--rec_img_shape REC_IMAGE_SHAPE] [--rec_batch_num REC_BATCH_NUM] @@ -127,6 +133,7 @@ --width_height_ratio WIDTH_HEIGHT_RATIO Det: + --det_use_cuda DET_USE_CUDA --det_model_path DET_MODEL_PATH --det_limit_side_len DET_LIMIT_SIDE_LEN --det_limit_type {max,min} @@ 
-137,6 +144,7 @@ --det_score_mode {slow,fast} Cls: + --cls_use_cuda CLS_USE_CUDA --cls_model_path CLS_MODEL_PATH --cls_image_shape CLS_IMAGE_SHAPE --cls_label_list CLS_LABEL_LIST @@ -144,6 +152,7 @@ --cls_thresh CLS_THRESH Rec: + --rec_use_cuda REC_USE_CUDA --rec_model_path REC_MODEL_PATH --rec_img_shape REC_IMAGE_SHAPE --rec_batch_num REC_BATCH_NUM diff --git a/python/rapidocr_onnxruntime/rapid_ocr_api.py b/python/rapidocr_onnxruntime/rapid_ocr_api.py index 0c351b3c7..62b74716b 100644 --- a/python/rapidocr_onnxruntime/rapid_ocr_api.py +++ b/python/rapidocr_onnxruntime/rapid_ocr_api.py @@ -4,7 +4,7 @@ import copy import importlib from pathlib import Path -from typing import Union +from typing import Optional, Union import cv2 import numpy as np @@ -18,10 +18,13 @@ class RapidOCR: - def __init__(self, **kwargs): - config_path = str(root_dir / "config.yaml") + def __init__(self, config_path: Optional[str] = None, **kwargs): + if config_path is None: + config_path = str(root_dir / "config.yaml") + if not Path(config_path).exists(): raise FileExistsError(f"{config_path} does not exist!") + config = read_yaml(config_path) config = concat_model_path(config) @@ -29,6 +32,9 @@ def __init__(self, **kwargs): updater = UpdateParameters() config = updater(config, **kwargs) + print(config) + exit() + global_config = config["Global"] self.print_verbose = global_config["print_verbose"] self.text_score = global_config["text_score"] diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index 0e5e4d038..f5f21f563 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -6,7 +6,7 @@ import warnings from io import BytesIO from pathlib import Path -from typing import Union +from typing import Dict, List, Union import cv2 import numpy as np @@ -200,6 +200,7 @@ def init_args(): global_group.add_argument("--width_height_ratio", type=int, default=8) det_group = parser.add_argument_group(title="Det") + 
det_group.add_argument("--det_use_cuda", action='store_true', default=False) det_group.add_argument("--det_model_path", type=str, default=None) det_group.add_argument("--det_limit_side_len", type=float, default=736) det_group.add_argument( @@ -214,6 +215,7 @@ def init_args(): ) cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument("--cls_use_cuda", action='store_true', default=False) cls_group.add_argument("--cls_model_path", type=str, default=None) cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) @@ -221,6 +223,7 @@ def init_args(): cls_group.add_argument("--cls_thresh", type=float, default=0.9) rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument("--rec_use_cuda", action='store_true', default=False) rec_group.add_argument("--rec_model_path", type=str, default=None) rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) rec_group.add_argument("--rec_batch_num", type=int, default=6) @@ -262,37 +265,52 @@ def update_global_params(self, config, global_dict): return config def update_det_params(self, config, det_dict): - if det_dict: - det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()} - if not det_dict["model_path"]: - det_dict["model_path"] = str(root_dir / config["model_path"]) - config.update(det_dict) + if not det_dict: + return config + + det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()} + model_path = det_dict.get('model_path', None) + if not model_path: + det_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(det_dict) return config def update_cls_params(self, config, cls_dict): - if cls_dict: - need_remove_prefix = ["cls_label_list", "cls_model_path"] - new_cls_dict = {} - for k, v in cls_dict.items(): - if k in need_remove_prefix: - k = k.split("cls_")[1] - new_cls_dict[k] = v - - if not new_cls_dict["model_path"]: - 
new_cls_dict["model_path"] = str(root_dir / config["model_path"]) - config.update(new_cls_dict) + if not cls_dict: + return config + + need_remove_prefix = ["cls_label_list", "cls_model_path", "cls_use_cuda"] + new_cls_dict = self.remove_prefix(cls_dict, 'cls_', need_remove_prefix) + + model_path = new_cls_dict.get('model_path', None) + if not model_path: + new_cls_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(new_cls_dict) return config def update_rec_params(self, config, rec_dict): - if rec_dict: - need_remove_prefix = ["rec_model_path"] - new_rec_dict = {} - for k, v in rec_dict.items(): - if k in need_remove_prefix: - k = k.split("rec_")[1] - new_rec_dict[k] = v - - if not new_rec_dict["model_path"]: - new_rec_dict["model_path"] = str(root_dir / config["model_path"]) - config.update(new_rec_dict) + if not rec_dict: + return config + + need_remove_prefix = ["rec_model_path", "rec_use_cuda"] + new_rec_dict = self.remove_prefix(rec_dict, 'rec_', need_remove_prefix) + + model_path = new_rec_dict.get('model_path', None) + if not model_path: + new_rec_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(new_rec_dict) return config + + @staticmethod + def remove_prefix( + config: Dict[str, str], prefix: str, remove_params: List[str] + ) -> Dict[str, str]: + new_rec_dict = {} + for k, v in config.items(): + if k in remove_params: + k = k.split(prefix)[1] + new_rec_dict[k] = v + return new_rec_dict