Skip to content

Commit

Permalink
Fix #8 non-Protocol members allowed in intersection (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
klausweiss authored Dec 27, 2023
1 parent 391d030 commit a9751e4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
Empty file.
19 changes: 19 additions & 0 deletions tests/test_past_issues/test_8_non_protocol_member/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Generic, TypeVar

from typing_protocol_intersection import ProtocolIntersection


class NotProtocol:
pass


T = TypeVar("T")


class Noop(Generic[T]):
@classmethod
def noop(cls, value: T) -> ProtocolIntersection[T]:
return value


np = Noop[NotProtocol].noop(NotProtocol())
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
ref: https://github.com/klausweiss/typing-protocol-intersection/issues/8
"""
from pathlib import Path

HERE = Path(__file__).parent


def test_8_non_protocol_member(run_mypy):
# given
input_file = HERE / "input.py"
# when
stdout, _stderr = run_mypy(input_file, no_incremental=False)
# then no error
assert not stdout.startswith("Success")
assert "error:" in stdout
assert "Only Protocols can be used in ProtocolIntersection" in stdout
17 changes: 13 additions & 4 deletions typing_protocol_intersection/mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing_extensions import TypeGuard

SignatureContext = typing.Union[mypy.plugin.FunctionSigContext, mypy.plugin.MethodSigContext]
AnyContext = typing.Union[SignatureContext, mypy.plugin.AnalyzeTypeContext]


class ProtocolIntersectionPlugin(mypy.plugin.Plugin):
Expand Down Expand Up @@ -116,6 +117,10 @@ def mk_protocol_intersection_typeinfo(


class ProtocolIntersectionResolver:
def __init__(self, context: SignatureContext) -> None:
super().__init__()
self._context = context

def fold_intersection_and_its_args(self, type_: mypy.types.Type) -> mypy.types.Type:
folded_type = self.fold_intersection(type_)
if isinstance(folded_type, mypy.types.Instance):
Expand All @@ -142,6 +147,8 @@ def _run_fold(self, type_: mypy.types.Instance, intersection_type_info_wrapper:
intersections_to_process.append(arg)
continue
if isinstance(arg, mypy.types.Instance):
if not arg.type.is_protocol:
_error_non_protocol_member(arg, context=self._context)
self._add_type_to_intersection(intersection_type_info_wrapper, arg)
return intersection_type_info_wrapper

Expand All @@ -166,7 +173,7 @@ def _is_intersection(typ: mypy.types.Type) -> TypeGuard[mypy.types.Instance]:


def intersection_function_signature_hook(context: SignatureContext) -> mypy.types.FunctionLike:
resolver = ProtocolIntersectionResolver()
resolver = ProtocolIntersectionResolver(context)
signature = context.default_signature
signature.ret_type = resolver.fold_intersection_and_its_args(signature.ret_type)
signature.arg_types = [resolver.fold_intersection_and_its_args(t) for t in signature.arg_types]
Expand All @@ -182,9 +189,7 @@ def _type_analyze_hook(context: mypy.plugin.AnalyzeTypeContext) -> mypy.types.Ty
if arg.type.is_protocol:
base_types_of_args.update(arg.type.mro)
else:
context.api.fail(
"Only Protocols can be used in ProtocolIntersection.", arg, code=mypy.errorcodes.VALID_TYPE
)
_error_non_protocol_member(arg, context=context)
symbol_table = mypy.nodes.SymbolTable(collections.ChainMap(*(base.names for base in base_types_of_args)))
type_info = mk_protocol_intersection_typeinfo(
context.type.name, fullname=UniqueFullname(fullname), symbol_table=symbol_table
Expand All @@ -197,6 +202,10 @@ def _type_analyze_hook(context: mypy.plugin.AnalyzeTypeContext) -> mypy.types.Ty
return _type_analyze_hook


def _error_non_protocol_member(arg: mypy.types.Type, *, context: AnyContext) -> None:
context.api.fail("Only Protocols can be used in ProtocolIntersection.", arg, code=mypy.errorcodes.VALID_TYPE)


def plugin(version: str) -> typing.Type[mypy.plugin.Plugin]:
version_prefix, *_ = version.split("dev.", maxsplit=1) # stripping +dev.f6a8037cc... suffix if applicable
numeric_prefixes = (_numeric_prefix(x) for x in version_prefix.split("."))
Expand Down

0 comments on commit a9751e4

Please sign in to comment.