Skip to content

Commit

Permalink
Move assert_command outside of json (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus authored Dec 12, 2023
1 parent d0f90f1 commit 6afeb4e
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 77 deletions.
76 changes: 76 additions & 0 deletions tests/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from collections.abc import Callable, Sequence
from typing import Any
from unittest.mock import AsyncMock, Mock, call

from deebot_client.authentication import Authenticator
from deebot_client.command import Command, CommandResult
from deebot_client.event_bus import EventBus
from deebot_client.events import Event
from deebot_client.models import Credentials, DeviceInfo, StaticDeviceInfo


def _wrap_command(command: Command) -> tuple[Command, Callable[[CommandResult], None]]:
result: CommandResult | None = None
execute_fn = command._execute

async def _execute(
_: Command,
authenticator: Authenticator,
device_info: DeviceInfo,
event_bus: EventBus,
) -> CommandResult:
nonlocal result
result = await execute_fn(authenticator, device_info, event_bus)
return result

def verify_result(expected_result: CommandResult) -> None:
assert result == expected_result

command._execute = _execute.__get__(command) # type: ignore[method-assign]
return (command, verify_result)


async def assert_command(
command: Command,
json_api_response: dict[str, Any],
expected_events: Event | None | Sequence[Event],
*,
static_device_info: StaticDeviceInfo,
command_result: CommandResult | None = None,
) -> None:
command_result = command_result or CommandResult.success()
event_bus = Mock(spec_set=EventBus)
authenticator = Mock(spec_set=Authenticator)
authenticator.authenticate = AsyncMock(
return_value=Credentials("token", "user_id", 9999)
)
authenticator.post_authenticated = AsyncMock(return_value=json_api_response)
device_info = DeviceInfo(
{
"company": "company",
"did": "did",
"name": "name",
"nick": "nick",
"resource": "resource",
"deviceName": "device_name",
"status": 1,
"class": "get_class",
},
static_device_info,
)

command, verify_result = _wrap_command(command)

await command.execute(authenticator, device_info, event_bus)

# verify
verify_result(command_result)
authenticator.post_authenticated.assert_called()
if expected_events:
if isinstance(expected_events, Sequence):
event_bus.notify.assert_has_calls([call(x) for x in expected_events])
assert event_bus.notify.call_count == len(expected_events)
else:
event_bus.notify.assert_called_once_with(expected_events)
else:
event_bus.notify.assert_not_called()
80 changes: 10 additions & 70 deletions tests/commands/json/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, Mock, call
from unittest.mock import Mock

from testfixtures import LogCapture

from deebot_client.authentication import Authenticator
from deebot_client.command import Command, CommandResult
from deebot_client.command import CommandResult
from deebot_client.commands.json.common import (
ExecuteCommand,
JsonSetCommand,
Expand All @@ -15,73 +14,12 @@
from deebot_client.events import EnableEvent, Event
from deebot_client.hardware.deebot import FALLBACK, get_static_device_info
from deebot_client.message import HandlingState
from deebot_client.models import Credentials, DeviceInfo
from tests.commands import assert_command as assert_command_base
from tests.helpers import get_message_json, get_request_json, get_success_body


def _wrap_command(command: Command) -> tuple[Command, Callable[[CommandResult], None]]:
result: CommandResult | None = None
execute_fn = command._execute

async def _execute(
_: Command,
authenticator: Authenticator,
device_info: DeviceInfo,
event_bus: EventBus,
) -> CommandResult:
nonlocal result
result = await execute_fn(authenticator, device_info, event_bus)
return result

def verify_result(expected_result: CommandResult) -> None:
assert result == expected_result

command._execute = _execute.__get__(command) # type: ignore[method-assign]
return (command, verify_result)


async def assert_command(
command: Command,
json_api_response: dict[str, Any],
expected_events: Event | None | Sequence[Event],
command_result: CommandResult | None = None,
) -> None:
command_result = command_result or CommandResult.success()
event_bus = Mock(spec_set=EventBus)
authenticator = Mock(spec_set=Authenticator)
authenticator.authenticate = AsyncMock(
return_value=Credentials("token", "user_id", 9999)
)
authenticator.post_authenticated = AsyncMock(return_value=json_api_response)
device_info = DeviceInfo(
{
"company": "company",
"did": "did",
"name": "name",
"nick": "nick",
"resource": "resource",
"deviceName": "device_name",
"status": 1,
"class": "get_class",
},
get_static_device_info(FALLBACK),
)

command, verify_result = _wrap_command(command)

await command.execute(authenticator, device_info, event_bus)

# verify
verify_result(command_result)
authenticator.post_authenticated.assert_called()
if expected_events:
if isinstance(expected_events, Sequence):
event_bus.notify.assert_has_calls([call(x) for x in expected_events])
assert event_bus.notify.call_count == len(expected_events)
else:
event_bus.notify.assert_called_once_with(expected_events)
else:
event_bus.notify.assert_not_called()
assert_command = partial(
assert_command_base, static_device_info=get_static_device_info(FALLBACK)
)


async def assert_execute_command(
Expand All @@ -98,7 +36,9 @@ async def assert_execute_command(
with LogCapture() as log:
body = {"code": 500, "msg": "fail"}
json = get_request_json(body)
await assert_command(command, json, None, CommandResult(HandlingState.FAILED))
await assert_command(
command, json, None, command_result=CommandResult(HandlingState.FAILED)
)

log.check_present(
(
Expand Down
4 changes: 3 additions & 1 deletion tests/commands/json/test_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ async def test_Charge(json: dict[str, Any], expected: StateEvent) -> None:

async def test_Charge_failed(caplog: pytest.LogCaptureFixture) -> None:
json = _prepare_json(500, "fail")
await assert_command(Charge(), json, None, CommandResult(HandlingState.FAILED))
await assert_command(
Charge(), json, None, command_result=CommandResult(HandlingState.FAILED)
)

assert (
"deebot_client.commands.json.common",
Expand Down
4 changes: 2 additions & 2 deletions tests/commands/json/test_clean_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def test_GetCleanLogs_analyse_logged(
GetCleanLogs(),
json,
None,
CommandResult(HandlingState.ANALYSE_LOGGED),
command_result=CommandResult(HandlingState.ANALYSE_LOGGED),
)

assert (
Expand All @@ -142,7 +142,7 @@ async def test_GetCleanLogs_handle_error(caplog: pytest.LogCaptureFixture) -> No
GetCleanLogs(),
{},
None,
CommandResult(HandlingState.ERROR),
command_result=CommandResult(HandlingState.ERROR),
)

assert (
Expand Down
2 changes: 1 addition & 1 deletion tests/commands/json/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ async def test_CustomCommand(
expected: CustomCommandEvent | None,
command_result: CommandResult,
) -> None:
await assert_command(command, json, expected, command_result)
await assert_command(command, json, expected, command_result=command_result)
8 changes: 5 additions & 3 deletions tests/commands/json/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def test_getCachedMapInfo() -> None:
GetCachedMapInfo(),
json,
CachedMapInfoEvent(expected_name, active=True),
CommandResult(
command_result=CommandResult(
HandlingState.SUCCESS,
{"map_id": expected_mid},
[GetMapSet(expected_mid, entry) for entry in MapSetType],
Expand Down Expand Up @@ -165,7 +165,7 @@ async def test_getMapSet() -> None:
GetMapSet(mid),
json,
MapSetEvent(MapSetType.ROOMS, subsets),
CommandResult(
command_result=CommandResult(
HandlingState.SUCCESS,
{"id": "199390082", "set_id": "8", "type": "ar", "subsets": subsets},
[
Expand Down Expand Up @@ -195,5 +195,7 @@ async def test_getMapTrace() -> None:
GetMapTrace(start),
json,
MapTraceEvent(start=start, total=total, data=trace_value),
CommandResult(HandlingState.SUCCESS, {"start": start, "total": total}, []),
command_result=CommandResult(
HandlingState.SUCCESS, {"start": start, "total": total}, []
),
)

0 comments on commit 6afeb4e

Please sign in to comment.