From e7fa784cb58417f198ee2577b576b3bae1d85bb6 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Wed, 10 Jan 2024 22:08:51 +0100 Subject: [PATCH] Fix typing --- src/paho/mqtt/client.py | 23 ++++++++++++----------- src/paho/mqtt/publish.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index f3c88b61..de0892a0 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -583,7 +583,7 @@ def __init__( clean_session: bool | None = None, userdata: Any = None, protocol: int = MQTTv311, - transport: str = "tcp", + transport: Literal["tcp", "websockets"] = "tcp", reconnect_on_failure: bool = True, manual_ack: bool = False, ) -> None: @@ -627,8 +627,13 @@ def __init__( locally. """ + transport = transport.lower() # type: ignore + if transport not in ("websockets", "tcp"): + raise ValueError( + f'transport must be "websockets" or "tcp", not {transport}') + self._manual_ack = manual_ack - self.transport = transport + self._transport = transport self._protocol = protocol self._userdata = userdata self._sock: SocketLike | None = None @@ -790,25 +795,21 @@ def keepalive(self, value: int) -> None: self._keepalive = value @property - def transport(self) -> str: + def transport(self) -> Literal["tcp", "websockets"]: """Transport method used for the connection.""" return self._transport @transport.setter - def transport(self, value: str) -> None: + def transport(self, value: Literal["tcp", "websockets"]) -> None: """ Update transport which should be "tcp" or "websockets". This will only be used on future (re)connection. You should probably use reconnect() to update the connection if established. """ - if value.lower() not in ("websockets", "tcp"): - raise ValueError( - f'transport must be "websockets" or "tcp", not {value}') - - self._transport = value.lower() + self._transport = value @property - def protocol(self) -> int: + def protocol(self) -> MQTTProtocolVersion: """Protocol version used (MQTT v3, MQTT v3.11, MQTTv5)""" return self.protocol @@ -818,7 +819,7 @@ def connect_timeout(self) -> float: return self._connect_timeout @connect_timeout.setter - def connect_timeout(self, value: float): + def connect_timeout(self, value: float) -> None: "Change connect_timeout for future (re)connection" if value <= 0.0: raise ValueError("timeout must be a positive number") diff --git a/src/paho/mqtt/publish.py b/src/paho/mqtt/publish.py index 0e68a82a..d07cc412 100644 --- a/src/paho/mqtt/publish.py +++ b/src/paho/mqtt/publish.py @@ -33,6 +33,12 @@ except ImportError: from typing_extensions import NotRequired, Required, TypedDict + try: + from typing import Literal + except ImportError: + from typing_extensions import Literal # type: ignore + + class AuthParameter(TypedDict, total=False): username: Required[str] @@ -108,7 +114,7 @@ def multiple( auth: AuthParameter | None = None, tls: TLSParameter | None = None, protocol: int = paho.MQTTv311, - transport: str = "tcp", + transport: Literal["tcp", "websockets"] = "tcp", proxy_args: Any | None = None, ) -> None: """Publish multiple messages to a broker, then disconnect cleanly. @@ -231,7 +237,7 @@ def single( auth: AuthParameter | None = None, tls: TLSParameter | None = None, protocol: int = paho.MQTTv311, - transport: str = "tcp", + transport: Literal["tcp", "websockets"] = "tcp", proxy_args: Any | None = None, ) -> None: """Publish a single message to a broker, then disconnect cleanly.