Skip to content

Commit

Permalink
add better unit tests
Browse files Browse the repository at this point in the history
also:
- limit python versions to 3.9 and 3.10
- remove unecessary black excludes
- add 2 vscode extension recommendations
- add .prettierrc.yaml
  • Loading branch information
mauwii committed Jul 20, 2023
1 parent 5ace23e commit d1bc69c
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 32 deletions.
5 changes: 5 additions & 0 deletions .prettierrc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
editorconfig: true
overrides:
- files: '*.{yaml,yml}'
options:
singleQuote: true
2 changes: 2 additions & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@
"ms-python.vscode-pylance",
"redhat.vscode-yaml",
"tamasfe.even-better-toml",
"github.vscode-github-actions",
"esbenp.prettier-vscode"
]
}
2 changes: 1 addition & 1 deletion patchmatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__app_id__ = "mauwii/PyPatchMatch"
__app_name__ = "PyPatchMatch"
__version__ = "1.0.0"
__version__ = "1.0.1"
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dynamic=["version"]
license={file="LICENSE"}
name='PyPatchMatch'
readme={content-type="text/markdown", file="README.md"}
requires-python=">=3.9"
requires-python=">=3.9,<3.11"

[project.urls]
'Source Code'='https://github.com/mauwii/PyPatchMatch'
Expand Down Expand Up @@ -47,13 +47,9 @@ include=["patchmatch", "patchmatch.csrc"]
exclude='''
/(
.git
| .hg
| .mypy_cache
| .tox
| .venv
| _build
| buck-out
| build
| csrc
| dist
)/
'''
Expand All @@ -62,6 +58,9 @@ line-length=88
source=['examples', 'patchmatch']
target-version=['py39']

[tool.coverage.report]
fail_under=75

[tool.isort]
profile="black"

Expand Down
90 changes: 90 additions & 0 deletions test_patch_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import tempfile

import numpy as np
import pytest
from PIL import Image

from patchmatch import patch_match


@pytest.fixture
def input_image():
return Image.open("./examples/images/forest_pruned.bmp")


@pytest.fixture
def mask(input_image):
return np.zeros((input_image.height, input_image.width), dtype=np.uint8)


@pytest.fixture
def ijmap(input_image):
return np.zeros((input_image.height, input_image.width, 3), dtype=np.float32)


def test_inpaint_custom_patch_size(input_image):
result = Image.fromarray(patch_match.inpaint(image=input_image, patch_size=3))
assert isinstance(result, Image.Image)


def test_inpaint_global_mask(input_image):
patch_match.set_verbose(True)
source = np.array(input_image)
source[:100, :100] = 255
global_mask = np.zeros_like(source[..., 0])
global_mask[:100, :100] = 1
result = Image.fromarray(
patch_match.inpaint(image=source, global_mask=global_mask, patch_size=3)
)
assert isinstance(result, Image.Image)


def test_download_url_to_file():
url = "https://github.com/mauwii/PyPatchMatch/raw/main/examples/images/forest.bmp"
dst = tempfile.mkstemp().__dir__
patch_match.download_url_to_file(url=url, dst=str(dst))
read_image = Image.open(str(dst))
assert isinstance(read_image, Image.Image)


def test_inpaint_regularity(input_image, ijmap):
ijmap = ijmap
result = patch_match.inpaint_regularity(input_image, None, ijmap, patch_size=3)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_guide_weight(input_image, ijmap):
ijmap = ijmap
result = patch_match.inpaint_regularity(input_image, None, ijmap, guide_weight=0.5)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_global_mask(input_image, mask, ijmap):
ijmap = ijmap
global_mask = mask
result = patch_match.inpaint_regularity(
input_image, None, ijmap, global_mask=global_mask
)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_mask(input_image, mask, ijmap):
ijmap = ijmap
mask = mask
result = patch_match.inpaint_regularity(input_image, mask, ijmap)
assert isinstance(result, np.ndarray)


def test_canonize_mask_array_2d_uint8():
mask = np.zeros((10, 10), dtype=np.uint8)
result = patch_match._canonize_mask_array(mask)
assert isinstance(result, np.ndarray)
assert result.ndim == 3
assert result.shape[2] == 1
assert result.dtype == np.uint8


