From f4fd5cdcba2e2e9540f6c58d650965ba1a8534ed Mon Sep 17 00:00:00 2001
From: SWHL
Date: Sat, 24 Jun 2023 10:27:12 +0800
Subject: [PATCH] Format by black
---
README.md | 3 +-
api/rapidocr_api/api.py | 29 ++-
docs/README_en.md | 3 +-
ocrweb/rapidocr_web/ocrweb.py | 26 +--
ocrweb_multi/build.py | 26 +--
ocrweb_multi/main.py | 60 ++---
ocrweb_multi/rapidocr/classify.py | 18 +-
ocrweb_multi/rapidocr/detect.py | 12 +-
ocrweb_multi/rapidocr/detect_process.py | 75 +++---
ocrweb_multi/rapidocr/main.py | 14 +-
ocrweb_multi/rapidocr/rapid_ocr_api.py | 54 ++---
ocrweb_multi/rapidocr/recognize.py | 24 +-
ocrweb_multi/utils/config.py | 4 +-
ocrweb_multi/utils/utils.py | 30 +--
python/demo.py | 78 ++++---
.../ch_ppocr_v2_cls/text_cls.py | 23 +-
.../ch_ppocr_v2_cls/utils.py | 9 +-
.../ch_ppocr_v3_det/text_detect.py | 51 ++--
.../ch_ppocr_v3_det/utils.py | 145 ++++++------
.../ch_ppocr_v3_rec/text_recognize.py | 24 +-
.../ch_ppocr_v3_rec/utils.py | 26 +--
python/rapidocr_onnxruntime/rapid_ocr_api.py | 90 ++++----
python/rapidocr_onnxruntime/utils.py | 217 +++++++++---------
.../ch_ppocr_v2_cls/text_cls.py | 23 +-
.../ch_ppocr_v2_cls/utils.py | 11 +-
.../ch_ppocr_v3_det/text_detect.py | 19 +-
.../ch_ppocr_v3_det/utils.py | 145 ++++++------
.../ch_ppocr_v3_rec/text_recognize.py | 26 +--
.../ch_ppocr_v3_rec/utils.py | 26 +--
python/rapidocr_openvino/rapid_ocr_api.py | 90 ++++----
python/rapidocr_openvino/utils.py | 159 +++++++------
python/setup_onnxruntime.py | 47 ++--
python/setup_openvino.py | 49 ++--
python/tests/base_module.py | 12 +-
python/tests/benchmark/benchmark.py | 18 +-
python/tests/test_all_ort.py | 34 +--
python/tests/test_all_vino.py | 34 +--
python/tests/test_cls.py | 16 +-
python/tests/test_det.py | 15 +-
python/tests/test_rec.py | 17 +-
40 files changed, 918 insertions(+), 864 deletions(-)
diff --git a/README.md b/README.md
index 405f83319..bc869b87b 100755
--- a/README.md
+++ b/README.md
@@ -17,10 +17,11 @@
-
+
+
diff --git a/api/rapidocr_api/api.py b/api/rapidocr_api/api.py
index c599092e0..d00b407cb 100644
--- a/api/rapidocr_api/api.py
+++ b/api/rapidocr_api/api.py
@@ -19,7 +19,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent))
-class OCRAPIUtils():
+class OCRAPIUtils:
def __init__(self) -> None:
self.ocr = RapidOCR()
@@ -31,10 +31,10 @@ def __call__(self, img):
if not ocr_res:
return json.dumps({})
- out_dict = {str(i): {'rec_txt': rec,
- 'dt_boxes': dt_box,
- 'score': score}
- for i, (dt_box, rec, score) in enumerate(ocr_res)}
+ out_dict = {
+ str(i): {"rec_txt": rec, "dt_boxes": dt_box, "score": score}
+ for i, (dt_box, rec, score) in enumerate(ocr_res)
+ }
return out_dict
@@ -44,10 +44,10 @@ def __call__(self, img):
@app.get("/")
async def root():
- return {'message': 'Welcome to RapidOCR Server!'}
+ return {"message": "Welcome to RapidOCR Server!"}
-@app.post('/ocr')
+@app.post("/ocr")
async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
if image_file:
img = Image.open(image_file.file)
@@ -57,25 +57,24 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
img = Image.open(io.BytesIO(img_b64decode))
else:
raise ValueError(
- 'When sending a post request, data or files must have a value.')
+ "When sending a post request, data or files must have a value."
+ )
ocr_res = processor(img)
return ocr_res
def main():
- parser = argparse.ArgumentParser('rapidocr_api')
- parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0',
- help='IP Address')
- parser.add_argument('-p', '--port', type=int, default=9003,
- help='IP port')
+ parser = argparse.ArgumentParser("rapidocr_api")
+ parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
+ parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
args = parser.parse_args()
cur_file_path = Path(__file__).resolve()
- app_path = f'{cur_file_path.parent.name}.{cur_file_path.stem}:app'
+ app_path = f"{cur_file_path.parent.name}.{cur_file_path.stem}:app"
print(app_path)
uvicorn.run(app_path, host=args.ip, port=args.port, reload=True)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/docs/README_en.md b/docs/README_en.md
index 292430340..94937469a 100644
--- a/docs/README_en.md
+++ b/docs/README_en.md
@@ -16,10 +16,11 @@
-
+
+
diff --git a/ocrweb/rapidocr_web/ocrweb.py b/ocrweb/rapidocr_web/ocrweb.py
index 64ed43c66..45825e7bc 100644
--- a/ocrweb/rapidocr_web/ocrweb.py
+++ b/ocrweb/rapidocr_web/ocrweb.py
@@ -14,36 +14,34 @@
root_dir = Path(__file__).resolve().parent
-app = Flask(__name__, template_folder='templates')
-app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024
+app = Flask(__name__, template_folder="templates")
+app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024
processor = OCRWebUtils()
-@app.route('/')
+@app.route("/")
def index():
- return render_template('index.html')
+ return render_template("index.html")
-@app.route('/ocr', methods=['POST'])
+@app.route("/ocr", methods=["POST"])
def ocr():
- if request.method == 'POST':
- img_str = request.get_json().get('file', None)
+ if request.method == "POST":
+ img_str = request.get_json().get("file", None)
ocr_res = processor(img_str)
return ocr_res
def main():
- parser = argparse.ArgumentParser('rapidocr_web')
- parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0',
- help='IP Address')
- parser.add_argument('-p', '--port', type=int, default=9003,
- help='IP port')
+ parser = argparse.ArgumentParser("rapidocr_web")
+ parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
+ parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
args = parser.parse_args()
- print(f'Successfully launched and visit https://{args.ip}:{args.port} to view.')
+ print(f"Successfully launched and visit https://{args.ip}:{args.port} to view.")
server = make_server(args.ip, args.port, app)
server.serve_forever()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/ocrweb_multi/build.py b/ocrweb_multi/build.py
index a23ee2791..5b09f91b0 100644
--- a/ocrweb_multi/build.py
+++ b/ocrweb_multi/build.py
@@ -1,21 +1,21 @@
import os
import shutil
-print('Compile ocrweb')
-os.system('pyinstaller -y main.spec')
+print("Compile ocrweb")
+os.system("pyinstaller -y main.spec")
-print('Compile wrapper')
-os.system('windres .\wrapper.rc -O coff -o wrapper.res')
-os.system('gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe')
+print("Compile wrapper")
+os.system("windres .\wrapper.rc -O coff -o wrapper.res")
+os.system("gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe")
-print('Copy config.yaml')
-shutil.copy2('config.yaml', 'dist/config.yaml')
+print("Copy config.yaml")
+shutil.copy2("config.yaml", "dist/config.yaml")
-print('Copy models')
-shutil.copytree('models', 'dist/models', dirs_exist_ok=True)
-os.remove('dist/models/.gitkeep')
+print("Copy models")
+shutil.copytree("models", "dist/models", dirs_exist_ok=True)
+os.remove("dist/models/.gitkeep")
-print('Pack to ocrweb.zip')
-shutil.make_archive('ocrweb', 'zip', 'dist')
+print("Pack to ocrweb.zip")
+shutil.make_archive("ocrweb", "zip", "dist")
-print('Done')
+print("Done")
diff --git a/ocrweb_multi/main.py b/ocrweb_multi/main.py
index 05f773699..270f2a47b 100644
--- a/ocrweb_multi/main.py
+++ b/ocrweb_multi/main.py
@@ -13,61 +13,67 @@
from utils.utils import tojson, parse_bool
app = Flask(__name__)
-log = logging.getLogger('app')
+log = logging.getLogger("app")
# 设置上传文件大小
-app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024
+app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024
-@app.route('/')
+@app.route("/")
def index():
- return send_file('static/index.html')
+ return send_file("static/index.html")
def json_response(data, status=200):
- return make_response(tojson(data), status, {"content-type": 'application/json'})
+ return make_response(tojson(data), status, {"content-type": "application/json"})
-@app.route('/lang')
+@app.route("/lang")
def get_languages():
"""返回可用语言列表"""
data = [
- {'code': key, 'name': val['name']} for key, val in conf['languages'].items()
+ {"code": key, "name": val["name"]} for key, val in conf["languages"].items()
]
- result = {'msg': 'OK', 'data': data}
- log.info('Send langs: %s', data)
+ result = {"msg": "OK", "data": data}
+ log.info("Send langs: %s", data)
return json_response(result)
-@app.route('/ocr', methods=['POST', 'GET'])
+@app.route("/ocr", methods=["POST", "GET"])
def ocr():
"""执行文字识别"""
- if conf['server'].get('token'):
- if request.values.get('token') != conf['server']['token']:
- return json_response({'msg': 'invalid token'}, status=403)
+ if conf["server"].get("token"):
+ if request.values.get("token") != conf["server"]["token"]:
+ return json_response({"msg": "invalid token"}, status=403)
- lang = request.values.get('lang') or 'ch'
- detect = parse_bool(request.values.get('detect') or 'true')
- classify = parse_bool(request.values.get('classify') or 'true')
+ lang = request.values.get("lang") or "ch"
+ detect = parse_bool(request.values.get("detect") or "true")
+ classify = parse_bool(request.values.get("classify") or "true")
- image_file = request.files.get('image')
+ image_file = request.files.get("image")
if not image_file:
- return json_response({'msg': 'no image'}, 400)
+ return json_response({"msg": "no image"}, 400)
nparr = np.frombuffer(image_file.stream.read(), np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
- log.info('Input: image %s, lang=%s, detect=%s, classify=%s', image.shape, lang, detect, classify)
+ log.info(
+ "Input: image %s, lang=%s, detect=%s, classify=%s",
+ image.shape,
+ lang,
+ detect,
+ classify,
+ )
if image.ndim == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
result = detect_recognize(image, lang=lang, detect=detect, classify=classify)
- log.info('OCR Done %s %s', result['ts'], len(result['results']))
- return json_response({'msg': 'OK', 'data': result})
+ log.info("OCR Done %s %s", result["ts"], len(result["results"]))
+ return json_response({"msg": "OK", "data": result})
-if __name__ == '__main__':
- logging.basicConfig(level='INFO')
- logging.getLogger('waitress').setLevel(logging.INFO)
- if parse_bool(conf.get('debug', '0')):
+if __name__ == "__main__":
+ logging.basicConfig(level="INFO")
+ logging.getLogger("waitress").setLevel(logging.INFO)
+ if parse_bool(conf.get("debug", "0")):
# Debug
- app.run(host=conf['server']['host'], port=conf['server']['port'], debug=True)
+ app.run(host=conf["server"]["host"], port=conf["server"]["port"], debug=True)
else:
# Deploy with waitress
- serve(app, host=conf['server']['host'], port=conf['server']['port'])
+ serve(app, host=conf["server"]["host"], port=conf["server"]["port"])
diff --git a/ocrweb_multi/rapidocr/classify.py b/ocrweb_multi/rapidocr/classify.py
index 5c2b37bcc..5a3e4c4e5 100644
--- a/ocrweb_multi/rapidocr/classify.py
+++ b/ocrweb_multi/rapidocr/classify.py
@@ -21,7 +21,7 @@
from utils.utils import OrtInferSession
-class ClsPostProcess():
+class ClsPostProcess:
"""Convert between text-label and text-index"""
def __init__(self, label_list):
@@ -40,18 +40,18 @@ def __call__(self, preds, label=None):
return decode_out, label
-class TextClassifier():
+class TextClassifier:
def __init__(self, path, config):
- self.cls_batch_num = config['batch_size']
- self.cls_thresh = config['score_thresh']
+ self.cls_batch_num = config["batch_size"]
+ self.cls_thresh = config["score_thresh"]
session_instance = OrtInferSession(path)
self.session = session_instance.session
metamap = self.session.get_modelmeta().custom_metadata_map
- self.cls_image_shape = json.loads(metamap['shape'])
+ self.cls_image_shape = json.loads(metamap["shape"])
- labels = json.loads(metamap['labels'])
+ labels = json.loads(metamap["labels"])
self.postprocess_op = ClsPostProcess(labels)
self.input_name = session_instance.get_input_name()
@@ -65,7 +65,7 @@ def resize_norm_img(self, img):
resized_w = int(math.ceil(img_h * ratio))
resized_image = cv2.resize(img, (resized_w, img_h))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if img_c == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -91,7 +91,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- cls_res = [['', 0.0]] * img_num
+ cls_res = [["", 0.0]] * img_num
batch_num = self.cls_batch_num
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
@@ -115,7 +115,7 @@ def __call__(self, img_list: List[np.ndarray]):
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
- if label == '180' and score > self.cls_thresh:
+ if label == "180" and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1
)
diff --git a/ocrweb_multi/rapidocr/detect.py b/ocrweb_multi/rapidocr/detect.py
index 136b0bf67..a6810ac78 100644
--- a/ocrweb_multi/rapidocr/detect.py
+++ b/ocrweb_multi/rapidocr/detect.py
@@ -21,10 +21,10 @@
from .detect_process import DBPostProcess, create_operators, transform
-class TextDetector():
+class TextDetector:
def __init__(self, path, config):
- self.preprocess_op = create_operators(config['pre_process'])
- self.postprocess_op = DBPostProcess(**config['post_process'])
+ self.preprocess_op = create_operators(config["pre_process"])
+ self.postprocess_op = DBPostProcess(**config["post_process"])
session_instance = OrtInferSession(path)
self.session = session_instance.session
@@ -32,11 +32,11 @@ def __init__(self, path, config):
def __call__(self, img):
if img is None:
- raise ValueError('img is None')
+ raise ValueError("img is None")
ori_im_shape = img.shape[:2]
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -49,7 +49,7 @@ def __call__(self, img):
post_result = self.postprocess_op(preds[0], shape_list)
- dt_boxes = post_result[0]['points']
+ dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
return dt_boxes
diff --git a/ocrweb_multi/rapidocr/detect_process.py b/ocrweb_multi/rapidocr/detect_process.py
index 2b4307b74..e941039f5 100644
--- a/ocrweb_multi/rapidocr/detect_process.py
+++ b/ocrweb_multi/rapidocr/detect_process.py
@@ -18,7 +18,6 @@
# @Contact: liekkaskono@163.com
from copy import deepcopy
import sys
-import warnings
import cv2
import numpy as np
@@ -27,15 +26,15 @@
from shapely.geometry import Polygon
-class DecodeImage():
+class DecodeImage:
"""decode image"""
- def __init__(self, img_mode='RGB', channel_first=False):
+ def __init__(self, img_mode="RGB", channel_first=False):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
- img = data['image']
+ img = data["image"]
if six.PY2:
assert (
type(img) is str and len(img) > 0
@@ -45,53 +44,53 @@ def __call__(self, data):
type(img) is bytes and len(img) > 0
), "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
+ img = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(img, 1)
if img is None:
return None
- if self.img_mode == 'GRAY':
+ if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]'
+ elif self.img_mode == "RGB":
+ assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
- data['image'] = img
+ data["image"] = img
return data
-class NormalizeImage():
+class NormalizeImage:
"""normalize image such as substract mean, divide std"""
- def __init__(self, scale=None, mean=None, std=None, order='chw'):
+ def __init__(self, scale=None, mean=None, std=None, order="chw"):
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype('float32')
- self.std = np.array(std).reshape(shape).astype('float32')
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
- img = np.array(data['image']).astype(np.float32)
- data['image'] = (img * self.scale - self.mean) / self.std
+ img = np.array(data["image"]).astype(np.float32)
+ data["image"] = (img * self.scale - self.mean) / self.std
return data
-class ToCHWImage():
+class ToCHWImage:
"""convert hwc image to chw image"""
def __init__(self):
pass
def __call__(self, data):
- img = data['image']
- data['image'] = img.transpose((2, 0, 1))
+ img = data["image"]
+ data["image"] = img.transpose((2, 0, 1))
return data
-class KeepKeys():
+class KeepKeys:
def __init__(self, keep_keys):
self.keep_keys = keep_keys
@@ -102,25 +101,25 @@ def __call__(self, data):
return data_list
-class DetResizeForTest():
+class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs['limit_side_len']
- self.limit_type = kwargs.get('limit_type', 'min')
- elif 'resize_long' in kwargs:
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs["limit_side_len"]
+ self.limit_type = kwargs.get("limit_type", "min")
+ elif "resize_long" in kwargs:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
else:
self.limit_side_len = 736
- self.limit_type = 'min'
+ self.limit_type = "min"
def __call__(self, data):
- img = data['image']
+ img = data["image"]
src_h, src_w, _ = img.shape
if self.resize_type == 0:
@@ -131,8 +130,8 @@ def __call__(self, data):
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
- data['image'] = img
- data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ data["image"] = img
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
@@ -156,7 +155,7 @@ def resize_image_type0(self, img):
h, w, _ = img.shape
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
@@ -234,7 +233,7 @@ def create_operators(op_param_list):
ops = []
for args in op_param_list:
args = deepcopy(args)
- op_class = op_map[args.pop('class')]
+ op_class = op_map[args.pop("class")]
ops.append(op_class(**args))
return ops
@@ -247,7 +246,7 @@ def draw_text_det_res(dt_boxes, img_path):
return src_im
-class DBPostProcess():
+class DBPostProcess:
"""The post process for Differentiable Binarization (DB)."""
def __init__(
@@ -270,10 +269,10 @@ def __init__(
self.dilation_kernel = None
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
- '''
+ """
bitmap = _bitmap
height, width = bitmap.shape
@@ -376,5 +375,5 @@ def __call__(self, pred, shape_list):
pred[batch_index], mask, src_w, src_h
)
- boxes_batch.append({'points': boxes})
+ boxes_batch.append({"points": boxes})
return boxes_batch
diff --git a/ocrweb_multi/rapidocr/main.py b/ocrweb_multi/rapidocr/main.py
index d88f8e755..016ebea35 100644
--- a/ocrweb_multi/rapidocr/main.py
+++ b/ocrweb_multi/rapidocr/main.py
@@ -13,21 +13,21 @@
@lru_cache(maxsize=None)
-def load_language_model(lang='ch'):
- models = conf['languages'][lang]
- print('model', models)
+def load_language_model(lang="ch"):
+ models = conf["languages"][lang]
+ print("model", models)
return RapidOCR(models)
-def detect_recognize(image, lang='ch', detect=True, classify=True):
+def detect_recognize(image, lang="ch", detect=True, classify=True):
model = load_language_model(lang)
results, ts = model(image, detect=detect, classify=classify)
- ts['total'] = sum(ts.values())
- return {'ts': ts, 'results': results}
+ ts["total"] = sum(ts.values())
+ return {"ts": ts, "results": results}
def check_and_read_gif(img_path):
- if Path(img_path).suffix.lower() == 'gif':
+ if Path(img_path).suffix.lower() == "gif":
gif = cv2.VideoCapture(img_path)
ret, frame = gif.read()
if not ret:
diff --git a/ocrweb_multi/rapidocr/rapid_ocr_api.py b/ocrweb_multi/rapidocr/rapid_ocr_api.py
index d5bc8e8e1..d2710a99b 100644
--- a/ocrweb_multi/rapidocr/rapid_ocr_api.py
+++ b/ocrweb_multi/rapidocr/rapid_ocr_api.py
@@ -52,61 +52,61 @@ def get_rotate_crop_image(img, points):
@lru_cache(maxsize=None)
def load_onnx_model(step, name):
- model_config = conf['models'][step][name]
+ model_config = conf["models"][step][name]
model_class = {
- 'detect': TextDetector,
- 'classify': TextClassifier,
- 'recognize': TextRecognizer,
+ "detect": TextDetector,
+ "classify": TextClassifier,
+ "recognize": TextRecognizer,
}[step]
- return model_class(model_config['path'], model_config.get('config'))
+ return model_class(model_config["path"], model_config.get("config"))
-class RapidOCR():
+class RapidOCR:
def __init__(self, config):
super(RapidOCR).__init__()
self.config = config
- self.text_score = config['config']['text_score']
- self.min_height = config['config']['min_height']
+ self.text_score = config["config"]["text_score"]
+ self.min_height = config["config"]["min_height"]
- models = config['models']
- self.text_detector = load_onnx_model('detect', models['detect'])
- self.text_recognizer = load_onnx_model('recognize', models['recognize'])
- self.text_cls = load_onnx_model('classify', models['classify'])
+ models = config["models"]
+ self.text_detector = load_onnx_model("detect", models["detect"])
+ self.text_recognizer = load_onnx_model("recognize", models["recognize"])
+ self.text_cls = load_onnx_model("classify", models["classify"])
def __call__(self, img: np.ndarray, detect=True, classify=True):
ticker = Ticker()
h, w = img.shape[:2]
if not detect or h < self.min_height:
dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w)
- ticker.tick('detect')
+ ticker.tick("detect")
else:
dt_boxes = self.text_detector(img)
- ticker.tick('detect')
+ ticker.tick("detect")
if dt_boxes is None or len(dt_boxes) < 1:
return [], ticker.maps
- if conf['global']['verbose']:
- print(f'boxes num: {len(dt_boxes)}')
+ if conf["global"]["verbose"]:
+ print(f"boxes num: {len(dt_boxes)}")
dt_boxes = self.sorted_boxes(dt_boxes)
img_crop_list = self.get_crop_img_list(img, dt_boxes)
- ticker.tick('post-detect')
+ ticker.tick("post-detect")
if classify:
# 进行子图像角度修正
img_crop_list, _ = self.text_cls(img_crop_list)
- ticker.tick('classify')
- if conf['global']['verbose']:
- print(f'cls num: {len(img_crop_list)}')
+ ticker.tick("classify")
+ if conf["global"]["verbose"]:
+ print(f"cls num: {len(img_crop_list)}")
recog_result = self.text_recognizer(img_crop_list)
- ticker.tick('recognize')
- if conf['global']['verbose']:
- print(f'rec_res num: {len(recog_result)}')
+ ticker.tick("recognize")
+ if conf["global"]["verbose"]:
+ print(f"rec_res num: {len(recog_result)}")
results = self.filter_boxes_rec_by_score(dt_boxes, recog_result)
- ticker.tick('post-recognize')
+ ticker.tick("post-recognize")
return results, ticker.maps
def get_boxes_img_without_det(self, img, h, w):
@@ -134,13 +134,13 @@ def sorted_boxes(dt_boxes):
sorted boxes(array) with shape [4, 2]
"""
- class AlignBox():
+ class AlignBox:
def __init__(self, data) -> None:
self.data = data
self.x = data[0][0]
self.y = data[0][1]
- def __lt__(self, other: 'AlignBox'):
+ def __lt__(self, other: "AlignBox"):
dy = self.y - other.y
# y差距小于10, 视为相等, 根据x排序
if abs(dy) < 10:
@@ -156,5 +156,5 @@ def filter_boxes_rec_by_score(self, dt_boxes, rec_res):
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
if score >= self.text_score:
- results.append({'box': box, 'text': text, 'score': score})
+ results.append({"box": box, "text": text, "score": score})
return results
diff --git a/ocrweb_multi/rapidocr/recognize.py b/ocrweb_multi/rapidocr/recognize.py
index ae14ff741..ff79e2540 100644
--- a/ocrweb_multi/rapidocr/recognize.py
+++ b/ocrweb_multi/rapidocr/recognize.py
@@ -11,26 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import argparse
import json
import math
-import time
from typing import List
import cv2
import numpy as np
-from utils.utils import get_resource_path, OrtInferSession
+from utils.utils import OrtInferSession
-class CTCLabelDecode():
+class CTCLabelDecode:
"""Convert between text-label and text-index"""
def __init__(self, characters: List[str]):
super(CTCLabelDecode, self).__init__()
self.characters = characters
- self.characters.append(' ')
+ self.characters.append(" ")
dict_character = self.add_special_char(self.characters)
self.character = dict_character
@@ -49,7 +47,7 @@ def __call__(self, preds, label=None):
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
def get_ignored_tokens(self):
@@ -81,22 +79,22 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
conf_list.append(1)
# avoid `Mean of empty slice.` warning
score = np.mean(conf_list) if conf_list else 0
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, score))
return result_list
-class TextRecognizer():
+class TextRecognizer:
def __init__(self, path, config):
- self.rec_batch_num = config.get('rec_batch_num', 6)
+ self.rec_batch_num = config.get("rec_batch_num", 6)
session_instance = OrtInferSession(path)
self.session = session_instance.session
metamap = session_instance.session.get_modelmeta().custom_metadata_map
- chars = metamap['dictionary'].splitlines()
+ chars = metamap["dictionary"].splitlines()
self.postprocess_op = CTCLabelDecode(chars)
- self.rec_image_shape = json.loads(metamap['shape'])
+ self.rec_image_shape = json.loads(metamap["shape"])
self.input_name = session_instance.get_input_name()
def resize_norm_img(self, img, max_wh_ratio):
@@ -114,7 +112,7 @@ def resize_norm_img(self, img, max_wh_ratio):
resized_w = int(math.ceil(img_height * ratio))
resized_image = cv2.resize(img, (resized_w, img_height))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
@@ -134,7 +132,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- rec_res = [['', 0.0]] * img_num
+ rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
for beg_img_no in range(0, img_num, batch_num):
diff --git a/ocrweb_multi/utils/config.py b/ocrweb_multi/utils/config.py
index aca5c6a68..fda18f8f8 100644
--- a/ocrweb_multi/utils/config.py
+++ b/ocrweb_multi/utils/config.py
@@ -19,9 +19,9 @@ def get_resource_path(name: str):
Path(name),
]:
if path.exists():
- print('Loaded:', path)
+ print("Loaded:", path)
return path
raise FileNotFoundError(name)
-conf = yaml.safe_load(get_resource_path('config.yaml').read_text(encoding='utf-8'))
+conf = yaml.safe_load(get_resource_path("config.yaml").read_text(encoding="utf-8"))
diff --git a/ocrweb_multi/utils/utils.py b/ocrweb_multi/utils/utils.py
index f1916113f..9e9b4ac78 100644
--- a/ocrweb_multi/utils/utils.py
+++ b/ocrweb_multi/utils/utils.py
@@ -14,33 +14,33 @@
def parse_bool(val):
if not isinstance(val, str):
return bool(val)
- return val.lower() in ('1', 'true', 'yes')
+ return val.lower() in ("1", "true", "yes")
def default(obj):
- if hasattr(obj, 'tolist'):
+ if hasattr(obj, "tolist"):
return obj.tolist()
return obj
def tojson(obj, **kws):
- return json.dumps(obj, default=default, ensure_ascii=False, **kws) + '\n'
+ return json.dumps(obj, default=default, ensure_ascii=False, **kws) + "\n"
-class OrtInferSession():
+class OrtInferSession:
def __init__(self, model_path):
- ort_conf = conf['global']
+ ort_conf = conf["global"]
sess_opt = SessionOptions()
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
- cuda_ep = 'CUDAExecutionProvider'
- cpu_ep = 'CPUExecutionProvider'
+ cuda_ep = "CUDAExecutionProvider"
+ cpu_ep = "CPUExecutionProvider"
providers = []
if (
- ort_conf['use_cuda']
- and get_device() == 'GPU'
+ ort_conf["use_cuda"]
+ and get_device() == "GPU"
and cuda_ep in get_available_providers()
):
providers = [(cuda_ep, ort_conf[cuda_ep])]
@@ -53,12 +53,12 @@ def __init__(self, model_path):
providers=providers,
)
- if ort_conf['use_cuda'] and cuda_ep not in self.session.get_providers():
+ if ort_conf["use_cuda"] and cuda_ep not in self.session.get_providers():
warnings.warn(
- f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
- 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
- 'you can check their relations from the offical web site: '
- 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
+ f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
+ "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
+ "you can check their relations from the offical web site: "
+ "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
RuntimeWarning,
)
@@ -69,7 +69,7 @@ def get_output_name(self, output_idx=0):
return self.session.get_outputs()[output_idx].name
-class Ticker():
+class Ticker:
def __init__(self, reset=True) -> None:
self.ts = time.perf_counter()
self.reset = reset
diff --git a/python/demo.py b/python/demo.py
index ab4e963c2..e33dd6eaf 100644
--- a/python/demo.py
+++ b/python/demo.py
@@ -10,18 +10,20 @@
from PIL import Image, ImageDraw, ImageFont
from rapidocr_onnxruntime import RapidOCR
+
# from rapidocr_openvino import RapidOCR
-def draw_ocr_box_txt(image, boxes, txts, font_path,
- scores=None, text_score=0.5):
+def draw_ocr_box_txt(image, boxes, txts, font_path, scores=None, text_score=0.5):
if not Path(font_path).exists():
- raise FileNotFoundError(f'The {font_path} does not exists! \n'
- f'Please download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing')
+ raise FileNotFoundError(
+ f"The {font_path} does not exists! \n"
+ f"Please download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing"
+ )
h, w = image.height, image.width
img_left = image.copy()
- img_right = Image.new('RGB', (w, h), (255, 255, 255))
+ img_right = Image.new("RGB", (w, h), (255, 255, 255))
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
@@ -30,40 +32,45 @@ def draw_ocr_box_txt(image, boxes, txts, font_path,
if scores is not None and float(scores[idx]) < text_score:
continue
- color = (random.randint(0, 255),
- random.randint(0, 255),
- random.randint(0, 255))
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
draw_left.polygon(box, fill=color)
- draw_right.polygon([box[0][0], box[0][1],
- box[1][0], box[1][1],
- box[2][0], box[2][1],
- box[3][0], box[3][1]],
- outline=color)
-
- box_height = math.sqrt((box[0][0] - box[3][0])**2
- + (box[0][1] - box[3][1])**2)
-
- box_width = math.sqrt((box[0][0] - box[1][0])**2
- + (box[0][1] - box[1][1])**2)
+ draw_right.polygon(
+ [
+ box[0][0],
+ box[0][1],
+ box[1][0],
+ box[1][1],
+ box[2][0],
+ box[2][1],
+ box[3][0],
+ box[3][1],
+ ],
+ outline=color,
+ )
+
+ box_height = math.sqrt(
+ (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2
+ )
+
+ box_width = math.sqrt(
+ (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2
+ )
if box_height > 2 * box_width:
font_size = max(int(box_width * 0.9), 10)
- font = ImageFont.truetype(font_path, font_size,
- encoding="utf-8")
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
cur_y = box[0][1]
for c in txt:
char_size = font.getsize(c)
- draw_right.text((box[0][0] + 3, cur_y), c,
- fill=(0, 0, 0), font=font)
+ draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
cur_y += char_size[1]
else:
font_size = max(int(box_height * 0.8), 10)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
- draw_right.text([box[0][0], box[0][1]], txt,
- fill=(0, 0, 0), font=font)
+ draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
- img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+ img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
return np.array(img_show)
@@ -73,29 +80,28 @@ def visualize(image_path, result, font_path="resources/fonts/FZYTK.TTF"):
image = Image.open(image_path)
boxes, txts, scores = list(zip(*result))
- draw_img = draw_ocr_box_txt(image, np.array(boxes),
- txts, font_path,
- scores,
- text_score=0.5)
+ draw_img = draw_ocr_box_txt(
+ image, np.array(boxes), txts, font_path, scores, text_score=0.5
+ )
draw_img_save = Path("./inference_results/")
if not draw_img_save.exists():
draw_img_save.mkdir(parents=True, exist_ok=True)
- image_save = str(draw_img_save / f'infer_{Path(image_path).name}')
+ image_save = str(draw_img_save / f"infer_{Path(image_path).name}")
cv2.imwrite(image_save, draw_img[:, :, ::-1])
- print(f'The infer result has saved in {image_save}')
+ print(f"The infer result has saved in {image_save}")
-if __name__ == '__main__':
+if __name__ == "__main__":
rapid_ocr = RapidOCR()
- image_path = 'tests/test_files/ch_en_num.jpg'
- with open(image_path, 'rb') as f:
+ image_path = "tests/test_files/ch_en_num.jpg"
+ with open(image_path, "rb") as f:
img = f.read()
result, elapse_list = rapid_ocr(img)
print(result)
print(elapse_list)
if result:
- visualize(image_path, result, font_path='resources/fonts/FZYTK.TTF')
+ visualize(image_path, result, font_path="resources/fonts/FZYTK.TTF")
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py
index d6abdacbe..7c289b457 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/text_cls.py
@@ -25,12 +25,12 @@
from .utils import ClsPostProcess
-class TextClassifier():
+class TextClassifier:
def __init__(self, config):
- self.cls_image_shape = config['cls_image_shape']
- self.cls_batch_num = config['cls_batch_num']
- self.cls_thresh = config['cls_thresh']
- self.postprocess_op = ClsPostProcess(config['label_list'])
+ self.cls_image_shape = config["cls_image_shape"]
+ self.cls_batch_num = config["cls_batch_num"]
+ self.cls_thresh = config["cls_thresh"]
+ self.postprocess_op = ClsPostProcess(config["label_list"])
self.infer = OrtInferSession(config)
@@ -47,7 +47,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- cls_res = [['', 0.0]] * img_num
+ cls_res = [["", 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
@@ -68,9 +68,10 @@ def __call__(self, img_list: List[np.ndarray]):
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
- if '180' in label and score > self.cls_thresh:
+ if "180" in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
- img_list[indices[beg_img_no + rno]], 1)
+ img_list[indices[beg_img_no + rno]], 1
+ )
return img_list, cls_res, elapse
def resize_norm_img(self, img):
@@ -83,7 +84,7 @@ def resize_norm_img(self, img):
resized_w = int(math.ceil(img_h * ratio))
resized_image = cv2.resize(img, (resized_w, img_h))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if img_c == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -99,8 +100,8 @@ def resize_norm_img(self, img):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--image_path', type=str, help='image_dir|image_path')
- parser.add_argument('--config_path', type=str, default='config.yaml')
+ parser.add_argument("--image_path", type=str, help="image_dir|image_path")
+ parser.add_argument("--config_path", type=str, default="config.yaml")
args = parser.parse_args()
config = read_yaml(args.config_path)
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py
index 466fee949..5c75d54ee 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v2_cls/utils.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-class ClsPostProcess():
- """ Convert between text-label and text-index """
+class ClsPostProcess:
+ """Convert between text-label and text-index"""
def __init__(self, label_list):
super(ClsPostProcess, self).__init__()
@@ -20,8 +20,9 @@ def __init__(self, label_list):
def __call__(self, preds, label=None):
pred_idxs = preds.argmax(axis=1)
- decode_out = [(self.label_list[idx], preds[i, idx])
- for i, idx in enumerate(pred_idxs)]
+ decode_out = [
+ (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
+ ]
if label is None:
return decode_out
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py
index 99f7a00de..0e42377f9 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/text_detect.py
@@ -25,33 +25,31 @@
from .utils import DBPostProcess, create_operators, transform
-class TextDetector():
+class TextDetector:
def __init__(self, config):
pre_process_list = {
- 'DetResizeForTest': {
- 'limit_side_len': config.get('limit_side_len', 736),
- 'limit_type': config.get('limit_type', 'min')
+ "DetResizeForTest": {
+ "limit_side_len": config.get("limit_side_len", 736),
+ "limit_type": config.get("limit_type", "min"),
},
- 'NormalizeImage': {
- 'std': [0.229, 0.224, 0.225],
- 'mean': [0.485, 0.456, 0.406],
- 'scale': '1./255.',
- 'order': 'hwc'
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
},
- 'ToCHWImage': None,
- 'KeepKeys': {
- 'keep_keys': ['image', 'shape']
- }
+ "ToCHWImage": None,
+ "KeepKeys": {"keep_keys": ["image", "shape"]},
}
self.preprocess_op = create_operators(pre_process_list)
post_process = {
- 'thresh': config.get('thresh', 0.3),
- 'box_thresh': config.get('box_thresh', 0.5),
- 'max_candidates': config.get('max_candidates', 1000),
- 'unclip_ratio': config.get('unclip_ratio', 1.6),
- 'use_dilation': config.get('use_dilation', True),
- 'score_mode': config.get('score_mode', 'fast'),
+ "thresh": config.get("thresh", 0.3),
+ "box_thresh": config.get("box_thresh", 0.5),
+ "max_candidates": config.get("max_candidates", 1000),
+ "unclip_ratio": config.get("unclip_ratio", 1.6),
+ "use_dilation": config.get("use_dilation", True),
+ "score_mode": config.get("score_mode", "fast"),
}
self.postprocess_op = DBPostProcess(**post_process)
@@ -59,11 +57,11 @@ def __init__(self, config):
def __call__(self, img):
if img is None:
- raise ValueError('img is None')
+ raise ValueError("img is None")
ori_im_shape = img.shape[:2]
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -76,7 +74,7 @@ def __call__(self, img):
preds = self.infer(img)[0]
post_result = self.postprocess_op(preds, shape_list)
- dt_boxes = post_result[0]['points']
+ dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
elapse = time.time() - starttime
return dt_boxes, elapse
@@ -129,8 +127,8 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--config_path', type=str, default='config.yaml')
- parser.add_argument('--image_path', type=str, default=None)
+ parser.add_argument("--config_path", type=str, default="config.yaml")
+ parser.add_argument("--image_path", type=str, default=None)
args = parser.parse_args()
config = read_yaml(args.config_path)
@@ -141,6 +139,7 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
dt_boxes, elapse = text_detector(img)
from utils import draw_text_det_res
+
src_im = draw_text_det_res(dt_boxes, args.image_path)
- cv2.imwrite('det_results.jpg', src_im)
- print('The det_results.jpg has been saved in the current directory.')
+ cv2.imwrite("det_results.jpg", src_im)
+ print("The det_results.jpg has been saved in the current directory.")
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py
index b1b489f2f..3b08e07ea 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_det/utils.py
@@ -25,69 +25,74 @@
from shapely.geometry import Polygon
-class DecodeImage():
- """ decode image """
+class DecodeImage:
+ """decode image"""
- def __init__(self, img_mode='RGB', channel_first=False):
+ def __init__(self, img_mode="RGB", channel_first=False):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
- img = data['image']
+ img = data["image"]
if six.PY2:
- assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is str and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
else:
- assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is bytes and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
+ img = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(img, 1)
if img is None:
return None
- if self.img_mode == 'GRAY':
+ if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]'
+ elif self.img_mode == "RGB":
+ assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
- data['image'] = img
+ data["image"] = img
return data
-class NormalizeImage():
- """ normalize image such as substract mean, divide std"""
+class NormalizeImage:
+ """normalize image such as substract mean, divide std"""
- def __init__(self, scale=None, mean=None, std=None, order='chw'):
+ def __init__(self, scale=None, mean=None, std=None, order="chw"):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype('float32')
- self.std = np.array(std).reshape(shape).astype('float32')
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
- img = np.array(data['image']).astype(np.float32)
- data['image'] = (img * self.scale - self.mean) / self.std
+ img = np.array(data["image"]).astype(np.float32)
+ data["image"] = (img * self.scale - self.mean) / self.std
return data
-class ToCHWImage():
- """ convert hwc image to chw image"""
+class ToCHWImage:
+ """convert hwc image to chw image"""
+
def __init__(self):
pass
def __call__(self, data):
- img = np.array(data['image'])
- data['image'] = img.transpose((2, 0, 1))
+ img = np.array(data["image"])
+ data["image"] = img.transpose((2, 0, 1))
return data
-class KeepKeys():
+class KeepKeys:
def __init__(self, keep_keys):
self.keep_keys = keep_keys
@@ -98,26 +103,26 @@ def __call__(self, data):
return data_list
-class DetResizeForTest():
+class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs.get('limit_side_len', 736)
- self.limit_type = kwargs.get('limit_type', 'min')
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs.get("limit_side_len", 736)
+ self.limit_type = kwargs.get("limit_type", "min")
- if 'resize_long' in kwargs:
+ if "resize_long" in kwargs:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
else:
- self.limit_side_len = kwargs.get('limit_side_len', 736)
- self.limit_type = kwargs.get('limit_type', 'min')
+ self.limit_side_len = kwargs.get("limit_side_len", 736)
+ self.limit_type = kwargs.get("limit_type", "min")
def __call__(self, data):
- img = data['image']
+ img = data["image"]
src_h, src_w = img.shape[:2]
if self.resize_type == 0:
@@ -128,8 +133,8 @@ def __call__(self, data):
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
- data['image'] = img
- data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ data["image"] = img
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
@@ -153,14 +158,14 @@ def resize_image_type0(self, img):
h, w = img.shape[:2]
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
else:
if min(h, w) < limit_side_len:
if h < w:
@@ -168,7 +173,7 @@ def resize_image_type0(self, img):
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
resize_h = int(h * ratio)
resize_w = int(w * ratio)
@@ -212,7 +217,7 @@ def resize_image_type2(self, img):
def transform(data, ops=None):
- """ transform """
+ """transform"""
if ops is None:
ops = []
@@ -240,21 +245,22 @@ def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
- cv2.polylines(src_im, [box], True,
- color=(255, 255, 0), thickness=2)
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
-class DBPostProcess():
+class DBPostProcess:
"""The post process for Differentiable Binarization (DB)."""
- def __init__(self,
- thresh=0.3,
- box_thresh=0.7,
- max_candidates=1000,
- unclip_ratio=2.0,
- score_mode="fast",
- use_dilation=False):
+ def __init__(
+ self,
+ thresh=0.3,
+ box_thresh=0.7,
+ max_candidates=1000,
+ unclip_ratio=2.0,
+ score_mode="fast",
+ use_dilation=False,
+ ):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
@@ -268,16 +274,17 @@ def __init__(self,
self.dilation_kernel = None
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
- '''
+ """
bitmap = _bitmap
height, width = bitmap.shape
- outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
- cv2.CHAIN_APPROX_SIMPLE)
+ outs = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
@@ -306,10 +313,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
continue
box = np.array(box)
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
@@ -341,9 +348,7 @@ def get_mini_boxes(self, contour):
index_2 = 3
index_3 = 2
- box = [
- points[index_1], points[index_2], points[index_3], points[index_4]
- ]
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
@@ -358,12 +363,12 @@ def box_score_fast(self, bitmap, _box):
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
- '''
+ """
box_score_slow: use polyon mean score as the mean score
- '''
+ """
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
@@ -379,7 +384,7 @@ def box_score_slow(self, bitmap, contour):
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def __call__(self, pred, shape_list):
pred = pred[:, 0, :, :]
@@ -391,11 +396,13 @@ def __call__(self, pred, shape_list):
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
- self.dilation_kernel)
+ self.dilation_kernel,
+ )
else:
mask = segmentation[batch_index]
- boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
- src_w, src_h)
+ boxes, scores = self.boxes_from_bitmap(
+ pred[batch_index], mask, src_w, src_h
+ )
- boxes_batch.append({'points': boxes})
+ boxes_batch.append({"points": boxes})
return boxes_batch
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py
index 1ce14a621..b487cd296 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/text_recognize.py
@@ -24,18 +24,18 @@
from .utils import CTCLabelDecode
-class TextRecognizer():
+class TextRecognizer:
def __init__(self, config):
self.session = OrtInferSession(config)
if self.session.have_key():
self.character_dict_path = self.session.get_character_list()
else:
- self.character_dict_path = config.get('keys_path', None)
+ self.character_dict_path = config.get("keys_path", None)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)
- self.rec_batch_num = config['rec_batch_num']
- self.rec_image_shape = config['rec_img_shape']
+ self.rec_batch_num = config["rec_batch_num"]
+ self.rec_image_shape = config["rec_img_shape"]
def __call__(self, img_list: List[np.ndarray]):
if isinstance(img_list, np.ndarray):
@@ -48,7 +48,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- rec_res = [['', 0.0]] * img_num
+ rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
@@ -62,8 +62,7 @@ def __call__(self, img_list: List[np.ndarray]):
norm_img_batch = []
for ino in range(beg_img_no, end_img_no):
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img_batch.append(norm_img[np.newaxis, :])
norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
@@ -90,21 +89,20 @@ def resize_norm_img(self, img, max_wh_ratio):
resized_w = int(math.ceil(img_height * ratio))
resized_image = cv2.resize(img, (resized_w, img_height))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
- padding_im = np.zeros((img_channel, img_height, img_width),
- dtype=np.float32)
+ padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--image_path', type=str, help='image_dir|image_path')
- parser.add_argument('--config_path', type=str, default='config.yaml')
+ parser.add_argument("--image_path", type=str, help="image_dir|image_path")
+ parser.add_argument("--config_path", type=str, default="config.yaml")
args = parser.parse_args()
config = read_yaml(args.config_path)
@@ -112,4 +110,4 @@ def resize_norm_img(self, img, max_wh_ratio):
img = cv2.imread(args.image_path)
rec_res, predict_time = text_recognizer(img)
- print(f'rec result: {rec_res}\t cost: {predict_time}s')
+ print(f"rec result: {rec_res}\t cost: {predict_time}s")
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py
index a1a53b8f2..1cde51931 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py
@@ -4,8 +4,8 @@
import numpy as np
-class CTCLabelDecode():
- """ Convert between text-label and text-index """
+class CTCLabelDecode:
+ """Convert between text-label and text-index"""
def __init__(self, character_dict_path):
super(CTCLabelDecode, self).__init__()
@@ -17,11 +17,11 @@ def __init__(self, character_dict_path):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
else:
self.character_str = character_dict_path
- self.character_str.append(' ')
+ self.character_str.append(" ")
dict_character = self.add_special_char(self.character_str)
self.character = dict_character
@@ -33,22 +33,21 @@ def __init__(self, character_dict_path):
def __call__(self, preds, label=None):
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
- text = self.decode(preds_idx, preds_prob,
- is_remove_duplicate=True)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
def get_ignored_tokens(self):
return [0] # for ctc blank
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -61,15 +60,16 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
continue
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list + [1e-50])))
return result_list
diff --git a/python/rapidocr_onnxruntime/rapid_ocr_api.py b/python/rapidocr_onnxruntime/rapid_ocr_api.py
index 7f92c0f95..0c351b3c7 100644
--- a/python/rapidocr_onnxruntime/rapid_ocr_api.py
+++ b/python/rapidocr_onnxruntime/rapid_ocr_api.py
@@ -12,17 +12,16 @@
from .ch_ppocr_v2_cls import TextClassifier
from .ch_ppocr_v3_det import TextDetector
from .ch_ppocr_v3_rec import TextRecognizer
-from .utils import (LoadImage, UpdateParameters, concat_model_path, init_args,
- read_yaml)
+from .utils import LoadImage, UpdateParameters, concat_model_path, init_args, read_yaml
root_dir = Path(__file__).resolve().parent
-class RapidOCR():
+class RapidOCR:
def __init__(self, **kwargs):
- config_path = str(root_dir / 'config.yaml')
+ config_path = str(root_dir / "config.yaml")
if not Path(config_path).exists():
- raise FileExistsError(f'{config_path} does not exist!')
+ raise FileExistsError(f"{config_path} does not exist!")
config = read_yaml(config_path)
config = concat_model_path(config)
@@ -30,30 +29,29 @@ def __init__(self, **kwargs):
updater = UpdateParameters()
config = updater(config, **kwargs)
- global_config = config['Global']
- self.print_verbose = global_config['print_verbose']
- self.text_score = global_config['text_score']
- self.min_height = global_config['min_height']
- self.width_height_ratio = global_config['width_height_ratio']
+ global_config = config["Global"]
+ self.print_verbose = global_config["print_verbose"]
+ self.text_score = global_config["text_score"]
+ self.min_height = global_config["min_height"]
+ self.width_height_ratio = global_config["width_height_ratio"]
- self.use_text_det = config['Global']['use_text_det']
+ self.use_text_det = config["Global"]["use_text_det"]
if self.use_text_det:
- self.text_detector = TextDetector(config['Det'])
+ self.text_detector = TextDetector(config["Det"])
- self.text_recognizer = TextRecognizer(config['Rec'])
+ self.text_recognizer = TextRecognizer(config["Rec"])
- self.use_angle_cls = config['Global']['use_angle_cls']
+ self.use_angle_cls = config["Global"]["use_angle_cls"]
if self.use_angle_cls:
- self.text_cls = TextClassifier(config['Cls'])
+ self.text_cls = TextClassifier(config["Cls"])
self.load_img = LoadImage()
- def __call__(self,
- img_content: Union[str, np.ndarray, bytes, Path], **kwargs):
+ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path], **kwargs):
if kwargs:
- box_thresh = kwargs.get('box_thresh', 0.5)
- unclip_ratio = kwargs.get('unclip_ratio', 1.6)
- text_score = kwargs.get('text_score', 0.5)
+ box_thresh = kwargs.get("box_thresh", 0.5)
+ unclip_ratio = kwargs.get("unclip_ratio", 1.6)
+ text_score = kwargs.get("text_score", 0.5)
self.text_detector.postprocess_op.box_thresh = box_thresh
self.text_detector.postprocess_op.unclip_ratio = unclip_ratio
@@ -66,9 +64,7 @@ def __call__(self,
else:
use_limit_ratio = w / h > self.width_height_ratio
- if not self.use_text_det \
- or h <= self.min_height \
- or use_limit_ratio:
+ if not self.use_text_det or h <= self.min_height or use_limit_ratio:
dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w)
det_elapse = 0.0
else:
@@ -77,7 +73,7 @@ def __call__(self,
return None, None
if self.print_verbose:
- print(f'dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}')
+ print(f"dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}")
dt_boxes = self.sorted_boxes(dt_boxes)
img_crop_list = self.get_crop_img_list(img, dt_boxes)
@@ -87,16 +83,17 @@ def __call__(self,
img_crop_list, _, cls_elapse = self.text_cls(img_crop_list)
if self.print_verbose:
- print(f'cls num: {len(img_crop_list)}, elapse: {cls_elapse}')
+ print(f"cls num: {len(img_crop_list)}, elapse: {cls_elapse}")
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
if self.print_verbose:
- print(f'rec_res num: {len(rec_res)}, elapse: {rec_elapse}')
+ print(f"rec_res num: {len(rec_res)}, elapse: {rec_elapse}")
- filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes,
- rec_res)
- fina_result = [[dt.tolist(), rec[0], str(rec[1])]
- for dt, rec in zip(filter_boxes, filter_rec_res)]
+ filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, rec_res)
+ fina_result = [
+ [dt.tolist(), rec[0], str(rec[1])]
+ for dt, rec in zip(filter_boxes, filter_rec_res)
+ ]
if fina_result:
return fina_result, [det_elapse, cls_elapse, rec_elapse]
return None, None
@@ -118,20 +115,31 @@ def get_rotate_crop_image(img, points):
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
+ np.linalg.norm(points[2] - points[3]),
+ )
+ )
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
+ np.linalg.norm(points[1] - points[2]),
+ )
+ )
+ pts_std = np.float32(
+ [
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ]
+ )
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
- M, (img_crop_width, img_crop_height),
+ M,
+ (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
+ flags=cv2.INTER_CUBIC,
+ )
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
@@ -159,8 +167,10 @@ def sorted_boxes(dt_boxes):
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
- if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 \
- and _boxes[j + 1][0][0] < _boxes[j][0][0]:
+ if (
+ abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
+ and _boxes[j + 1][0][0] < _boxes[j][0][0]
+ ):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
@@ -188,5 +198,5 @@ def main():
print(elapse_list)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py
index a9fedefd5..62e062ade 100644
--- a/python/rapidocr_onnxruntime/utils.py
+++ b/python/rapidocr_onnxruntime/utils.py
@@ -10,69 +10,83 @@
import cv2
import numpy as np
import yaml
-from onnxruntime import (GraphOptimizationLevel, InferenceSession,
- SessionOptions, get_available_providers, get_device)
+from onnxruntime import (
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+)
from PIL import Image, UnidentifiedImageError
root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
-class OrtInferSession():
+class OrtInferSession:
def __init__(self, config):
sess_opt = SessionOptions()
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
- cpu_ep = 'CPUExecutionProvider'
+ cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
- 'arena_extend_strategy': 'kSameAsRequested',
+ "arena_extend_strategy": "kSameAsRequested",
}
- cuda_ep = 'CUDAExecutionProvider'
+ cuda_ep = "CUDAExecutionProvider"
cuda_provider_options = {
- 'device_id': 0,
- 'arena_extend_strategy': 'kNextPowerOfTwo',
- 'cudnn_conv_algo_search': 'EXHAUSTIVE',
- 'do_copy_in_default_stream': True
+ "device_id": 0,
+ "arena_extend_strategy": "kNextPowerOfTwo",
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
+ "do_copy_in_default_stream": True,
}
EP_list = []
- if config['use_cuda'] and get_device() == 'GPU' \
- and cuda_ep in get_available_providers():
+ if (
+ config["use_cuda"]
+ and get_device() == "GPU"
+ and cuda_ep in get_available_providers()
+ ):
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
- self._verify_model(config['model_path'])
- self.session = InferenceSession(config['model_path'],
- sess_options=sess_opt,
- providers=EP_list)
+ self._verify_model(config["model_path"])
+ self.session = InferenceSession(
+ config["model_path"], sess_options=sess_opt, providers=EP_list
+ )
- if config['use_cuda'] and cuda_ep not in self.session.get_providers():
- warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
- 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
- 'you can check their relations from the offical web site: '
- 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
- RuntimeWarning)
+ if config["use_cuda"] and cuda_ep not in self.session.get_providers():
+ warnings.warn(
+ f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
+ "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
+ "you can check their relations from the offical web site: "
+ "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+ RuntimeWarning,
+ )
def __call__(self, input_content: np.ndarray) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), [input_content]))
try:
return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
- raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
+ raise ONNXRuntimeError("ONNXRuntime inference failed.") from e
- def get_input_names(self, ):
+ def get_input_names(
+ self,
+ ):
return [v.name for v in self.session.get_inputs()]
- def get_output_names(self,):
+ def get_output_names(
+ self,
+ ):
return [v.name for v in self.session.get_outputs()]
- def get_character_list(self, key: str = 'character'):
+ def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
- def have_key(self, key: str = 'character') -> bool:
+ def have_key(self, key: str = "character") -> bool:
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
return True
@@ -82,23 +96,26 @@ def have_key(self, key: str = 'character') -> bool:
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
- raise FileNotFoundError(f'{model_path} does not exists.')
+ raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
- raise FileExistsError(f'{model_path} is not a file.')
+ raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
pass
-class LoadImage():
- def __init__(self, ):
+class LoadImage:
+ def __init__(
+ self,
+ ):
pass
def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
raise LoadImageError(
- f'The img type {type(img)} does not in {InputType.__args__}')
+ f"The img type {type(img)} does not in {InputType.__args__}"
+ )
img = self.load_img(img)
@@ -117,8 +134,7 @@ def load_img(self, img: InputType) -> np.ndarray:
img = np.array(Image.open(img))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
except UnidentifiedImageError as e:
- raise LoadImageError(
- f'cannot identify image file {img}') from e
+ raise LoadImageError(f"cannot identify image file {img}") from e
return img
if isinstance(img, bytes):
@@ -129,12 +145,11 @@ def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, np.ndarray):
return img
- raise LoadImageError(f'{type(img)} is not supported!')
+ raise LoadImageError(f"{type(img)} is not supported!")
@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
- '''RGBA → RGB
- '''
+ """RGBA → RGB"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))
@@ -148,7 +163,7 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
- raise LoadImageError(f'{file_path} does not exist.')
+ raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
@@ -156,77 +171,74 @@ class LoadImageError(Exception):
def read_yaml(yaml_path):
- with open(yaml_path, 'rb') as f:
+ with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
def concat_model_path(config):
- key = 'model_path'
- config['Det'][key] = str(root_dir / config['Det'][key])
- config['Rec'][key] = str(root_dir / config['Rec'][key])
- config['Cls'][key] = str(root_dir / config['Cls'][key])
+ key = "model_path"
+ config["Det"][key] = str(root_dir / config["Det"][key])
+ config["Rec"][key] = str(root_dir / config["Rec"][key])
+ config["Cls"][key] = str(root_dir / config["Cls"][key])
return config
def init_args():
parser = argparse.ArgumentParser()
- parser.add_argument('-img', '--img_path', type=str, default=None,
- required=True)
- parser.add_argument('-p', '--print_cost',
- action='store_true', default=False)
-
- global_group = parser.add_argument_group(title='Global')
- global_group.add_argument('--text_score', type=float, default=0.5)
- global_group.add_argument('--use_angle_cls', type=bool, default=True)
- global_group.add_argument('--use_text_det', type=bool, default=True)
- global_group.add_argument('--print_verbose', type=bool, default=False)
- global_group.add_argument('--min_height', type=int, default=30)
- global_group.add_argument('--width_height_ratio', type=int, default=8)
-
- det_group = parser.add_argument_group(title='Det')
- det_group.add_argument('--det_model_path', type=str, default=None)
- det_group.add_argument('--det_limit_side_len', type=float, default=736)
- det_group.add_argument('--det_limit_type', type=str, default='min',
- choices=['max', 'min'])
- det_group.add_argument('--det_thresh', type=float, default=0.3)
- det_group.add_argument('--det_box_thresh', type=float, default=0.5)
- det_group.add_argument('--det_unclip_ratio', type=float, default=1.6)
- det_group.add_argument('--det_use_dilation', type=bool, default=True)
- det_group.add_argument('--det_score_mode', type=str, default='fast',
- choices=['slow', 'fast'])
-
- cls_group = parser.add_argument_group(title='Cls')
- cls_group.add_argument('--cls_model_path', type=str, default=None)
- cls_group.add_argument('--cls_image_shape', type=list,
- default=[3, 48, 192])
- cls_group.add_argument('--cls_label_list', type=list,
- default=['0', '180'])
- cls_group.add_argument('--cls_batch_num', type=int, default=6)
- cls_group.add_argument('--cls_thresh', type=float, default=0.9)
-
- rec_group = parser.add_argument_group(title='Rec')
- rec_group.add_argument('--rec_model_path', type=str, default=None)
- rec_group.add_argument('--rec_img_shape', type=list,
- default=[3, 48, 320])
- rec_group.add_argument('--rec_batch_num', type=int, default=6)
+ parser.add_argument("-img", "--img_path", type=str, default=None, required=True)
+ parser.add_argument("-p", "--print_cost", action="store_true", default=False)
+
+ global_group = parser.add_argument_group(title="Global")
+ global_group.add_argument("--text_score", type=float, default=0.5)
+ global_group.add_argument("--use_angle_cls", type=bool, default=True)
+ global_group.add_argument("--use_text_det", type=bool, default=True)
+ global_group.add_argument("--print_verbose", type=bool, default=False)
+ global_group.add_argument("--min_height", type=int, default=30)
+ global_group.add_argument("--width_height_ratio", type=int, default=8)
+
+ det_group = parser.add_argument_group(title="Det")
+ det_group.add_argument("--det_model_path", type=str, default=None)
+ det_group.add_argument("--det_limit_side_len", type=float, default=736)
+ det_group.add_argument(
+ "--det_limit_type", type=str, default="min", choices=["max", "min"]
+ )
+ det_group.add_argument("--det_thresh", type=float, default=0.3)
+ det_group.add_argument("--det_box_thresh", type=float, default=0.5)
+ det_group.add_argument("--det_unclip_ratio", type=float, default=1.6)
+ det_group.add_argument("--det_use_dilation", type=bool, default=True)
+ det_group.add_argument(
+ "--det_score_mode", type=str, default="fast", choices=["slow", "fast"]
+ )
+
+ cls_group = parser.add_argument_group(title="Cls")
+ cls_group.add_argument("--cls_model_path", type=str, default=None)
+ cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192])
+ cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"])
+ cls_group.add_argument("--cls_batch_num", type=int, default=6)
+ cls_group.add_argument("--cls_thresh", type=float, default=0.9)
+
+ rec_group = parser.add_argument_group(title="Rec")
+ rec_group.add_argument("--rec_model_path", type=str, default=None)
+ rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
+ rec_group.add_argument("--rec_batch_num", type=int, default=6)
args = parser.parse_args()
return args
-class UpdateParameters():
+class UpdateParameters:
def __init__(self) -> None:
pass
def parse_kwargs(self, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {}
for k, v in kwargs.items():
- if k.startswith('det'):
+ if k.startswith("det"):
det_dict[k] = v
- elif k.startswith('cls'):
+ elif k.startswith("cls"):
cls_dict[k] = v
- elif k.startswith('rec'):
+ elif k.startswith("rec"):
rec_dict[k] = v
else:
global_dict[k] = v
@@ -235,11 +247,10 @@ def parse_kwargs(self, **kwargs):
def __call__(self, config, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs)
new_config = {
- 'Global': self.update_global_params(config['Global'],
- global_dict),
- 'Det': self.update_det_params(config['Det'], det_dict),
- 'Cls': self.update_cls_params(config['Cls'], cls_dict),
- 'Rec': self.update_rec_params(config['Rec'], rec_dict)
+ "Global": self.update_global_params(config["Global"], global_dict),
+ "Det": self.update_det_params(config["Det"], det_dict),
+ "Cls": self.update_cls_params(config["Cls"], cls_dict),
+ "Rec": self.update_rec_params(config["Rec"], rec_dict),
}
return new_config
@@ -250,38 +261,36 @@ def update_global_params(self, config, global_dict):
def update_det_params(self, config, det_dict):
if det_dict:
- det_dict = {k.split('det_')[1]: v for k, v in det_dict.items()}
- if not det_dict['model_path']:
- det_dict['model_path'] = str(root_dir / config['model_path'])
+ det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()}
+ if not det_dict["model_path"]:
+ det_dict["model_path"] = str(root_dir / config["model_path"])
config.update(det_dict)
return config
def update_cls_params(self, config, cls_dict):
if cls_dict:
- need_remove_prefix = ['cls_label_list', 'cls_model_path']
+ need_remove_prefix = ["cls_label_list", "cls_model_path"]
new_cls_dict = {}
for k, v in cls_dict.items():
if k in need_remove_prefix:
- k = k.split('cls_')[1]
+ k = k.split("cls_")[1]
new_cls_dict[k] = v
- if not new_cls_dict['model_path']:
- new_cls_dict['model_path'] = str(
- root_dir / config['model_path'])
+ if not new_cls_dict["model_path"]:
+ new_cls_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_cls_dict)
return config
def update_rec_params(self, config, rec_dict):
if rec_dict:
- need_remove_prefix = ['rec_model_path']
+ need_remove_prefix = ["rec_model_path"]
new_rec_dict = {}
for k, v in rec_dict.items():
if k in need_remove_prefix:
- k = k.split('rec_')[1]
+ k = k.split("rec_")[1]
new_rec_dict[k] = v
- if not new_rec_dict['model_path']:
- new_rec_dict['model_path'] = str(
- root_dir / config['model_path'])
+ if not new_rec_dict["model_path"]:
+ new_rec_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_rec_dict)
return config
diff --git a/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py b/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py
index a397d3514..5c44dbd57 100644
--- a/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py
+++ b/python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py
@@ -25,12 +25,12 @@
from .utils import ClsPostProcess
-class TextClassifier():
+class TextClassifier:
def __init__(self, config):
- self.cls_image_shape = config['cls_image_shape']
- self.cls_batch_num = config['cls_batch_num']
- self.cls_thresh = config['cls_thresh']
- self.postprocess_op = ClsPostProcess(config['label_list'])
+ self.cls_image_shape = config["cls_image_shape"]
+ self.cls_batch_num = config["cls_batch_num"]
+ self.cls_thresh = config["cls_thresh"]
+ self.postprocess_op = ClsPostProcess(config["label_list"])
self.infer = OpenVINOInferSession(config)
@@ -47,7 +47,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- cls_res = [['', 0.0]] * img_num
+ cls_res = [["", 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
@@ -68,9 +68,10 @@ def __call__(self, img_list: List[np.ndarray]):
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
- if '180' in label and score > self.cls_thresh:
+ if "180" in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
- img_list[indices[beg_img_no + rno]], 1)
+ img_list[indices[beg_img_no + rno]], 1
+ )
return img_list, cls_res, elapse
def resize_norm_img(self, img):
@@ -83,7 +84,7 @@ def resize_norm_img(self, img):
resized_w = int(math.ceil(img_h * ratio))
resized_image = cv2.resize(img, (resized_w, img_h))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
if img_c == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
@@ -99,8 +100,8 @@ def resize_norm_img(self, img):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--image_path', type=str, help='image_dir|image_path')
- parser.add_argument('--config_path', type=str, default='config.yaml')
+ parser.add_argument("--image_path", type=str, help="image_dir|image_path")
+ parser.add_argument("--config_path", type=str, default="config.yaml")
args = parser.parse_args()
config = read_yaml(args.config_path)
diff --git a/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py b/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py
index cb579fa05..5c75d54ee 100644
--- a/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py
+++ b/python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-class ClsPostProcess():
- """ Convert between text-label and text-index """
+class ClsPostProcess:
+ """Convert between text-label and text-index"""
def __init__(self, label_list):
super(ClsPostProcess, self).__init__()
@@ -20,10 +20,11 @@ def __init__(self, label_list):
def __call__(self, preds, label=None):
pred_idxs = preds.argmax(axis=1)
- decode_out = [(self.label_list[idx], preds[i, idx])
- for i, idx in enumerate(pred_idxs)]
+ decode_out = [
+ (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
+ ]
if label is None:
return decode_out
label = [(self.label_list[idx], 1.0) for idx in label]
- return decode_out, label
\ No newline at end of file
+ return decode_out, label
diff --git a/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py b/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py
index 5c50cbbdc..0b11a314f 100644
--- a/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py
+++ b/python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py
@@ -23,16 +23,16 @@
from .utils import DBPostProcess, create_operators, transform
-class TextDetector():
+class TextDetector:
def __init__(self, config):
- self.preprocess_op = create_operators(config['pre_process'])
- self.postprocess_op = DBPostProcess(**config['post_process'])
+ self.preprocess_op = create_operators(config["pre_process"])
+ self.postprocess_op = DBPostProcess(**config["post_process"])
self.infer = OpenVINOInferSession(config)
def __call__(self, img):
ori_im = img.copy()
- data = {'image': img}
+ data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -44,7 +44,7 @@ def __call__(self, img):
starttime = time.time()
preds = self.infer(img)
post_result = self.postprocess_op(preds, shape_list)
- dt_boxes = post_result[0]['points']
+ dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse
@@ -96,8 +96,8 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--config_path', type=str, default='config.yaml')
- parser.add_argument('--image_path', type=str, default=None)
+ parser.add_argument("--config_path", type=str, default="config.yaml")
+ parser.add_argument("--image_path", type=str, default=None)
args = parser.parse_args()
config = read_yaml(args.config_path)
@@ -108,6 +108,7 @@ def filter_tag_det_res(self, dt_boxes, image_shape):
dt_boxes, elapse = text_detector(img)
from utils import draw_text_det_res
+
src_im = draw_text_det_res(dt_boxes, args.image_path)
- cv2.imwrite('det_results.jpg', src_im)
- print('The det_results.jpg has been saved in the current directory.')
+ cv2.imwrite("det_results.jpg", src_im)
+ print("The det_results.jpg has been saved in the current directory.")
diff --git a/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py b/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py
index 781dbb1a9..2e586b7a5 100644
--- a/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py
+++ b/python/rapidocr_openvino/ch_ppocr_v3_det/utils.py
@@ -25,69 +25,74 @@
from shapely.geometry import Polygon
-class DecodeImage():
- """ decode image """
+class DecodeImage:
+ """decode image"""
- def __init__(self, img_mode='RGB', channel_first=False):
+ def __init__(self, img_mode="RGB", channel_first=False):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
- img = data['image']
+ img = data["image"]
if six.PY2:
- assert type(img) is str and len(img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is str and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
else:
- assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
+ assert (
+ type(img) is bytes and len(img) > 0
+ ), "invalid input 'img' in DecodeImage"
- img = np.frombuffer(img, dtype='uint8')
+ img = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(img, 1)
if img is None:
return None
- if self.img_mode == 'GRAY':
+ if self.img_mode == "GRAY":
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
- elif self.img_mode == 'RGB':
- assert img.shape[2] == 3, f'invalid shape of image[{img.shape}]'
+ elif self.img_mode == "RGB":
+ assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
- data['image'] = img
+ data["image"] = img
return data
-class NormalizeImage():
- """ normalize image such as substract mean, divide std"""
+class NormalizeImage:
+ """normalize image such as substract mean, divide std"""
- def __init__(self, scale=None, mean=None, std=None, order='chw'):
+ def __init__(self, scale=None, mean=None, std=None, order="chw"):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype('float32')
- self.std = np.array(std).reshape(shape).astype('float32')
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
- img = np.array(data['image']).astype(np.float32)
- data['image'] = (img * self.scale - self.mean) / self.std
+ img = np.array(data["image"]).astype(np.float32)
+ data["image"] = (img * self.scale - self.mean) / self.std
return data
-class ToCHWImage():
- """ convert hwc image to chw image"""
+class ToCHWImage:
+ """convert hwc image to chw image"""
+
def __init__(self):
pass
def __call__(self, data):
- img = np.array(data['image'])
- data['image'] = img.transpose((2, 0, 1))
+ img = np.array(data["image"])
+ data["image"] = img.transpose((2, 0, 1))
return data
-class KeepKeys():
+class KeepKeys:
def __init__(self, keep_keys):
self.keep_keys = keep_keys
@@ -98,26 +103,26 @@ def __call__(self, data):
return data_list
-class DetResizeForTest():
+class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
- if 'image_shape' in kwargs:
- self.image_shape = kwargs['image_shape']
+ if "image_shape" in kwargs:
+ self.image_shape = kwargs["image_shape"]
self.resize_type = 1
- elif 'limit_side_len' in kwargs:
- self.limit_side_len = kwargs.get('limit_side_len', 736)
- self.limit_type = kwargs.get('limit_type', 'min')
+ elif "limit_side_len" in kwargs:
+ self.limit_side_len = kwargs.get("limit_side_len", 736)
+ self.limit_type = kwargs.get("limit_type", "min")
- if 'resize_long' in kwargs:
+ if "resize_long" in kwargs:
self.resize_type = 2
- self.resize_long = kwargs.get('resize_long', 960)
+ self.resize_long = kwargs.get("resize_long", 960)
else:
- self.limit_side_len = kwargs.get('limit_side_len', 736)
- self.limit_type = kwargs.get('limit_type', 'min')
+ self.limit_side_len = kwargs.get("limit_side_len", 736)
+ self.limit_type = kwargs.get("limit_type", "min")
def __call__(self, data):
- img = data['image']
+ img = data["image"]
src_h, src_w = img.shape[:2]
if self.resize_type == 0:
@@ -128,8 +133,8 @@ def __call__(self, data):
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
- data['image'] = img
- data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ data["image"] = img
+ data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
@@ -153,14 +158,14 @@ def resize_image_type0(self, img):
h, w = img.shape[:2]
# limit the max side
- if self.limit_type == 'max':
+ if self.limit_type == "max":
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
else:
if min(h, w) < limit_side_len:
if h < w:
@@ -168,7 +173,7 @@ def resize_image_type0(self, img):
else:
ratio = float(limit_side_len) / w
else:
- ratio = 1.
+ ratio = 1.0
resize_h = int(h * ratio)
resize_w = int(w * ratio)
@@ -212,7 +217,7 @@ def resize_image_type2(self, img):
def transform(data, ops=None):
- """ transform """
+ """transform"""
if ops is None:
ops = []
for op in ops:
@@ -239,21 +244,22 @@ def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = np.array(box).astype(np.int32).reshape(-1, 2)
- cv2.polylines(src_im, [box], True,
- color=(255, 255, 0), thickness=2)
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
return src_im
-class DBPostProcess():
+class DBPostProcess:
"""The post process for Differentiable Binarization (DB)."""
- def __init__(self,
- thresh=0.3,
- box_thresh=0.7,
- max_candidates=1000,
- unclip_ratio=2.0,
- score_mode="fast",
- use_dilation=False):
+ def __init__(
+ self,
+ thresh=0.3,
+ box_thresh=0.7,
+ max_candidates=1000,
+ unclip_ratio=2.0,
+ score_mode="fast",
+ use_dilation=False,
+ ):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
@@ -267,16 +273,17 @@ def __init__(self,
self.dilation_kernel = None
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
- '''
+ """
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
- '''
+ """
bitmap = _bitmap
height, width = bitmap.shape
- outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
- cv2.CHAIN_APPROX_SIMPLE)
+ outs = cv2.findContours(
+ (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+ )
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
@@ -305,10 +312,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
continue
box = np.array(box)
- box[:, 0] = np.clip(
- np.round(box[:, 0] / width * dest_width), 0, dest_width)
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
- np.round(box[:, 1] / height * dest_height), 0, dest_height)
+ np.round(box[:, 1] / height * dest_height), 0, dest_height
+ )
boxes.append(box.astype(np.int16))
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
@@ -340,9 +347,7 @@ def get_mini_boxes(self, contour):
index_2 = 3
index_3 = 2
- box = [
- points[index_1], points[index_2], points[index_3], points[index_4]
- ]
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
@@ -357,12 +362,12 @@ def box_score_fast(self, bitmap, _box):
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
- '''
+ """
box_score_slow: use polyon mean score as the mean score
- '''
+ """
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
@@ -378,7 +383,7 @@ def box_score_slow(self, bitmap, contour):
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
- return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
def __call__(self, pred, shape_list):
pred = pred[:, 0, :, :]
@@ -390,11 +395,13 @@ def __call__(self, pred, shape_list):
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
- self.dilation_kernel)
+ self.dilation_kernel,
+ )
else:
mask = segmentation[batch_index]
- boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
- src_w, src_h)
+ boxes, scores = self.boxes_from_bitmap(
+ pred[batch_index], mask, src_w, src_h
+ )
- boxes_batch.append({'points': boxes})
+ boxes_batch.append({"points": boxes})
return boxes_batch
diff --git a/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py
index 1bf6bb095..ff274f805 100644
--- a/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py
+++ b/python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py
@@ -28,13 +28,13 @@
from .utils import CTCLabelDecode
-class TextRecognizer():
+class TextRecognizer:
def __init__(self, config):
- self.rec_image_shape = config['rec_img_shape']
- self.rec_batch_num = config['rec_batch_num']
+ self.rec_image_shape = config["rec_img_shape"]
+ self.rec_batch_num = config["rec_batch_num"]
- dict_path = str(Path(__file__).parent / 'ppocr_keys_v1.txt')
- self.character_dict_path = config.get('keys_path', dict_path)
+ dict_path = str(Path(__file__).parent / "ppocr_keys_v1.txt")
+ self.character_dict_path = config.get("keys_path", dict_path)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)
self.infer = OpenVINOInferSession(config)
@@ -50,7 +50,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))
img_num = len(img_list)
- rec_res = [['', 0.0]] * img_num
+ rec_res = [["", 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
@@ -64,8 +64,7 @@ def __call__(self, img_list: List[np.ndarray]):
norm_img_batch = []
for ino in range(beg_img_no, end_img_no):
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img_batch.append(norm_img[np.newaxis, :])
norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32)
@@ -91,21 +90,20 @@ def resize_norm_img(self, img, max_wh_ratio):
resized_w = int(math.ceil(img_height * ratio))
resized_image = cv2.resize(img, (resized_w, img_height))
- resized_image = resized_image.astype('float32')
+ resized_image = resized_image.astype("float32")
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
- padding_im = np.zeros((img_channel, img_height, img_width),
- dtype=np.float32)
+ padding_im = np.zeros((img_channel, img_height, img_width), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--image_path', type=str, help='image_dir|image_path')
- parser.add_argument('--config_path', type=str, default='config.yaml')
+ parser.add_argument("--image_path", type=str, help="image_dir|image_path")
+ parser.add_argument("--config_path", type=str, default="config.yaml")
args = parser.parse_args()
config = read_yaml(args.config_path)
@@ -113,4 +111,4 @@ def resize_norm_img(self, img, max_wh_ratio):
img = cv2.imread(args.image_path)
rec_res, predict_time = text_recognizer(img)
- print(f'rec result: {rec_res}\t cost: {predict_time}s')
+ print(f"rec result: {rec_res}\t cost: {predict_time}s")
diff --git a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py
index f40d906f6..9587af350 100644
--- a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py
+++ b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py
@@ -4,8 +4,8 @@
import numpy as np
-class CTCLabelDecode():
- """ Convert between text-label and text-index """
+class CTCLabelDecode:
+ """Convert between text-label and text-index"""
def __init__(self, character_dict_path):
super(CTCLabelDecode, self).__init__()
@@ -15,9 +15,9 @@ def __init__(self, character_dict_path):
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
- line = line.decode('utf-8').strip("\n").strip("\r\n")
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
- self.character_str.append(' ')
+ self.character_str.append(" ")
dict_character = self.add_special_char(self.character_str)
self.character = dict_character
@@ -29,22 +29,21 @@ def __init__(self, character_dict_path):
def __call__(self, preds, label=None):
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
- text = self.decode(preds_idx, preds_prob,
- is_remove_duplicate=True)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
+ dict_character = ["blank"] + dict_character
return dict_character
def get_ignored_tokens(self):
return [0] # for ctc blank
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
- """ convert text-index into text-label. """
+ """convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -57,15 +56,16 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
continue
if is_remove_duplicate:
# only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
+ if (
+ idx > 0
+ and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]
+ ):
continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
- text = ''.join(char_list)
+ text = "".join(char_list)
result_list.append((text, np.mean(conf_list + [1e-50])))
return result_list
diff --git a/python/rapidocr_openvino/rapid_ocr_api.py b/python/rapidocr_openvino/rapid_ocr_api.py
index 7f92c0f95..0c351b3c7 100644
--- a/python/rapidocr_openvino/rapid_ocr_api.py
+++ b/python/rapidocr_openvino/rapid_ocr_api.py
@@ -12,17 +12,16 @@
from .ch_ppocr_v2_cls import TextClassifier
from .ch_ppocr_v3_det import TextDetector
from .ch_ppocr_v3_rec import TextRecognizer
-from .utils import (LoadImage, UpdateParameters, concat_model_path, init_args,
- read_yaml)
+from .utils import LoadImage, UpdateParameters, concat_model_path, init_args, read_yaml
root_dir = Path(__file__).resolve().parent
-class RapidOCR():
+class RapidOCR:
def __init__(self, **kwargs):
- config_path = str(root_dir / 'config.yaml')
+ config_path = str(root_dir / "config.yaml")
if not Path(config_path).exists():
- raise FileExistsError(f'{config_path} does not exist!')
+ raise FileExistsError(f"{config_path} does not exist!")
config = read_yaml(config_path)
config = concat_model_path(config)
@@ -30,30 +29,29 @@ def __init__(self, **kwargs):
updater = UpdateParameters()
config = updater(config, **kwargs)
- global_config = config['Global']
- self.print_verbose = global_config['print_verbose']
- self.text_score = global_config['text_score']
- self.min_height = global_config['min_height']
- self.width_height_ratio = global_config['width_height_ratio']
+ global_config = config["Global"]
+ self.print_verbose = global_config["print_verbose"]
+ self.text_score = global_config["text_score"]
+ self.min_height = global_config["min_height"]
+ self.width_height_ratio = global_config["width_height_ratio"]
- self.use_text_det = config['Global']['use_text_det']
+ self.use_text_det = config["Global"]["use_text_det"]
if self.use_text_det:
- self.text_detector = TextDetector(config['Det'])
+ self.text_detector = TextDetector(config["Det"])
- self.text_recognizer = TextRecognizer(config['Rec'])
+ self.text_recognizer = TextRecognizer(config["Rec"])
- self.use_angle_cls = config['Global']['use_angle_cls']
+ self.use_angle_cls = config["Global"]["use_angle_cls"]
if self.use_angle_cls:
- self.text_cls = TextClassifier(config['Cls'])
+ self.text_cls = TextClassifier(config["Cls"])
self.load_img = LoadImage()
- def __call__(self,
- img_content: Union[str, np.ndarray, bytes, Path], **kwargs):
+ def __call__(self, img_content: Union[str, np.ndarray, bytes, Path], **kwargs):
if kwargs:
- box_thresh = kwargs.get('box_thresh', 0.5)
- unclip_ratio = kwargs.get('unclip_ratio', 1.6)
- text_score = kwargs.get('text_score', 0.5)
+ box_thresh = kwargs.get("box_thresh", 0.5)
+ unclip_ratio = kwargs.get("unclip_ratio", 1.6)
+ text_score = kwargs.get("text_score", 0.5)
self.text_detector.postprocess_op.box_thresh = box_thresh
self.text_detector.postprocess_op.unclip_ratio = unclip_ratio
@@ -66,9 +64,7 @@ def __call__(self,
else:
use_limit_ratio = w / h > self.width_height_ratio
- if not self.use_text_det \
- or h <= self.min_height \
- or use_limit_ratio:
+ if not self.use_text_det or h <= self.min_height or use_limit_ratio:
dt_boxes, img_crop_list = self.get_boxes_img_without_det(img, h, w)
det_elapse = 0.0
else:
@@ -77,7 +73,7 @@ def __call__(self,
return None, None
if self.print_verbose:
- print(f'dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}')
+ print(f"dt_boxes num: {len(dt_boxes)}, elapse: {det_elapse}")
dt_boxes = self.sorted_boxes(dt_boxes)
img_crop_list = self.get_crop_img_list(img, dt_boxes)
@@ -87,16 +83,17 @@ def __call__(self,
img_crop_list, _, cls_elapse = self.text_cls(img_crop_list)
if self.print_verbose:
- print(f'cls num: {len(img_crop_list)}, elapse: {cls_elapse}')
+ print(f"cls num: {len(img_crop_list)}, elapse: {cls_elapse}")
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
if self.print_verbose:
- print(f'rec_res num: {len(rec_res)}, elapse: {rec_elapse}')
+ print(f"rec_res num: {len(rec_res)}, elapse: {rec_elapse}")
- filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes,
- rec_res)
- fina_result = [[dt.tolist(), rec[0], str(rec[1])]
- for dt, rec in zip(filter_boxes, filter_rec_res)]
+ filter_boxes, filter_rec_res = self.filter_boxes_rec_by_score(dt_boxes, rec_res)
+ fina_result = [
+ [dt.tolist(), rec[0], str(rec[1])]
+ for dt, rec in zip(filter_boxes, filter_rec_res)
+ ]
if fina_result:
return fina_result, [det_elapse, cls_elapse, rec_elapse]
return None, None
@@ -118,20 +115,31 @@ def get_rotate_crop_image(img, points):
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
+ np.linalg.norm(points[2] - points[3]),
+ )
+ )
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
+ np.linalg.norm(points[1] - points[2]),
+ )
+ )
+ pts_std = np.float32(
+ [
+ [0, 0],
+ [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height],
+ ]
+ )
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
- M, (img_crop_width, img_crop_height),
+ M,
+ (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
+ flags=cv2.INTER_CUBIC,
+ )
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
@@ -159,8 +167,10 @@ def sorted_boxes(dt_boxes):
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
- if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 \
- and _boxes[j + 1][0][0] < _boxes[j][0][0]:
+ if (
+ abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
+ and _boxes[j + 1][0][0] < _boxes[j][0][0]
+ ):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
@@ -188,5 +198,5 @@ def main():
print(elapse_list)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/python/rapidocr_openvino/utils.py b/python/rapidocr_openvino/utils.py
index 4f01f387c..3c699b9f5 100644
--- a/python/rapidocr_openvino/utils.py
+++ b/python/rapidocr_openvino/utils.py
@@ -16,14 +16,14 @@
InputType = Union[str, np.ndarray, bytes, Path]
-class OpenVINOInferSession():
+class OpenVINOInferSession:
def __init__(self, config):
ie = Core()
- config['model_path'] = str(root_dir / config['model_path'])
- self._verify_model(config['model_path'])
- model_onnx = ie.read_model(config['model_path'])
- compile_model = ie.compile_model(model=model_onnx, device_name='CPU')
+ config["model_path"] = str(root_dir / config["model_path"])
+ self._verify_model(config["model_path"])
+ model_onnx = ie.read_model(config["model_path"])
+ compile_model = ie.compile_model(model=model_onnx, device_name="CPU")
self.session = compile_model.create_infer_request()
def __call__(self, input_content: np.ndarray) -> np.ndarray:
@@ -34,19 +34,22 @@ def __call__(self, input_content: np.ndarray) -> np.ndarray:
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
- raise FileNotFoundError(f'{model_path} does not exists.')
+ raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
- raise FileExistsError(f'{model_path} is not a file.')
+ raise FileExistsError(f"{model_path} is not a file.")
-class LoadImage():
- def __init__(self, ):
+class LoadImage:
+ def __init__(
+ self,
+ ):
pass
def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
raise LoadImageError(
- f'The img type {type(img)} does not in {InputType.__args__}')
+ f"The img type {type(img)} does not in {InputType.__args__}"
+ )
img = self.load_img(img)
@@ -65,8 +68,7 @@ def load_img(self, img: InputType) -> np.ndarray:
img = np.array(Image.open(img))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
except UnidentifiedImageError as e:
- raise LoadImageError(
- f'cannot identify image file {img}') from e
+ raise LoadImageError(f"cannot identify image file {img}") from e
return img
if isinstance(img, bytes):
@@ -77,12 +79,11 @@ def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, np.ndarray):
return img
- raise LoadImageError(f'{type(img)} is not supported!')
+ raise LoadImageError(f"{type(img)} is not supported!")
@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
- '''RGBA → RGB
- '''
+ """RGBA → RGB"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))
@@ -96,7 +97,7 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
- raise LoadImageError(f'{file_path} does not exist.')
+ raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
@@ -104,77 +105,74 @@ class LoadImageError(Exception):
def read_yaml(yaml_path):
- with open(yaml_path, 'rb') as f:
+ with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
def concat_model_path(config):
- key = 'model_path'
- config['Det'][key] = str(root_dir / config['Det'][key])
- config['Rec'][key] = str(root_dir / config['Rec'][key])
- config['Cls'][key] = str(root_dir / config['Cls'][key])
+ key = "model_path"
+ config["Det"][key] = str(root_dir / config["Det"][key])
+ config["Rec"][key] = str(root_dir / config["Rec"][key])
+ config["Cls"][key] = str(root_dir / config["Cls"][key])
return config
def init_args():
parser = argparse.ArgumentParser()
- parser.add_argument('-img', '--img_path', type=str, default=None,
- required=True)
- parser.add_argument('-p', '--print_cost',
- action='store_true', default=False)
-
- global_group = parser.add_argument_group(title='Global')
- global_group.add_argument('--text_score', type=float, default=0.5)
- global_group.add_argument('--use_angle_cls', type=bool, default=True)
- global_group.add_argument('--use_text_det', type=bool, default=True)
- global_group.add_argument('--print_verbose', type=bool, default=False)
- global_group.add_argument('--min_height', type=int, default=30)
- global_group.add_argument('--width_height_ratio', type=int, default=8)
-
- det_group = parser.add_argument_group(title='Det')
- det_group.add_argument('--det_model_path', type=str, default=None)
- det_group.add_argument('--det_limit_side_len', type=float, default=736)
- det_group.add_argument('--det_limit_type', type=str, default='min',
- choices=['max', 'min'])
- det_group.add_argument('--det_thresh', type=float, default=0.3)
- det_group.add_argument('--det_box_thresh', type=float, default=0.5)
- det_group.add_argument('--det_unclip_ratio', type=float, default=1.6)
- det_group.add_argument('--det_use_dilation', type=bool, default=True)
- det_group.add_argument('--det_score_mode', type=str, default='fast',
- choices=['slow', 'fast'])
-
- cls_group = parser.add_argument_group(title='Cls')
- cls_group.add_argument('--cls_model_path', type=str, default=None)
- cls_group.add_argument('--cls_image_shape', type=list,
- default=[3, 48, 192])
- cls_group.add_argument('--cls_label_list', type=list,
- default=['0', '180'])
- cls_group.add_argument('--cls_batch_num', type=int, default=6)
- cls_group.add_argument('--cls_thresh', type=float, default=0.9)
-
- rec_group = parser.add_argument_group(title='Rec')
- rec_group.add_argument('--rec_model_path', type=str, default=None)
- rec_group.add_argument('--rec_image_shape', type=list,
- default=[3, 48, 320])
- rec_group.add_argument('--rec_batch_num', type=int, default=6)
+ parser.add_argument("-img", "--img_path", type=str, default=None, required=True)
+ parser.add_argument("-p", "--print_cost", action="store_true", default=False)
+
+ global_group = parser.add_argument_group(title="Global")
+ global_group.add_argument("--text_score", type=float, default=0.5)
+ global_group.add_argument("--use_angle_cls", type=bool, default=True)
+ global_group.add_argument("--use_text_det", type=bool, default=True)
+ global_group.add_argument("--print_verbose", type=bool, default=False)
+ global_group.add_argument("--min_height", type=int, default=30)
+ global_group.add_argument("--width_height_ratio", type=int, default=8)
+
+ det_group = parser.add_argument_group(title="Det")
+ det_group.add_argument("--det_model_path", type=str, default=None)
+ det_group.add_argument("--det_limit_side_len", type=float, default=736)
+ det_group.add_argument(
+ "--det_limit_type", type=str, default="min", choices=["max", "min"]
+ )
+ det_group.add_argument("--det_thresh", type=float, default=0.3)
+ det_group.add_argument("--det_box_thresh", type=float, default=0.5)
+ det_group.add_argument("--det_unclip_ratio", type=float, default=1.6)
+ det_group.add_argument("--det_use_dilation", type=bool, default=True)
+ det_group.add_argument(
+ "--det_score_mode", type=str, default="fast", choices=["slow", "fast"]
+ )
+
+ cls_group = parser.add_argument_group(title="Cls")
+ cls_group.add_argument("--cls_model_path", type=str, default=None)
+ cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192])
+ cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"])
+ cls_group.add_argument("--cls_batch_num", type=int, default=6)
+ cls_group.add_argument("--cls_thresh", type=float, default=0.9)
+
+ rec_group = parser.add_argument_group(title="Rec")
+ rec_group.add_argument("--rec_model_path", type=str, default=None)
+ rec_group.add_argument("--rec_image_shape", type=list, default=[3, 48, 320])
+ rec_group.add_argument("--rec_batch_num", type=int, default=6)
args = parser.parse_args()
return args
-class UpdateParameters():
+class UpdateParameters:
def __init__(self) -> None:
pass
def parse_kwargs(self, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {}
for k, v in kwargs.items():
- if k.startswith('det'):
+ if k.startswith("det"):
det_dict[k] = v
- elif k.startswith('cls'):
+ elif k.startswith("cls"):
cls_dict[k] = v
- elif k.startswith('rec'):
+ elif k.startswith("rec"):
rec_dict[k] = v
else:
global_dict[k] = v
@@ -183,11 +181,10 @@ def parse_kwargs(self, **kwargs):
def __call__(self, config, **kwargs):
global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs)
new_config = {
- 'Global': self.update_global_params(config['Global'],
- global_dict),
- 'Det': self.update_det_params(config['Det'], det_dict),
- 'Cls': self.update_cls_params(config['Cls'], cls_dict),
- 'Rec': self.update_rec_params(config['Rec'], rec_dict)
+ "Global": self.update_global_params(config["Global"], global_dict),
+ "Det": self.update_det_params(config["Det"], det_dict),
+ "Cls": self.update_cls_params(config["Cls"], cls_dict),
+ "Rec": self.update_rec_params(config["Rec"], rec_dict),
}
return new_config
@@ -198,38 +195,36 @@ def update_global_params(self, config, global_dict):
def update_det_params(self, config, det_dict):
if det_dict:
- det_dict = {k.split('det_')[1]: v for k, v in det_dict.items()}
- if not det_dict['model_path']:
- det_dict['model_path'] = str(root_dir / config['model_path'])
+ det_dict = {k.split("det_")[1]: v for k, v in det_dict.items()}
+ if not det_dict["model_path"]:
+ det_dict["model_path"] = str(root_dir / config["model_path"])
config.update(det_dict)
return config
def update_cls_params(self, config, cls_dict):
if cls_dict:
- need_remove_prefix = ['cls_label_list', 'cls_model_path']
+ need_remove_prefix = ["cls_label_list", "cls_model_path"]
new_cls_dict = {}
for k, v in cls_dict.items():
if k in need_remove_prefix:
- k = k.split('cls_')[1]
+ k = k.split("cls_")[1]
new_cls_dict[k] = v
- if not new_cls_dict['model_path']:
- new_cls_dict['model_path'] = str(
- root_dir / config['model_path'])
+ if not new_cls_dict["model_path"]:
+ new_cls_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_cls_dict)
return config
def update_rec_params(self, config, rec_dict):
if rec_dict:
- need_remove_prefix = ['rec_model_path']
+ need_remove_prefix = ["rec_model_path"]
new_rec_dict = {}
for k, v in rec_dict.items():
if k in need_remove_prefix:
- k = k.split('rec_')[1]
+ k = k.split("rec_")[1]
new_rec_dict[k] = v
- if not new_rec_dict['model_path']:
- new_rec_dict['model_path'] = str(
- root_dir / config['model_path'])
+ if not new_rec_dict["model_path"]:
+ new_rec_dict["model_path"] = str(root_dir / config["model_path"])
config.update(new_rec_dict)
return config
diff --git a/python/setup_onnxruntime.py b/python/setup_onnxruntime.py
index d0dc3ca3d..d5e181a06 100644
--- a/python/setup_onnxruntime.py
+++ b/python/setup_onnxruntime.py
@@ -10,14 +10,14 @@
def get_readme():
root_dir = Path(__file__).resolve().parent.parent
- readme_path = str(root_dir / 'docs' / 'doc_whl_rapidocr_ort.md')
+ readme_path = str(root_dir / "docs" / "doc_whl_rapidocr_ort.md")
print(readme_path)
- with open(readme_path, 'r', encoding='utf-8') as f:
+ with open(readme_path, "r", encoding="utf-8") as f:
readme = f.read()
return readme
-MODULE_NAME = 'rapidocr_onnxruntime'
+MODULE_NAME = "rapidocr_onnxruntime"
obtainer = GetPyPiLatestVersion()
latest_version = obtainer(MODULE_NAME)
@@ -25,7 +25,7 @@ def get_readme():
# 优先提取commit message中的语义化版本号,如无,则自动加1
if len(sys.argv) > 2:
- match_str = ' '.join(sys.argv[2:])
+ match_str = " ".join(sys.argv[2:])
matched_versions = obtainer.extract_version(match_str)
if matched_versions:
VERSION_NUM = matched_versions
@@ -37,31 +37,38 @@ def get_readme():
platforms="Any",
description="A cross platform OCR Library based on OnnxRuntime.",
long_description=get_readme(),
- long_description_content_type='text/markdown',
+ long_description_content_type="text/markdown",
author="SWHL",
author_email="liekkaskono@163.com",
url="https://github.com/RapidAI/RapidOCR",
- license='Apache-2.0',
+ license="Apache-2.0",
include_package_data=True,
install_requires=[
- "pyclipper>=1.2.1", "onnxruntime>=1.7.0", "opencv_python>=4.5.1.48",
- "numpy>=1.19.3", "six>=1.15.0", "Shapely>=1.7.1", 'PyYAML', 'Pillow'
+ "pyclipper>=1.2.1",
+ "onnxruntime>=1.7.0",
+ "opencv_python>=4.5.1.48",
+ "numpy>=1.19.3",
+ "six>=1.15.0",
+ "Shapely>=1.7.1",
+ "PyYAML",
+ "Pillow",
],
- package_dir={'': MODULE_NAME},
+ package_dir={"": MODULE_NAME},
packages=setuptools.find_namespace_packages(where=MODULE_NAME),
- package_data={'': ['*.onnx', '*.yaml']},
+ package_data={"": ["*.onnx", "*.yaml"]},
keywords=[
- 'ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr'
+ "ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr"
],
classifiers=[
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Programming Language :: Python :: 3.8',
- 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10',
- 'Programming Language :: Python :: 3.11',
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
],
- python_requires='>=3.6,<3.12',
+ python_requires=">=3.6,<3.12",
entry_points={
- 'console_scripts': [f'{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main'],
- })
+ "console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main"],
+ },
+)
diff --git a/python/setup_openvino.py b/python/setup_openvino.py
index d52e9d6cc..6431c93a3 100644
--- a/python/setup_openvino.py
+++ b/python/setup_openvino.py
@@ -10,14 +10,14 @@
def get_readme():
root_dir = Path(__file__).resolve().parent.parent
- readme_path = str(root_dir / 'docs' / 'doc_whl_rapidocr_vino.md')
+ readme_path = str(root_dir / "docs" / "doc_whl_rapidocr_vino.md")
print(readme_path)
- with open(readme_path, 'r', encoding='utf-8') as f:
+ with open(readme_path, "r", encoding="utf-8") as f:
readme = f.read()
return readme
-MODULE_NAME = 'rapidocr_openvino'
+MODULE_NAME = "rapidocr_openvino"
obtainer = GetPyPiLatestVersion()
latest_version = obtainer(MODULE_NAME)
@@ -25,7 +25,7 @@ def get_readme():
# 优先提取commit message中的语义化版本号,如无,则自动加1
if len(sys.argv) > 2:
- match_str = ' '.join(sys.argv[2:])
+ match_str = " ".join(sys.argv[2:])
matched_versions = obtainer.extract_version(match_str)
if matched_versions:
VERSION_NUM = matched_versions
@@ -37,31 +37,38 @@ def get_readme():
platforms="Any",
description="A cross platform OCR Library based on OpenVINO.",
long_description=get_readme(),
- long_description_content_type='text/markdown',
+ long_description_content_type="text/markdown",
author="SWHL",
author_email="liekkaskono@163.com",
url="https://github.com/RapidAI/RapidOCR",
- license='Apache-2.0',
+ license="Apache-2.0",
include_package_data=True,
- install_requires=["pyclipper>=1.2.1", "openvino>=2022.2.0",
- "opencv_python>=4.5.1.48", "numpy>=1.19.3",
- "six>=1.15.0", "Shapely>=1.7.1", 'PyYAML', 'Pillow'],
- package_dir={'': MODULE_NAME},
+ install_requires=[
+ "pyclipper>=1.2.1",
+ "openvino>=2022.2.0",
+ "opencv_python>=4.5.1.48",
+ "numpy>=1.19.3",
+ "six>=1.15.0",
+ "Shapely>=1.7.1",
+ "PyYAML",
+ "Pillow",
+ ],
+ package_dir={"": MODULE_NAME},
packages=setuptools.find_namespace_packages(where=MODULE_NAME),
- package_data={'': ['*.onnx', '*.yaml', '*.txt']},
+ package_data={"": ["*.onnx", "*.yaml", "*.txt"]},
keywords=[
- 'ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr'
+ "ocr,text_detection,text_recognition,db,onnxruntime,paddleocr,openvino,rapidocr"
],
classifiers=[
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Programming Language :: Python :: 3.8',
- 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10',
- 'Programming Language :: Python :: 3.11',
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
],
- python_requires='>=3.6,<3.12',
+ python_requires=">=3.6,<3.12",
entry_points={
- 'console_scripts': [f'{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main'],
- }
+ "console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.rapid_ocr_api:main"],
+ },
)
diff --git a/python/tests/base_module.py b/python/tests/base_module.py
index 13bd0bdfe..73026d233 100644
--- a/python/tests/base_module.py
+++ b/python/tests/base_module.py
@@ -8,25 +8,25 @@
import yaml
-class BaseModule():
- def __init__(self, package_name: str = 'rapidocr_onnxruntime'):
+class BaseModule:
+ def __init__(self, package_name: str = "rapidocr_onnxruntime"):
self.package_name = package_name
self.root_dir = Path(__file__).resolve().parent.parent
self.package_dir = self.root_dir / self.package_name
- self.tests_dir = self.root_dir / 'tests'
+ self.tests_dir = self.root_dir / "tests"
sys.path.append(str(self.root_dir))
sys.path.append(str(self.package_dir))
def init_module(self, module_name: str, class_name: str = None):
if class_name is None:
- module_part = importlib.import_module(f'{self.package_name}')
+ module_part = importlib.import_module(f"{self.package_name}")
return module_part
- module_part = importlib.import_module(f'{self.package_name}.{module_name}')
+ module_part = importlib.import_module(f"{self.package_name}.{module_name}")
return getattr(module_part, class_name)
@staticmethod
def read_yaml(yaml_path: str):
- with open(yaml_path, 'rb') as f:
+ with open(yaml_path, "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
diff --git a/python/tests/benchmark/benchmark.py b/python/tests/benchmark/benchmark.py
index 8b55e8d20..7afe0233a 100644
--- a/python/tests/benchmark/benchmark.py
+++ b/python/tests/benchmark/benchmark.py
@@ -14,22 +14,22 @@
from rapidocr_onnxruntime import RapidOCR
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--yaml_path', type=str, default='config.yaml')
+ parser.add_argument("--yaml_path", type=str, default="config.yaml")
args = parser.parse_args()
yaml_path = cur_dir / args.yaml_path
rapid_ocr = RapidOCR(yaml_path)
- image_dir = cur_dir / 'test_images_benchmark'
+ image_dir = cur_dir / "test_images_benchmark"
if not image_dir.exists():
- raise FileNotFoundError(f'{image_dir} does not exits!!')
+ raise FileNotFoundError(f"{image_dir} does not exits!!")
image_list = list(image_dir.iterdir())
cost_time_list = []
- for image_path in tqdm(image_list, desc='Test'):
+ for image_path in tqdm(image_list, desc="Test"):
img = cv2.imread(str(image_path))
start_time = time.time()
@@ -40,6 +40,8 @@
total_time = sum(cost_time_list)
avg_time = total_time / len(cost_time_list)
- print(f'Total Files: {len(image_list)}, '
- f'Total Time: {total_time:.5f}, '
- f'Average Time: {avg_time:.5f}')
+ print(
+ f"Total Files: {len(image_list)}, "
+ f"Total Time: {total_time:.5f}, "
+ f"Average Time: {avg_time:.5f}"
+ )
diff --git a/python/tests/test_all_ort.py b/python/tests/test_all_ort.py
index 6267c4199..4427c6cf4 100644
--- a/python/tests/test_all_ort.py
+++ b/python/tests/test_all_ort.py
@@ -16,14 +16,14 @@
from rapidocr_onnxruntime import RapidOCR, LoadImageError
rapid_ocr = RapidOCR()
-tests_dir = root_dir / 'tests' / 'test_files'
+tests_dir = root_dir / "tests" / "test_files"
def test_normal():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
img = cv2.imread(str(image_path))
result, _ = rapid_ocr(img)
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
@@ -42,29 +42,29 @@ def test_zeros():
def test_input_str():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
result, _ = rapid_ocr(str(image_path))
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_bytes():
- image_path = tests_dir / 'ch_en_num.jpg'
- with open(image_path, 'rb') as f:
+ image_path = tests_dir / "ch_en_num.jpg"
+ with open(image_path, "rb") as f:
result, _ = rapid_ocr(f.read())
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_path():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
result, _ = rapid_ocr(image_path)
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
rapid_ocr = RapidOCR(text_score=1)
result, _ = rapid_ocr(image_path)
@@ -72,27 +72,27 @@ def test_input_parameters():
def test_input_det_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(det_model_path='1.onnx')
+ rapid_ocr = RapidOCR(det_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
def test_input_cls_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(cls_model_path='1.onnx')
+ rapid_ocr = RapidOCR(cls_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
def test_input_rec_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(rec_model_path='1.onnx')
+ rapid_ocr = RapidOCR(rec_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
diff --git a/python/tests/test_all_vino.py b/python/tests/test_all_vino.py
index 5551632aa..530f04ac0 100644
--- a/python/tests/test_all_vino.py
+++ b/python/tests/test_all_vino.py
@@ -16,14 +16,14 @@
rapid_ocr = RapidOCR()
-tests_dir = root_dir / 'tests' / 'test_files'
+tests_dir = root_dir / "tests" / "test_files"
def test_normal():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
img = cv2.imread(str(image_path))
result, _ = rapid_ocr(img)
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
@@ -42,29 +42,29 @@ def test_zeros():
def test_input_str():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
result, _ = rapid_ocr(str(image_path))
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_bytes():
- image_path = tests_dir / 'ch_en_num.jpg'
- with open(image_path, 'rb') as f:
+ image_path = tests_dir / "ch_en_num.jpg"
+ with open(image_path, "rb") as f:
result, _ = rapid_ocr(f.read())
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_path():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
result, _ = rapid_ocr(image_path)
- assert result[0][1] == '正品促销'
+ assert result[0][1] == "正品促销"
assert len(result) == 17
def test_input_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
rapid_ocr = RapidOCR(text_score=1)
result, _ = rapid_ocr(image_path)
@@ -72,27 +72,27 @@ def test_input_parameters():
def test_input_det_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(det_model_path='1.onnx')
+ rapid_ocr = RapidOCR(det_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
def test_input_cls_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(cls_model_path='1.onnx')
+ rapid_ocr = RapidOCR(cls_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
def test_input_rec_parameters():
- image_path = tests_dir / 'ch_en_num.jpg'
+ image_path = tests_dir / "ch_en_num.jpg"
with pytest.raises(FileNotFoundError) as exc_info:
- rapid_ocr = RapidOCR(rec_model_path='1.onnx')
+ rapid_ocr = RapidOCR(rec_model_path="1.onnx")
result, _ = rapid_ocr(image_path)
raise FileNotFoundError()
assert exc_info.type is FileNotFoundError
diff --git a/python/tests/test_cls.py b/python/tests/test_cls.py
index 72c5b75e9..a5ebd044c 100644
--- a/python/tests/test_cls.py
+++ b/python/tests/test_cls.py
@@ -7,24 +7,22 @@
@pytest.mark.parametrize(
- 'package_name',
- [('rapidocr_onnxruntime'),
- ('rapidocr_openvino')]
+ "package_name", [("rapidocr_onnxruntime"), ("rapidocr_openvino")]
)
def test_cls(package_name: str):
- module_name = 'ch_ppocr_v2_cls'
- class_name = 'TextClassifier'
+ module_name = "ch_ppocr_v2_cls"
+ class_name = "TextClassifier"
base = BaseModule(package_name=package_name)
TextClassifier = base.init_module(module_name, class_name)
- yaml_path = base.package_dir / module_name / 'config.yaml'
+ yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
- config['model_path'] = str(base.package_dir / config['model_path'])
+ config["model_path"] = str(base.package_dir / config["model_path"])
text_cls = TextClassifier(config)
- img_path = base.tests_dir / 'test_files' / 'text_cls.jpg'
+ img_path = base.tests_dir / "test_files" / "text_cls.jpg"
img = cv2.imread(str(img_path))
result = text_cls([img])
- assert result[1][0][0] == '180'
+ assert result[1][0][0] == "180"
diff --git a/python/tests/test_det.py b/python/tests/test_det.py
index c8e9ec516..d7bd0570d 100644
--- a/python/tests/test_det.py
+++ b/python/tests/test_det.py
@@ -6,23 +6,20 @@
from base_module import BaseModule
-@pytest.mark.parametrize(
- 'package_name',
- ['rapidocr_onnxruntime', 'rapidocr_openvino']
-)
+@pytest.mark.parametrize("package_name", ["rapidocr_onnxruntime", "rapidocr_openvino"])
def test_det(package_name):
- module_name = 'ch_ppocr_v3_det'
- class_name = 'TextDetector'
+ module_name = "ch_ppocr_v3_det"
+ class_name = "TextDetector"
base = BaseModule(package_name)
TextDetector = base.init_module(module_name, class_name)
- yaml_path = base.package_dir / module_name / 'config.yaml'
+ yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
- config['model_path'] = str(base.package_dir / config['model_path'])
+ config["model_path"] = str(base.package_dir / config["model_path"])
text_det = TextDetector(config)
- img_path = base.tests_dir / 'test_files' / 'text_det.jpg'
+ img_path = base.tests_dir / "test_files" / "text_det.jpg"
img = cv2.imread(str(img_path))
dt_boxes, elapse = text_det(img)
assert dt_boxes.shape == (18, 4, 2)
diff --git a/python/tests/test_rec.py b/python/tests/test_rec.py
index 98b79d085..e75a9d6ce 100644
--- a/python/tests/test_rec.py
+++ b/python/tests/test_rec.py
@@ -6,24 +6,21 @@
from base_module import BaseModule
-@pytest.mark.parametrize(
- 'package_name',
- ['rapidocr_onnxruntime', 'rapidocr_openvino']
-)
+@pytest.mark.parametrize("package_name", ["rapidocr_onnxruntime", "rapidocr_openvino"])
def test_det(package_name):
- module_name = 'ch_ppocr_v3_rec'
- class_name = 'TextRecognizer'
+ module_name = "ch_ppocr_v3_rec"
+ class_name = "TextRecognizer"
base = BaseModule(package_name)
TextRecognizer = base.init_module(module_name, class_name)
- yaml_path = base.package_dir / module_name / 'config.yaml'
+ yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
- config['model_path'] = str(base.package_dir / config['model_path'])
+ config["model_path"] = str(base.package_dir / config["model_path"])
text_rec = TextRecognizer(config)
- img_path = base.tests_dir / 'test_files' / 'text_rec.jpg'
+ img_path = base.tests_dir / "test_files" / "text_rec.jpg"
img = cv2.imread(str(img_path))
rec_res, elapse = text_rec([img])
- assert rec_res[0][0] == '韩国小馆'
+ assert rec_res[0][0] == "韩国小馆"