Skip to content

Commit

Permalink
Update LoadImage Class and support Image as input.
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Feb 28, 2024
1 parent e55c3a2 commit 2caa0b0
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 60 deletions.
44 changes: 25 additions & 19 deletions python/rapidocr_onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class OrtInferSession:
Expand Down Expand Up @@ -122,8 +122,9 @@ def __call__(self, img: InputType) -> np.ndarray:
f"The img type {type(img)} does not in {InputType.__args__}"
)

origin_img_type = type(img)
img = self.load_img(img)
img = self.convert_img(img)
img = self.convert_img(img, origin_img_type)
return img

def load_img(self, img: InputType) -> np.ndarray:
Expand All @@ -142,9 +143,12 @@ def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, np.ndarray):
return img

if isinstance(img, Image.Image):
return np.array(img)

raise LoadImageError(f"{type(img)} is not supported!")

def convert_img(self, img: np.ndarray):
def convert_img(self, img: np.ndarray, origin_img_type):
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

Expand All @@ -156,31 +160,20 @@ def convert_img(self, img: np.ndarray):
if channel == 2:
return self.cvt_two_to_three(img)

if channel == 3:
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img

if channel == 4:
return self.cvt_four_to_three(img)

if channel == 3:
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)

raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
Expand All @@ -195,6 +188,19 @@ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
Expand Down
44 changes: 25 additions & 19 deletions python/rapidocr_openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class OpenVINOInferSession:
Expand Down Expand Up @@ -52,8 +52,9 @@ def __call__(self, img: InputType) -> np.ndarray:
f"The img type {type(img)} does not in {InputType.__args__}"
)

origin_img_type = type(img)
img = self.load_img(img)
img = self.convert_img(img)
img = self.convert_img(img, origin_img_type)
return img

def load_img(self, img: InputType) -> np.ndarray:
Expand All @@ -72,9 +73,12 @@ def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, np.ndarray):
return img

if isinstance(img, Image.Image):
return np.array(img)

raise LoadImageError(f"{type(img)} is not supported!")

def convert_img(self, img: np.ndarray):
def convert_img(self, img: np.ndarray, origin_img_type):
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

Expand All @@ -86,31 +90,20 @@ def convert_img(self, img: np.ndarray):
if channel == 2:
return self.cvt_two_to_three(img)

if channel == 3:
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img

if channel == 4:
return self.cvt_four_to_three(img)

if channel == 3:
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)

raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
Expand All @@ -125,6 +118,19 @@ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
Expand Down
44 changes: 25 additions & 19 deletions python/rapidocr_paddle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class PaddleInferSession:
Expand Down Expand Up @@ -136,8 +136,9 @@ def __call__(self, img: InputType) -> np.ndarray:
f"The img type {type(img)} does not in {InputType.__args__}"
)

origin_img_type = type(img)
img = self.load_img(img)
img = self.convert_img(img)
img = self.convert_img(img, origin_img_type)
return img

def load_img(self, img: InputType) -> np.ndarray:
Expand All @@ -156,9 +157,12 @@ def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, np.ndarray):
return img

if isinstance(img, Image.Image):
return np.array(img)

raise LoadImageError(f"{type(img)} is not supported!")

def convert_img(self, img: np.ndarray):
def convert_img(self, img: np.ndarray, origin_img_type):
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

Expand All @@ -170,31 +174,20 @@ def convert_img(self, img: np.ndarray):
if channel == 2:
return self.cvt_two_to_three(img)

if channel == 3:
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img

if channel == 4:
return self.cvt_four_to_three(img)

if channel == 3:
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)

raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
Expand All @@ -209,6 +202,19 @@ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
    """Collapse a four-channel image into a three-channel one.

    The input planes are unpacked as (r, g, b, a) and the colour planes
    are reassembled in b, g, r order. The alpha plane is used twice: as
    a mask on the colour planes and, inverted, as an additive layer.

    Args:
        img: Four-channel image array; the last plane is treated as alpha.

    Returns:
        Three-channel array in b, g, r plane order.
    """
    red, green, blue, alpha = cv2.split(img)

    # Inverted alpha, replicated onto three planes so it can be added to
    # the colour image (for 8-bit data this is 255 - alpha per pixel).
    inverse_alpha = cv2.cvtColor(cv2.bitwise_not(alpha), cv2.COLOR_GRAY2BGR)

    bgr = cv2.merge((blue, green, red))
    # Keep colour only where the alpha mask is non-zero.
    visible = cv2.bitwise_and(bgr, bgr, mask=alpha)
    # Adding the inverted alpha pushes transparent regions toward white.
    return cv2.add(visible, inverse_alpha)

@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_cls_rec():

def test_det_cls_rec():
img = cv2.imread(str(img_path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

result, _ = engine(img)
assert result[0][1] == "正品促销"
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_cls_rec():

def test_det_cls_rec():
img = cv2.imread(str(img_path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

result, _ = engine(img)
assert result[0][1] == "正品促销"
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_vino.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def test_cls_rec():

def test_det_cls_rec():
img = cv2.imread(str(img_path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

result, _ = engine(img)
assert result[0][1] == "正品促销"
Expand Down

0 comments on commit 2caa0b0

Please sign in to comment.