Skip to content

Commit

Permalink
Adds type hinting
Browse files Browse the repository at this point in the history
Attempts to add type hints as much as possible, without breaking
usage. However, some methods may deserve refactoring if the intended
behavior differed from what mypy found. (e.g. Returning None when a
boolean was desired)
  • Loading branch information
mfulgo committed Oct 11, 2024
1 parent 049a5eb commit fea65cd
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 95 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ classifiers = [
'Topic :: Software Development :: Testing',
'Topic :: Utilities',
]
dependencies = ["pytest>=7.0.0"]
dependencies = [
"pytest >= 7.0.0",
"typing-extensions >= 4.12.2, < 5; python_version < '3.11'"
]

[project.urls]
Home = "https://github.com/okken/pytest-check"
Expand Down
11 changes: 6 additions & 5 deletions src/pytest_check/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import pytest
from . import check_functions

# make sure assert rewriting happens
pytest.register_assert_rewrite("pytest_check.check_functions")

# allow for top level helper function access:
# import pytest_check
# pytest_check.equal(1, 1)
from pytest_check.check_functions import * # noqa: F401, F402, F403, E402
from .check_functions import * # noqa: F401, F402, F403, E402

# allow to know if any_failures due to any previous check
from pytest_check.check_log import any_failures # noqa: F401, F402, F403, E402
from .check_log import any_failures # noqa: F401, F402, F403, E402

# allow top level raises:
# from pytest_check import raises
# with raises(Exception):
# raise Exception
# with raises(AssertionError):
# assert 0
from pytest_check.check_raises import raises # noqa: F401, F402, F403, E402
from .check_raises import raises # noqa: F401, F402, F403, E402

# allow for with blocks and assert:
# from pytest_check import check
# with check:
# assert 1 == 2
from pytest_check.context_manager import check # noqa: F401, F402, F403, E402
from .context_manager import check # noqa: F401, F402, F403, E402

# allow check.raises()
setattr(check, "raises", raises)
Expand All @@ -33,7 +34,7 @@

# allow check.check as a context manager.
# weird, but some people are doing it.
# decprecate this eventually
# deprecate this eventually
setattr(check, "check", check)

# allow for helper functions to be part of check context
Expand Down
103 changes: 76 additions & 27 deletions src/pytest_check/check_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
from __future__ import annotations
import functools
import sys
from typing import (
Any,
Callable,
Container,
Protocol,
SupportsFloat,
SupportsIndex,
TypeVar,
Union,
)
if sys.version_info < (3, 10): # pragma: no cover
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

import pytest
import math
Expand Down Expand Up @@ -33,12 +49,31 @@
]


def check_func(func):
_P = ParamSpec("_P")
_T = TypeVar("_T")

class _ComparableGreaterThan(Protocol):
def __gt__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableGreaterThanOrEqual(Protocol):
def __ge__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableLessThan(Protocol):
def __lt__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableLessThanOrEqual(Protocol):
def __le__(self, other: Any) -> bool: ... # pragma: no cover


def check_func(func: Callable[_P, _T]) -> Callable[_P, bool]:
@functools.wraps(func)
def wrapper(*args, **kwds):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool:
__tracebackhide__ = True
try:
func(*args, **kwds)
func(*args, **kwargs)
return True
except AssertionError as e:
log_failure(e)
Expand All @@ -47,11 +82,11 @@ def wrapper(*args, **kwds):
return wrapper


def assert_equal(a, b, msg=""): # pragma: no cover
def assert_equal(a: object, b: object, msg: str = "") -> None: # pragma: no cover
assert a == b, msg


