diff --git a/tests/test_past_issues/test_8_non_protocol_member/__init__.py b/tests/test_past_issues/test_8_non_protocol_member/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_past_issues/test_8_non_protocol_member/input.py b/tests/test_past_issues/test_8_non_protocol_member/input.py new file mode 100644 index 0000000..1ac4e16 --- /dev/null +++ b/tests/test_past_issues/test_8_non_protocol_member/input.py @@ -0,0 +1,19 @@ +from typing import Generic, TypeVar + +from typing_protocol_intersection import ProtocolIntersection + + +class NotProtocol: + pass + + +T = TypeVar("T") + + +class Noop(Generic[T]): + @classmethod + def noop(cls, value: T) -> ProtocolIntersection[T]: + return value + + +np = Noop[NotProtocol].noop(NotProtocol()) diff --git a/tests/test_past_issues/test_8_non_protocol_member/test_8_non_protocol_member.py b/tests/test_past_issues/test_8_non_protocol_member/test_8_non_protocol_member.py new file mode 100644 index 0000000..4505f92 --- /dev/null +++ b/tests/test_past_issues/test_8_non_protocol_member/test_8_non_protocol_member.py @@ -0,0 +1,17 @@ +""" +ref: https://github.com/klausweiss/typing-protocol-intersection/issues/8 +""" +from pathlib import Path + +HERE = Path(__file__).parent + + +def test_8_non_protocol_member(run_mypy): + # given + input_file = HERE / "input.py" + # when + stdout, _stderr = run_mypy(input_file, no_incremental=False) + # then no error + assert not stdout.startswith("Success") + assert "error:" in stdout + assert "Only Protocols can be used in ProtocolIntersection" in stdout diff --git a/typing_protocol_intersection/mypy_plugin.py b/typing_protocol_intersection/mypy_plugin.py index 1b19748..8ffee82 100644 --- a/typing_protocol_intersection/mypy_plugin.py +++ b/typing_protocol_intersection/mypy_plugin.py @@ -18,6 +18,7 @@ from typing_extensions import TypeGuard SignatureContext = typing.Union[mypy.plugin.FunctionSigContext, mypy.plugin.MethodSigContext] +AnyContext = typing.Union[SignatureContext, mypy.plugin.AnalyzeTypeContext] class ProtocolIntersectionPlugin(mypy.plugin.Plugin): @@ -116,6 +117,10 @@ def mk_protocol_intersection_typeinfo( class ProtocolIntersectionResolver: + def __init__(self, context: SignatureContext) -> None: + super().__init__() + self._context = context + def fold_intersection_and_its_args(self, type_: mypy.types.Type) -> mypy.types.Type: folded_type = self.fold_intersection(type_) if isinstance(folded_type, mypy.types.Instance): @@ -142,6 +147,8 @@ def _run_fold(self, type_: mypy.types.Instance, intersection_type_info_wrapper: intersections_to_process.append(arg) continue if isinstance(arg, mypy.types.Instance): + if not arg.type.is_protocol: + _error_non_protocol_member(arg, context=self._context) self._add_type_to_intersection(intersection_type_info_wrapper, arg) return intersection_type_info_wrapper @@ -166,7 +173,7 @@ def _is_intersection(typ: mypy.types.Type) -> TypeGuard[mypy.types.Instance]: def intersection_function_signature_hook(context: SignatureContext) -> mypy.types.FunctionLike: - resolver = ProtocolIntersectionResolver() + resolver = ProtocolIntersectionResolver(context) signature = context.default_signature signature.ret_type = resolver.fold_intersection_and_its_args(signature.ret_type) signature.arg_types = [resolver.fold_intersection_and_its_args(t) for t in signature.arg_types] @@ -182,9 +189,7 @@ def _type_analyze_hook(context: mypy.plugin.AnalyzeTypeContext) -> mypy.types.Ty if arg.type.is_protocol: base_types_of_args.update(arg.type.mro) else: - context.api.fail( - "Only Protocols can be used in ProtocolIntersection.", arg, code=mypy.errorcodes.VALID_TYPE - ) + _error_non_protocol_member(arg, context=context) symbol_table = mypy.nodes.SymbolTable(collections.ChainMap(*(base.names for base in base_types_of_args))) type_info = mk_protocol_intersection_typeinfo( context.type.name, fullname=UniqueFullname(fullname), symbol_table=symbol_table @@ -197,6 +202,10 @@ def _type_analyze_hook(context: mypy.plugin.AnalyzeTypeContext) -> mypy.types.Ty return _type_analyze_hook +def _error_non_protocol_member(arg: mypy.types.Type, *, context: AnyContext) -> None: + context.api.fail("Only Protocols can be used in ProtocolIntersection.", arg, code=mypy.errorcodes.VALID_TYPE) + + def plugin(version: str) -> typing.Type[mypy.plugin.Plugin]: version_prefix, *_ = version.split("dev.", maxsplit=1) # stripping +dev.f6a8037cc... suffix if applicable numeric_prefixes = (_numeric_prefix(x) for x in version_prefix.split("."))