From fea65cd17d0e4fad7127ca724e412781041fa899 Mon Sep 17 00:00:00 2001 From: Matt F Date: Thu, 10 Oct 2024 23:10:07 -0400 Subject: [PATCH] Adds type hinting Attempts to add type hints as much as possible, without breaking usage. However, some methods may deserve refactoring if the intended behavior differed from what mypy found. (e.g. Returning None when a boolean was desired) --- pyproject.toml | 5 +- src/pytest_check/__init__.py | 11 +-- src/pytest_check/check_functions.py | 103 ++++++++++++++++++++------- src/pytest_check/check_log.py | 24 ++++--- src/pytest_check/check_raises.py | 34 ++++++--- src/pytest_check/context_manager.py | 32 ++++++--- src/pytest_check/plugin.py | 37 ++++++---- src/pytest_check/pseudo_traceback.py | 33 ++++++--- src/pytest_check/py.typed | 0 tox.ini | 36 +++++++--- 10 files changed, 220 insertions(+), 95 deletions(-) create mode 100644 src/pytest_check/py.typed diff --git a/pyproject.toml b/pyproject.toml index 2858030..6cec4cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,10 @@ classifiers = [ 'Topic :: Software Development :: Testing', 'Topic :: Utilities', ] -dependencies = ["pytest>=7.0.0"] +dependencies = [ + "pytest >= 7.0.0", + "typing-extensions >= 4.12.2, < 5; python_version < '3.11'" +] [project.urls] Home = "https://github.com/okken/pytest-check" diff --git a/src/pytest_check/__init__.py b/src/pytest_check/__init__.py index 210fb19..5c0b4bf 100644 --- a/src/pytest_check/__init__.py +++ b/src/pytest_check/__init__.py @@ -1,4 +1,5 @@ import pytest +from . import check_functions # make sure assert rewriting happens pytest.register_assert_rewrite("pytest_check.check_functions") @@ -6,10 +7,10 @@ # allow for top level helper function access: # import pytest_check # pytest_check.equal(1, 1) -from pytest_check.check_functions import * # noqa: F401, F402, F403, E402 +from .check_functions import * # noqa: F401, F402, F403, E402 # allow to know if any_failures due to any previous check -from pytest_check.check_log import any_failures # noqa: F401, F402, F403, E402 +from .check_log import any_failures # noqa: F401, F402, F403, E402 # allow top level raises: # from pytest_check import raises @@ -17,13 +18,13 @@ # raise Exception # with raises(AssertionError): # assert 0 -from pytest_check.check_raises import raises # noqa: F401, F402, F403, E402 +from .check_raises import raises # noqa: F401, F402, F403, E402 # allow for with blocks and assert: # from pytest_check import check # with check: # assert 1 == 2 -from pytest_check.context_manager import check # noqa: F401, F402, F403, E402 +from .context_manager import check # noqa: F401, F402, F403, E402 # allow check.raises() setattr(check, "raises", raises) @@ -33,7 +34,7 @@ # allow check.check as a context manager. # weird, but some people are doing it. -# decprecate this eventually +# deprecate this eventually setattr(check, "check", check) # allow for helper functions to be part of check context diff --git a/src/pytest_check/check_functions.py b/src/pytest_check/check_functions.py index 344e121..3ac35cc 100644 --- a/src/pytest_check/check_functions.py +++ b/src/pytest_check/check_functions.py @@ -1,4 +1,20 @@ +from __future__ import annotations import functools +import sys +from typing import ( + Any, + Callable, + Container, + Protocol, + SupportsFloat, + SupportsIndex, + TypeVar, + Union, +) +if sys.version_info < (3, 10): # pragma: no cover + from typing_extensions import ParamSpec +else: + from typing import ParamSpec import pytest import math @@ -33,12 +49,31 @@ ] -def check_func(func): +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class _ComparableGreaterThan(Protocol): + def __gt__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableGreaterThanOrEqual(Protocol): + def __ge__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableLessThan(Protocol): + def __lt__(self, other: Any) -> bool: ... # pragma: no cover + + +class _ComparableLessThanOrEqual(Protocol): + def __le__(self, other: Any) -> bool: ... # pragma: no cover + + +def check_func(func: Callable[_P, _T]) -> Callable[_P, bool]: @functools.wraps(func) - def wrapper(*args, **kwds): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool: __tracebackhide__ = True try: - func(*args, **kwds) + func(*args, **kwargs) return True except AssertionError as e: log_failure(e) @@ -47,11 +82,11 @@ def wrapper(*args, **kwds): return wrapper -def assert_equal(a, b, msg=""): # pragma: no cover +def assert_equal(a: object, b: object, msg: str = "") -> None: # pragma: no cover assert a == b, msg -def equal(a, b, msg=""): +def equal(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a == b: return True @@ -60,7 +95,7 @@ def equal(a, b, msg=""): return False -def not_equal(a, b, msg=""): +def not_equal(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a != b: return True @@ -69,7 +104,7 @@ def not_equal(a, b, msg=""): return False -def is_(a, b, msg=""): +def is_(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a is b: return True @@ -78,7 +113,7 @@ def is_(a, b, msg=""): return False -def is_not(a, b, msg=""): +def is_not(a: object, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a is not b: return True @@ -87,7 +122,7 @@ def is_not(a, b, msg=""): return False -def is_true(x, msg=""): +def is_true(x: object, msg: str = "") -> bool: __tracebackhide__ = True if bool(x): return True @@ -96,7 +131,7 @@ def is_true(x, msg=""): return False -def is_false(x, msg=""): +def is_false(x: object, msg: str = "") -> bool: __tracebackhide__ = True if not bool(x): return True @@ -105,7 +140,7 @@ def is_false(x, msg=""): return False -def is_none(x, msg=""): +def is_none(x: object, msg: str = "") -> bool: __tracebackhide__ = True if x is None: return True @@ -114,7 +149,7 @@ def is_none(x, msg=""): return False -def is_not_none(x, msg=""): +def is_not_none(x: object, msg: str = "") -> bool: __tracebackhide__ = True if x is not None: return True @@ -123,7 +158,7 @@ def is_not_none(x, msg=""): return False -def is_nan(a, msg=""): +def is_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool: __tracebackhide__ = True if math.isnan(a): return True @@ -132,7 +167,7 @@ def is_nan(a, msg=""): return False -def is_not_nan(a, msg=""): +def is_not_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool: __tracebackhide__ = True if not math.isnan(a): return True @@ -140,7 +175,8 @@ def is_not_nan(a, msg=""): log_failure(f"check {a} is not NaN", msg) return False -def is_in(a, b, msg=""): + +def is_in(a: _T, b: Container[_T], msg: str = "") -> bool: __tracebackhide__ = True if a in b: return True @@ -149,7 +185,7 @@ def is_in(a, b, msg=""): return False -def is_not_in(a, b, msg=""): +def is_not_in(a: _T, b: Container[_T], msg: str = "") -> bool: __tracebackhide__ = True if a not in b: return True @@ -158,7 +194,9 @@ def is_not_in(a, b, msg=""): return False -def is_instance(a, b, msg=""): +_TypeTuple = Union[type, tuple['_TypeTuple', ...]] + +def is_instance(a: object, b: _TypeTuple, msg: str = "") -> bool: __tracebackhide__ = True if isinstance(a, b): return True @@ -167,7 +205,7 @@ def is_instance(a, b, msg=""): return False -def is_not_instance(a, b, msg=""): +def is_not_instance(a: object, b: _TypeTuple, msg: str = "") -> bool: __tracebackhide__ = True if not isinstance(a, b): return True @@ -176,7 +214,9 @@ def is_not_instance(a, b, msg=""): return False -def almost_equal(a, b, rel=None, abs=None, msg=""): +def almost_equal( + a: object, b: object, rel: Any = None, abs: Any = None, msg: str = "" +) -> bool: """ For rel and abs tolerance, see: See https://docs.pytest.org/en/latest/builtin.html#pytest.approx @@ -189,7 +229,9 @@ def almost_equal(a, b, rel=None, abs=None, msg=""): return False -def not_almost_equal(a, b, rel=None, abs=None, msg=""): +def not_almost_equal( + a: object, b: object, rel: Any = None, abs: Any = None, msg: str = "" +) -> bool: """ For rel and abs tolerance, see: See https://docs.pytest.org/en/latest/builtin.html#pytest.approx @@ -202,7 +244,7 @@ def not_almost_equal(a, b, rel=None, abs=None, msg=""): return False -def greater(a, b, msg=""): +def greater(a: _ComparableGreaterThan, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a > b: return True @@ -211,7 +253,7 @@ def greater(a, b, msg=""): return False -def greater_equal(a, b, msg=""): +def greater_equal(a: _ComparableGreaterThanOrEqual, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a >= b: return True @@ -220,7 +262,7 @@ def greater_equal(a, b, msg=""): return False -def less(a, b, msg=""): +def less(a: _ComparableLessThan, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a < b: return True @@ -229,7 +271,7 @@ def less(a, b, msg=""): return False -def less_equal(a, b, msg=""): +def less_equal(a: _ComparableLessThanOrEqual, b: object, msg: str = "") -> bool: __tracebackhide__ = True if a <= b: return True @@ -238,7 +280,9 @@ def less_equal(a, b, msg=""): return False -def between(b, a, c, msg="", ge=False, le=False): +def between( + b: Any, a: Any, c: Any, msg: str = "", ge: bool = False, le: bool = False +) -> bool: __tracebackhide__ = True if ge and le: if a <= b <= c: @@ -266,11 +310,16 @@ def between(b, a, c, msg="", ge=False, le=False): return False -def between_equal(b, a, c, msg=""): +def between_equal( + b: _ComparableLessThanOrEqual, + a: _ComparableLessThanOrEqual, + c: object, + msg:str = "", +) -> bool: __tracebackhide__ = True return between(b, a, c, msg, ge=True, le=True) -def fail(msg): +def fail(msg: str) -> None: __tracebackhide__ = True log_failure(msg) diff --git a/src/pytest_check/check_log.py b/src/pytest_check/check_log.py index ca5a1b0..84fad43 100644 --- a/src/pytest_check/check_log.py +++ b/src/pytest_check/check_log.py @@ -1,25 +1,29 @@ +from __future__ import annotations +from collections.abc import Iterable +from typing import Callable + from .pseudo_traceback import _build_pseudo_trace_str should_use_color = False COLOR_RED = "\x1b[31m" COLOR_RESET = "\x1b[0m" -_failures = [] +_failures: list[str] = [] _stop_on_fail = False _default_max_fail = None _default_max_report = None _default_max_tb = 1 -_max_fail = _default_max_fail -_max_report = _default_max_report +_max_fail: int | None = _default_max_fail +_max_report: int | None = _default_max_report _max_tb = _default_max_tb _num_failures = 0 -_fail_function = None +_fail_function: Callable[[str], None] | None = None _showlocals = False -def clear_failures(): - # get's called at the beginning of each test function +def clear_failures() -> None: + # gets called at the beginning of each test function global _failures, _num_failures global _max_fail, _max_report, _max_tb _failures = [] @@ -33,11 +37,13 @@ def any_failures() -> bool: return bool(get_failures()) -def get_failures(): +def get_failures() -> list[str]: return _failures -def log_failure(msg="", check_str="", tb=None): +def log_failure( + msg: object = "", check_str: str = "", tb: Iterable[str] | None = None +) -> None: global _num_failures __tracebackhide__ = True _num_failures += 1 @@ -60,7 +66,7 @@ def log_failure(msg="", check_str="", tb=None): msg = f"FAILURE: {msg}" _failures.append(msg) if _fail_function: - _fail_function(msg) + _fail_function(str(msg)) if _max_fail and (_num_failures >= _max_fail): assert_msg = f"pytest-check max fail of {_num_failures} reached" diff --git a/src/pytest_check/check_raises.py b/src/pytest_check/check_raises.py index bd301a3..0eca41c 100644 --- a/src/pytest_check/check_raises.py +++ b/src/pytest_check/check_raises.py @@ -1,9 +1,18 @@ +from __future__ import annotations +from typing import Iterable, Any + from .check_log import log_failure _stop_on_fail = False -def raises(expected_exception, *args, **kwargs): +# TODO: Returning Any isn't ideal, but returning CheckRaisesContext | None +# would require callers to type ignore or declare the type when using `with`. +# Or, it could always return CheckRaisesContext, just an empty one after +# calling the passed function. +def raises( + expected_exception: type | Iterable[type], *args: Any, **kwargs: object +) -> Any: """ Check that a given callable or context raises an error of a given type. @@ -39,23 +48,23 @@ def raises(expected_exception, *args, **kwargs): __tracebackhide__ = True if isinstance(expected_exception, type): - excepted_exceptions = (expected_exception,) + expected_exceptions: Iterable[type] = (expected_exception,) else: - excepted_exceptions = expected_exception + expected_exceptions = expected_exception assert all( isinstance(exc, type) or issubclass(exc, BaseException) - for exc in excepted_exceptions + for exc in expected_exceptions ) msg = kwargs.pop("msg", None) if not args: assert not kwargs, f"Unexpected kwargs for pytest_check.raises: {kwargs}" - return CheckRaisesContext(expected_exception, msg=msg) + return CheckRaisesContext(*expected_exceptions, msg=msg) else: func = args[0] assert callable(func) - with CheckRaisesContext(expected_exception, msg=msg): + with CheckRaisesContext(*expected_exceptions, msg=msg): func(*args[1:], **kwargs) @@ -69,18 +78,18 @@ class CheckRaisesContext: CheckContextManager. """ - def __init__(self, *expected_excs, msg=None): + def __init__(self, *expected_excs: type, msg: object = None) -> None: self.expected_excs = expected_excs self.msg = msg - def __enter__(self): + def __enter__(self) -> "CheckRaisesContext": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: type, exc_val: object, exc_tb: object) -> bool: __tracebackhide__ = True if exc_type is not None and issubclass(exc_type, self.expected_excs): - # This is the case where an error has occured within the context - # but it is the type we're expecting. Therefore we return True + # This is the case where an error has occured within the context, + # but it is the type we're expecting. Therefore, we return True # to silence this error and proceed with execution outside the # context. return True @@ -97,3 +106,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): # without raising an error, hence `return True`. log_failure(self.msg if self.msg else exc_val) return True + + # Stop on fail, so return True + return False diff --git a/src/pytest_check/context_manager.py b/src/pytest_check/context_manager.py index 4f19877..363a682 100644 --- a/src/pytest_check/context_manager.py +++ b/src/pytest_check/context_manager.py @@ -1,5 +1,9 @@ +from __future__ import annotations import warnings import traceback +from types import TracebackType +from typing import Callable, Type + from . import check_log from .check_log import log_failure @@ -12,18 +16,23 @@ class CheckContextManager: - def __init__(self): - self.msg = None + def __init__(self) -> None: + self.msg: object = None - def __enter__(self): + def __enter__(self) -> "CheckContextManager": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: __tracebackhide__ = True if exc_type is not None and issubclass(exc_type, AssertionError): if _stop_on_fail: self.msg = None - return + return None else: fmt_tb = traceback.format_exception(exc_type, exc_val, exc_tb) if self.msg is not None: @@ -33,27 +42,28 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.msg = None return True self.msg = None + return None - def __call__(self, msg=None): + def __call__(self, msg: object = None) -> "CheckContextManager": self.msg = msg return self - def set_no_tb(self): + def set_no_tb(self) -> None: warnings.warn( "set_no_tb() is deprecated; use set_max_tb(0)", DeprecationWarning ) check_log._max_tb = 0 - def set_max_fail(self, x): + def set_max_fail(self, x: int) -> None: check_log._max_fail = x - def set_max_report(self, x): + def set_max_report(self, x: int) -> None: check_log._max_report = x - def set_max_tb(self, x): + def set_max_tb(self, x: int) -> None: check_log._max_tb = x - def call_on_fail(self, func): + def call_on_fail(self, func: Callable[[str], None]) -> None: """Experimental feature - may change with any release""" check_log._fail_function = func diff --git a/src/pytest_check/plugin.py b/src/pytest_check/plugin.py index 659a446..d2e4752 100644 --- a/src/pytest_check/plugin.py +++ b/src/pytest_check/plugin.py @@ -1,28 +1,38 @@ import sys import os +from typing import Generator import pytest -from _pytest._code.code import ExceptionInfo +from pytest import CallInfo, Config, Item, Parser, TestReport from _pytest.skipping import xfailed_key -from _pytest.reports import ExceptionChainRepr -from _pytest._code.code import ExceptionRepr, ReprFileLocation +from _pytest._code.code import ( + ExceptionChainRepr, + ExceptionInfo, + ExceptionRepr, + ReprFileLocation, +) +from pluggy import Result from . import check_log, check_raises, context_manager, pseudo_traceback +from .context_manager import CheckContextManager @pytest.hookimpl(hookwrapper=True, trylast=True) -def pytest_runtest_makereport(item, call): - outcome = yield - report = outcome.get_result() +def pytest_runtest_makereport( + item: Item, call: CallInfo[None] +) -> Generator[None, Result[TestReport], None]: + outcome: Result[TestReport] = yield + report: TestReport = outcome.get_result() num_failures = check_log._num_failures failures = check_log.get_failures() check_log.clear_failures() if failures: - if item._store[xfailed_key]: + xfailed_value = item._store[xfailed_key] + if xfailed_value: report.outcome = "skipped" - report.wasxfail = item._store[xfailed_key].reason + report.wasxfail = xfailed_value.reason else: summary = f"Failed Checks: {num_failures}" @@ -57,8 +67,9 @@ def pytest_runtest_makereport(item, call): e_str = str(e) e_str = e_str.split('FAILURE: ')[1] # Remove redundant "Failure: " reprcrash = ReprFileLocation(item.nodeid, 0, e_str) - reprtraceback = ExceptionRepr(reprcrash, excinfo) - chain_repr = ExceptionChainRepr([(reprtraceback, reprcrash, str(e))]) + # FIXME - the next two lines have broken types + reprtraceback = ExceptionRepr(reprcrash, excinfo) # type: ignore + chain_repr = ExceptionChainRepr([(reprtraceback, reprcrash, str(e))]) # type: ignore report.longrepr = chain_repr else: # pragma: no cover # coverage is run on latest pytest @@ -69,7 +80,7 @@ def pytest_runtest_makereport(item, call): call.excinfo = excinfo -def pytest_configure(config): +def pytest_configure(config: Config) -> None: # Add some red to the failure output, if stdout can accommodate it. isatty = sys.stdout.isatty() color = getattr(config.option, "color", None) @@ -100,12 +111,12 @@ def pytest_configure(config): # def test_a(check): # check.equal(a, b) @pytest.fixture(name="check") -def check_fixture(): +def check_fixture() -> CheckContextManager: return context_manager.check # add some options -def pytest_addoption(parser): +def pytest_addoption(parser: Parser) -> None: parser.addoption( "--check-max-report", action="store", diff --git a/src/pytest_check/pseudo_traceback.py b/src/pytest_check/pseudo_traceback.py index ccbf1d8..e631904 100644 --- a/src/pytest_check/pseudo_traceback.py +++ b/src/pytest_check/pseudo_traceback.py @@ -1,15 +1,28 @@ +from __future__ import annotations import inspect import os import re +import sys +from collections.abc import Iterable +from inspect import FrameInfo from pprint import pformat +from typing import AnyStr, Any + +if sys.version_info < (3, 11): # pragma: no cover + from typing_extensions import LiteralString +else: + from typing import LiteralString + _traceback_style = "auto" -def get_full_context(frame): +def get_full_context( + frame: FrameInfo +) -> tuple[AnyStr | LiteralString, Any, Any, str, Any, bool]: (_, filename, line, funcname, contextlist) = frame[0:5] - locals = frame.frame.f_locals - tb_hide = locals.get("__tracebackhide__", False) + locals_ = frame.frame.f_locals + tb_hide = locals_.get("__tracebackhide__", False) try: filename = os.path.relpath(filename) except ValueError: # pragma: no cover @@ -21,13 +34,13 @@ def get_full_context(frame): # But.... we'll keep looking for a way to test it. :) filename = os.path.abspath(filename) # pragma: no cover context = contextlist[0].strip() if contextlist else "" - return (filename, line, funcname, context, locals, tb_hide) + return filename, line, funcname, context, locals_, tb_hide COLOR_RED = "\x1b[31m" COLOR_RESET = "\x1b[0m" -def reformat_raw_traceback(lines, color): - formatted = [] +def reformat_raw_traceback(lines: Iterable[str], color: bool) -> str: + formatted: list[str] = [] for line in lines: if 'Traceback (most recent call last)' in line: continue @@ -49,12 +62,14 @@ def reformat_raw_traceback(lines, color): # I don't have a test case to hit this clause yet # And I can't think of one. # But it feels weird to not have the if/else. - # Thus the "no cover" + # Thus, the "no cover" formatted.append(line) # pragma: no cover return '\n'.join(formatted) -def _build_pseudo_trace_str(showlocals, tb, color): +def _build_pseudo_trace_str( + showlocals: bool, tb: Iterable[str] | None, color: bool +) -> str: """ built traceback styles for better error message only supports no @@ -76,7 +91,7 @@ def _build_pseudo_trace_str(showlocals, tb, color): # we want to trace through user code, not 3rd party or builtin libs if "site-packages" in file: break - # if called outside of a test, we might hit this + # if called outside a test, we might hit this if "" in func: break if tb_hide: diff --git a/src/pytest_check/py.typed b/src/pytest_check/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tox.ini b/tox.ini index 01c6af7..64c9386 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,14 @@ [tox] # Environments to run by default -env_list = py39, py310, py311, py312, pytest_earliest, coverage, lint +env_list = + py39 + py310 + py311 + py312 + pytest_earliest + coverage + lint + mypy_earliest skip_missing_interpreters = true @@ -12,24 +20,34 @@ wheel_build_env = .pkg [testenv:coverage] deps = coverage -basepython = python3.12 +base_python = python3.12 commands = - coverage run --source={envsitepackagesdir}/pytest_check,tests -m pytest + coverage run --source={envsitepackagesdir}/pytest_check,tests -m pytest coverage report --fail-under=100 --show-missing description = Run pytest, with coverage [testenv:pytest_earliest] deps = pytest==7.0.0 -basepython = python3.11 +base_python = python3.11 commands = pytest {posargs} description = Run earliest supported pytest [testenv:lint] -skip_install = true -deps = ruff -basepython = python3.12 -commands = ruff check src tests examples -description = Run ruff over src, test, exampless +deps = + mypy + ruff +base_python = python3.12 +commands = + ruff check src tests examples + mypy --strict --pretty src + mypy --pretty tests +description = Run ruff and mypy over src, test, examples + +[testenv:mypy_earliest] +deps = mypy +base_python = python3.9 +commands = mypy --strict --pretty src +description = Run mypy over src for earliest supported python [pytest] addopts =