def equal(a, b, msg=""):
def equal(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a == b:
return True
Expand All @@ -60,7 +95,7 @@ def equal(a, b, msg=""):
return False


def not_equal(a, b, msg=""):
def not_equal(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a != b:
return True
Expand All @@ -69,7 +104,7 @@ def not_equal(a, b, msg=""):
return False


def is_(a, b, msg=""):
def is_(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a is b:
return True
Expand All @@ -78,7 +113,7 @@ def is_(a, b, msg=""):
return False


def is_not(a, b, msg=""):
def is_not(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a is not b:
return True
Expand All @@ -87,7 +122,7 @@ def is_not(a, b, msg=""):
return False


def is_true(x, msg=""):
def is_true(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if bool(x):
return True
Expand All @@ -96,7 +131,7 @@ def is_true(x, msg=""):
return False


def is_false(x, msg=""):
def is_false(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if not bool(x):
return True
Expand All @@ -105,7 +140,7 @@ def is_false(x, msg=""):
return False


def is_none(x, msg=""):
def is_none(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if x is None:
return True
Expand All @@ -114,7 +149,7 @@ def is_none(x, msg=""):
return False


def is_not_none(x, msg=""):
def is_not_none(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if x is not None:
return True
Expand All @@ -123,7 +158,7 @@ def is_not_none(x, msg=""):
return False


def is_nan(a, msg=""):
def is_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool:
__tracebackhide__ = True
if math.isnan(a):
return True
Expand All @@ -132,15 +167,16 @@ def is_nan(a, msg=""):
return False


def is_not_nan(a, msg=""):
def is_not_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool:
__tracebackhide__ = True
if not math.isnan(a):
return True
else:
log_failure(f"check {a} is not NaN", msg)
return False

def is_in(a, b, msg=""):

def is_in(a: _T, b: Container[_T], msg: str = "") -> bool:
__tracebackhide__ = True
if a in b:
return True
Expand All @@ -149,7 +185,7 @@ def is_in(a, b, msg=""):
return False


def is_not_in(a, b, msg=""):
def is_not_in(a: _T, b: Container[_T], msg: str = "") -> bool:
__tracebackhide__ = True
if a not in b:
return True
Expand All @@ -158,7 +194,9 @@ def is_not_in(a, b, msg=""):
return False


def is_instance(a, b, msg=""):
_TypeTuple = Union[type, tuple['_TypeTuple', ...]]

def is_instance(a: object, b: _TypeTuple, msg: str = "") -> bool:
__tracebackhide__ = True
if isinstance(a, b):
return True
Expand All @@ -167,7 +205,7 @@ def is_instance(a, b, msg=""):
return False


def is_not_instance(a, b, msg=""):
def is_not_instance(a: object, b: _TypeTuple, msg: str = "") -> bool:
__tracebackhide__ = True
if not isinstance(a, b):
return True
Expand All @@ -176,7 +214,9 @@ def is_not_instance(a, b, msg=""):
return False


def almost_equal(a, b, rel=None, abs=None, msg=""):
def almost_equal(
a: object, b: object, rel: Any = None, abs: Any = None, msg: str = ""
) -> bool:
"""
For rel and abs tolerance, see:
See https://docs.pytest.org/en/latest/builtin.html#pytest.approx
Expand All @@ -189,7 +229,9 @@ def almost_equal(a, b, rel=None, abs=None, msg=""):
return False


def not_almost_equal(a, b, rel=None, abs=None, msg=""):
def not_almost_equal(
a: object, b: object, rel: Any = None, abs: Any = None, msg: str = ""
) -> bool:
"""
For rel and abs tolerance, see:
See https://docs.pytest.org/en/latest/builtin.html#pytest.approx
Expand All @@ -202,7 +244,7 @@ def not_almost_equal(a, b, rel=None, abs=None, msg=""):
return False


def greater(a, b, msg=""):
def greater(a: _ComparableGreaterThan, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a > b:
return True
Expand All @@ -211,7 +253,7 @@ def greater(a, b, msg=""):
return False


def greater_equal(a, b, msg=""):
def greater_equal(a: _ComparableGreaterThanOrEqual, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a >= b:
return True
Expand All @@ -220,7 +262,7 @@ def greater_equal(a, b, msg=""):
return False


def less(a, b, msg=""):
def less(a: _ComparableLessThan, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a < b:
return True
Expand All @@ -229,7 +271,7 @@ def less(a, b, msg=""):
return False


def less_equal(a, b, msg=""):
def less_equal(a: _ComparableLessThanOrEqual, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a <= b:
return True
Expand All @@ -238,7 +280,9 @@ def less_equal(a, b, msg=""):
return False


def between(b, a, c, msg="", ge=False, le=False):
def between(
b: Any, a: Any, c: Any, msg: str = "", ge: bool = False, le: bool = False
) -> bool:
__tracebackhide__ = True
if ge and le:
if a <= b <= c:
Expand Down Expand Up @@ -266,11 +310,16 @@ def between(b, a, c, msg="", ge=False, le=False):
return False


def between_equal(b, a, c, msg=""):
def between_equal(
b: _ComparableLessThanOrEqual,
a: _ComparableLessThanOrEqual,
c: object,
msg:str = "",
) -> bool:
__tracebackhide__ = True
return between(b, a, c, msg, ge=True, le=True)


def fail(msg):
def fail(msg: str) -> None:
__tracebackhide__ = True
log_failure(msg)
24 changes: 15 additions & 9 deletions src/pytest_check/check_log.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
from __future__ import annotations
from collections.abc import Iterable
from typing import Callable

from .pseudo_traceback import _build_pseudo_trace_str

should_use_color = False
COLOR_RED = "\x1b[31m"
COLOR_RESET = "\x1b[0m"
_failures = []
_failures: list[str] = []
_stop_on_fail = False

_default_max_fail = None
_default_max_report = None
_default_max_tb = 1

_max_fail = _default_max_fail
_max_report = _default_max_report
_max_fail: int | None = _default_max_fail
_max_report: int | None = _default_max_report
_max_tb = _default_max_tb
_num_failures = 0
_fail_function = None
_fail_function: Callable[[str], None] | None = None

_showlocals = False

def clear_failures():
# get's called at the beginning of each test function
def clear_failures() -> None:
# gets called at the beginning of each test function
global _failures, _num_failures
global _max_fail, _max_report, _max_tb
_failures = []
Expand All @@ -33,11 +37,13 @@ def any_failures() -> bool:
return bool(get_failures())


def get_failures():
def get_failures() -> list[str]:
return _failures


def log_failure(msg="", check_str="", tb=None):
def log_failure(
msg: object = "", check_str: str = "", tb: Iterable[str] | None = None
) -> None:
global _num_failures
__tracebackhide__ = True
_num_failures += 1
Expand All @@ -60,7 +66,7 @@ def log_failure(msg="", check_str="", tb=None):
msg = f"FAILURE: {msg}"
_failures.append(msg)
if _fail_function:
_fail_function(msg)
_fail_function(str(msg))

if _max_fail and (_num_failures >= _max_fail):
assert_msg = f"pytest-check max fail of {_num_failures} reached"
Expand Down
Loading

0 comments on commit fea65cd

Please sign in to comment.