From f4fd5cdcba2e2e9540f6c58d650965ba1a8534ed Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 24 Jun 2023 10:27:12 +0800 Subject: [PATCH] Format by black --- README.md | 3 +- api/rapidocr_api/api.py | 29 ++- docs/README_en.md | 3 +- ocrweb/rapidocr_web/ocrweb.py | 26 +-- ocrweb_multi/build.py | 26 +-- ocrweb_multi/main.py | 60 ++--- ocrweb_multi/rapidocr/classify.py | 18 +- ocrweb_multi/rapidocr/detect.py | 12 +- ocrweb_multi/rapidocr/detect_process.py | 75 +++--- ocrweb_multi/rapidocr/main.py | 14 +- ocrweb_multi/rapidocr/rapid_ocr_api.py | 54 ++--- ocrweb_multi/rapidocr/recognize.py | 24 +- ocrweb_multi/utils/config.py | 4 +- ocrweb_multi/utils/utils.py | 30 +-- python/demo.py | 78 ++++--- .../ch_ppocr_v2_cls/text_cls.py | 23 +- .../ch_ppocr_v2_cls/utils.py | 9 +- .../ch_ppocr_v3_det/text_detect.py | 51 ++-- .../ch_ppocr_v3_det/utils.py | 145 ++++++------ .../ch_ppocr_v3_rec/text_recognize.py | 24 +- .../ch_ppocr_v3_rec/utils.py | 26 +-- python/rapidocr_onnxruntime/rapid_ocr_api.py | 90 ++++---- python/rapidocr_onnxruntime/utils.py | 217 +++++++++--------- .../ch_ppocr_v2_cls/text_cls.py | 23 +- .../ch_ppocr_v2_cls/utils.py | 11 +- .../ch_ppocr_v3_det/text_detect.py | 19 +- .../ch_ppocr_v3_det/utils.py | 145 ++++++------ .../ch_ppocr_v3_rec/text_recognize.py | 26 +-- .../ch_ppocr_v3_rec/utils.py | 26 +-- python/rapidocr_openvino/rapid_ocr_api.py | 90 ++++---- python/rapidocr_openvino/utils.py | 159 +++++++------ python/setup_onnxruntime.py | 47 ++-- python/setup_openvino.py | 49 ++-- python/tests/base_module.py | 12 +- python/tests/benchmark/benchmark.py | 18 +- python/tests/test_all_ort.py | 34 +-- python/tests/test_all_vino.py | 34 +-- python/tests/test_cls.py | 16 +- python/tests/test_det.py | 15 +- python/tests/test_rec.py | 17 +- 40 files changed, 918 insertions(+), 864 deletions(-) diff --git a/README.md b/README.md index 405f83319..bc869b87b 100755 --- a/README.md +++ b/README.md @@ -17,10 +17,11 @@ PyPI - SemVer2.0 Documentation Status + SemVer2.0 +

diff --git a/api/rapidocr_api/api.py b/api/rapidocr_api/api.py index c599092e0..d00b407cb 100644 --- a/api/rapidocr_api/api.py +++ b/api/rapidocr_api/api.py @@ -19,7 +19,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) -class OCRAPIUtils(): +class OCRAPIUtils: def __init__(self) -> None: self.ocr = RapidOCR() @@ -31,10 +31,10 @@ def __call__(self, img): if not ocr_res: return json.dumps({}) - out_dict = {str(i): {'rec_txt': rec, - 'dt_boxes': dt_box, - 'score': score} - for i, (dt_box, rec, score) in enumerate(ocr_res)} + out_dict = { + str(i): {"rec_txt": rec, "dt_boxes": dt_box, "score": score} + for i, (dt_box, rec, score) in enumerate(ocr_res) + } return out_dict @@ -44,10 +44,10 @@ def __call__(self, img): @app.get("/") async def root(): - return {'message': 'Welcome to RapidOCR Server!'} + return {"message": "Welcome to RapidOCR Server!"} -@app.post('/ocr') +@app.post("/ocr") async def ocr(image_file: UploadFile = None, image_data: str = Form(None)): if image_file: img = Image.open(image_file.file) @@ -57,25 +57,24 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None)): img = Image.open(io.BytesIO(img_b64decode)) else: raise ValueError( - 'When sending a post request, data or files must have a value.') + "When sending a post request, data or files must have a value." + ) ocr_res = processor(img) return ocr_res def main(): - parser = argparse.ArgumentParser('rapidocr_api') - parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0', - help='IP Address') - parser.add_argument('-p', '--port', type=int, default=9003, - help='IP port') + parser = argparse.ArgumentParser("rapidocr_api") + parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address") + parser.add_argument("-p", "--port", type=int, default=9003, help="IP port") args = parser.parse_args() cur_file_path = Path(__file__).resolve() - app_path = f'{cur_file_path.parent.name}.{cur_file_path.stem}:app' + app_path = f"{cur_file_path.parent.name}.{cur_file_path.stem}:app" print(app_path) uvicorn.run(app_path, host=args.ip, port=args.port, reload=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/docs/README_en.md b/docs/README_en.md index 292430340..94937469a 100644 --- a/docs/README_en.md +++ b/docs/README_en.md @@ -16,10 +16,11 @@ PyPI - SemVer2.0 Documentation Status + SemVer2.0 +

