Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds type hinting; drops python 3.8 #169

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@ readme = "README.md"
license = {file = "LICENSE.txt"}
description="A pytest plugin that allows multiple failures per test."
version = "2.4.1"
requires-python = ">=3.8"
requires-python = ">=3.9"
classifiers = [
"License :: OSI Approved :: MIT License",
"Framework :: Pytest" ,
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Software Development :: Testing',
'Topic :: Utilities',
]
dependencies = ["pytest>=7.0.0"]
dependencies = [
"pytest >= 7.0.0",
"typing-extensions >= 4.12.2, < 5; python_version < '3.11'"
]

[project.urls]
Home = "https://github.com/okken/pytest-check"
Expand Down
11 changes: 6 additions & 5 deletions src/pytest_check/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import pytest
from . import check_functions

# make sure assert rewriting happens
pytest.register_assert_rewrite("pytest_check.check_functions")

# allow for top level helper function access:
# import pytest_check
# pytest_check.equal(1, 1)
from pytest_check.check_functions import * # noqa: F401, F402, F403, E402
from .check_functions import * # noqa: F401, F402, F403, E402

# allow to know if any_failures due to any previous check
from pytest_check.check_log import any_failures # noqa: F401, F402, F403, E402
from .check_log import any_failures # noqa: F401, F402, F403, E402

# allow top level raises:
# from pytest_check import raises
# with raises(Exception):
# raise Exception
# with raises(AssertionError):
# assert 0
from pytest_check.check_raises import raises # noqa: F401, F402, F403, E402
from .check_raises import raises # noqa: F401, F402, F403, E402

# allow for with blocks and assert:
# from pytest_check import check
# with check:
# assert 1 == 2
from pytest_check.context_manager import check # noqa: F401, F402, F403, E402
from .context_manager import check # noqa: F401, F402, F403, E402

# allow check.raises()
setattr(check, "raises", raises)
Expand All @@ -33,7 +34,7 @@

# allow check.check as a context manager.
# weird, but some people are doing it.
# decprecate this eventually
# deprecate this eventually
setattr(check, "check", check)

# allow for helper functions to be part of check context
Expand Down
103 changes: 76 additions & 27 deletions src/pytest_check/check_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
from __future__ import annotations
import functools
import sys
from typing import (
Any,
Callable,
Container,
Protocol,
SupportsFloat,
SupportsIndex,
TypeVar,
Union,
)
if sys.version_info < (3, 10): # pragma: no cover
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

import pytest
import math
Expand Down Expand Up @@ -33,12 +49,31 @@
]


def check_func(func):
_P = ParamSpec("_P")
_T = TypeVar("_T")

class _ComparableGreaterThan(Protocol):
def __gt__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableGreaterThanOrEqual(Protocol):
def __ge__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableLessThan(Protocol):
def __lt__(self, other: Any) -> bool: ... # pragma: no cover


class _ComparableLessThanOrEqual(Protocol):
def __le__(self, other: Any) -> bool: ... # pragma: no cover


def check_func(func: Callable[_P, _T]) -> Callable[_P, bool]:
@functools.wraps(func)
def wrapper(*args, **kwds):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> bool:
__tracebackhide__ = True
try:
func(*args, **kwds)
func(*args, **kwargs)
return True
except AssertionError as e:
log_failure(e)
Expand All @@ -47,11 +82,11 @@ def wrapper(*args, **kwds):
return wrapper


def assert_equal(a, b, msg=""): # pragma: no cover
def assert_equal(a: object, b: object, msg: str = "") -> None: # pragma: no cover
assert a == b, msg


