diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 47c9de7..6a24aee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,12 +8,12 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/psf/black - rev: "23.9.1" + rev: "23.10.0" hooks: - id: black language_version: python3 - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.292 + rev: v0.1.1 hooks: - id: ruff args: ["--fix"] @@ -29,3 +29,9 @@ repos: hooks: - id: taplo language_version: stable + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.6.1 + hooks: + - id: mypy + additional_dependencies: + - pytest diff --git a/.vscode/settings.json b/.vscode/settings.json index 946aed9..b538f66 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,13 +1,13 @@ { "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": [], + "python.testing.pytestArgs": ["--color=yes"], "[python]": { "editor.formatOnSave": true, "editor.defaultFormatter": "ms-python.black-formatter", "editor.codeActionsOnSave": { - "source.fixAll.ruff": true, - "source.organizeImports.ruff": true, + "source.fixAll": true, + "source.organizeImports": true, }, }, "[toml]": { diff --git a/pyproject.toml b/pyproject.toml index 3176476..55b78f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ ignore = [ ] allowed-confusables = ["’"] [tool.ruff.isort] - known-first-party = ["legacy_api_wrap"] +required-imports = ["from __future__ import annotations"] [tool.ruff.extend-per-file-ignores] "src/testing/*.py" = ["INP001"] "tests/**/test_*.py" = [ @@ -45,6 +45,11 @@ known-first-party = ["legacy_api_wrap"] "S101", # tests use `assert` ] +[tool.mypy] +strict = true +explicit_package_bases = true +mypy_path = "src" + [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", diff --git a/src/legacy_api_wrap/__init__.py b/src/legacy_api_wrap/__init__.py index e569116..3b75edd 100644 --- a/src/legacy_api_wrap/__init__.py +++ b/src/legacy_api_wrap/__init__.py @@ -10,13 +10,15 @@ from __future__ import annotations +import sys from functools import wraps from inspect import Parameter, signature -from typing import TYPE_CHECKING, Callable, TypeVar +from typing import TYPE_CHECKING from warnings import warn if TYPE_CHECKING: - from typing import ParamSpec + from collections.abc import Callable + from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -25,6 +27,8 @@ POS_TYPES = {Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD} +# The actual returned Callable of course accepts more positional parameters, +# but we want the type to lie so end users don’t rely on the deprecated API. def legacy_api(*old_positionals: str) -> Callable[[Callable[P, R]], Callable[P, R]]: """Legacy API wrapper. @@ -57,30 +61,33 @@ def wrapper(fn: Callable[P, R]) -> Callable[P, R]: par_types = [p.kind for p in sig.parameters.values()] has_var = Parameter.VAR_POSITIONAL in par_types n_required = sum(1 for p in sig.parameters.values() if p.default is Parameter.empty) - n_positional = INF if has_var else sum(1 for p in par_types if p in POS_TYPES) + n_positional = sys.maxsize if has_var else sum(1 for p in par_types if p in POS_TYPES) @wraps(fn) - def fn_compatible(*args: P.args, **kw: P.kwargs) -> R: - if len(args) > n_positional: - args, args_rest = args[:n_positional], args[n_positional:] - if args_rest: - if len(args_rest) > len(old_positionals): - n_max = n_positional + len(old_positionals) - msg = ( - f"{fn.__name__}() takes from {n_required} to {n_max} parameters, " - f"but {len(args) + len(args_rest)} were given." - ) - raise TypeError(msg) - warn( - f"The specified parameters {old_positionals[:len(args_rest)]!r} are " - "no longer positional. " - f"Please specify them like `{old_positionals[0]}={args_rest[0]!r}`", - DeprecationWarning, - stacklevel=2, - ) - kw = {**kw, **dict(zip(old_positionals, args_rest))} - - return fn(*args, **kw) + def fn_compatible(*args_all: P.args, **kw: P.kwargs) -> R: + if len(args_all) <= n_positional: + return fn(*args_all, **kw) + + args_pos: P.args + args_pos, args_rest = args_all[:n_positional], args_all[n_positional:] + + if len(args_rest) > len(old_positionals): + n_max = n_positional + len(old_positionals) + msg = ( + f"{fn.__name__}() takes from {n_required} to {n_max} parameters, " + f"but {len(args_pos) + len(args_rest)} were given." + ) + raise TypeError(msg) + warn( + f"The specified parameters {old_positionals[:len(args_rest)]!r} are " + "no longer positional. " + f"Please specify them like `{old_positionals[0]}={args_rest[0]!r}`", + DeprecationWarning, + stacklevel=2, + ) + kw_new: P.kwargs = {**kw, **dict(zip(old_positionals, args_rest))} + + return fn(*args_pos, **kw_new) return fn_compatible diff --git a/src/testing/legacy_api_wrap/py.typed b/src/testing/legacy_api_wrap/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/testing/legacy_api_wrap/pytest.py b/src/testing/legacy_api_wrap/pytest.py index 6c45acd..d71d41d 100644 --- a/src/testing/legacy_api_wrap/pytest.py +++ b/src/testing/legacy_api_wrap/pytest.py @@ -1,15 +1,21 @@ """Pytest plugin for legacy_api_wrap.""" +from __future__ import annotations + import sys import warnings +from typing import TYPE_CHECKING import pytest +if TYPE_CHECKING: + from collections.abc import Generator + __all__ = ["_doctest_env", "pytest_itemcollected"] @pytest.fixture() -def _doctest_env() -> None: +def _doctest_env() -> Generator[None, None, None]: """Pytest fixture to make doctests not error on expected warnings.""" sys.stderr, stderr_orig = sys.stdout, sys.stderr with warnings.catch_warnings(): diff --git a/tests/test_basic.py b/tests/test_basic.py index 6058f77..0d18e40 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -10,7 +10,7 @@ # def old(a, b=None, d=1, c=2): # pass @legacy_api("d", "c") -def new(a, b=None, *, c=2, d=1, e=3): # noqa: ANN001, ANN201 +def new(a, b=None, *, c=2, d=1, e=3): # type: ignore[no-untyped-def] # noqa: ANN001, ANN201 return {"a": a, "b": b, "c": c, "d": d, "e": e} @@ -24,13 +24,13 @@ def test_new_param_available() -> None: def test_old_positional_order() -> None: with pytest.deprecated_call(): - res = new(12, 13, 14) + res = new(12, 13, 14) # type: ignore[misc] assert res["d"] == 14 def test_warning_stack() -> None: with pytest.deprecated_call() as record: - new(12, 13, 14) + new(12, 13, 14) # type: ignore[misc] w = record.pop() assert w.filename == __file__ @@ -40,4 +40,4 @@ def test_too_many_args() -> None: TypeError, match=r"new\(\) takes from 1 to 4 parameters, but 5 were given\.", ): - new(1, 2, 3, 4, 5) + new(1, 2, 3, 4, 5) # type: ignore[misc]