Skip to content

Commit

Permalink
Rename mqtt config + arguments (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus authored Jan 29, 2024
1 parent 9ae284c commit 4a6023d
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 29 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ import logging
import time

from deebot_client.api_client import ApiClient
from deebot_client.authentication import Authenticator
from deebot_client.authentication import Authenticator, create_rest_config
from deebot_client.commands import *
from deebot_client.commands.clean import CleanAction
from deebot_client.events import BatteryEvent
from deebot_client.configuration import Configuration
from deebot_client.mqtt_client import MqttClient
from deebot_client.mqtt_client import MqttClient, create_mqtt_config
from deebot_client.util import md5
from deebot_client.device import Device

Expand All @@ -42,16 +41,17 @@ country = "de"
async def main():
async with aiohttp.ClientSession() as session:
logging.basicConfig(level=logging.DEBUG)
config = create_config(session, device_id=device_id, country=country)
rest_config = create_rest_config(session, device_id=device_id, country=country)

authenticator = Authenticator(config.rest, account_id, password_hash)
authenticator = Authenticator(rest_config, account_id, password_hash)
api_client = ApiClient(authenticator)

devices_ = await api_client.get_devices()

bot = Device(devices_[0], authenticator)

mqtt = MqttClient(config.mqtt, authenticator)
mqtt_config = create_mqtt_config(device_id=device_id, country=country)
mqtt = MqttClient(mqtt_config, authenticator)
await bot.initialize(mqtt)

async def on_battery(event: BatteryEvent):
Expand Down
3 changes: 2 additions & 1 deletion deebot_client/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ class RestConfiguration:
auth_code_url: str


def create_config(
def create_rest_config(
session: ClientSession,
*,
device_id: str,
country: str,
override_rest_url: str | None = None,
Expand Down
10 changes: 9 additions & 1 deletion deebot_client/const.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Constants module."""
from __future__ import annotations

from enum import StrEnum
from enum import Enum, StrEnum
from typing import Self

REALM = "ecouser.net"
Expand Down Expand Up @@ -30,6 +30,14 @@ def get(cls, value: str) -> Self | None:
return None


class UndefinedType(Enum):
"""Singleton type for use with not set sentinel values."""

_singleton = 0


UNDEFINED = UndefinedType._singleton # pylint: disable=protected-access # noqa: SLF001

# from https://github.com/mrbungle64/ecovacs-deebot.js/blob/master/library/errorCodes.json
ERROR_CODES = {
-3: "Error parsing response data",
Expand Down
15 changes: 8 additions & 7 deletions deebot_client/mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiomqtt import Client, Message, MqttError as AioMqttError
from cachetools import TTLCache

from deebot_client.const import DataType
from deebot_client.const import UNDEFINED, DataType, UndefinedType
from deebot_client.exceptions import AuthenticationError, MqttError

from .commands import COMMANDS_WITH_MQTT_P2P_HANDLING
Expand Down Expand Up @@ -56,22 +56,22 @@ class MqttConfiguration:
device_id: str


def create_config(
def create_mqtt_config(
*,
device_id: str,
country: str,
override_mqtt_url: str | None = None,
*,
disable_ssl_context_validation: bool = False,
ssl_context: ssl.SSLContext | None | UndefinedType = UNDEFINED,
) -> MqttConfiguration:
"""Create configuration."""
continent_postfix = get_continent_url_postfix(country.upper())

ssl_ctx = None
if override_mqtt_url:
url = urlparse(override_mqtt_url)
match url.scheme:
case "mqtt":
default_port = 1883
ssl_ctx = None
case "mqtts":
default_port = 8883
ssl_ctx = ssl.create_default_context()
Expand All @@ -86,12 +86,13 @@ def create_config(
else:
hostname = f"mq{continent_postfix}.ecouser.net"
port = 443

if not override_mqtt_url or disable_ssl_context_validation:
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE

if ssl_context is not UNDEFINED:
ssl_ctx = ssl_context

return MqttConfiguration(
hostname=hostname,
port=port,
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from deebot_client.authentication import (
Authenticator,
RestConfiguration,
create_config as create_config_rest,
create_rest_config as create_config_rest,
)
from deebot_client.event_bus import EventBus
from deebot_client.hardware.deebot import FALLBACK, get_static_device_info
Expand All @@ -24,7 +24,7 @@
from deebot_client.mqtt_client import (
MqttClient,
MqttConfiguration,
create_config as create_config_mqtt,
create_mqtt_config as create_config_mqtt,
)

from .fixtures.mqtt_server import MqttServer
Expand Down
4 changes: 2 additions & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from deebot_client.authentication import Authenticator, create_config
from deebot_client.authentication import Authenticator, create_rest_config
from deebot_client.models import Credentials

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_config_override_rest_url(
expected_auth_code_url: str,
) -> None:
"""Test rest configuration."""
config = create_config(
config = create_rest_config(
session=session,
device_id="123",
country=country,
Expand Down
22 changes: 12 additions & 10 deletions tests/test_mqtt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from deebot_client.commands.json.battery import GetBattery
from deebot_client.commands.json.volume import SetVolume
from deebot_client.const import DataType
from deebot_client.const import UNDEFINED, DataType, UndefinedType
from deebot_client.exceptions import AuthenticationError, MqttError
from deebot_client.mqtt_client import MqttClient, MqttConfiguration, create_config
from deebot_client.mqtt_client import MqttClient, MqttConfiguration, create_mqtt_config

from .mqtt_util import subscribe, verify_subscribe

Expand Down Expand Up @@ -343,33 +343,35 @@ async def test_mqtt_task_exceptions(
],
)
@pytest.mark.parametrize("device_id", ["test", "123"])
@pytest.mark.parametrize("disable_ssl_context_validation", [True, False])
@pytest.mark.parametrize("ssl_context", [UNDEFINED, None, ssl.create_default_context()])
def test_config(
authenticator: Authenticator,
country: str,
device_id: str,
override_mqtt_url: str | None,
expected_hostname: str,
expected_port: int,
ssl_context: ssl.SSLContext | None | UndefinedType,
*,
disable_ssl_context_validation: bool,
expect_ssl_context: bool,
) -> None:
"""Test mqtt part of the configuration."""
client = MqttClient(
create_config(
create_mqtt_config(
device_id=device_id,
country=country,
override_mqtt_url=override_mqtt_url,
disable_ssl_context_validation=disable_ssl_context_validation,
ssl_context=ssl_context,
),
authenticator,
)
config = client._config
assert config.hostname == expected_hostname
assert config.device_id == device_id
assert config.port == expected_port
if expect_ssl_context or disable_ssl_context_validation:
if isinstance(ssl_context, ssl.SSLContext) or (
expect_ssl_context and isinstance(ssl_context, UndefinedType)
):
assert isinstance(config.ssl_context, ssl.SSLContext)
else:
assert config.ssl_context is None
Expand All @@ -389,7 +391,7 @@ def test_config_override_mqtt_url_invalid(
"""Test that an invalid mqtt override url will raise a DeebotError."""
with pytest.raises(MqttError, match=error_msg):
MqttClient(
create_config(
create_mqtt_config(
device_id="123",
country="IT",
override_mqtt_url=override_mqtt_url,
Expand All @@ -401,7 +403,7 @@ def test_config_override_mqtt_url_invalid(
async def test_verify_config(authenticator: Authenticator) -> None:
with patch("deebot_client.mqtt_client.Client", autospec=True) as client_mock:
client = MqttClient(
create_config(
create_mqtt_config(
device_id="123",
country="IT",
),
Expand All @@ -416,7 +418,7 @@ async def test_verify_config_fails(authenticator: Authenticator) -> None:
with patch("deebot_client.mqtt_client.Client", autospec=True) as client_mock:
client_mock.return_value.__aenter__.side_effect = AioMqttError
client = MqttClient(
create_config(
create_mqtt_config(
device_id="123",
country="IT",
),
Expand Down

0 comments on commit 4a6023d

Please sign in to comment.