def equal(a, b, msg=""):
def equal(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a == b:
return True
Expand All @@ -60,7 +95,7 @@ def equal(a, b, msg=""):
return False


def not_equal(a, b, msg=""):
def not_equal(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a != b:
return True
Expand All @@ -69,7 +104,7 @@ def not_equal(a, b, msg=""):
return False


def is_(a, b, msg=""):
def is_(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a is b:
return True
Expand All @@ -78,7 +113,7 @@ def is_(a, b, msg=""):
return False


def is_not(a, b, msg=""):
def is_not(a: object, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a is not b:
return True
Expand All @@ -87,7 +122,7 @@ def is_not(a, b, msg=""):
return False


def is_true(x, msg=""):
def is_true(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if bool(x):
return True
Expand All @@ -96,7 +131,7 @@ def is_true(x, msg=""):
return False


def is_false(x, msg=""):
def is_false(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if not bool(x):
return True
Expand All @@ -105,7 +140,7 @@ def is_false(x, msg=""):
return False


def is_none(x, msg=""):
def is_none(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if x is None:
return True
Expand All @@ -114,7 +149,7 @@ def is_none(x, msg=""):
return False


def is_not_none(x, msg=""):
def is_not_none(x: object, msg: str = "") -> bool:
__tracebackhide__ = True
if x is not None:
return True
Expand All @@ -123,7 +158,7 @@ def is_not_none(x, msg=""):
return False


def is_nan(a, msg=""):
def is_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool:
__tracebackhide__ = True
if math.isnan(a):
return True
Expand All @@ -132,15 +167,16 @@ def is_nan(a, msg=""):
return False


def is_not_nan(a, msg=""):
def is_not_nan(a: SupportsFloat | SupportsIndex, msg: str = "") -> bool:
__tracebackhide__ = True
if not math.isnan(a):
return True
else:
log_failure(f"check {a} is not NaN", msg)
return False

def is_in(a, b, msg=""):

def is_in(a: _T, b: Container[_T], msg: str = "") -> bool:
__tracebackhide__ = True
if a in b:
return True
Expand All @@ -149,7 +185,7 @@ def is_in(a, b, msg=""):
return False


def is_not_in(a, b, msg=""):
def is_not_in(a: _T, b: Container[_T], msg: str = "") -> bool:
__tracebackhide__ = True
if a not in b:
return True
Expand All @@ -158,7 +194,9 @@ def is_not_in(a, b, msg=""):
return False


def is_instance(a, b, msg=""):
_TypeTuple = Union[type, tuple['_TypeTuple', ...]]

def is_instance(a: object, b: _TypeTuple, msg: str = "") -> bool:
__tracebackhide__ = True
if isinstance(a, b):
return True
Expand All @@ -167,7 +205,7 @@ def is_instance(a, b, msg=""):
return False


def is_not_instance(a, b, msg=""):
def is_not_instance(a: object, b: _TypeTuple, msg: str = "") -> bool:
__tracebackhide__ = True
if not isinstance(a, b):
return True
Expand All @@ -176,7 +214,9 @@ def is_not_instance(a, b, msg=""):
return False


def almost_equal(a, b, rel=None, abs=None, msg=""):
def almost_equal(
a: object, b: object, rel: Any = None, abs: Any = None, msg: str = ""
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nailing down types for rel and abs was proving to be a pain. So... Any it is!

) -> bool:
"""
For rel and abs tolerance, see:
See https://docs.pytest.org/en/latest/builtin.html#pytest.approx
Expand All @@ -189,7 +229,9 @@ def almost_equal(a, b, rel=None, abs=None, msg=""):
return False


def not_almost_equal(a, b, rel=None, abs=None, msg=""):
def not_almost_equal(
a: object, b: object, rel: Any = None, abs: Any = None, msg: str = ""
) -> bool:
"""
For rel and abs tolerance, see:
See https://docs.pytest.org/en/latest/builtin.html#pytest.approx
Expand All @@ -202,7 +244,7 @@ def not_almost_equal(a, b, rel=None, abs=None, msg=""):
return False


def greater(a, b, msg=""):
def greater(a: _ComparableGreaterThan, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a > b:
return True
Expand All @@ -211,7 +253,7 @@ def greater(a, b, msg=""):
return False


def greater_equal(a, b, msg=""):
def greater_equal(a: _ComparableGreaterThanOrEqual, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a >= b:
return True
Expand All @@ -220,7 +262,7 @@ def greater_equal(a, b, msg=""):
return False


def less(a, b, msg=""):
def less(a: _ComparableLessThan, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a < b:
return True
Expand All @@ -229,7 +271,7 @@ def less(a, b, msg=""):
return False


def less_equal(a, b, msg=""):
def less_equal(a: _ComparableLessThanOrEqual, b: object, msg: str = "") -> bool:
__tracebackhide__ = True
if a <= b:
return True
Expand All @@ -238,7 +280,9 @@ def less_equal(a, b, msg=""):
return False


def between(b, a, c, msg="", ge=False, le=False):
def between(
b: Any, a: Any, c: Any, msg: str = "", ge: bool = False, le: bool = False
) -> bool:
__tracebackhide__ = True
if ge and le:
if a <= b <= c:
Expand Down Expand Up @@ -266,11 +310,16 @@ def between(b, a, c, msg="", ge=False, le=False):
return False


def between_equal(b, a, c, msg=""):
def between_equal(
b: _ComparableLessThanOrEqual,
a: _ComparableLessThanOrEqual,
c: object,
msg:str = "",
) -> bool:
__tracebackhide__ = True
return between(b, a, c, msg, ge=True, le=True)


def fail(msg):
def fail(msg: str) -> None:
__tracebackhide__ = True
log_failure(msg)
Loading