Skip to content

Commit

Permalink
chore: Optimize code structure
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Sep 6, 2024
1 parent 034e7f5 commit 6b60256
Show file tree
Hide file tree
Showing 14 changed files with 3 additions and 115 deletions.
14 changes: 0 additions & 14 deletions python/rapidocr_onnxruntime/ch_ppocr_v2_cls/config.yaml

This file was deleted.

29 changes: 0 additions & 29 deletions python/rapidocr_onnxruntime/ch_ppocr_v3_det/config.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions python/rapidocr_onnxruntime/ch_ppocr_v3_rec/config.yaml

This file was deleted.

6 changes: 3 additions & 3 deletions python/rapidocr_onnxruntime/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import cv2
import numpy as np

from .ch_ppocr_v2_cls import TextClassifier
from .ch_ppocr_v3_det import TextDetector
from .ch_ppocr_v3_rec import TextRecognizer
from .ch_ppocr_cls import TextClassifier
from .ch_ppocr_det import TextDetector
from .ch_ppocr_rec import TextRecognizer
from .utils import (
LoadImage,
UpdateParameters,
Expand Down
57 changes: 0 additions & 57 deletions python/tests/test_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import cv2
import numpy as np
import pytest
from base_module import BaseModule

root_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(root_dir))
Expand Down Expand Up @@ -208,59 +207,3 @@ def test_input_three_ndim_one_channel():

assert result[0][1] == "正品促销"
assert len(result) == 17


def test_det():
module_name = "ch_ppocr_v3_det"
class_name = "TextDetector"

base = BaseModule(package_name)
TextDetector = base.init_module(module_name, class_name)

yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
config["model_path"] = str(base.package_dir / config["model_path"])

text_det = TextDetector(config)
img_path = base.tests_dir / "test_files" / "text_det.jpg"
img = cv2.imread(str(img_path))
dt_boxes, elapse = text_det(img)
assert dt_boxes.shape == (18, 4, 2)


def test_cls():
module_name = "ch_ppocr_v2_cls"
class_name = "TextClassifier"

base = BaseModule(package_name=package_name)
TextClassifier = base.init_module(module_name, class_name)

yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
config["model_path"] = str(base.package_dir / config["model_path"])

text_cls = TextClassifier(config)

img_path = base.tests_dir / "test_files" / "text_cls.jpg"
img = cv2.imread(str(img_path))
result = text_cls([img])
assert result[1][0][0] == "180"


def test_rec():
module_name = "ch_ppocr_v3_rec"
class_name = "TextRecognizer"

base = BaseModule(package_name)
TextRecognizer = base.init_module(module_name, class_name)

yaml_path = base.package_dir / module_name / "config.yaml"
config = base.read_yaml(str(yaml_path))
config["model_path"] = str(base.package_dir / config["model_path"])

text_rec = TextRecognizer(config)

img_path = base.tests_dir / "test_files" / "text_rec.jpg"
img = cv2.imread(str(img_path))
rec_res, elapse = text_rec(img)
assert rec_res[0][0] == "韩国小馆"

0 comments on commit 6b60256

Please sign in to comment.