diff --git a/ocrweb/rapidocr_web/ocrweb.py b/ocrweb/rapidocr_web/ocrweb.py index 64ed43c66..45825e7bc 100644 --- a/ocrweb/rapidocr_web/ocrweb.py +++ b/ocrweb/rapidocr_web/ocrweb.py @@ -14,36 +14,34 @@ root_dir = Path(__file__).resolve().parent -app = Flask(__name__, template_folder='templates') -app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024 +app = Flask(__name__, template_folder="templates") +app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024 processor = OCRWebUtils() -@app.route('/') +@app.route("/") def index(): - return render_template('index.html') + return render_template("index.html") -@app.route('/ocr', methods=['POST']) +@app.route("/ocr", methods=["POST"]) def ocr(): - if request.method == 'POST': - img_str = request.get_json().get('file', None) + if request.method == "POST": + img_str = request.get_json().get("file", None) ocr_res = processor(img_str) return ocr_res def main(): - parser = argparse.ArgumentParser('rapidocr_web') - parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0', - help='IP Address') - parser.add_argument('-p', '--port', type=int, default=9003, - help='IP port') + parser = argparse.ArgumentParser("rapidocr_web") + parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address") + parser.add_argument("-p", "--port", type=int, default=9003, help="IP port") args = parser.parse_args() - print(f'Successfully launched and visit https://{args.ip}:{args.port} to view.') + print(f"Successfully launched and visit https://{args.ip}:{args.port} to view.") server = make_server(args.ip, args.port, app) server.serve_forever() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/ocrweb_multi/build.py b/ocrweb_multi/build.py index a23ee2791..5b09f91b0 100644 --- a/ocrweb_multi/build.py +++ b/ocrweb_multi/build.py @@ -1,21 +1,21 @@ import os import shutil -print('Compile ocrweb') -os.system('pyinstaller -y main.spec') +print("Compile ocrweb") +os.system("pyinstaller -y main.spec") -print('Compile wrapper') -os.system('windres .\wrapper.rc -O coff -o wrapper.res') -os.system('gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe') +print("Compile wrapper") +os.system("windres .\wrapper.rc -O coff -o wrapper.res") +os.system("gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe") -print('Copy config.yaml') -shutil.copy2('config.yaml', 'dist/config.yaml') +print("Copy config.yaml") +shutil.copy2("config.yaml", "dist/config.yaml") -print('Copy models') -shutil.copytree('models', 'dist/models', dirs_exist_ok=True) -os.remove('dist/models/.gitkeep') +print("Copy models") +shutil.copytree("models", "dist/models", dirs_exist_ok=True) +os.remove("dist/models/.gitkeep") -print('Pack to ocrweb.zip') -shutil.make_archive('ocrweb', 'zip', 'dist') +print("Pack to ocrweb.zip") +shutil.make_archive("ocrweb", "zip", "dist") -print('Done') +print("Done") diff --git a/ocrweb_multi/main.py b/ocrweb_multi/main.py index 05f773699..270f2a47b 100644 --- a/ocrweb_multi/main.py +++ b/ocrweb_multi/main.py @@ -13,61 +13,67 @@ from utils.utils import tojson, parse_bool app = Flask(__name__) -log = logging.getLogger('app') +log = logging.getLogger("app") # 设置上传文件大小 -app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024 +app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024 -@app.route('/') +@app.route("/") def index(): - return send_file('static/index.html') + return send_file("static/index.html") def json_response(data, status=200): - return make_response(tojson(data), status, {"content-type": 'application/json'}) + return make_response(tojson(data), status, {"content-type": "application/json"}) -@app.route('/lang') +@app.route("/lang") def get_languages(): """返回可用语言列表""" data = [ - {'code': key, 'name': val['name']} for key, val in conf['languages'].items() + {"code": key, "name": val["name"]} for key, val in conf["languages"].items() ] - result = {'msg': 'OK', 'data': data} - log.info('Send langs: %s', data) + result = {"msg": "OK", "data": data} + log.info("Send langs: %s", data) return json_response(result) -@app.route('/ocr', methods=['POST', 'GET']) +@app.route("/ocr", methods=["POST", "GET"]) def ocr(): """执行文字识别""" - if conf['server'].get('token'): - if request.values.get('token') != conf['server']['token']: - return json_response({'msg': 'invalid token'}, status=403) + if conf["server"].get("token"): + if request.values.get("token") != conf["server"]["token"]: + return json_response({"msg": "invalid token"}, status=403) - lang = request.values.get('lang') or 'ch' - detect = parse_bool(request.values.get('detect') or 'true') - classify = parse_bool(request.values.get('classify') or 'true') + lang = request.values.get("lang") or "ch" + detect = parse_bool(request.values.get("detect") or "true") + classify = parse_bool(request.values.get("classify") or "true") - image_file = request.files.get('image') + image_file = request.files.get("image") if not image_file: - return json_response({'msg': 'no image'}, 400) + return json_response({"msg": "no image"}, 400) nparr = np.frombuffer(image_file.stream.read(), np.uint8) image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - log.info('Input: image %s, lang=%s, detect=%s, classify=%s', image.shape, lang, detect, classify) + log.info( + "Input: image %s, lang=%s, detect=%s, classify=%s", + image.shape, + lang, + detect, + classify, + ) if image.ndim == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) result = detect_recognize(image, lang=lang, detect=detect, classify=classify) - log.info('OCR Done %s %s', result['ts'], len(result['results'])) - return json_response({'msg': 'OK', 'data': result}) + log.info("OCR Done %s %s", result["ts"], len(result["results"])) + return json_response({"msg": "OK", "data": result}) -if __name__ == '__main__': - logging.basicConfig(level='INFO') - logging.getLogger('waitress').setLevel(logging.INFO) - if parse_bool(conf.get('debug', '0')): +if __name__ == "__main__": + logging.basicConfig(level="INFO") + logging.getLogger("waitress").setLevel(logging.INFO) + if parse_bool(conf.get("debug", "0")): # Debug - app.run(host=conf['server']['host'], port=conf['server']['port'], debug=True) + app.run(host=conf["server"]["host"], port=conf["server"]["port"], debug=True) else: # Deploy with waitress - serve(app, host=conf['server']['host'], port=conf['server']['port']) + serve(app, host=conf["server"]["host"], port=conf["server"]["port"]) diff --git a/ocrweb_multi/rapidocr/classify.py b/ocrweb_multi/rapidocr/classify.py index 5c2b37bcc..5a3e4c4e5 100644 --- a/ocrweb_multi/rapidocr/classify.py +++ b/ocrweb_multi/rapidocr/classify.py @@ -21,7 +21,7 @@ from utils.utils import OrtInferSession -class ClsPostProcess(): +class ClsPostProcess: """Convert between text-label and text-index""" def __init__(self, label_list): @@ -40,18 +40,18 @@ def __call__(self, preds, label=None): return decode_out, label -class TextClassifier(): +class TextClassifier: def __init__(self, path, config): - self.cls_batch_num = config['batch_size'] - self.cls_thresh = config['score_thresh'] + self.cls_batch_num = config["batch_size"] + self.cls_thresh = config["score_thresh"] session_instance = OrtInferSession(path) self.session = session_instance.session metamap = self.session.get_modelmeta().custom_metadata_map - self.cls_image_shape = json.loads(metamap['shape']) + self.cls_image_shape = json.loads(metamap["shape"]) - labels = json.loads(metamap['labels']) + labels = json.loads(metamap["labels"]) self.postprocess_op = ClsPostProcess(labels) self.input_name = session_instance.get_input_name() @@ -65,7 +65,7 @@ def resize_norm_img(self, img): resized_w = int(math.ceil(img_h * ratio)) resized_image = cv2.resize(img, (resized_w, img_h)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") if img_c == 1: resized_image = resized_image / 255 resized_image = resized_image[np.newaxis, :] @@ -91,7 +91,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - cls_res = [['', 0.0]] * img_num + cls_res = [["", 0.0]] * img_num batch_num = self.cls_batch_num for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) @@ -115,7 +115,7 @@ def __call__(self, img_list: List[np.ndarray]): for rno in range(len(cls_result)): label, score = cls_result[rno] cls_res[indices[beg_img_no + rno]] = [label, score] - if label == '180' and score > self.cls_thresh: + if label == "180" and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1 ) diff --git a/ocrweb_multi/rapidocr/detect.py b/ocrweb_multi/rapidocr/detect.py index 136b0bf67..a6810ac78 100644 --- a/ocrweb_multi/rapidocr/detect.py +++ b/ocrweb_multi/rapidocr/detect.py @@ -21,10 +21,10 @@ from .detect_process import DBPostProcess, create_operators, transform -class TextDetector(): +class TextDetector: def __init__(self, path, config): - self.preprocess_op = create_operators(config['pre_process']) - self.postprocess_op = DBPostProcess(**config['post_process']) + self.preprocess_op = create_operators(config["pre_process"]) + self.postprocess_op = DBPostProcess(**config["post_process"]) session_instance = OrtInferSession(path) self.session = session_instance.session @@ -32,11 +32,11 @@ def __init__(self, path, config): def __call__(self, img): if img is None: - raise ValueError('img is None') + raise ValueError("img is None") ori_im_shape = img.shape[:2] - data = {'image': img} + data = {"image": img} data = transform(data, self.preprocess_op) img, shape_list = data if img is None: @@ -49,7 +49,7 @@ def __call__(self, img): post_result = self.postprocess_op(preds[0], shape_list) - dt_boxes = post_result[0]['points'] + dt_boxes = post_result[0]["points"] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape) return dt_boxes diff --git a/ocrweb_multi/rapidocr/detect_process.py b/ocrweb_multi/rapidocr/detect_process.py index 2b4307b74..e941039f5 100644 --- a/ocrweb_multi/rapidocr/detect_process.py +++ b/ocrweb_multi/rapidocr/detect_process.py @@ -18,7 +18,6 @@ # @Contact: liekkaskono@163.com from copy import deepcopy import sys -import warnings import cv2 import numpy as np @@ -27,15 +26,15 @@ from shapely.geometry import Polygon -class DecodeImage(): +class DecodeImage: """decode image""" - def __init__(self, img_mode='RGB', channel_first=False): + def __init__(self, img_mode="RGB", channel_first=False): self.img_mode = img_mode self.channel_first = channel_first def __call__(self, data): - img = data['image'] + img = data["image"] if six.PY2: assert ( type(img) is str and len(img) > 0 @@ -45,53 +44,53 @@ def __call__(self, data): type(img) is bytes and len(img) > 0 ), "invalid input 'img' in DecodeImage" - img = np.frombuffer(img, dtype='uint8') + img = np.frombuffer(img, dtype="uint8") img = cv2.imdecode(img, 1) if img is None: return None - if self.img_mode == 'GRAY': + if self.img_mode == "GRAY": img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif self.img_mode == 'RGB': - assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]' + elif self.img_mode == "RGB": + assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: img = img.transpose((2, 0, 1)) - data['image'] = img + data["image"] = img return data -class NormalizeImage(): +class NormalizeImage: """normalize image such as substract mean, divide std""" - def __init__(self, scale=None, mean=None, std=None, order='chw'): + def __init__(self, scale=None, mean=None, std=None, order="chw"): self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) mean = mean if mean is not None else [0.485, 0.456, 0.406] std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") def __call__(self, data): - img = np.array(data['image']).astype(np.float32) - data['image'] = (img * self.scale - self.mean) / self.std + img = np.array(data["image"]).astype(np.float32) + data["image"] = (img * self.scale - self.mean) / self.std return data -class ToCHWImage(): +class ToCHWImage: """convert hwc image to chw image""" def __init__(self): pass def __call__(self, data): - img = data['image'] - data['image'] = img.transpose((2, 0, 1)) + img = data["image"] + data["image"] = img.transpose((2, 0, 1)) return data -class KeepKeys(): +class KeepKeys: def __init__(self, keep_keys): self.keep_keys = keep_keys @@ -102,25 +101,25 @@ def __call__(self, data): return data_list -class DetResizeForTest(): +class DetResizeForTest: def __init__(self, **kwargs): super(DetResizeForTest, self).__init__() self.resize_type = 0 - if 'image_shape' in kwargs: - self.image_shape = kwargs['image_shape'] + if "image_shape" in kwargs: + self.image_shape = kwargs["image_shape"] self.resize_type = 1 - elif 'limit_side_len' in kwargs: - self.limit_side_len = kwargs['limit_side_len'] - self.limit_type = kwargs.get('limit_type', 'min') - elif 'resize_long' in kwargs: + elif "limit_side_len" in kwargs: + self.limit_side_len = kwargs["limit_side_len"] + self.limit_type = kwargs.get("limit_type", "min") + elif "resize_long" in kwargs: self.resize_type = 2 - self.resize_long = kwargs.get('resize_long', 960) + self.resize_long = kwargs.get("resize_long", 960) else: self.limit_side_len = 736 - self.limit_type = 'min' + self.limit_type = "min" def __call__(self, data): - img = data['image'] + img = data["image"] src_h, src_w, _ = img.shape if self.resize_type == 0: @@ -131,8 +130,8 @@ def __call__(self, data): else: # img, shape = self.resize_image_type1(img) img, [ratio_h, ratio_w] = self.resize_image_type1(img) - data['image'] = img - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + data["image"] = img + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def resize_image_type1(self, img): @@ -156,7 +155,7 @@ def resize_image_type0(self, img): h, w, _ = img.shape # limit the max side - if self.limit_type == 'max': + if self.limit_type == "max": if max(h, w) > limit_side_len: if h > w: ratio = float(limit_side_len) / h @@ -234,7 +233,7 @@ def create_operators(op_param_list): ops = [] for args in op_param_list: args = deepcopy(args) - op_class = op_map[args.pop('class')] + op_class = op_map[args.pop("class")] ops.append(op_class(**args)) return ops @@ -247,7 +246,7 @@ def draw_text_det_res(dt_boxes, img_path): return src_im -class DBPostProcess(): +class DBPostProcess: """The post process for Differentiable Binarization (DB).""" def __init__( @@ -270,10 +269,10 @@ def __init__( self.dilation_kernel = None def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' + """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} - ''' + """ bitmap = _bitmap height, width = bitmap.shape @@ -376,5 +375,5 @@ def __call__(self, pred, shape_list): pred[batch_index], mask, src_w, src_h ) - boxes_batch.append({'points': boxes}) + boxes_batch.append({"points": boxes}) return boxes_batch diff --git a/ocrweb_multi/rapidocr/main.py b/ocrweb_multi/rapidocr/main.py index d88f8e755..016ebea35 100644 --- a/ocrweb_multi/rapidocr/main.py +++ b/ocrweb_multi/rapidocr/main.py @@ -13,21 +13,21 @@ @lru_cache(maxsize=None) -def load_language_model(lang='ch'): - models = conf['languages'][lang] - print('model', models) +def load_language_model(lang="ch"): + models = conf["languages"][lang] + print("model", models) return RapidOCR(models) -def detect_recognize(image, lang='ch', detect=True, classify=True): +def detect_recognize(image, lang="ch", detect=True, classify=True): model = load_language_model(lang) results, ts = model(image, detect=detect, classify=classify) - ts['total'] = sum(ts.values()) - return {'ts': ts, 'results': results} + ts["total"] = sum(ts.values()) + return {"ts": ts, "results": results} def check_and_read_gif(img_path): - if Path(img_path).suffix.lower() == 'gif': + if Path(img_path).suffix.lower() == "gif": gif = cv2.VideoCapture(img_path) ret, frame = gif.read() if not ret: diff --git a/ocrweb_multi/rapidocr/rapid_ocr_api.py b/ocrweb_multi/rapidocr/rapid_ocr_api.py index d5bc8e8e1..d2710a99b 100644 --- a/ocrweb_multi/rapidocr/rapid_ocr_api.py +++ b/ocrweb_multi/rapidocr/rapid_ocr_api.py @@ -52,61 +52,61 @@ def get_rotate_crop_image(img, points): @lru_cache(maxsize=None) def load_onnx_model(step, name): - model_config = conf['models'][step][name] + model_config = conf["models"][step][name] model_class = { - 'detect': TextDetector, - 'classify': TextClassifier, - 'recognize': TextRecognizer, + "detect": TextDetector, + "classify": TextClassifier, + "recognize": TextRecognizer, }[step] - return model_class(model_config['path'], model_config.get('config')) + return model_class(model_config["path"], model_config.get("config")) -class RapidOCR(): +class RapidOCR: def __init__(self, config): super(RapidOCR).__init__() self.config = config - self.text_score = config['config']['text_score'] - self.min_height = config['config']['min_height'] + self.text_score = config["config"]["text_score"] + self.min_height = config["config"]["min_height"] - models = config['models'] - self.text_detector = load_onnx_model('detect', models['detect']) - self.text_recognizer = load_onnx_model('recognize', models['recognize']) - self.text_cls = load_onnx_model('classify', models['classify']) + models = config["models"] + self.text_detector = load_onnx_model("detect", models["detect"]) + self.text_recognizer = load_onnx_model("recognize", models["recognize"]) + self.text_cls = load_onnx_model("classify", models["classify"]) def __call__(self, img: np.ndarray, detect=True, classify=True): ticker = Ticker() h, w = img.shape[:2] if not detect or h < self.min_height: dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w) - ticker.tick('detect') + ticker.tick("detect") else: dt_boxes = self.text_detector(img) - ticker.tick('detect') + ticker.tick("detect") if dt_boxes is None or len(dt_boxes) < 1: return [], ticker.maps - if conf['global']['verbose']: - print(f'boxes num: {len(dt_boxes)}') + if conf["global"]["verbose"]: + print(f"boxes num: {len(dt_boxes)}") dt_boxes = self.sorted_boxes(dt_boxes) img_crop_list = self.get_crop_img_list(img, dt_boxes) - ticker.tick('post-detect') + ticker.tick("post-detect") if classify: # 进行子图像角度修正 img_crop_list, _ = self.text_cls(img_crop_list) - ticker.tick('classify') - if conf['global']['verbose']: - print(f'cls num: {len(img_crop_list)}') + ticker.tick("classify") + if conf["global"]["verbose"]: + print(f"cls num: {len(img_crop_list)}") recog_result = self.text_recognizer(img_crop_list) - ticker.tick('recognize') - if conf['global']['verbose']: - print(f'rec_res num: {len(recog_result)}') + ticker.tick("recognize") + if conf["global"]["verbose"]: + print(f"rec_res num: {len(recog_result)}") results = self.filter_boxes_rec_by_score(dt_boxes, recog_result) - ticker.tick('post-recognize') + ticker.tick("post-recognize") return results, ticker.maps def get_boxes_img_without_det(self, img, h, w): @@ -134,13 +134,13 @@ def sorted_boxes(dt_boxes): sorted boxes(array) with shape [4, 2] """ - class AlignBox(): + class AlignBox: def __init__(self, data) -> None: self.data = data self.x = data[0][0] self.y = data[0][1] - def __lt__(self, other: 'AlignBox'): + def __lt__(self, other: "AlignBox"): dy = self.y - other.y # y差距小于10, 视为相等, 根据x排序 if abs(dy) < 10: @@ -156,5 +156,5 @@ def filter_boxes_rec_by_score(self, dt_boxes, rec_res): for box, rec_reuslt in zip(dt_boxes, rec_res): text, score = rec_reuslt if score >= self.text_score: - results.append({'box': box, 'text': text, 'score': score}) + results.append({"box": box, "text": text, "score": score}) return results diff --git a/ocrweb_multi/rapidocr/recognize.py b/ocrweb_multi/rapidocr/recognize.py index ae14ff741..ff79e2540 100644 --- a/ocrweb_multi/rapidocr/recognize.py +++ b/ocrweb_multi/rapidocr/recognize.py @@ -11,26 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import json import math -import time from typing import List import cv2 import numpy as np -from utils.utils import get_resource_path, OrtInferSession +from utils.utils import OrtInferSession -class CTCLabelDecode(): +class CTCLabelDecode: """Convert between text-label and text-index""" def __init__(self, characters: List[str]): super(CTCLabelDecode, self).__init__() self.characters = characters - self.characters.append(' ') + self.characters.append(" ") dict_character = self.add_special_char(self.characters) self.character = dict_character @@ -49,7 +47,7 @@ def __call__(self, preds, label=None): return text, label def add_special_char(self, dict_character): - dict_character = ['blank'] + dict_character + dict_character = ["blank"] + dict_character return dict_character def get_ignored_tokens(self): @@ -81,22 +79,22 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): conf_list.append(1) # avoid `Mean of empty slice.` warning score = np.mean(conf_list) if conf_list else 0 - text = ''.join(char_list) + text = "".join(char_list) result_list.append((text, score)) return result_list -class TextRecognizer(): +class TextRecognizer: def __init__(self, path, config): - self.rec_batch_num = config.get('rec_batch_num', 6) + self.rec_batch_num = config.get("rec_batch_num", 6) session_instance = OrtInferSession(path) self.session = session_instance.session metamap = session_instance.session.get_modelmeta().custom_metadata_map - chars = metamap['dictionary'].splitlines() + chars = metamap["dictionary"].splitlines() self.postprocess_op = CTCLabelDecode(chars) - self.rec_image_shape = json.loads(metamap['shape']) + self.rec_image_shape = json.loads(metamap["shape"]) self.input_name = session_instance.get_input_name() def resize_norm_img(self, img, max_wh_ratio): @@ -114,7 +112,7 @@ def resize_norm_img(self, img, max_wh_ratio): resized_w = int(math.ceil(img_height * ratio)) resized_image = cv2.resize(img, (resized_w, img_height)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 @@ -134,7 +132,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - rec_res = [['', 0.0]] * img_num + rec_res = [["", 0.0]] * img_num batch_num = self.rec_batch_num for beg_img_no in range(0, img_num, batch_num): diff --git a/ocrweb_multi/utils/config.py b/ocrweb_multi/utils/config.py index aca5c6a68..fda18f8f8 100644 --- a/ocrweb_multi/utils/config.py +++ b/ocrweb_multi/utils/config.py @@ -19,9 +19,9 @@ def get_resource_path(name: str): Path(name), ]: if path.exists(): - print('Loaded:', path) + print("Loaded:", path) return path raise FileNotFoundError(name) -conf = yaml.safe_load(get_resource_path('config.yaml').read_text(encoding='utf-8')) +conf = yaml.safe_load(get_resource_path("config.yaml").read_text(encoding="utf-8")) diff --git a/ocrweb_multi/utils/utils.py b/ocrweb_multi/utils/utils.py index f1916113f..9e9b4ac78 100644 --- a/ocrweb_multi/utils/utils.py +++ b/ocrweb_multi/utils/utils.py @@ -14,33 +14,33 @@ def parse_bool(val): if not isinstance(val, str): return bool(val) - return val.lower() in ('1', 'true', 'yes') + return val.lower() in ("1", "true", "yes") def default(obj): - if hasattr(obj, 'tolist'): + if hasattr(obj, "tolist"): return obj.tolist() return obj def tojson(obj, **kws): - return json.dumps(obj, default=default, ensure_ascii=False, **kws) + '\n' + return json.dumps(obj, default=default, ensure_ascii=False, **kws) + "\n" -class OrtInferSession(): +class OrtInferSession: def __init__(self, model_path): - ort_conf = conf['global'] + ort_conf = conf["global"] sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False - cuda_ep = 'CUDAExecutionProvider' - cpu_ep = 'CPUExecutionProvider' + cuda_ep = "CUDAExecutionProvider" + cpu_ep = "CPUExecutionProvider" providers = [] if ( - ort_conf['use_cuda'] - and get_device() == 'GPU' + ort_conf["use_cuda"] + and get_device() == "GPU" and cuda_ep in get_available_providers() ): providers = [(cuda_ep, ort_conf[cuda_ep])] @@ -53,12 +53,12 @@ def __init__(self, model_path): providers=providers, ) - if ort_conf['use_cuda'] and cuda_ep not in self.session.get_providers(): + if ort_conf["use_cuda"] and cuda_ep not in self.session.get_providers(): warnings.warn( - f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' - 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' - 'you can check their relations from the offical web site: ' - 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html', + f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n" + "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " + "you can check their relations from the offical web site: " + "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", RuntimeWarning, ) @@ -69,7 +69,7 @@ def get_output_name(self, output_idx=0): return self.session.get_outputs()[output_idx].name -class Ticker(): +class Ticker: def __init__(self, reset=True) -> None: self.ts = time.perf_counter() self.reset = reset diff --git a/python/demo.py b/python/demo.py index ab4e963c2..e33dd6eaf 100644 --- a/python/demo.py +++ b/python/demo.py @@ -10,18 +10,20 @@ from PIL import Image, ImageDraw, ImageFont from rapidocr_onnxruntime import RapidOCR + # from rapidocr_openvino import RapidOCR -def draw_ocr_box_txt(image, boxes, txts, font_path, - scores=None, text_score=0.5): +def draw_ocr_box_txt(image, boxes, txts, font_path, scores=None, text_score=0.5): if not Path(font_path).exists(): - raise FileNotFoundError(f'The {font_path} does not exists! \n' - f'Please download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing') + raise FileNotFoundError( + f"The {font_path} does not exists! \n" + f"Please download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" + ) h, w = image.height, image.width img_left = image.copy() - img_right = Image.new('RGB', (w, h), (255, 255, 255)) + img_right = Image.new("RGB", (w, h), (255, 255, 255)) random.seed(0) draw_left = ImageDraw.Draw(img_left) @@ -30,40 +32,45 @@ def draw_ocr_box_txt(image, boxes, txts, font_path, if scores is not None and float(scores[idx]) < text_score: continue - color = (random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255)) + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) draw_left.polygon(box, fill=color) - draw_right.polygon([box[0][0], box[0][1], - box[1][0], box[1][1], - box[2][0], box[2][1], - box[3][0], box[3][1]], - outline=color) - - box_height = math.sqrt((box[0][0] - box[3][0])**2 - + (box[0][1] - box[3][1])**2) - - box_width = math.sqrt((box[0][0] - box[1][0])**2 - + (box[0][1] - box[1][1])**2) + draw_right.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + outline=color, + ) + + box_height = math.sqrt( + (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2 + ) + + box_width = math.sqrt( + (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2 + ) if box_height > 2 * box_width: font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, - encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") cur_y = box[0][1] for c in txt: char_size = font.getsize(c) - draw_right.text((box[0][0] + 3, cur_y), c, - fill=(0, 0, 0), font=font) + draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) cur_y += char_size[1] else: font_size = max(int(box_height * 0.8), 10) font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - draw_right.text([box[0][0], box[0][1]], txt, - fill=(0, 0, 0), font=font) + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) img_left = Image.blend(image, img_left, 0.5) - img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) img_show.paste(img_left, (0, 0, w, h)) img_show.paste(img_right, (w, 0, w * 2, h)) return np.array(img_show) @@ -73,29 +80,28 @@ def visualize(image_path, result, font_path="resources/fonts/FZYTK.TTF"): image = Image.open(image_path) boxes, txts, scores = list(zip(*result)) - draw_img = draw_ocr_box_txt(image, np.array(boxes), - txts, font_path, - scores, - text_score=0.5) + draw_img = draw_ocr_box_txt( + image, np.array(boxes), txts, font_path, scores, text_score=0.5 + ) draw_img_save = Path("./inference_results/") if not draw_img_save.exists(): draw_img_save.mkdir(parents=True, exist_ok=True) - image_save = str(draw_img_save / f'infer_{Path(image_path).name}') + image_save = str(draw_img_save / f"infer_{Path(image_path).name}") cv2.imwrite(image_save, draw_img[:, :, ::-1]) - print(f'The infer result has saved in {image_save}') + print(f"The infer result has saved in {image_save}") -if __name__ == '__main__': +if __name__ == "__main__": rapid_ocr = RapidOCR() - image_path = 'tests/test_files/ch_en_num.jpg' - with open(image_path, 'rb') as f: + image_path = "tests/test_files/ch_en_num.jpg" + with open(image_path, "rb") as f: img = f.read() result, elapse_list = rapid_ocr(img) print(result) print(elapse_list) if result: - visualize(image_path, result, font_path='resources/fonts/FZYTK.TTF') + visualize(image_path, result, font_path="resources/fonts/FZYTK.TTF") diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py index d6abdacbe..7c289b457 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py @@ -25,12 +25,12 @@ from .utils import ClsPostProcess -class TextClassifier(): +class TextClassifier: def __init__(self, config): - self.cls_image_shape = config['cls_image_shape'] - self.cls_batch_num = config['cls_batch_num'] - self.cls_thresh = config['cls_thresh'] - self.postprocess_op = ClsPostProcess(config['label_list']) + self.cls_image_shape = config["cls_image_shape"] + self.cls_batch_num = config["cls_batch_num"] + self.cls_thresh = config["cls_thresh"] + self.postprocess_op = ClsPostProcess(config["label_list"]) self.infer = OrtInferSession(config) @@ -47,7 +47,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - cls_res = [['', 0.0]] * img_num + cls_res = [["", 0.0]] * img_num batch_num = self.cls_batch_num elapse = 0 for beg_img_no in range(0, img_num, batch_num): @@ -68,9 +68,10 @@ def __call__(self, img_list: List[np.ndarray]): for rno in range(len(cls_result)): label, score = cls_result[rno] cls_res[indices[beg_img_no + rno]] = [label, score] - if '180' in label and score > self.cls_thresh: + if "180" in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( - img_list[indices[beg_img_no + rno]], 1) + img_list[indices[beg_img_no + rno]], 1 + ) return img_list, cls_res, elapse def resize_norm_img(self, img): @@ -83,7 +84,7 @@ def resize_norm_img(self, img): resized_w = int(math.ceil(img_h * ratio)) resized_image = cv2.resize(img, (resized_w, img_h)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") if img_c == 1: resized_image = resized_image / 255 resized_image = resized_image[np.newaxis, :] @@ -99,8 +100,8 @@ def resize_norm_img(self, img): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--image_path', type=str, help='image_dir|image_path') - parser.add_argument('--config_path', type=str, default='config.yaml') + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") args = parser.parse_args() config = read_yaml(args.config_path) diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py index 466fee949..5c75d54ee 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -class ClsPostProcess(): - """ Convert between text-label and text-index """ +class ClsPostProcess: + """Convert between text-label and text-index""" def __init__(self, label_list): super(ClsPostProcess, self).__init__() @@ -20,8 +20,9 @@ def __init__(self, label_list): def __call__(self, preds, label=None): pred_idxs = preds.argmax(axis=1) - decode_out = [(self.label_list[idx], preds[i, idx]) - for i, idx in enumerate(pred_idxs)] + decode_out = [ + (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs) + ] if label is None: return decode_out diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py index 99f7a00de..0e42377f9 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py @@ -25,33 +25,31 @@ from .utils import DBPostProcess, create_operators, transform -class TextDetector(): +class TextDetector: def __init__(self, config): pre_process_list = { - 'DetResizeForTest': { - 'limit_side_len': config.get('limit_side_len', 736), - 'limit_type': config.get('limit_type', 'min') + "DetResizeForTest": { + "limit_side_len": config.get("limit_side_len", 736), + "limit_type": config.get("limit_type", "min"), }, - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' + "NormalizeImage": { + "std": [0.229, 0.224, 0.225], + "mean": [0.485, 0.456, 0.406], + "scale": "1./255.", + "order": "hwc", }, - 'ToCHWImage': None, - 'KeepKeys': { - 'keep_keys': ['image', 'shape'] - } + "ToCHWImage": None, + "KeepKeys": {"keep_keys": ["image", "shape"]}, } self.preprocess_op = create_operators(pre_process_list) post_process = { - 'thresh': config.get('thresh', 0.3), - 'box_thresh': config.get('box_thresh', 0.5), - 'max_candidates': config.get('max_candidates', 1000), - 'unclip_ratio': config.get('unclip_ratio', 1.6), - 'use_dilation': config.get('use_dilation', True), - 'score_mode': config.get('score_mode', 'fast'), + "thresh": config.get("thresh", 0.3), + "box_thresh": config.get("box_thresh", 0.5), + "max_candidates": config.get("max_candidates", 1000), + "unclip_ratio": config.get("unclip_ratio", 1.6), + "use_dilation": config.get("use_dilation", True), + "score_mode": config.get("score_mode", "fast"), } self.postprocess_op = DBPostProcess(**post_process) @@ -59,11 +57,11 @@ def __init__(self, config): def __call__(self, img): if img is None: - raise ValueError('img is None') + raise ValueError("img is None") ori_im_shape = img.shape[:2] - data = {'image': img} + data = {"image": img} data = transform(data, self.preprocess_op) img, shape_list = data if img is None: @@ -76,7 +74,7 @@ def __call__(self, img): preds = self.infer(img)[0] post_result = self.postprocess_op(preds, shape_list) - dt_boxes = post_result[0]['points'] + dt_boxes = post_result[0]["points"] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape) elapse = time.time() - starttime return dt_boxes, elapse @@ -129,8 +127,8 @@ def filter_tag_det_res(self, dt_boxes, image_shape): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--config_path', type=str, default='config.yaml') - parser.add_argument('--image_path', type=str, default=None) + parser.add_argument("--config_path", type=str, default="config.yaml") + parser.add_argument("--image_path", type=str, default=None) args = parser.parse_args() config = read_yaml(args.config_path) @@ -141,6 +139,7 @@ def filter_tag_det_res(self, dt_boxes, image_shape): dt_boxes, elapse = text_detector(img) from utils import draw_text_det_res + src_im = draw_text_det_res(dt_boxes, args.image_path) - cv2.imwrite('det_results.jpg', src_im) - print('The det_results.jpg has been saved in the current directory.') + cv2.imwrite("det_results.jpg", src_im) + print("The det_results.jpg has been saved in the current directory.") diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py index b1b489f2f..3b08e07ea 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py @@ -25,69 +25,74 @@ from shapely.geometry import Polygon -class DecodeImage(): - """ decode image """ +class DecodeImage: + """decode image""" - def __init__(self, img_mode='RGB', channel_first=False): + def __init__(self, img_mode="RGB", channel_first=False): self.img_mode = img_mode self.channel_first = channel_first def __call__(self, data): - img = data['image'] + img = data["image"] if six.PY2: - assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage" + assert ( + type(img) is str and len(img) > 0 + ), "invalid input 'img' in DecodeImage" else: - assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage" + assert ( + type(img) is bytes and len(img) > 0 + ), "invalid input 'img' in DecodeImage" - img = np.frombuffer(img, dtype='uint8') + img = np.frombuffer(img, dtype="uint8") img = cv2.imdecode(img, 1) if img is None: return None - if self.img_mode == 'GRAY': + if self.img_mode == "GRAY": img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif self.img_mode == 'RGB': - assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]' + elif self.img_mode == "RGB": + assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: img = img.transpose((2, 0, 1)) - data['image'] = img + data["image"] = img return data -class NormalizeImage(): - """ normalize image such as substract mean, divide std""" +class NormalizeImage: + """normalize image such as substract mean, divide std""" - def __init__(self, scale=None, mean=None, std=None, order='chw'): + def __init__(self, scale=None, mean=None, std=None, order="chw"): if isinstance(scale, str): scale = eval(scale) self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) mean = mean if mean is not None else [0.485, 0.456, 0.406] std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") def __call__(self, data): - img = np.array(data['image']).astype(np.float32) - data['image'] = (img * self.scale - self.mean) / self.std + img = np.array(data["image"]).astype(np.float32) + data["image"] = (img * self.scale - self.mean) / self.std return data -class ToCHWImage(): - """ convert hwc image to chw image""" +class ToCHWImage: + """convert hwc image to chw image""" + def __init__(self): pass def __call__(self, data): - img = np.array(data['image']) - data['image'] = img.transpose((2, 0, 1)) + img = np.array(data["image"]) + data["image"] = img.transpose((2, 0, 1)) return data -class KeepKeys(): +class KeepKeys: def __init__(self, keep_keys): self.keep_keys = keep_keys @@ -98,26 +103,26 @@ def __call__(self, data): return data_list -class DetResizeForTest(): +class DetResizeForTest: def __init__(self, **kwargs): super(DetResizeForTest, self).__init__() self.resize_type = 0 - if 'image_shape' in kwargs: - self.image_shape = kwargs['image_shape'] + if "image_shape" in kwargs: + self.image_shape = kwargs["image_shape"] self.resize_type = 1 - elif 'limit_side_len' in kwargs: - self.limit_side_len = kwargs.get('limit_side_len', 736) - self.limit_type = kwargs.get('limit_type', 'min') + elif "limit_side_len" in kwargs: + self.limit_side_len = kwargs.get("limit_side_len", 736) + self.limit_type = kwargs.get("limit_type", "min") - if 'resize_long' in kwargs: + if "resize_long" in kwargs: self.resize_type = 2 - self.resize_long = kwargs.get('resize_long', 960) + self.resize_long = kwargs.get("resize_long", 960) else: - self.limit_side_len = kwargs.get('limit_side_len', 736) - self.limit_type = kwargs.get('limit_type', 'min') + self.limit_side_len = kwargs.get("limit_side_len", 736) + self.limit_type = kwargs.get("limit_type", "min") def __call__(self, data): - img = data['image'] + img = data["image"] src_h, src_w = img.shape[:2] if self.resize_type == 0: @@ -128,8 +133,8 @@ def __call__(self, data): else: # img, shape = self.resize_image_type1(img) img, [ratio_h, ratio_w] = self.resize_image_type1(img) - data['image'] = img - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + data["image"] = img + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def resize_image_type1(self, img): @@ -153,14 +158,14 @@ def resize_image_type0(self, img): h, w = img.shape[:2] # limit the max side - if self.limit_type == 'max': + if self.limit_type == "max": if max(h, w) > limit_side_len: if h > w: ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w else: - ratio = 1. + ratio = 1.0 else: if min(h, w) < limit_side_len: if h < w: @@ -168,7 +173,7 @@ def resize_image_type0(self, img): else: ratio = float(limit_side_len) / w else: - ratio = 1. + ratio = 1.0 resize_h = int(h * ratio) resize_w = int(w * ratio) @@ -212,7 +217,7 @@ def resize_image_type2(self, img): def transform(data, ops=None): - """ transform """ + """transform""" if ops is None: ops = [] @@ -240,21 +245,22 @@ def draw_text_det_res(dt_boxes, img_path): src_im = cv2.imread(img_path) for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2) - cv2.polylines(src_im, [box], True, - color=(255, 255, 0), thickness=2) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) return src_im -class DBPostProcess(): +class DBPostProcess: """The post process for Differentiable Binarization (DB).""" - def __init__(self, - thresh=0.3, - box_thresh=0.7, - max_candidates=1000, - unclip_ratio=2.0, - score_mode="fast", - use_dilation=False): + def __init__( + self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + score_mode="fast", + use_dilation=False, + ): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates @@ -268,16 +274,17 @@ def __init__(self, self.dilation_kernel = None def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' + """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} - ''' + """ bitmap = _bitmap height, width = bitmap.shape - outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, - cv2.CHAIN_APPROX_SIMPLE) + outs = cv2.findContours( + (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE + ) if len(outs) == 3: img, contours, _ = outs[0], outs[1], outs[2] elif len(outs) == 2: @@ -306,10 +313,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): continue box = np.array(box) - box[:, 0] = np.clip( - np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) box[:, 1] = np.clip( - np.round(box[:, 1] / height * dest_height), 0, dest_height) + np.round(box[:, 1] / height * dest_height), 0, dest_height + ) boxes.append(box.astype(np.int16)) scores.append(score) return np.array(boxes, dtype=np.int16), scores @@ -341,9 +348,7 @@ def get_mini_boxes(self, contour): index_2 = 3 index_3 = 2 - box = [ - points[index_1], points[index_2], points[index_3], points[index_4] - ] + box = [points[index_1], points[index_2], points[index_3], points[index_4]] return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): @@ -358,12 +363,12 @@ def box_score_fast(self, bitmap, _box): box[:, 0] = box[:, 0] - xmin box[:, 1] = box[:, 1] - ymin cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def box_score_slow(self, bitmap, contour): - ''' + """ box_score_slow: use polyon mean score as the mean score - ''' + """ h, w = bitmap.shape[:2] contour = contour.copy() contour = np.reshape(contour, (-1, 2)) @@ -379,7 +384,7 @@ def box_score_slow(self, bitmap, contour): contour[:, 1] = contour[:, 1] - ymin cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def __call__(self, pred, shape_list): pred = pred[:, 0, :, :] @@ -391,11 +396,13 @@ def __call__(self, pred, shape_list): if self.dilation_kernel is not None: mask = cv2.dilate( np.array(segmentation[batch_index]).astype(np.uint8), - self.dilation_kernel) + self.dilation_kernel, + ) else: mask = segmentation[batch_index] - boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, - src_w, src_h) + boxes, scores = self.boxes_from_bitmap( + pred[batch_index], mask, src_w, src_h + ) - boxes_batch.append({'points': boxes}) + boxes_batch.append({"points": boxes}) return boxes_batch diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py index 1ce14a621..b487cd296 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py @@ -24,18 +24,18 @@ from .utils import CTCLabelDecode -class TextRecognizer(): +class TextRecognizer: def __init__(self, config): self.session = OrtInferSession(config) if self.session.have_key(): self.character_dict_path = self.session.get_character_list() else: - self.character_dict_path = config.get('keys_path', None) + self.character_dict_path = config.get("keys_path", None) self.postprocess_op = CTCLabelDecode(self.character_dict_path) - self.rec_batch_num = config['rec_batch_num'] - self.rec_image_shape = config['rec_img_shape'] + self.rec_batch_num = config["rec_batch_num"] + self.rec_image_shape = config["rec_img_shape"] def __call__(self, img_list: List[np.ndarray]): if isinstance(img_list, np.ndarray): @@ -48,7 +48,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - rec_res = [['', 0.0]] * img_num + rec_res = [["", 0.0]] * img_num batch_num = self.rec_batch_num elapse = 0 @@ -62,8 +62,7 @@ def __call__(self, img_list: List[np.ndarray]): norm_img_batch = [] for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) norm_img_batch.append(norm_img[np.newaxis, :]) norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) @@ -90,21 +89,20 @@ def resize_norm_img(self, img, max_wh_ratio): resized_w = int(math.ceil(img_height * ratio)) resized_image = cv2.resize(img, (resized_w, img_height)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 - padding_im = np.zeros((img_channel, img_height, img_width), - dtype=np.float32) + padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image return padding_im if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--image_path', type=str, help='image_dir|image_path') - parser.add_argument('--config_path', type=str, default='config.yaml') + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") args = parser.parse_args() config = read_yaml(args.config_path) @@ -112,4 +110,4 @@ def resize_norm_img(self, img, max_wh_ratio): img = cv2.imread(args.image_path) rec_res, predict_time = text_recognizer(img) - print(f'rec result: {rec_res}\t cost: {predict_time}s') + print(f"rec result: {rec_res}\t cost: {predict_time}s") diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py index a1a53b8f2..1cde51931 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py @@ -4,8 +4,8 @@ import numpy as np -class CTCLabelDecode(): - """ Convert between text-label and text-index """ +class CTCLabelDecode: + """Convert between text-label and text-index""" def __init__(self, character_dict_path): super(CTCLabelDecode, self).__init__() @@ -17,11 +17,11 @@ def __init__(self, character_dict_path): with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: - line = line.decode('utf-8').strip("\n").strip("\r\n") + line = line.decode("utf-8").strip("\n").strip("\r\n") self.character_str.append(line) else: self.character_str = character_dict_path - self.character_str.append(' ') + self.character_str.append(" ") dict_character = self.add_special_char(self.character_str) self.character = dict_character @@ -33,22 +33,21 @@ def __init__(self, character_dict_path): def __call__(self, preds, label=None): preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, - is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) if label is None: return text label = self.decode(label) return text, label def add_special_char(self, dict_character): - dict_character = ['blank'] + dict_character + dict_character = ["blank"] + dict_character return dict_character def get_ignored_tokens(self): return [0] # for ctc blank def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ + """convert text-index into text-label.""" result_list = [] ignored_tokens = self.get_ignored_tokens() @@ -61,15 +60,16 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): continue if is_remove_duplicate: # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) + char_list.append(self.character[int(text_index[batch_idx][idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: conf_list.append(1) - text = ''.join(char_list) + text = "".join(char_list) result_list.append((text, np.mean(conf_list + [1e-50]))) return result_list diff --git a/python/rapidocr_onnxruntime/rapid_ocr_api.py b/python/rapidocr_onnxruntime/rapid_ocr_api.py index 7f92c0f95..0c351b3c7 100644 --- a/python/rapidocr_onnxruntime/rapid_ocr_api.py +++ b/python/rapidocr_onnxruntime/rapid_ocr_api.py @@ -12,17 +12,16 @@ from .ch_ppocr_v2_cls import TextClassifier from .ch_ppocr_v3_det import TextDetector from .ch_ppocr_v3_rec import TextRecognizer -from .utils import (LoadImage, UpdateParameters, concat_model_path, init_args, - read_yaml) +from .utils import LoadImage, UpdateParameters, concat_model_path, init_args, read_yaml root_dir = Path(__file__).resolve().parent -class RapidOCR(): +class RapidOCR: def __init__(self, **kwargs): - config_path = str(root_dir / 'config.yaml') + config_path = str(root_dir / "config.yaml") if not Path(config_path).exists(): - raise FileExistsError(f'{config_path} does not exist!') + raise FileExistsError(f"{config_path} does not exist!") config = read_yaml(config_path) config = concat_model_path(config) @@ -30,30 +29,29 @@ def __init__(self, **kwargs): updater = UpdateParameters() config = updater(config, **kwargs) - global_config = config['Global'] - self.print_verbose = global_config['print_verbose'] - self.text_score = global_config['text_score'] - self.min_height = global_config['min_height'] - self.width_height_ratio = global_config['width_height_ratio'] + global_config = config["Global"] + self.print_verbose = global_config["print_verbose"] + self.text_score = global_config["text_score"] + self.min_height = global_config["min_height"] + self.width_height_ratio = global_config["width_height_ratio"] - self.use_text_det = config['Global']['use_text_det'] + self.use_text_det = config["Global"]["use_text_det"] if self.use_text_det: - self.text_detector = TextDetector(config['Det']) + self.text_detector = TextDetector(config["Det"]) - self.text_recognizer = TextRecognizer(config['Rec']) + self.text_recognizer = TextRecognizer(config["Rec"]) - self.use_angle_cls = config['Global']['use_angle_cls'] + self.use_angle_cls = config["Global"]["use_angle_cls"] if self.use_angle_cls: - self.text_cls = TextClassifier(config['Cls']) + self.text_cls = TextClassifier(config["Cls"]) self.load_img = LoadImage() - def __call__(self, - img_content: Union[str, np.ndarray, bytes, Path], **kwargs): + def __call__(self, img_content: Union[str, np.ndarray, bytes, Path], **kwargs): if kwargs: - box_thresh = kwargs.get('box_thresh', 0.5) - unclip_ratio = kwargs.get('unclip_ratio', 1.6) - text_score = kwargs.get('text_score', 0.5) + box_thresh = kwargs.get("box_thresh", 0.5) + unclip_ratio = kwargs.get("unclip_ratio", 1.6) + text_score = kwargs.get("text_score", 0.5) self.text_detector.postprocess_op.box_thresh = box_thresh self.text_detector.postprocess_op.unclip_ratio = unclip_ratio @@ -66,9 +64,7 @@ def __call__(self, else: use_limit_ratio = w / h > self.width_height_ratio - if not self.use_text_det \ - or h <= self.min_height \ - or use_limit_ratio: + if not self.use_text_det or h <= self.min_height or use_limit_ratio: dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w) det_elapse = 0.0 else: @@ -77,7 +73,7 @@ def __call__(self, return None, None if self.print_verbose: - print(f'dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}') + print(f"dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}") dt_boxes = self.sorted_boxes(dt_boxes) img_crop_list = self.get_crop_img_list(img, dt_boxes) @@ -87,16 +83,17 @@ def __call__(self, img_crop_list, _, cls_elapse = self.text_cls(img_crop_list) if self.print_verbose: - print(f'cls num: {len(img_crop_list)}, elapse: {cls_elapse}') + print(f"cls num: {len(img_crop_list)}, elapse: {cls_elapse}") rec_res, rec_elapse = self.text_recognizer(img_crop_list) if self.print_verbose: - print(f'rec_res num: {len(rec_res)}, elapse: {rec_elapse}') + print(f"rec_res num: {len(rec_res)}, elapse: {rec_elapse}") - filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, - rec_res) - fina_result = [[dt.tolist(), rec[0], str(rec[1])] - for dt, rec in zip(filter_boxes, filter_rec_res)] + filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, rec_res) + fina_result = [ + [dt.tolist(), rec[0], str(rec[1])] + for dt, rec in zip(filter_boxes, filter_rec_res) + ] if fina_result: return fina_result, [det_elapse, cls_elapse, rec_elapse] return None, None @@ -118,20 +115,31 @@ def get_rotate_crop_image(img, points): img_crop_width = int( max( np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) + np.linalg.norm(points[2] - points[3]), + ) + ) img_crop_height = int( max( np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) + np.linalg.norm(points[1] - points[2]), + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) M = cv2.getPerspectiveTransform(points, pts_std) dst_img = cv2.warpPerspective( img, - M, (img_crop_width, img_crop_height), + M, + (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) + flags=cv2.INTER_CUBIC, + ) dst_img_height, dst_img_width = dst_img.shape[0:2] if dst_img_height * 1.0 / dst_img_width >= 1.5: dst_img = np.rot90(dst_img) @@ -159,8 +167,10 @@ def sorted_boxes(dt_boxes): for i in range(num_boxes - 1): for j in range(i, -1, -1): - if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 \ - and _boxes[j + 1][0][0] < _boxes[j][0][0]: + if ( + abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 + and _boxes[j + 1][0][0] < _boxes[j][0][0] + ): tmp = _boxes[j] _boxes[j] = _boxes[j + 1] _boxes[j + 1] = tmp @@ -188,5 +198,5 @@ def main(): print(elapse_list) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index a9fedefd5..62e062ade 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -10,69 +10,83 @@ import cv2 import numpy as np import yaml -from onnxruntime import (GraphOptimizationLevel, InferenceSession, - SessionOptions, get_available_providers, get_device) +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) from PIL import Image, UnidentifiedImageError root_dir = Path(__file__).resolve().parent InputType = Union[str, np.ndarray, bytes, Path] -class OrtInferSession(): +class OrtInferSession: def __init__(self, config): sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - cpu_ep = 'CPUExecutionProvider' + cpu_ep = "CPUExecutionProvider" cpu_provider_options = { - 'arena_extend_strategy': 'kSameAsRequested', + "arena_extend_strategy": "kSameAsRequested", } - cuda_ep = 'CUDAExecutionProvider' + cuda_ep = "CUDAExecutionProvider" cuda_provider_options = { - 'device_id': 0, - 'arena_extend_strategy': 'kNextPowerOfTwo', - 'cudnn_conv_algo_search': 'EXHAUSTIVE', - 'do_copy_in_default_stream': True + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, } EP_list = [] - if config['use_cuda'] and get_device() == 'GPU' \ - and cuda_ep in get_available_providers(): + if ( + config["use_cuda"] + and get_device() == "GPU" + and cuda_ep in get_available_providers() + ): EP_list = [(cuda_ep, cuda_provider_options)] EP_list.append((cpu_ep, cpu_provider_options)) - self._verify_model(config['model_path']) - self.session = InferenceSession(config['model_path'], - sess_options=sess_opt, - providers=EP_list) + self._verify_model(config["model_path"]) + self.session = InferenceSession( + config["model_path"], sess_options=sess_opt, providers=EP_list + ) - if config['use_cuda'] and cuda_ep not in self.session.get_providers(): - warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' - 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' - 'you can check their relations from the offical web site: ' - 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html', - RuntimeWarning) + if config["use_cuda"] and cuda_ep not in self.session.get_providers(): + warnings.warn( + f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n" + "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " + "you can check their relations from the offical web site: " + "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", + RuntimeWarning, + ) def __call__(self, input_content: np.ndarray) -> np.ndarray: input_dict = dict(zip(self.get_input_names(), [input_content])) try: return self.session.run(self.get_output_names(), input_dict) except Exception as e: - raise ONNXRuntimeError('ONNXRuntime inference failed.') from e + raise ONNXRuntimeError("ONNXRuntime inference failed.") from e - def get_input_names(self, ): + def get_input_names( + self, + ): return [v.name for v in self.session.get_inputs()] - def get_output_names(self,): + def get_output_names( + self, + ): return [v.name for v in self.session.get_outputs()] - def get_character_list(self, key: str = 'character'): + def get_character_list(self, key: str = "character"): return self.meta_dict[key].splitlines() - def have_key(self, key: str = 'character') -> bool: + def have_key(self, key: str = "character") -> bool: self.meta_dict = self.session.get_modelmeta().custom_metadata_map if key in self.meta_dict.keys(): return True @@ -82,23 +96,26 @@ def have_key(self, key: str = 'character') -> bool: def _verify_model(model_path): model_path = Path(model_path) if not model_path.exists(): - raise FileNotFoundError(f'{model_path} does not exists.') + raise FileNotFoundError(f"{model_path} does not exists.") if not model_path.is_file(): - raise FileExistsError(f'{model_path} is not a file.') + raise FileExistsError(f"{model_path} is not a file.") class ONNXRuntimeError(Exception): pass -class LoadImage(): - def __init__(self, ): +class LoadImage: + def __init__( + self, + ): pass def __call__(self, img: InputType) -> np.ndarray: if not isinstance(img, InputType.__args__): raise LoadImageError( - f'The img type {type(img)} does not in {InputType.__args__}') + f"The img type {type(img)} does not in {InputType.__args__}" + ) img = self.load_img(img) @@ -117,8 +134,7 @@ def load_img(self, img: InputType) -> np.ndarray: img = np.array(Image.open(img)) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) except UnidentifiedImageError as e: - raise LoadImageError( - f'cannot identify image file {img}') from e + raise LoadImageError(f"cannot identify image file {img}") from e return img if isinstance(img, bytes): @@ -129,12 +145,11 @@ def load_img(self, img: InputType) -> np.ndarray: if isinstance(img, np.ndarray): return img - raise LoadImageError(f'{type(img)} is not supported!') + raise LoadImageError(f"{type(img)} is not supported!") @staticmethod def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - '''RGBA → RGB - ''' + """RGBA → RGB""" r, g, b, a = cv2.split(img) new_img = cv2.merge((b, g, r)) @@ -148,7 +163,7 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: @staticmethod def verify_exist(file_path: Union[str, Path]): if not Path(file_path).exists(): - raise LoadImageError(f'{file_path} does not exist.') + raise LoadImageError(f"{file_path} does not exist.") class LoadImageError(Exception): @@ -156,77 +171,74 @@ class LoadImageError(Exception): def read_yaml(yaml_path): - with open(yaml_path, 'rb') as f: + with open(yaml_path, "rb") as f: data = yaml.load(f, Loader=yaml.Loader) return data def concat_model_path(config): - key = 'model_path' - config['Det'][key] = str(root_dir / config['Det'][key]) - config['Rec'][key] = str(root_dir / config['Rec'][key]) - config['Cls'][key] = str(root_dir / config['Cls'][key]) + key = "model_path" + config["Det"][key] = str(root_dir / config["Det"][key]) + config["Rec"][key] = str(root_dir / config["Rec"][key]) + config["Cls"][key] = str(root_dir / config["Cls"][key]) return config def init_args(): parser = argparse.ArgumentParser() - parser.add_argument('-img', '--img_path', type=str, default=None, - required=True) - parser.add_argument('-p', '--print_cost', - action='store_true', default=False) - - global_group = parser.add_argument_group(title='Global') - global_group.add_argument('--text_score', type=float, default=0.5) - global_group.add_argument('--use_angle_cls', type=bool, default=True) - global_group.add_argument('--use_text_det', type=bool, default=True) - global_group.add_argument('--print_verbose', type=bool, default=False) - global_group.add_argument('--min_height', type=int, default=30) - global_group.add_argument('--width_height_ratio', type=int, default=8) - - det_group = parser.add_argument_group(title='Det') - det_group.add_argument('--det_model_path', type=str, default=None) - det_group.add_argument('--det_limit_side_len', type=float, default=736) - det_group.add_argument('--det_limit_type', type=str, default='min', - choices=['max', 'min']) - det_group.add_argument('--det_thresh', type=float, default=0.3) - det_group.add_argument('--det_box_thresh', type=float, default=0.5) - det_group.add_argument('--det_unclip_ratio', type=float, default=1.6) - det_group.add_argument('--det_use_dilation', type=bool, default=True) - det_group.add_argument('--det_score_mode', type=str, default='fast', - choices=['slow', 'fast']) - - cls_group = parser.add_argument_group(title='Cls') - cls_group.add_argument('--cls_model_path', type=str, default=None) - cls_group.add_argument('--cls_image_shape', type=list, - default=[3, 48, 192]) - cls_group.add_argument('--cls_label_list', type=list, - default=['0', '180']) - cls_group.add_argument('--cls_batch_num', type=int, default=6) - cls_group.add_argument('--cls_thresh', type=float, default=0.9) - - rec_group = parser.add_argument_group(title='Rec') - rec_group.add_argument('--rec_model_path', type=str, default=None) - rec_group.add_argument('--rec_img_shape', type=list, - default=[3, 48, 320]) - rec_group.add_argument('--rec_batch_num', type=int, default=6) + parser.add_argument("-img", "--img_path", type=str, default=None, required=True) + parser.add_argument("-p", "--print_cost", action="store_true", default=False) + + global_group = parser.add_argument_group(title="Global") + global_group.add_argument("--text_score", type=float, default=0.5) + global_group.add_argument("--use_angle_cls", type=bool, default=True) + global_group.add_argument("--use_text_det", type=bool, default=True) + global_group.add_argument("--print_verbose", type=bool, default=False) + global_group.add_argument("--min_height", type=int, default=30) + global_group.add_argument("--width_height_ratio", type=int, default=8) + + det_group = parser.add_argument_group(title="Det") + det_group.add_argument("--det_model_path", type=str, default=None) + det_group.add_argument("--det_limit_side_len", type=float, default=736) + det_group.add_argument( + "--det_limit_type", type=str, default="min", choices=["max", "min"] + ) + det_group.add_argument("--det_thresh", type=float, default=0.3) + det_group.add_argument("--det_box_thresh", type=float, default=0.5) + det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) + det_group.add_argument("--det_use_dilation", type=bool, default=True) + det_group.add_argument( + "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] + ) + + cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument("--cls_model_path", type=str, default=None) + cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) + cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) + cls_group.add_argument("--cls_batch_num", type=int, default=6) + cls_group.add_argument("--cls_thresh", type=float, default=0.9) + + rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) + rec_group.add_argument("--rec_batch_num", type=int, default=6) args = parser.parse_args() return args -class UpdateParameters(): +class UpdateParameters: def __init__(self) -> None: pass def parse_kwargs(self, **kwargs): global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} for k, v in kwargs.items(): - if k.startswith('det'): + if k.startswith("det"): det_dict[k] = v - elif k.startswith('cls'): + elif k.startswith("cls"): cls_dict[k] = v - elif k.startswith('rec'): + elif k.startswith("rec"): rec_dict[k] = v else: global_dict[k] = v @@ -235,11 +247,10 @@ def parse_kwargs(self, **kwargs): def __call__(self, config, **kwargs): global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) new_config = { - 'Global': self.update_global_params(config['Global'], - global_dict), - 'Det': self.update_det_params(config['Det'], det_dict), - 'Cls': self.update_cls_params(config['Cls'], cls_dict), - 'Rec': self.update_rec_params(config['Rec'], rec_dict) + "Global": self.update_global_params(config["Global"], global_dict), + "Det": self.update_det_params(config["Det"], det_dict), + "Cls": self.update_cls_params(config["Cls"], cls_dict), + "Rec": self.update_rec_params(config["Rec"], rec_dict), } return new_config @@ -250,38 +261,36 @@ def update_global_params(self, config, global_dict): def update_det_params(self, config, det_dict): if det_dict: - det_dict = {k.split('det_')[1]: v for k, v in det_dict.items()} - if not det_dict['model_path']: - det_dict['model_path'] = str(root_dir / config['model_path']) + det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()} + if not det_dict["model_path"]: + det_dict["model_path"] = str(root_dir / config["model_path"]) config.update(det_dict) return config def update_cls_params(self, config, cls_dict): if cls_dict: - need_remove_prefix = ['cls_label_list', 'cls_model_path'] + need_remove_prefix = ["cls_label_list", "cls_model_path"] new_cls_dict = {} for k, v in cls_dict.items(): if k in need_remove_prefix: - k = k.split('cls_')[1] + k = k.split("cls_")[1] new_cls_dict[k] = v - if not new_cls_dict['model_path']: - new_cls_dict['model_path'] = str( - root_dir / config['model_path']) + if not new_cls_dict["model_path"]: + new_cls_dict["model_path"] = str(root_dir / config["model_path"]) config.update(new_cls_dict) return config def update_rec_params(self, config, rec_dict): if rec_dict: - need_remove_prefix = ['rec_model_path'] + need_remove_prefix = ["rec_model_path"] new_rec_dict = {} for k, v in rec_dict.items(): if k in need_remove_prefix: - k = k.split('rec_')[1] + k = k.split("rec_")[1] new_rec_dict[k] = v - if not new_rec_dict['model_path']: - new_rec_dict['model_path'] = str( - root_dir / config['model_path']) + if not new_rec_dict["model_path"]: + new_rec_dict["model_path"] = str(root_dir / config["model_path"]) config.update(new_rec_dict) return config diff --git a/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py b/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py index a397d3514..5c44dbd57 100644 --- a/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py +++ b/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py @@ -25,12 +25,12 @@ from .utils import ClsPostProcess -class TextClassifier(): +class TextClassifier: def __init__(self, config): - self.cls_image_shape = config['cls_image_shape'] - self.cls_batch_num = config['cls_batch_num'] - self.cls_thresh = config['cls_thresh'] - self.postprocess_op = ClsPostProcess(config['label_list']) + self.cls_image_shape = config["cls_image_shape"] + self.cls_batch_num = config["cls_batch_num"] + self.cls_thresh = config["cls_thresh"] + self.postprocess_op = ClsPostProcess(config["label_list"]) self.infer = OpenVINOInferSession(config) @@ -47,7 +47,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - cls_res = [['', 0.0]] * img_num + cls_res = [["", 0.0]] * img_num batch_num = self.cls_batch_num elapse = 0 for beg_img_no in range(0, img_num, batch_num): @@ -68,9 +68,10 @@ def __call__(self, img_list: List[np.ndarray]): for rno in range(len(cls_result)): label, score = cls_result[rno] cls_res[indices[beg_img_no + rno]] = [label, score] - if '180' in label and score > self.cls_thresh: + if "180" in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( - img_list[indices[beg_img_no + rno]], 1) + img_list[indices[beg_img_no + rno]], 1 + ) return img_list, cls_res, elapse def resize_norm_img(self, img): @@ -83,7 +84,7 @@ def resize_norm_img(self, img): resized_w = int(math.ceil(img_h * ratio)) resized_image = cv2.resize(img, (resized_w, img_h)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") if img_c == 1: resized_image = resized_image / 255 resized_image = resized_image[np.newaxis, :] @@ -99,8 +100,8 @@ def resize_norm_img(self, img): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--image_path', type=str, help='image_dir|image_path') - parser.add_argument('--config_path', type=str, default='config.yaml') + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") args = parser.parse_args() config = read_yaml(args.config_path) diff --git a/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py b/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py index cb579fa05..5c75d54ee 100644 --- a/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py +++ b/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -class ClsPostProcess(): - """ Convert between text-label and text-index """ +class ClsPostProcess: + """Convert between text-label and text-index""" def __init__(self, label_list): super(ClsPostProcess, self).__init__() @@ -20,10 +20,11 @@ def __init__(self, label_list): def __call__(self, preds, label=None): pred_idxs = preds.argmax(axis=1) - decode_out = [(self.label_list[idx], preds[i, idx]) - for i, idx in enumerate(pred_idxs)] + decode_out = [ + (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs) + ] if label is None: return decode_out label = [(self.label_list[idx], 1.0) for idx in label] - return decode_out, label \ No newline at end of file + return decode_out, label diff --git a/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py b/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py index 5c50cbbdc..0b11a314f 100644 --- a/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py +++ b/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py @@ -23,16 +23,16 @@ from .utils import DBPostProcess, create_operators, transform -class TextDetector(): +class TextDetector: def __init__(self, config): - self.preprocess_op = create_operators(config['pre_process']) - self.postprocess_op = DBPostProcess(**config['post_process']) + self.preprocess_op = create_operators(config["pre_process"]) + self.postprocess_op = DBPostProcess(**config["post_process"]) self.infer = OpenVINOInferSession(config) def __call__(self, img): ori_im = img.copy() - data = {'image': img} + data = {"image": img} data = transform(data, self.preprocess_op) img, shape_list = data if img is None: @@ -44,7 +44,7 @@ def __call__(self, img): starttime = time.time() preds = self.infer(img) post_result = self.postprocess_op(preds, shape_list) - dt_boxes = post_result[0]['points'] + dt_boxes = post_result[0]["points"] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) elapse = time.time() - starttime return dt_boxes, elapse @@ -96,8 +96,8 @@ def filter_tag_det_res(self, dt_boxes, image_shape): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--config_path', type=str, default='config.yaml') - parser.add_argument('--image_path', type=str, default=None) + parser.add_argument("--config_path", type=str, default="config.yaml") + parser.add_argument("--image_path", type=str, default=None) args = parser.parse_args() config = read_yaml(args.config_path) @@ -108,6 +108,7 @@ def filter_tag_det_res(self, dt_boxes, image_shape): dt_boxes, elapse = text_detector(img) from utils import draw_text_det_res + src_im = draw_text_det_res(dt_boxes, args.image_path) - cv2.imwrite('det_results.jpg', src_im) - print('The det_results.jpg has been saved in the current directory.') + cv2.imwrite("det_results.jpg", src_im) + print("The det_results.jpg has been saved in the current directory.") diff --git a/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py b/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py index 781dbb1a9..2e586b7a5 100644 --- a/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py +++ b/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py @@ -25,69 +25,74 @@ from shapely.geometry import Polygon -class DecodeImage(): - """ decode image """ +class DecodeImage: + """decode image""" - def __init__(self, img_mode='RGB', channel_first=False): + def __init__(self, img_mode="RGB", channel_first=False): self.img_mode = img_mode self.channel_first = channel_first def __call__(self, data): - img = data['image'] + img = data["image"] if six.PY2: - assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage" + assert ( + type(img) is str and len(img) > 0 + ), "invalid input 'img' in DecodeImage" else: - assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage" + assert ( + type(img) is bytes and len(img) > 0 + ), "invalid input 'img' in DecodeImage" - img = np.frombuffer(img, dtype='uint8') + img = np.frombuffer(img, dtype="uint8") img = cv2.imdecode(img, 1) if img is None: return None - if self.img_mode == 'GRAY': + if self.img_mode == "GRAY": img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif self.img_mode == 'RGB': - assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]' + elif self.img_mode == "RGB": + assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: img = img.transpose((2, 0, 1)) - data['image'] = img + data["image"] = img return data -class NormalizeImage(): - """ normalize image such as substract mean, divide std""" +class NormalizeImage: + """normalize image such as substract mean, divide std""" - def __init__(self, scale=None, mean=None, std=None, order='chw'): + def __init__(self, scale=None, mean=None, std=None, order="chw"): if isinstance(scale, str): scale = eval(scale) self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) mean = mean if mean is not None else [0.485, 0.456, 0.406] std = std if std is not None else [0.229, 0.224, 0.225] - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") def __call__(self, data): - img = np.array(data['image']).astype(np.float32) - data['image'] = (img * self.scale - self.mean) / self.std + img = np.array(data["image"]).astype(np.float32) + data["image"] = (img * self.scale - self.mean) / self.std return data -class ToCHWImage(): - """ convert hwc image to chw image""" +class ToCHWImage: + """convert hwc image to chw image""" + def __init__(self): pass def __call__(self, data): - img = np.array(data['image']) - data['image'] = img.transpose((2, 0, 1)) + img = np.array(data["image"]) + data["image"] = img.transpose((2, 0, 1)) return data -class KeepKeys(): +class KeepKeys: def __init__(self, keep_keys): self.keep_keys = keep_keys @@ -98,26 +103,26 @@ def __call__(self, data): return data_list -class DetResizeForTest(): +class DetResizeForTest: def __init__(self, **kwargs): super(DetResizeForTest, self).__init__() self.resize_type = 0 - if 'image_shape' in kwargs: - self.image_shape = kwargs['image_shape'] + if "image_shape" in kwargs: + self.image_shape = kwargs["image_shape"] self.resize_type = 1 - elif 'limit_side_len' in kwargs: - self.limit_side_len = kwargs.get('limit_side_len', 736) - self.limit_type = kwargs.get('limit_type', 'min') + elif "limit_side_len" in kwargs: + self.limit_side_len = kwargs.get("limit_side_len", 736) + self.limit_type = kwargs.get("limit_type", "min") - if 'resize_long' in kwargs: + if "resize_long" in kwargs: self.resize_type = 2 - self.resize_long = kwargs.get('resize_long', 960) + self.resize_long = kwargs.get("resize_long", 960) else: - self.limit_side_len = kwargs.get('limit_side_len', 736) - self.limit_type = kwargs.get('limit_type', 'min') + self.limit_side_len = kwargs.get("limit_side_len", 736) + self.limit_type = kwargs.get("limit_type", "min") def __call__(self, data): - img = data['image'] + img = data["image"] src_h, src_w = img.shape[:2] if self.resize_type == 0: @@ -128,8 +133,8 @@ def __call__(self, data): else: # img, shape = self.resize_image_type1(img) img, [ratio_h, ratio_w] = self.resize_image_type1(img) - data['image'] = img - data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + data["image"] = img + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) return data def resize_image_type1(self, img): @@ -153,14 +158,14 @@ def resize_image_type0(self, img): h, w = img.shape[:2] # limit the max side - if self.limit_type == 'max': + if self.limit_type == "max": if max(h, w) > limit_side_len: if h > w: ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w else: - ratio = 1. + ratio = 1.0 else: if min(h, w) < limit_side_len: if h < w: @@ -168,7 +173,7 @@ def resize_image_type0(self, img): else: ratio = float(limit_side_len) / w else: - ratio = 1. + ratio = 1.0 resize_h = int(h * ratio) resize_w = int(w * ratio) @@ -212,7 +217,7 @@ def resize_image_type2(self, img): def transform(data, ops=None): - """ transform """ + """transform""" if ops is None: ops = [] for op in ops: @@ -239,21 +244,22 @@ def draw_text_det_res(dt_boxes, img_path): src_im = cv2.imread(img_path) for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2) - cv2.polylines(src_im, [box], True, - color=(255, 255, 0), thickness=2) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) return src_im -class DBPostProcess(): +class DBPostProcess: """The post process for Differentiable Binarization (DB).""" - def __init__(self, - thresh=0.3, - box_thresh=0.7, - max_candidates=1000, - unclip_ratio=2.0, - score_mode="fast", - use_dilation=False): + def __init__( + self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + score_mode="fast", + use_dilation=False, + ): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates @@ -267,16 +273,17 @@ def __init__(self, self.dilation_kernel = None def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' + """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} - ''' + """ bitmap = _bitmap height, width = bitmap.shape - outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, - cv2.CHAIN_APPROX_SIMPLE) + outs = cv2.findContours( + (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE + ) if len(outs) == 3: img, contours, _ = outs[0], outs[1], outs[2] elif len(outs) == 2: @@ -305,10 +312,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): continue box = np.array(box) - box[:, 0] = np.clip( - np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) box[:, 1] = np.clip( - np.round(box[:, 1] / height * dest_height), 0, dest_height) + np.round(box[:, 1] / height * dest_height), 0, dest_height + ) boxes.append(box.astype(np.int16)) scores.append(score) return np.array(boxes, dtype=np.int16), scores @@ -340,9 +347,7 @@ def get_mini_boxes(self, contour): index_2 = 3 index_3 = 2 - box = [ - points[index_1], points[index_2], points[index_3], points[index_4] - ] + box = [points[index_1], points[index_2], points[index_3], points[index_4]] return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): @@ -357,12 +362,12 @@ def box_score_fast(self, bitmap, _box): box[:, 0] = box[:, 0] - xmin box[:, 1] = box[:, 1] - ymin cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def box_score_slow(self, bitmap, contour): - ''' + """ box_score_slow: use polyon mean score as the mean score - ''' + """ h, w = bitmap.shape[:2] contour = contour.copy() contour = np.reshape(contour, (-1, 2)) @@ -378,7 +383,7 @@ def box_score_slow(self, bitmap, contour): contour[:, 1] = contour[:, 1] - ymin cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] def __call__(self, pred, shape_list): pred = pred[:, 0, :, :] @@ -390,11 +395,13 @@ def __call__(self, pred, shape_list): if self.dilation_kernel is not None: mask = cv2.dilate( np.array(segmentation[batch_index]).astype(np.uint8), - self.dilation_kernel) + self.dilation_kernel, + ) else: mask = segmentation[batch_index] - boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, - src_w, src_h) + boxes, scores = self.boxes_from_bitmap( + pred[batch_index], mask, src_w, src_h + ) - boxes_batch.append({'points': boxes}) + boxes_batch.append({"points": boxes}) return boxes_batch diff --git a/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py index 1bf6bb095..ff274f805 100644 --- a/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py +++ b/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py @@ -28,13 +28,13 @@ from .utils import CTCLabelDecode -class TextRecognizer(): +class TextRecognizer: def __init__(self, config): - self.rec_image_shape = config['rec_img_shape'] - self.rec_batch_num = config['rec_batch_num'] + self.rec_image_shape = config["rec_img_shape"] + self.rec_batch_num = config["rec_batch_num"] - dict_path = str(Path(__file__).parent / 'ppocr_keys_v1.txt') - self.character_dict_path = config.get('keys_path', dict_path) + dict_path = str(Path(__file__).parent / "ppocr_keys_v1.txt") + self.character_dict_path = config.get("keys_path", dict_path) self.postprocess_op = CTCLabelDecode(self.character_dict_path) self.infer = OpenVINOInferSession(config) @@ -50,7 +50,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - rec_res = [['', 0.0]] * img_num + rec_res = [["", 0.0]] * img_num batch_num = self.rec_batch_num elapse = 0 @@ -64,8 +64,7 @@ def __call__(self, img_list: List[np.ndarray]): norm_img_batch = [] for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) norm_img_batch.append(norm_img[np.newaxis, :]) norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) @@ -91,21 +90,20 @@ def resize_norm_img(self, img, max_wh_ratio): resized_w = int(math.ceil(img_height * ratio)) resized_image = cv2.resize(img, (resized_w, img_height)) - resized_image = resized_image.astype('float32') + resized_image = resized_image.astype("float32") resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image -= 0.5 resized_image /= 0.5 - padding_im = np.zeros((img_channel, img_height, img_width), - dtype=np.float32) + padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image return padding_im if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--image_path', type=str, help='image_dir|image_path') - parser.add_argument('--config_path', type=str, default='config.yaml') + parser.add_argument("--image_path", type=str, help="image_dir|image_path") + parser.add_argument("--config_path", type=str, default="config.yaml") args = parser.parse_args() config = read_yaml(args.config_path) @@ -113,4 +111,4 @@ def resize_norm_img(self, img, max_wh_ratio): img = cv2.imread(args.image_path) rec_res, predict_time = text_recognizer(img) - print(f'rec result: {rec_res}\t cost: {predict_time}s') + print(f"rec result: {rec_res}\t cost: {predict_time}s") diff --git a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py index f40d906f6..9587af350 100644 --- a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py @@ -4,8 +4,8 @@ import numpy as np -class CTCLabelDecode(): - """ Convert between text-label and text-index """ +class CTCLabelDecode: + """Convert between text-label and text-index""" def __init__(self, character_dict_path): super(CTCLabelDecode, self).__init__() @@ -15,9 +15,9 @@ def __init__(self, character_dict_path): with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: - line = line.decode('utf-8').strip("\n").strip("\r\n") + line = line.decode("utf-8").strip("\n").strip("\r\n") self.character_str.append(line) - self.character_str.append(' ') + self.character_str.append(" ") dict_character = self.add_special_char(self.character_str) self.character = dict_character @@ -29,22 +29,21 @@ def __init__(self, character_dict_path): def __call__(self, preds, label=None): preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, - is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) if label is None: return text label = self.decode(label) return text, label def add_special_char(self, dict_character): - dict_character = ['blank'] + dict_character + dict_character = ["blank"] + dict_character return dict_character def get_ignored_tokens(self): return [0] # for ctc blank def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ + """convert text-index into text-label.""" result_list = [] ignored_tokens = self.get_ignored_tokens() @@ -57,15 +56,16 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): continue if is_remove_duplicate: # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) + char_list.append(self.character[int(text_index[batch_idx][idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: conf_list.append(1) - text = ''.join(char_list) + text = "".join(char_list) result_list.append((text, np.mean(conf_list + [1e-50]))) return result_list diff --git a/python/rapidocr_openvino/rapid_ocr_api.py b/python/rapidocr_openvino/rapid_ocr_api.py index 7f92c0f95..0c351b3c7 100644 --- a/python/rapidocr_openvino/rapid_ocr_api.py +++ b/python/rapidocr_openvino/rapid_ocr_api.py @@ -12,17 +12,16 @@ from .ch_ppocr_v2_cls import TextClassifier from .ch_ppocr_v3_det import TextDetector from .ch_ppocr_v3_rec import TextRecognizer -from .utils import (LoadImage, UpdateParameters, concat_model_path, init_args, - read_yaml) +from .utils import LoadImage, UpdateParameters, concat_model_path, init_args, read_yaml root_dir = Path(__file__).resolve().parent -class RapidOCR(): +class RapidOCR: def __init__(self, **kwargs): - config_path = str(root_dir / 'config.yaml') + config_path = str(root_dir / "config.yaml") if not Path(config_path).exists(): - raise FileExistsError(f'{config_path} does not exist!') + raise FileExistsError(f"{config_path} does not exist!") config = read_yaml(config_path) config = concat_model_path(config) @@ -30,30 +29,29 @@ def __init__(self, **kwargs): updater = UpdateParameters() config = updater(config, **kwargs) - global_config = config['Global'] - self.print_verbose = global_config['print_verbose'] - self.text_score = global_config['text_score'] - self.min_height = global_config['min_height'] - self.width_height_ratio = global_config['width_height_ratio'] + global_config = config["Global"] + self.print_verbose = global_config["print_verbose"] + self.text_score = global_config["text_score"] + self.min_height = global_config["min_height"] + self.width_height_ratio = global_config["width_height_ratio"] - self.use_text_det = config['Global']['use_text_det'] + self.use_text_det = config["Global"]["use_text_det"] if self.use_text_det: - self.text_detector = TextDetector(config['Det']) + self.text_detector = TextDetector(config["Det"]) - self.text_recognizer = TextRecognizer(config['Rec']) + self.text_recognizer = TextRecognizer(config["Rec"]) - self.use_angle_cls = config['Global']['use_angle_cls'] + self.use_angle_cls = config["Global"]["use_angle_cls"] if self.use_angle_cls: - self.text_cls = TextClassifier(config['Cls']) + self.text_cls = TextClassifier(config["Cls"]) self.load_img = LoadImage() - def __call__(self, - img_content: Union[str, np.ndarray, bytes, Path], **kwargs): + def __call__(self, img_content: Union[str, np.ndarray, bytes, Path], **kwargs): if kwargs: - box_thresh = kwargs.get('box_thresh', 0.5) - unclip_ratio = kwargs.get('unclip_ratio', 1.6) - text_score = kwargs.get('text_score', 0.5) + box_thresh = kwargs.get("box_thresh", 0.5) + unclip_ratio = kwargs.get("unclip_ratio", 1.6) + text_score = kwargs.get("text_score", 0.5) self.text_detector.postprocess_op.box_thresh = box_thresh self.text_detector.postprocess_op.unclip_ratio = unclip_ratio @@ -66,9 +64,7 @@ def __call__(self, else: use_limit_ratio = w / h > self.width_height_ratio - if not self.use_text_det \ - or h <= self.min_height \ - or use_limit_ratio: + if not self.use_text_det or h <= self.min_height or use_limit_ratio: dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w) det_elapse = 0.0 else: @@ -77,7 +73,7 @@ def __call__(self, return None, None if self.print_verbose: - print(f'dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}') + print(f"dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}") dt_boxes = self.sorted_boxes(dt_boxes) img_crop_list = self.get_crop_img_list(img, dt_boxes) @@ -87,16 +83,17 @@ def __call__(self, img_crop_list, _, cls_elapse = self.text_cls(img_crop_list) if self.print_verbose: - print(f'cls num: {len(img_crop_list)}, elapse: {cls_elapse}') + print(f"cls num: {len(img_crop_list)}, elapse: {cls_elapse}") rec_res, rec_elapse = self.text_recognizer(img_crop_list) if self.print_verbose: - print(f'rec_res num: {len(rec_res)}, elapse: {rec_elapse}') + print(f"rec_res num: {len(rec_res)}, elapse: {rec_elapse}") - filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, - rec_res) - fina_result = [[dt.tolist(), rec[0], str(rec[1])] - for dt, rec in zip(filter_boxes, filter_rec_res)] + filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, rec_res) + fina_result = [ + [dt.tolist(), rec[0], str(rec[1])] + for dt, rec in zip(filter_boxes, filter_rec_res) + ] if fina_result: return fina_result, [det_elapse, cls_elapse, rec_elapse] return None, None @@ -118,20 +115,31 @@ def get_rotate_crop_image(img, points): img_crop_width = int( max( np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) + np.linalg.norm(points[2] - points[3]), + ) + ) img_crop_height = int( max( np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) + np.linalg.norm(points[1] - points[2]), + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) M = cv2.getPerspectiveTransform(points, pts_std) dst_img = cv2.warpPerspective( img, - M, (img_crop_width, img_crop_height), + M, + (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) + flags=cv2.INTER_CUBIC, + ) dst_img_height, dst_img_width = dst_img.shape[0:2] if dst_img_height * 1.0 / dst_img_width >= 1.5: dst_img = np.rot90(dst_img) @@ -159,8 +167,10 @@ def sorted_boxes(dt_boxes): for i in range(num_boxes - 1): for j in range(i, -1, -1): - if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 \ - and _boxes[j + 1][0][0] < _boxes[j][0][0]: + if ( + abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 + and _boxes[j + 1][0][0] < _boxes[j][0][0] + ): tmp = _boxes[j] _boxes[j] = _boxes[j + 1] _boxes[j + 1] = tmp @@ -188,5 +198,5 @@ def main(): print(elapse_list) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/rapidocr_openvino/utils.py b/python/rapidocr_openvino/utils.py index 4f01f387c..3c699b9f5 100644 --- a/python/rapidocr_openvino/utils.py +++ b/python/rapidocr_openvino/utils.py @@ -16,14 +16,14 @@ InputType = Union[str, np.ndarray, bytes, Path] -class OpenVINOInferSession(): +class OpenVINOInferSession: def __init__(self, config): ie = Core() - config['model_path'] = str(root_dir / config['model_path']) - self._verify_model(config['model_path']) - model_onnx = ie.read_model(config['model_path']) - compile_model = ie.compile_model(model=model_onnx, device_name='CPU') + config["model_path"] = str(root_dir / config["model_path"]) + self._verify_model(config["model_path"]) + model_onnx = ie.read_model(config["model_path"]) + compile_model = ie.compile_model(model=model_onnx, device_name="CPU") self.session = compile_model.create_infer_request() def __call__(self, input_content: np.ndarray) -> np.ndarray: @@ -34,19 +34,22 @@ def __call__(self, input_content: np.ndarray) -> np.ndarray: def _verify_model(model_path): model_path = Path(model_path) if not model_path.exists(): - raise FileNotFoundError(f'{model_path} does not exists.') + raise FileNotFoundError(f"{model_path} does not exists.") if not model_path.is_file(): - raise FileExistsError(f'{model_path} is not a file.') + raise FileExistsError(f"{model_path} is not a file.") -class LoadImage(): - def __init__(self, ): +class LoadImage: + def __init__( + self, + ): pass def __call__(self, img: InputType) -> np.ndarray: if not isinstance(img, InputType.__args__): raise LoadImageError( - f'The img type {type(img)} does not in {InputType.__args__}') + f"The img type {type(img)} does not in {InputType.__args__}" + ) img = self.load_img(img) @@ -65,8 +68,7 @@ def load_img(self, img: InputType) -> np.ndarray: img = np.array(Image.open(img)) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) except UnidentifiedImageError as e: - raise LoadImageError( - f'cannot identify image file {img}') from e + raise LoadImageError(f"cannot identify image file {img}") from e return img if isinstance(img, bytes): @@ -77,12 +79,11 @@ def load_img(self, img: InputType) -> np.ndarray: if isinstance(img, np.ndarray): return img - raise LoadImageError(f'{type(img)} is not supported!') + raise LoadImageError(f"{type(img)} is not supported!") @staticmethod def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - '''RGBA → RGB - ''' + """RGBA → RGB""" r, g, b, a = cv2.split(img) new_img = cv2.merge((b, g, r)) @@ -96,7 +97,7 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: @staticmethod def verify_exist(file_path: Union[str, Path]): if not Path(file_path).exists(): - raise LoadImageError(f'{file_path} does not exist.') + raise LoadImageError(f"{file_path} does not exist.") class LoadImageError(Exception): @@ -104,77 +105,74 @@ class LoadImageError(Exception): def read_yaml(yaml_path): - with open(yaml_path, 'rb') as f: + with open(yaml_path, "rb") as f: data = yaml.load(f, Loader=yaml.Loader) return data def concat_model_path(config): - key = 'model_path' - config['Det'][key] = str(root_dir / config['Det'][key]) - config['Rec'][key] = str(root_dir / config['Rec'][key]) - config['Cls'][key] = str(root_dir / config['Cls'][key]) + key = "model_path" + config["Det"][key] = str(root_dir / config["Det"][key]) + config["Rec"][key] = str(root_dir / config["Rec"][key]) + config["Cls"][key] = str(root_dir / config["Cls"][key]) return config def init_args(): parser = argparse.ArgumentParser() - parser.add_argument('-img', '--img_path', type=str, default=None, - required=True) - parser.add_argument('-p', '--print_cost', - action='store_true', default=False) - - global_group = parser.add_argument_group(title='Global') - global_group.add_argument('--text_score', type=float, default=0.5) - global_group.add_argument('--use_angle_cls', type=bool, default=True) - global_group.add_argument('--use_text_det', type=bool, default=True) - global_group.add_argument('--print_verbose', type=bool, default=False) - global_group.add_argument('--min_height', type=int, default=30) - global_group.add_argument('--width_height_ratio', type=int, default=8) - - det_group = parser.add_argument_group(title='Det') - det_group.add_argument('--det_model_path', type=str, default=None) - det_group.add_argument('--det_limit_side_len', type=float, default=736) - det_group.add_argument('--det_limit_type', type=str, default='min', - choices=['max', 'min']) - det_group.add_argument('--det_thresh', type=float, default=0.3) - det_group.add_argument('--det_box_thresh', type=float, default=0.5) - det_group.add_argument('--det_unclip_ratio', type=float, default=1.6) - det_group.add_argument('--det_use_dilation', type=bool, default=True) - det_group.add_argument('--det_score_mode', type=str, default='fast', - choices=['slow', 'fast']) - - cls_group = parser.add_argument_group(title='Cls') - cls_group.add_argument('--cls_model_path', type=str, default=None) - cls_group.add_argument('--cls_image_shape', type=list, - default=[3, 48, 192]) - cls_group.add_argument('--cls_label_list', type=list, - default=['0', '180']) - cls_group.add_argument('--cls_batch_num', type=int, default=6) - cls_group.add_argument('--cls_thresh', type=float, default=0.9) - - rec_group = parser.add_argument_group(title='Rec') - rec_group.add_argument('--rec_model_path', type=str, default=None) - rec_group.add_argument('--rec_image_shape', type=list, - default=[3, 48, 320]) - rec_group.add_argument('--rec_batch_num', type=int, default=6) + parser.add_argument("-img", "--img_path", type=str, default=None, required=True) + parser.add_argument("-p", "--print_cost", action="store_true", default=False) + + global_group = parser.add_argument_group(title="Global") + global_group.add_argument("--text_score", type=float, default=0.5) + global_group.add_argument("--use_angle_cls", type=bool, default=True) + global_group.add_argument("--use_text_det", type=bool, default=True) + global_group.add_argument("--print_verbose", type=bool, default=False) + global_group.add_argument("--min_height", type=int, default=30) + global_group.add_argument("--width_height_ratio", type=int, default=8) + + det_group = parser.add_argument_group(title="Det") + det_group.add_argument("--det_model_path", type=str, default=None) + det_group.add_argument("--det_limit_side_len", type=float, default=736) + det_group.add_argument( + "--det_limit_type", type=str, default="min", choices=["max", "min"] + ) + det_group.add_argument("--det_thresh", type=float, default=0.3) + det_group.add_argument("--det_box_thresh", type=float, default=0.5) + det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) + det_group.add_argument("--det_use_dilation", type=bool, default=True) + det_group.add_argument( + "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] + ) + + cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument("--cls_model_path", type=str, default=None) + cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) + cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) + cls_group.add_argument("--cls_batch_num", type=int, default=6) + cls_group.add_argument("--cls_thresh", type=float, default=0.9) + + rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_image_shape", type=list, default=[3, 48, 320]) + rec_group.add_argument("--rec_batch_num", type=int, default=6) args = parser.parse_args() return args -class UpdateParameters(): +class UpdateParameters: def __init__(self) -> None: pass def parse_kwargs(self, **kwargs): global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} for k, v in kwargs.items(): - if k.startswith('det'): + if k.startswith("det"): det_dict[k] = v - elif k.startswith('cls'): + elif k.startswith("cls"): cls_dict[k] = v - elif k.startswith('rec'): + elif k.startswith("rec"): rec_dict[k] = v else: global_dict[k] = v @@ -183,11 +181,10 @@ def parse_kwargs(self, **kwargs): def __call__(self, config, **kwargs): global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) new_config = { - 'Global': self.update_global_params(config['Global'], - global_dict), - 'Det': self.update_det_params(config['Det'], det_dict), - 'Cls': self.update_cls_params(config['Cls'], cls_dict), - 'Rec': self.update_rec_params(config['Rec'], rec_dict) + "Global": self.update_global_params(config["Global"], global_dict), + "Det": self.update_det_params(config["Det"], det_dict), + "Cls": self.update_cls_params(config["Cls"], cls_dict), + "Rec": self.update_rec_params(config["Rec"], rec_dict), } return new_config @@ -198,38 +195,36 @@ def update_global_params(self, config, global_dict): def update_det_params(self, config, det_dict): if det_dict: - det_dict = {k.split('det_')[1]: v for k, v in det_dict.items()} - if not det_dict['model_path']: - det_dict['model_path'] = str(root_dir / config['model_path']) + det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()} + if not det_dict["model_path"]: + det_dict["model_path"] = str(root_dir / config["model_path"]) config.update(det_dict) return config def update_cls_params(self, config, cls_dict): if cls_dict: - need_remove_prefix = ['cls_label_list', 'cls_model_path'] + need_remove_prefix = ["cls_label_list", "cls_model_path"] new_cls_dict = {} for k, v in cls_dict.items(): if k in need_remove_prefix: - k = k.split('cls_')[1] + k = k.split("cls_")[1] new_cls_dict[k] = v - if not new_cls_dict['model_path']: - new_cls_dict['model_path'] = str( - root_dir / config['model_path']) + if not new_cls_dict["model_path"]: + new_cls_dict["model_path"] = str(root_dir / config["model_path"]) config.update(new_cls_dict) return config def update_rec_params(self, config, rec_dict): if rec_dict: - need_remove_prefix = ['rec_model_path'] + need_remove_prefix = ["rec_model_path"] new_rec_dict = {} for k, v in rec_dict.items(): if k in need_remove_prefix: - k = k.split('rec_')[1] + k = k.split("rec_")[1] new_rec_dict[k] = v - if not new_rec_dict['model_path']: - new_rec_dict['model_path'] = str( - root_dir / config['model_path']) + if not new_rec_dict["model_path"]: + new_rec_dict["model_path"] = str(root_dir / config["model_path"]) config.update(new_rec_dict) return config diff --git a/python/setup_onnxruntime.py b/python/setup_onnxruntime.py index d0dc3ca3d..d5e181a06 100644 --- a/python/setup_onnxruntime.py +++ b/python/setup_onnxruntime.py @@ -10,14 +10,14 @@ def get_readme(): root_dir = Path(__file__).resolve().parent.parent - readme_path = str(root_dir / 'docs' / 'doc_whl_rapidocr_ort.md') + readme_path = str(root_dir / "docs" / "doc_whl_rapidocr_ort.md") print(readme_path) - with open(readme_path, 'r', encoding='utf-8') as f: + with open(readme_path, "r", encoding="utf-8") as f: readme = f.read() return readme -MODULE_NAME = 'rapidocr_onnxruntime' +MODULE_NAME = "rapidocr_onnxruntime" obtainer = GetPyPiLatestVersion() latest_version = obtainer(MODULE_NAME) @@ -25,7 +25,7 @@ def get_readme(): # 优先提取commit message中的语义化版本号,如无,则自动加1 if len(sys.argv) > 2: - match_str = ' '.join(sys.argv[2:]) + match_str = " ".join(sys.argv[2:]) matched_versions = obtainer.extract_version(match_str) if matched_versions: VERSION_NUM = matched_versions @@ -37,31 +37,38 @@ def get_readme(): platforms="Any", description="A cross platform OCR Library based on OnnxRuntime.", long_description=get_readme(), - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", author="SWHL", author_email="liekkaskono@163.com", url="https://github.com/RapidAI/RapidOCR", - license='Apache-2.0', + license="Apache-2.0", include_package_data=True, install_requires=[ - "pyclipper>=1.2.1", "onnxruntime>=1.7.0", "opencv_python>=4.5.1.48", - "numpy>=1.19.3", "six>=1.15.0", "Shapely>=1.7.1", 'PyYAML', 'Pillow' + "pyclipper>=1.2.1", + "onnxruntime>=1.7.0", + "opencv_python>=4.5.1.48", + "numpy>=1.19.3", + "six>=1.15.0", + "Shapely>=1.7.1", + "PyYAML", + "Pillow", ], - package_dir={'': MODULE_NAME}, + package_dir={"": MODULE_NAME}, packages=setuptools.find_namespace_packages(where=MODULE_NAME), - package_data={'': ['*.onnx', '*.yaml']}, + package_data={"": ["*.onnx", "*.yaml"]}, keywords=[ - 'ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr' + "ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr" ], classifiers=[ - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - python_requires='>=3.6,<3.12', + python_requires=">=3.6,<3.12", entry_points={ - 'console_scripts': [f'{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main'], - }) + "console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main"], + }, +) diff --git a/python/setup_openvino.py b/python/setup_openvino.py index d52e9d6cc..6431c93a3 100644 --- a/python/setup_openvino.py +++ b/python/setup_openvino.py @@ -10,14 +10,14 @@ def get_readme(): root_dir = Path(__file__).resolve().parent.parent - readme_path = str(root_dir / 'docs' / 'doc_whl_rapidocr_vino.md') + readme_path = str(root_dir / "docs" / "doc_whl_rapidocr_vino.md") print(readme_path) - with open(readme_path, 'r', encoding='utf-8') as f: + with open(readme_path, "r", encoding="utf-8") as f: readme = f.read() return readme -MODULE_NAME = 'rapidocr_openvino' +MODULE_NAME = "rapidocr_openvino" obtainer = GetPyPiLatestVersion() latest_version = obtainer(MODULE_NAME) @@ -25,7 +25,7 @@ def get_readme(): # 优先提取commit message中的语义化版本号,如无,则自动加1 if len(sys.argv) > 2: - match_str = ' '.join(sys.argv[2:]) + match_str = " ".join(sys.argv[2:]) matched_versions = obtainer.extract_version(match_str) if matched_versions: VERSION_NUM = matched_versions @@ -37,31 +37,38 @@ def get_readme(): platforms="Any", description="A cross platform OCR Library based on OpenVINO.", long_description=get_readme(), - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", author="SWHL", author_email="liekkaskono@163.com", url="https://github.com/RapidAI/RapidOCR", - license='Apache-2.0', + license="Apache-2.0", include_package_data=True, - install_requires=["pyclipper>=1.2.1", "openvino>=2022.2.0", - "opencv_python>=4.5.1.48", "numpy>=1.19.3", - "six>=1.15.0", "Shapely>=1.7.1", 'PyYAML', 'Pillow'], - package_dir={'': MODULE_NAME}, + install_requires=[ + "pyclipper>=1.2.1", + "openvino>=2022.2.0", + "opencv_python>=4.5.1.48", + "numpy>=1.19.3", + "six>=1.15.0", + "Shapely>=1.7.1", + "PyYAML", + "Pillow", + ], + package_dir={"": MODULE_NAME}, packages=setuptools.find_namespace_packages(where=MODULE_NAME), - package_data={'': ['*.onnx', '*.yaml', '*.txt']}, + package_data={"": ["*.onnx", "*.yaml", "*.txt"]}, keywords=[ - 'ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr' + "ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr" ], classifiers=[ - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - python_requires='>=3.6,<3.12', + python_requires=">=3.6,<3.12", entry_points={ - 'console_scripts': [f'{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main'], - } + "console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main"], + }, ) diff --git a/python/tests/base_module.py b/python/tests/base_module.py index 13bd0bdfe..73026d233 100644 --- a/python/tests/base_module.py +++ b/python/tests/base_module.py @@ -8,25 +8,25 @@ import yaml -class BaseModule(): - def __init__(self, package_name: str = 'rapidocr_onnxruntime'): +class BaseModule: + def __init__(self, package_name: str = "rapidocr_onnxruntime"): self.package_name = package_name self.root_dir = Path(__file__).resolve().parent.parent self.package_dir = self.root_dir / self.package_name - self.tests_dir = self.root_dir / 'tests' + self.tests_dir = self.root_dir / "tests" sys.path.append(str(self.root_dir)) sys.path.append(str(self.package_dir)) def init_module(self, module_name: str, class_name: str = None): if class_name is None: - module_part = importlib.import_module(f'{self.package_name}') + module_part = importlib.import_module(f"{self.package_name}") return module_part - module_part = importlib.import_module(f'{self.package_name}.{module_name}') + module_part = importlib.import_module(f"{self.package_name}.{module_name}") return getattr(module_part, class_name) @staticmethod def read_yaml(yaml_path: str): - with open(yaml_path, 'rb') as f: + with open(yaml_path, "rb") as f: data = yaml.load(f, Loader=yaml.Loader) return data diff --git a/python/tests/benchmark/benchmark.py b/python/tests/benchmark/benchmark.py index 8b55e8d20..7afe0233a 100644 --- a/python/tests/benchmark/benchmark.py +++ b/python/tests/benchmark/benchmark.py @@ -14,22 +14,22 @@ from rapidocr_onnxruntime import RapidOCR -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--yaml_path', type=str, default='config.yaml') + parser.add_argument("--yaml_path", type=str, default="config.yaml") args = parser.parse_args() yaml_path = cur_dir / args.yaml_path rapid_ocr = RapidOCR(yaml_path) - image_dir = cur_dir / 'test_images_benchmark' + image_dir = cur_dir / "test_images_benchmark" if not image_dir.exists(): - raise FileNotFoundError(f'{image_dir} does not exits!!') + raise FileNotFoundError(f"{image_dir} does not exits!!") image_list = list(image_dir.iterdir()) cost_time_list = [] - for image_path in tqdm(image_list, desc='Test'): + for image_path in tqdm(image_list, desc="Test"): img = cv2.imread(str(image_path)) start_time = time.time() @@ -40,6 +40,8 @@ total_time = sum(cost_time_list) avg_time = total_time / len(cost_time_list) - print(f'Total Files: {len(image_list)}, ' - f'Total Time: {total_time:.5f}, ' - f'Average Time: {avg_time:.5f}') + print( + f"Total Files: {len(image_list)}, " + f"Total Time: {total_time:.5f}, " + f"Average Time: {avg_time:.5f}" + ) diff --git a/python/tests/test_all_ort.py b/python/tests/test_all_ort.py index 6267c4199..4427c6cf4 100644 --- a/python/tests/test_all_ort.py +++ b/python/tests/test_all_ort.py @@ -16,14 +16,14 @@ from rapidocr_onnxruntime import RapidOCR, LoadImageError rapid_ocr = RapidOCR() -tests_dir = root_dir / 'tests' / 'test_files' +tests_dir = root_dir / "tests" / "test_files" def test_normal(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" img = cv2.imread(str(image_path)) result, _ = rapid_ocr(img) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 @@ -42,29 +42,29 @@ def test_zeros(): def test_input_str(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" result, _ = rapid_ocr(str(image_path)) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_bytes(): - image_path = tests_dir / 'ch_en_num.jpg' - with open(image_path, 'rb') as f: + image_path = tests_dir / "ch_en_num.jpg" + with open(image_path, "rb") as f: result, _ = rapid_ocr(f.read()) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_path(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" result, _ = rapid_ocr(image_path) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" rapid_ocr = RapidOCR(text_score=1) result, _ = rapid_ocr(image_path) @@ -72,27 +72,27 @@ def test_input_parameters(): def test_input_det_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(det_model_path='1.onnx') + rapid_ocr = RapidOCR(det_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError def test_input_cls_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(cls_model_path='1.onnx') + rapid_ocr = RapidOCR(cls_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError def test_input_rec_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(rec_model_path='1.onnx') + rapid_ocr = RapidOCR(rec_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError diff --git a/python/tests/test_all_vino.py b/python/tests/test_all_vino.py index 5551632aa..530f04ac0 100644 --- a/python/tests/test_all_vino.py +++ b/python/tests/test_all_vino.py @@ -16,14 +16,14 @@ rapid_ocr = RapidOCR() -tests_dir = root_dir / 'tests' / 'test_files' +tests_dir = root_dir / "tests" / "test_files" def test_normal(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" img = cv2.imread(str(image_path)) result, _ = rapid_ocr(img) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 @@ -42,29 +42,29 @@ def test_zeros(): def test_input_str(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" result, _ = rapid_ocr(str(image_path)) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_bytes(): - image_path = tests_dir / 'ch_en_num.jpg' - with open(image_path, 'rb') as f: + image_path = tests_dir / "ch_en_num.jpg" + with open(image_path, "rb") as f: result, _ = rapid_ocr(f.read()) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_path(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" result, _ = rapid_ocr(image_path) - assert result[0][1] == '正品促销' + assert result[0][1] == "正品促销" assert len(result) == 17 def test_input_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" rapid_ocr = RapidOCR(text_score=1) result, _ = rapid_ocr(image_path) @@ -72,27 +72,27 @@ def test_input_parameters(): def test_input_det_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(det_model_path='1.onnx') + rapid_ocr = RapidOCR(det_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError def test_input_cls_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(cls_model_path='1.onnx') + rapid_ocr = RapidOCR(cls_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError def test_input_rec_parameters(): - image_path = tests_dir / 'ch_en_num.jpg' + image_path = tests_dir / "ch_en_num.jpg" with pytest.raises(FileNotFoundError) as exc_info: - rapid_ocr = RapidOCR(rec_model_path='1.onnx') + rapid_ocr = RapidOCR(rec_model_path="1.onnx") result, _ = rapid_ocr(image_path) raise FileNotFoundError() assert exc_info.type is FileNotFoundError diff --git a/python/tests/test_cls.py b/python/tests/test_cls.py index 72c5b75e9..a5ebd044c 100644 --- a/python/tests/test_cls.py +++ b/python/tests/test_cls.py @@ -7,24 +7,22 @@ @pytest.mark.parametrize( - 'package_name', - [('rapidocr_onnxruntime'), - ('rapidocr_openvino')] + "package_name", [("rapidocr_onnxruntime"), ("rapidocr_openvino")] ) def test_cls(package_name: str): - module_name = 'ch_ppocr_v2_cls' - class_name = 'TextClassifier' + module_name = "ch_ppocr_v2_cls" + class_name = "TextClassifier" base = BaseModule(package_name=package_name) TextClassifier = base.init_module(module_name, class_name) - yaml_path = base.package_dir / module_name / 'config.yaml' + yaml_path = base.package_dir / module_name / "config.yaml" config = base.read_yaml(str(yaml_path)) - config['model_path'] = str(base.package_dir / config['model_path']) + config["model_path"] = str(base.package_dir / config["model_path"]) text_cls = TextClassifier(config) - img_path = base.tests_dir / 'test_files' / 'text_cls.jpg' + img_path = base.tests_dir / "test_files" / "text_cls.jpg" img = cv2.imread(str(img_path)) result = text_cls([img]) - assert result[1][0][0] == '180' + assert result[1][0][0] == "180" diff --git a/python/tests/test_det.py b/python/tests/test_det.py index c8e9ec516..d7bd0570d 100644 --- a/python/tests/test_det.py +++ b/python/tests/test_det.py @@ -6,23 +6,20 @@ from base_module import BaseModule -@pytest.mark.parametrize( - 'package_name', - ['rapidocr_onnxruntime', 'rapidocr_openvino'] -) +@pytest.mark.parametrize("package_name", ["rapidocr_onnxruntime", "rapidocr_openvino"]) def test_det(package_name): - module_name = 'ch_ppocr_v3_det' - class_name = 'TextDetector' + module_name = "ch_ppocr_v3_det" + class_name = "TextDetector" base = BaseModule(package_name) TextDetector = base.init_module(module_name, class_name) - yaml_path = base.package_dir / module_name / 'config.yaml' + yaml_path = base.package_dir / module_name / "config.yaml" config = base.read_yaml(str(yaml_path)) - config['model_path'] = str(base.package_dir / config['model_path']) + config["model_path"] = str(base.package_dir / config["model_path"]) text_det = TextDetector(config) - img_path = base.tests_dir / 'test_files' / 'text_det.jpg' + img_path = base.tests_dir / "test_files" / "text_det.jpg" img = cv2.imread(str(img_path)) dt_boxes, elapse = text_det(img) assert dt_boxes.shape == (18, 4, 2) diff --git a/python/tests/test_rec.py b/python/tests/test_rec.py index 98b79d085..e75a9d6ce 100644 --- a/python/tests/test_rec.py +++ b/python/tests/test_rec.py @@ -6,24 +6,21 @@ from base_module import BaseModule -@pytest.mark.parametrize( - 'package_name', - ['rapidocr_onnxruntime', 'rapidocr_openvino'] -) +@pytest.mark.parametrize("package_name", ["rapidocr_onnxruntime", "rapidocr_openvino"]) def test_det(package_name): - module_name = 'ch_ppocr_v3_rec' - class_name = 'TextRecognizer' + module_name = "ch_ppocr_v3_rec" + class_name = "TextRecognizer" base = BaseModule(package_name) TextRecognizer = base.init_module(module_name, class_name) - yaml_path = base.package_dir / module_name / 'config.yaml' + yaml_path = base.package_dir / module_name / "config.yaml" config = base.read_yaml(str(yaml_path)) - config['model_path'] = str(base.package_dir / config['model_path']) + config["model_path"] = str(base.package_dir / config["model_path"]) text_rec = TextRecognizer(config) - img_path = base.tests_dir / 'test_files' / 'text_rec.jpg' + img_path = base.tests_dir / "test_files" / "text_rec.jpg" img = cv2.imread(str(img_path)) rec_res, elapse = text_rec([img]) - assert rec_res[0][0] == '韩国小馆' + assert rec_res[0][0] == "韩国小馆"