Skip to content

Commit

Permalink
strict typing (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Oct 20, 2023
1 parent c8ca42b commit b73f6f2
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 35 deletions.
10 changes: 8 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/psf/black
rev: "23.9.1"
rev: "23.10.0"
hooks:
- id: black
language_version: python3
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.1
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -29,3 +29,9 @@ repos:
hooks:
- id: taplo
language_version: stable
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
hooks:
- id: mypy
additional_dependencies:
- pytest
6 changes: 3 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.testing.pytestArgs": [],
"python.testing.pytestArgs": ["--color=yes"],
"[python]": {
"editor.formatOnSave": true,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.codeActionsOnSave": {
"source.fixAll.ruff": true,
"source.organizeImports.ruff": true,
"source.fixAll": true,
"source.organizeImports": true,
},
},
"[toml]": {
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ ignore = [
]
allowed-confusables = [""]
[tool.ruff.isort]

known-first-party = ["legacy_api_wrap"]
required-imports = ["from __future__ import annotations"]
[tool.ruff.extend-per-file-ignores]
"src/testing/*.py" = ["INP001"]
"tests/**/test_*.py" = [
Expand All @@ -45,6 +45,11 @@ known-first-party = ["legacy_api_wrap"]
"S101", # tests use `assert`
]

[tool.mypy]
strict = true
explicit_package_bases = true
mypy_path = "src"

[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
Expand Down
55 changes: 31 additions & 24 deletions src/legacy_api_wrap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

from __future__ import annotations

import sys
from functools import wraps
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING
from warnings import warn

if TYPE_CHECKING:
from typing import ParamSpec
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
Expand All @@ -25,6 +27,8 @@
POS_TYPES = {Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD}


# The actual returned Callable of course accepts more positional parameters,
# but we want the type to lie so end users don’t rely on the deprecated API.
def legacy_api(*old_positionals: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Legacy API wrapper.
Expand Down Expand Up @@ -57,30 +61,33 @@ def wrapper(fn: Callable[P, R]) -> Callable[P, R]:
par_types = [p.kind for p in sig.parameters.values()]
has_var = Parameter.VAR_POSITIONAL in par_types
n_required = sum(1 for p in sig.parameters.values() if p.default is Parameter.empty)
n_positional = INF if has_var else sum(1 for p in par_types if p in POS_TYPES)
n_positional = sys.maxsize if has_var else sum(1 for p in par_types if p in POS_TYPES)

@wraps(fn)
def fn_compatible(*args: P.args, **kw: P.kwargs) -> R:
if len(args) > n_positional:
args, args_rest = args[:n_positional], args[n_positional:]
if args_rest:
if len(args_rest) > len(old_positionals):
n_max = n_positional + len(old_positionals)
msg = (
f"{fn.__name__}() takes from {n_required} to {n_max} parameters, "
f"but {len(args) + len(args_rest)} were given."
)
raise TypeError(msg)
warn(
f"The specified parameters {old_positionals[:len(args_rest)]!r} are "
"no longer positional. "
f"Please specify them like `{old_positionals[0]}={args_rest[0]!r}`",
DeprecationWarning,
stacklevel=2,
)
kw = {**kw, **dict(zip(old_positionals, args_rest))}

return fn(*args, **kw)
def fn_compatible(*args_all: P.args, **kw: P.kwargs) -> R:
if len(args_all) <= n_positional:
return fn(*args_all, **kw)

args_pos: P.args
args_pos, args_rest = args_all[:n_positional], args_all[n_positional:]

if len(args_rest) > len(old_positionals):
n_max = n_positional + len(old_positionals)
msg = (
f"{fn.__name__}() takes from {n_required} to {n_max} parameters, "
f"but {len(args_pos) + len(args_rest)} were given."
)
raise TypeError(msg)
warn(
f"The specified parameters {old_positionals[:len(args_rest)]!r} are "
"no longer positional. "
f"Please specify them like `{old_positionals[0]}={args_rest[0]!r}`",
DeprecationWarning,
stacklevel=2,
)
kw_new: P.kwargs = {**kw, **dict(zip(old_positionals, args_rest))}

return fn(*args_pos, **kw_new)

return fn_compatible

Expand Down
Empty file.
8 changes: 7 additions & 1 deletion src/testing/legacy_api_wrap/pytest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Pytest plugin for legacy_api_wrap."""

from __future__ import annotations

import sys
import warnings
from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from collections.abc import Generator

__all__ = ["_doctest_env", "pytest_itemcollected"]


@pytest.fixture()
def _doctest_env() -> None:
def _doctest_env() -> Generator[None, None, None]:
"""Pytest fixture to make doctests not error on expected warnings."""
sys.stderr, stderr_orig = sys.stdout, sys.stderr
with warnings.catch_warnings():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# def old(a, b=None, d=1, c=2):
# pass
@legacy_api("d", "c")
def new(a, b=None, *, c=2, d=1, e=3): # noqa: ANN001, ANN201
def new(a, b=None, *, c=2, d=1, e=3): # type: ignore[no-untyped-def] # noqa: ANN001, ANN201
return {"a": a, "b": b, "c": c, "d": d, "e": e}


Expand All @@ -24,13 +24,13 @@ def test_new_param_available() -> None:

def test_old_positional_order() -> None:
with pytest.deprecated_call():
res = new(12, 13, 14)
res = new(12, 13, 14) # type: ignore[misc]
assert res["d"] == 14


def test_warning_stack() -> None:
with pytest.deprecated_call() as record:
new(12, 13, 14)
new(12, 13, 14) # type: ignore[misc]
w = record.pop()
assert w.filename == __file__

Expand All @@ -40,4 +40,4 @@ def test_too_many_args() -> None:
TypeError,
match=r"new\(\) takes from 1 to 4 parameters, but 5 were given\.",
):
new(1, 2, 3, 4, 5)
new(1, 2, 3, 4, 5) # type: ignore[misc]

0 comments on commit b73f6f2

Please sign in to comment.