def test_canonize_mask_array_invalid_input():
with pytest.raises(AssertionError):
mask = np.zeros((10, 10, 3), dtype=np.uint8)
patch_match._canonize_mask_array(mask)
25 changes: 0 additions & 25 deletions tests/test_inpaint.py

This file was deleted.

133 changes: 133 additions & 0 deletions tests/test_patch_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os

import numpy as np
import pytest
from PIL import Image

from patchmatch import patch_match


@pytest.fixture
def input_image():
return Image.open("./examples/images/forest_pruned.bmp")


@pytest.fixture
def ijmap(input_image):
return np.zeros((input_image.height, input_image.width, 3), dtype=np.float32)


@pytest.fixture
def mask(input_image):
return np.zeros((input_image.height, input_image.width), dtype=np.uint8)


@pytest.fixture
def temp_file():
filename = "temp.txt"
with open(filename, "w") as f:
f.write("test")
yield filename
os.remove(filename)


def test_inpaint_custom_patch_size(input_image, mask):
result = patch_match.inpaint(input_image, mask, patch_size=1)
assert isinstance(result, np.ndarray)


def test_inpaint_global_mask(input_image):
source = np.array(input_image)
source[:100, :100] = 255
global_mask = np.zeros_like(source[..., 0])
global_mask[:100, :100] = 1
result = Image.fromarray(
patch_match.inpaint(image=source, global_mask=global_mask, patch_size=3)
)
assert isinstance(result, Image.Image)


def test_np_to_pymat():
npmat = np.zeros((10, 10, 3), dtype=np.uint8)
result = patch_match.np_to_pymat(npmat)
assert isinstance(result, patch_match.CMatT)


def test_pymat_to_np():
pymat = patch_match.CMatT(None, patch_match.CShapeT(10, 10, 3), 0)
npmat = np.zeros((10, 10, 3), dtype=np.uint8)
pymat.data_ptr = npmat.ctypes.data
result = patch_match.pymat_to_np(pymat)
assert isinstance(result, np.ndarray)
assert result.shape == npmat.shape
assert result.dtype == npmat.dtype
assert np.allclose(result, npmat)


def test_canonize_mask_array_2d_uint8():
mask = np.zeros((10, 10), dtype=np.uint8)
result = patch_match._canonize_mask_array(mask)
assert isinstance(result, np.ndarray)
assert result.ndim == 3
assert result.shape[2] == 1
assert result.dtype == np.uint8


def test_canonize_mask_array_3d_uint8():
mask = np.zeros((10, 10, 1), dtype=np.uint8)
result = patch_match._canonize_mask_array(mask)
assert isinstance(result, np.ndarray)
assert result.ndim == 3
assert result.shape[2] == 1
assert result.dtype == np.uint8


def test_canonize_mask_array_pil_image():
mask = Image.new("L", (10, 10))
result = patch_match._canonize_mask_array(mask)
assert isinstance(result, np.ndarray)
assert result.ndim == 3
assert result.shape[2] == 1
assert result.dtype == np.uint8


def test_canonize_mask_array_invalid_input():
with pytest.raises(AssertionError):
mask = np.zeros((10, 10, 3), dtype=np.uint8)
patch_match._canonize_mask_array(mask)


def test_inpaint_regularity(input_image, ijmap):
result = patch_match.inpaint_regularity(input_image, None, ijmap, patch_size=3)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_patch_size(input_image, ijmap):
result = patch_match.inpaint_regularity(input_image, None, ijmap, patch_size=5)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_guide_weight(input_image, ijmap):
result = patch_match.inpaint_regularity(
input_image, None, ijmap, patch_size=3, guide_weight=0.5
)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_global_mask(input_image, ijmap, mask):
global_mask = mask
result = patch_match.inpaint_regularity(
input_image, None, ijmap, global_mask=global_mask, patch_size=3
)
assert isinstance(result, np.ndarray)


def test_inpaint_regularity_custom_mask(input_image, ijmap, mask):
result = patch_match.inpaint_regularity(input_image, mask, ijmap, patch_size=3)
assert isinstance(result, np.ndarray)


def test_download_url_to_file(temp_file):
url = "https://www.google.com"
patch_match.download_url_to_file(url, temp_file)
assert os.path.exists(temp_file)

0 comments on commit d1bc69c

Please sign in to comment.