diff --git a/pyproject.toml b/pyproject.toml index 2858030..6cec4cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,10 @@ classifiers = [ 'Topic :: Software Development :: Testing', 'Topic :: Utilities', ] -dependencies = ["pytest>=7.0.0"] +dependencies = [ + "pytest >= 7.0.0", + "typing-extensions >= 4.12.2, < 5; python_version < '3.11'" +] [project.urls] Home = "https://github.com/okken/pytest-check" diff --git a/src/pytest_check/__init__.py b/src/pytest_check/__init__.py index 210fb19..5c0b4bf 100644 --- a/src/pytest_check/__init__.py +++ b/src/pytest_check/__init__.py @@ -1,4 +1,5 @@ import pytest +from . import check_functions # make sure assert rewriting happens pytest.register_assert_rewrite("pytest_check.check_functions") @@ -6,10 +7,10 @@ # allow for top level helper function access: # import pytest_check # pytest_check.equal(1, 1) -from pytest_check.check_functions import * # noqa: F401, F402, F403, E402 +from .check_functions import * # noqa: F401, F402, F403, E402 # allow to know if any_failures due to any previous check -from pytest_check.check_log import any_failures # noqa: F401, F402, F403, E402 +from .check_log import any_failures # noqa: F401, F402, F403, E402 # allow top level raises: # from pytest_check import raises @@ -17,13 +18,13 @@ # raise Exception # with raises(AssertionError): # assert 0 -from pytest_check.check_raises import raises # noqa: F401, F402, F403, E402 +from .check_raises import raises # noqa: F401, F402, F403, E402 # allow for with blocks and assert: # from pytest_check import check # with check: # assert 1 == 2 -from pytest_check.context_manager import check # noqa: F401, F402, F403, E402 +from .context_manager import check # noqa: F401, F402, F403, E402 # allow check.raises() setattr(check, "raises", raises) @@ -33,7 +34,7 @@ # allow check.check as a context manager. # weird, but some people are doing it. -# decprecate this eventually +# deprecate this eventually setattr(check, "check", check) # allow for helper functions to be part of check context diff --git a/src/pytest_check/check_functions.py b/src/pytest_check/check_functions.py index 344e121..3ac35cc 100644 --- a/src/pytest_check/check_functions.py +++ b/src/pytest_check/check_functions.py @@ -1,4 +1,20 @@ +from __future__ import annotations import functools +import sys +from typing import ( + Any, + Callable, + Container, + Protocol, + SupportsFloat, + SupportsIndex, + TypeVar, + Union, +) +if sys.version_info < (3, 10): # pragma: no cover + from typing_extensions import ParamSpec +else: + from typing import ParamSpec import pytest import math @@ -33,12 +49,31 @@ ] -def check_func(func): +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class _ComparableGreaterThan(Protocol): + def __gt__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableGreaterThanOrEqual(Protocol): + def __ge__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableLessThan(Protocol): + def __lt__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableLessThanOrEqual(Protocol): + def __le__(self, other: Any) -> bool: ... # pragma: no cover + + +def check_func(func: Callable[_P, _T]) -> Callable[_P, bool]: @functools.wraps(func) - def wrapper(*args, **kwds): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool: __tracebackhide__ = True try: - func(*args, **kwds) + func(*args, **kwargs) return True except AssertionError as e: log_failure(e) @@ -47,11 +82,11 @@ def wrapper(*args, **kwds): return wrapper -def assert_equal(a, b, msg=""): # pragma: no cover +def assert_equal(a: object, b: object, msg: str = "") -> None: # pragma: no cover assert a == b, msg -def equal(a, b, msg=""): +def equal(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a == b: return True @@ -60,7 +95,7 @@ def equal(a, b, msg=""): return False -def not_equal(a, b, msg=""): +def not_equal(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a != b: return True @@ -69,7 +104,7 @@ def not_equal(a, b, msg=""): return False -def is_(a, b, msg=""): +def is_(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a is b: return True @@ -78,7 +113,7 @@ def is_(a, b, msg=""): return False -def is_not(a, b, msg=""): +def is_not(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a is not b: return True @@ -87,7 +122,7 @@ def is_not(a, b, msg=""): return False -def is_true(x, msg=""): +def is_true(x: object, msg: str = "") -> bool: __tracebackhide__ = True if bool(x): return True @@ -96,7 +131,7 @@ def is_true(x, msg=""): return False -def is_false(x, msg=""): +def is_false(x: object, msg: str = "") -> bool: __tracebackhide__ = True if not bool(x): return True @@ -105,7 +140,7 @@ def is_false(x, msg=""): return False -def is_none(x, msg=""): +def is_none(x: object, msg: str = "") -> bool: __tracebackhide__ = True if x is None: return True @@ -114,7 +149,7 @@ def is_none(x, msg=""): return False -def is_not_none(x, msg=""): +def is_not_none(x: object, msg: str = "") -> bool: __tracebackhide__ = True if x is not None: return True @@ -123,7 +158,7 @@ def is_not_none(x, msg=""): return False -def is_nan(a, msg=""): +def is_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool: __tracebackhide__ = True if math.isnan(a): return True @@ -132,7 +167,7 @@ def is_nan(a, msg=""): return False -def is_not_nan(a, msg=""): +def is_not_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool: __tracebackhide__ = True if not math.isnan(a): return True @@ -140,7 +175,8 @@ def is_not_nan(a, msg=""): log_failure(f"check {a} is not NaN", msg) return False -def is_in(a, b, msg=""): + +def is_in(a: _T, b: Container[_T], msg: str = "") -> bool: __tracebackhide__ = True if a in b: return True @@ -149,7 +185,7 @@ def is_in(a, b, msg=""): return False -def is_not_in(a, b, msg=""): +def is_not_in(a: _T, b: Container[_T], msg: str = "") -> bool: __tracebackhide__ = True if a not in b: return True @@ -158,7 +194,9 @@ def is_not_in(a, b, msg=""): return False -def is_instance(a, b, msg=""): +_TypeTuple = Union[type, tuple['_TypeTuple', ...]] + +def is_instance(a: object, b: _TypeTuple, msg: str = "") -> bool: __tracebackhide__ = True if isinstance(a, b): return True @@ -167,7 +205,7 @@ def is_instance(a, b, msg=""): return False -def is_not_instance(a, b, msg=""): +def is_not_instance(a: object, b: _TypeTuple, msg: str = "") -> bool: __tracebackhide__ = True if not isinstance(a, b): return True @@ -176,7 +214,9 @@ def is_not_instance(a, b, msg=""): return False -def almost_equal(a, b, rel=None, abs=None, msg=""): +def almost_equal( + a: object, b: object, rel: Any = None, abs: Any = None, msg: str = "" +) -> bool: """ For rel and abs tolerance, see: See https://docs.pytest.org/en/latest/builtin.html#pytest.approx @@ -189,7 +229,9 @@ def almost_equal(a, b, rel=None, abs=None, msg=""): return False -def not_almost_equal(a, b, rel=None, abs=None, msg=""): +def not_almost_equal( + a: object, b: object, rel: Any = None, abs: Any = None, msg: str = "" +) -> bool: """ For rel and abs tolerance, see: See https://docs.pytest.org/en/latest/builtin.html#pytest.approx @@ -202,7 +244,7 @@ def not_almost_equal(a, b, rel=None, abs=None, msg=""): return False -def greater(a, b, msg=""): +def greater(a: _ComparableGreaterThan, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a > b: return True @@ -211,7 +253,7 @@ def greater(a, b, msg=""): return False -def greater_equal(a, b, msg=""): +def greater_equal(a: _ComparableGreaterThanOrEqual, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a >= b: return True @@ -220,7 +262,7 @@ def greater_equal(a, b, msg=""): return False -def less(a, b, msg=""): +def less(a: _ComparableLessThan, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a < b: return True @@ -229,7 +271,7 @@ def less(a, b, msg=""): return False -def less_equal(a, b, msg=""): +def less_equal(a: _ComparableLessThanOrEqual, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a <= b: return True @@ -238,7 +280,9 @@ def less_equal(a, b, msg=""): return False -def between(b, a, c, msg="", ge=False, le=False): +def between( + b: Any, a: Any, c: Any, msg: str = "", ge: bool = False, le: bool = False +) -> bool: __tracebackhide__ = True if ge and le: if a <= b <= c: @@ -266,11 +310,16 @@ def between(b, a, c, msg="", ge=False, le=False): return False -def between_equal(b, a, c, msg=""): +def between_equal( + b: _ComparableLessThanOrEqual, + a: _ComparableLessThanOrEqual, + c: object, + msg:str = "", +) -> bool: __tracebackhide__ = True return between(b, a, c, msg, ge=True, le=True) -def fail(msg): +def fail(msg: str) -> None: __tracebackhide__ = True log_failure(msg) diff --git a/src/pytest_check/check_log.py b/src/pytest_check/check_log.py index ca5a1b0..84fad43 100644 --- a/src/pytest_check/check_log.py +++ b/src/pytest_check/check_log.py @@ -1,25 +1,29 @@ +from __future__ import annotations +from collections.abc import Iterable +from typing import Callable + from .pseudo_traceback import _build_pseudo_trace_str should_use_color = False COLOR_RED = "\x1b[31m" COLOR_RESET = "\x1b[0m" -_failures = [] +_failures: list[str] = [] _stop_on_fail = False _default_max_fail = None _default_max_report = None _default_max_tb = 1 -_max_fail = _default_max_fail -_max_report = _default_max_report +_max_fail: int | None = _default_max_fail +_max_report: int | None = _default_max_report _max_tb = _default_max_tb _num_failures = 0 -_fail_function = None +_fail_function: Callable[[str], None] | None = None _showlocals = False -def clear_failures(): - # get's called at the beginning of each test function +def clear_failures() -> None: + # gets called at the beginning of each test function global _failures, _num_failures global _max_fail, _max_report, _max_tb _failures = [] @@ -33,11 +37,13 @@ def any_failures() -> bool: return bool(get_failures()) -def get_failures(): +def get_failures() -> list[str]: return _failures -def log_failure(msg="", check_str="", tb=None): +def log_failure( + msg: object = "", check_str: str = "", tb: Iterable[str] | None = None +) -> None: global _num_failures __tracebackhide__ = True _num_failures += 1 @@ -60,7 +66,7 @@ def log_failure(msg="", check_str="", tb=None): msg = f"FAILURE: {msg}" _failures.append(msg) if _fail_function: - _fail_function(msg) + _fail_function(str(msg)) if _max_fail and (_num_failures >= _max_fail): assert_msg = f"pytest-check max fail of {_num_failures} reached" diff --git a/src/pytest_check/check_raises.py b/src/pytest_check/check_raises.py index bd301a3..0eca41c 100644 --- a/src/pytest_check/check_raises.py +++ b/src/pytest_check/check_raises.py @@ -1,9 +1,18 @@ +from __future__ import annotations +from typing import Iterable, Any + from .check_log import log_failure _stop_on_fail = False -def raises(expected_exception, *args, **kwargs): +# TODO: Returning Any isn't ideal, but returning CheckRaisesContext | None +# would require callers to type ignore or declare the type when using `with`. +# Or, it could always return CheckRaisesContext, just an empty one after +# calling the passed function. +def raises( + expected_exception: type | Iterable[type], *args: Any, **kwargs: object +) -> Any: """ Check that a given callable or context raises an error of a given type. @@ -39,23 +48,23 @@ def raises(expected_exception, *args, **kwargs): __tracebackhide__ = True if isinstance(expected_exception, type): - excepted_exceptions = (expected_exception,) + expected_exceptions: Iterable[type] = (expected_exception,) else: - excepted_exceptions = expected_exception + expected_exceptions = expected_exception assert all( isinstance(exc, type) or issubclass(exc, BaseException) - for exc in excepted_exceptions + for exc in expected_exceptions ) msg = kwargs.pop("msg", None) if not args: assert not kwargs, f"Unexpected kwargs for pytest_check.raises: {kwargs}" - return CheckRaisesContext(expected_exception, msg=msg) + return CheckRaisesContext(*expected_exceptions, msg=msg) else: func = args[0] assert callable(func) - with CheckRaisesContext(expected_exception, msg=msg): + with CheckRaisesContext(*expected_exceptions, msg=msg): func(*args[1:], **kwargs) @@ -69,18 +78,18 @@ class CheckRaisesContext: CheckContextManager. """ - def __init__(self, *expected_excs, msg=None): + def __init__(self, *expected_excs: type, msg: object = None) -> None: self.expected_excs = expected_excs self.msg = msg - def __enter__(self): + def __enter__(self) -> "CheckRaisesContext": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: type, exc_val: object, exc_tb: object) -> bool: __tracebackhide__ = True if exc_type is not None and issubclass(exc_type, self.expected_excs): - # This is the case where an error has occured within the context - # but it is the type we're expecting. Therefore we return True + # This is the case where an error has occured within the context, + # but it is the type we're expecting. Therefore, we return True # to silence this error and proceed with execution outside the # context. return True @@ -97,3 +106,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): # without raising an error, hence `return True`. log_failure(self.msg if self.msg else exc_val) return True + + # Stop on fail, so return True + return False diff --git a/src/pytest_check/context_manager.py b/src/pytest_check/context_manager.py index 4f19877..363a682 100644 --- a/src/pytest_check/context_manager.py +++ b/src/pytest_check/context_manager.py @@ -1,5 +1,9 @@ +from __future__ import annotations import warnings import traceback +from types import TracebackType +from typing import Callable, Type + from . import check_log from .check_log import log_failure @@ -12,18 +16,23 @@ class CheckContextManager: - def __init__(self): - self.msg = None + def __init__(self) -> None: + self.msg: object = None - def __enter__(self): + def __enter__(self) -> "CheckContextManager": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: __tracebackhide__ = True if exc_type is not None and issubclass(exc_type, AssertionError): if _stop_on_fail: self.msg = None - return + return None else: fmt_tb = traceback.format_exception(exc_type, exc_val, exc_tb) if self.msg is not None: @@ -33,27 +42,28 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.msg = None return True self.msg = None + return None - def __call__(self, msg=None): + def __call__(self, msg: object = None) -> "CheckContextManager": self.msg = msg return self - def set_no_tb(self): + def set_no_tb(self) -> None: warnings.warn( "set_no_tb() is deprecated; use set_max_tb(0)", DeprecationWarning ) check_log._max_tb = 0 - def set_max_fail(self, x): + def set_max_fail(self, x: int) -> None: check_log._max_fail = x - def set_max_report(self, x): + def set_max_report(self, x: int) -> None: check_log._max_report = x - def set_max_tb(self, x): + def set_max_tb(self, x: int) -> None: check_log._max_tb = x - def call_on_fail(self, func): + def call_on_fail(self, func: Callable[[str], None]) -> None: """Experimental feature - may change with any release""" check_log._fail_function = func diff --git a/src/pytest_check/plugin.py b/src/pytest_check/plugin.py index 659a446..d2e4752 100644 --- a/src/pytest_check/plugin.py +++ b/src/pytest_check/plugin.py @@ -1,28 +1,38 @@ import sys import os +from typing import Generator import pytest -from _pytest._code.code import ExceptionInfo +from pytest import CallInfo, Config, Item, Parser, TestReport from _pytest.skipping import xfailed_key -from _pytest.reports import ExceptionChainRepr -from _pytest._code.code import ExceptionRepr, ReprFileLocation +from _pytest._code.code import ( + ExceptionChainRepr, + ExceptionInfo, + ExceptionRepr, + ReprFileLocation, +) +from pluggy import Result from . import check_log, check_raises, context_manager, pseudo_traceback +from .context_manager import CheckContextManager @pytest.hookimpl(hookwrapper=True, trylast=True) -def pytest_runtest_makereport(item, call): - outcome = yield - report = outcome.get_result() +def pytest_runtest_makereport( + item: Item, call: CallInfo[None] +) -> Generator[None, Result[TestReport], None]: + outcome: Result[TestReport] = yield + report: TestReport = outcome.get_result() num_failures = check_log._num_failures failures = check_log.get_failures() check_log.clear_failures() if failures: - if item._store[xfailed_key]: + xfailed_value = item._store[xfailed_key] + if xfailed_value: report.outcome = "skipped" - report.wasxfail = item._store[xfailed_key].reason + report.wasxfail = xfailed_value.reason else: summary = f"Failed Checks: {num_failures}" @@ -57,8 +67,9 @@ def pytest_runtest_makereport(item, call): e_str = str(e) e_str = e_str.split('FAILURE: ')[1] # Remove redundant "Failure: " reprcrash = ReprFileLocation(item.nodeid, 0, e_str) - reprtraceback = ExceptionRepr(reprcrash, excinfo) - chain_repr = ExceptionChainRepr([(reprtraceback, reprcrash, str(e))]) + # FIXME - the next two lines have broken types + reprtraceback = ExceptionRepr(reprcrash, excinfo) # type: ignore + chain_repr = ExceptionChainRepr([(reprtraceback, reprcrash, str(e))]) # type: ignore report.longrepr = chain_repr else: # pragma: no cover # coverage is run on latest pytest @@ -69,7 +80,7 @@ def pytest_runtest_makereport(item, call): call.excinfo = excinfo -def pytest_configure(config): +def pytest_configure(config: Config) -> None: # Add some red to the failure output, if stdout can accommodate it. isatty = sys.stdout.isatty() color = getattr(config.option, "color", None) @@ -100,12 +111,12 @@ def pytest_configure(config): # def test_a(check): # check.equal(a, b) @pytest.fixture(name="check") -def check_fixture(): +def check_fixture() -> CheckContextManager: return context_manager.check # add some options -def pytest_addoption(parser): +def pytest_addoption(parser: Parser) -> None: parser.addoption( "--check-max-report", action="store", diff --git a/src/pytest_check/pseudo_traceback.py b/src/pytest_check/pseudo_traceback.py index ccbf1d8..e631904 100644 --- a/src/pytest_check/pseudo_traceback.py +++ b/src/pytest_check/pseudo_traceback.py @@ -1,15 +1,28 @@ +from __future__ import annotations import inspect import os import re +import sys +from collections.abc import Iterable +from inspect import FrameInfo from pprint import pformat +from typing import AnyStr, Any + +if sys.version_info < (3, 11): # pragma: no cover + from typing_extensions import LiteralString +else: + from typing import LiteralString + _traceback_style = "auto" -def get_full_context(frame): +def get_full_context( + frame: FrameInfo +) -> tuple[AnyStr | LiteralString, Any, Any, str, Any, bool]: (_, filename, line, funcname, contextlist) = frame[0:5] - locals = frame.frame.f_locals - tb_hide = locals.get("__tracebackhide__", False) + locals_ = frame.frame.f_locals + tb_hide = locals_.get("__tracebackhide__", False) try: filename = os.path.relpath(filename) except ValueError: # pragma: no cover @@ -21,13 +34,13 @@ def get_full_context(frame): # But.... we'll keep looking for a way to test it. :) filename = os.path.abspath(filename) # pragma: no cover context = contextlist[0].strip() if contextlist else "" - return (filename, line, funcname, context, locals, tb_hide) + return filename, line, funcname, context, locals_, tb_hide COLOR_RED = "\x1b[31m" COLOR_RESET = "\x1b[0m" -def reformat_raw_traceback(lines, color): - formatted = [] +def reformat_raw_traceback(lines: Iterable[str], color: bool) -> str: + formatted: list[str] = [] for line in lines: if 'Traceback (most recent call last)' in line: continue @@ -49,12 +62,14 @@ def reformat_raw_traceback(lines, color): # I don't have a test case to hit this clause yet # And I can't think of one. # But it feels weird to not have the if/else. - # Thus the "no cover" + # Thus, the "no cover" formatted.append(line) # pragma: no cover return '\n'.join(formatted) -def _build_pseudo_trace_str(showlocals, tb, color): +def _build_pseudo_trace_str( + showlocals: bool, tb: Iterable[str] | None, color: bool +) -> str: """ built traceback styles for better error message only supports no @@ -76,7 +91,7 @@ def _build_pseudo_trace_str(showlocals, tb, color): # we want to trace through user code, not 3rd party or builtin libs if "site-packages" in file: break - # if called outside of a test, we might hit this + # if called outside a test, we might hit this if "" in func: break if tb_hide: diff --git a/src/pytest_check/py.typed b/src/pytest_check/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tox.ini b/tox.ini index 01c6af7..64c9386 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,14 @@ [tox] # Environments to run by default -env_list = py39, py310, py311, py312, pytest_earliest, coverage, lint +env_list = + py39 + py310 + py311 + py312 + pytest_earliest + coverage + lint + mypy_earliest skip_missing_interpreters = true @@ -12,24 +20,34 @@ wheel_build_env = .pkg [testenv:coverage] deps = coverage -basepython = python3.12 +base_python = python3.12 commands = - coverage run --source={envsitepackagesdir}/pytest_check,tests -m pytest + coverage run --source={envsitepackagesdir}/pytest_check,tests -m pytest coverage report --fail-under=100 --show-missing description = Run pytest, with coverage [testenv:pytest_earliest] deps = pytest==7.0.0 -basepython = python3.11 +base_python = python3.11 commands = pytest {posargs} description = Run earliest supported pytest [testenv:lint] -skip_install = true -deps = ruff -basepython = python3.12 -commands = ruff check src tests examples -description = Run ruff over src, test, exampless +deps = + mypy + ruff +base_python = python3.12 +commands = + ruff check src tests examples + mypy --strict --pretty src + mypy --pretty tests +description = Run ruff and mypy over src, test, examples + +[testenv:mypy_earliest] +deps = mypy +base_python = python3.9 +commands = mypy --strict --pretty src +description = Run mypy over src for earliest supported python [pytest] addopts =