diff --git a/tests/commands/__init__.py b/tests/commands/__init__.py index e69de29b..19732206 100644 --- a/tests/commands/__init__.py +++ b/tests/commands/__init__.py @@ -0,0 +1,76 @@ +from collections.abc import Callable, Sequence +from typing import Any +from unittest.mock import AsyncMock, Mock, call + +from deebot_client.authentication import Authenticator +from deebot_client.command import Command, CommandResult +from deebot_client.event_bus import EventBus +from deebot_client.events import Event +from deebot_client.models import Credentials, DeviceInfo, StaticDeviceInfo + + +def _wrap_command(command: Command) -> tuple[Command, Callable[[CommandResult], None]]: + result: CommandResult | None = None + execute_fn = command._execute + + async def _execute( + _: Command, + authenticator: Authenticator, + device_info: DeviceInfo, + event_bus: EventBus, + ) -> CommandResult: + nonlocal result + result = await execute_fn(authenticator, device_info, event_bus) + return result + + def verify_result(expected_result: CommandResult) -> None: + assert result == expected_result + + command._execute = _execute.__get__(command) # type: ignore[method-assign] + return (command, verify_result) + + +async def assert_command( + command: Command, + json_api_response: dict[str, Any], + expected_events: Event | None | Sequence[Event], + *, + static_device_info: StaticDeviceInfo, + command_result: CommandResult | None = None, +) -> None: + command_result = command_result or CommandResult.success() + event_bus = Mock(spec_set=EventBus) + authenticator = Mock(spec_set=Authenticator) + authenticator.authenticate = AsyncMock( + return_value=Credentials("token", "user_id", 9999) + ) + authenticator.post_authenticated = AsyncMock(return_value=json_api_response) + device_info = DeviceInfo( + { + "company": "company", + "did": "did", + "name": "name", + "nick": "nick", + "resource": "resource", + "deviceName": "device_name", + "status": 1, + "class": "get_class", + }, + static_device_info, + ) + + command, verify_result = _wrap_command(command) + + await command.execute(authenticator, device_info, event_bus) + + # verify + verify_result(command_result) + authenticator.post_authenticated.assert_called() + if expected_events: + if isinstance(expected_events, Sequence): + event_bus.notify.assert_has_calls([call(x) for x in expected_events]) + assert event_bus.notify.call_count == len(expected_events) + else: + event_bus.notify.assert_called_once_with(expected_events) + else: + event_bus.notify.assert_not_called() diff --git a/tests/commands/json/__init__.py b/tests/commands/json/__init__.py index 87fcc576..d4bd431f 100644 --- a/tests/commands/json/__init__.py +++ b/tests/commands/json/__init__.py @@ -1,11 +1,10 @@ -from collections.abc import Callable, Sequence +from functools import partial from typing import Any -from unittest.mock import AsyncMock, Mock, call +from unittest.mock import Mock from testfixtures import LogCapture -from deebot_client.authentication import Authenticator -from deebot_client.command import Command, CommandResult +from deebot_client.command import CommandResult from deebot_client.commands.json.common import ( ExecuteCommand, JsonSetCommand, @@ -15,73 +14,12 @@ from deebot_client.events import EnableEvent, Event from deebot_client.hardware.deebot import FALLBACK, get_static_device_info from deebot_client.message import HandlingState -from deebot_client.models import Credentials, DeviceInfo +from tests.commands import assert_command as assert_command_base from tests.helpers import get_message_json, get_request_json, get_success_body - -def _wrap_command(command: Command) -> tuple[Command, Callable[[CommandResult], None]]: - result: CommandResult | None = None - execute_fn = command._execute - - async def _execute( - _: Command, - authenticator: Authenticator, - device_info: DeviceInfo, - event_bus: EventBus, - ) -> CommandResult: - nonlocal result - result = await execute_fn(authenticator, device_info, event_bus) - return result - - def verify_result(expected_result: CommandResult) -> None: - assert result == expected_result - - command._execute = _execute.__get__(command) # type: ignore[method-assign] - return (command, verify_result) - - -async def assert_command( - command: Command, - json_api_response: dict[str, Any], - expected_events: Event | None | Sequence[Event], - command_result: CommandResult | None = None, -) -> None: - command_result = command_result or CommandResult.success() - event_bus = Mock(spec_set=EventBus) - authenticator = Mock(spec_set=Authenticator) - authenticator.authenticate = AsyncMock( - return_value=Credentials("token", "user_id", 9999) - ) - authenticator.post_authenticated = AsyncMock(return_value=json_api_response) - device_info = DeviceInfo( - { - "company": "company", - "did": "did", - "name": "name", - "nick": "nick", - "resource": "resource", - "deviceName": "device_name", - "status": 1, - "class": "get_class", - }, - get_static_device_info(FALLBACK), - ) - - command, verify_result = _wrap_command(command) - - await command.execute(authenticator, device_info, event_bus) - - # verify - verify_result(command_result) - authenticator.post_authenticated.assert_called() - if expected_events: - if isinstance(expected_events, Sequence): - event_bus.notify.assert_has_calls([call(x) for x in expected_events]) - assert event_bus.notify.call_count == len(expected_events) - else: - event_bus.notify.assert_called_once_with(expected_events) - else: - event_bus.notify.assert_not_called() +assert_command = partial( + assert_command_base, static_device_info=get_static_device_info(FALLBACK) +) async def assert_execute_command( @@ -98,7 +36,9 @@ async def assert_execute_command( with LogCapture() as log: body = {"code": 500, "msg": "fail"} json = get_request_json(body) - await assert_command(command, json, None, CommandResult(HandlingState.FAILED)) + await assert_command( + command, json, None, command_result=CommandResult(HandlingState.FAILED) + ) log.check_present( ( diff --git a/tests/commands/json/test_charge.py b/tests/commands/json/test_charge.py index 43c8399b..c0ecec33 100644 --- a/tests/commands/json/test_charge.py +++ b/tests/commands/json/test_charge.py @@ -37,7 +37,9 @@ async def test_Charge(json: dict[str, Any], expected: StateEvent) -> None: async def test_Charge_failed(caplog: pytest.LogCaptureFixture) -> None: json = _prepare_json(500, "fail") - await assert_command(Charge(), json, None, CommandResult(HandlingState.FAILED)) + await assert_command( + Charge(), json, None, command_result=CommandResult(HandlingState.FAILED) + ) assert ( "deebot_client.commands.json.common", diff --git a/tests/commands/json/test_clean_log.py b/tests/commands/json/test_clean_log.py index 30cfa229..ffbbf32b 100644 --- a/tests/commands/json/test_clean_log.py +++ b/tests/commands/json/test_clean_log.py @@ -127,7 +127,7 @@ async def test_GetCleanLogs_analyse_logged( GetCleanLogs(), json, None, - CommandResult(HandlingState.ANALYSE_LOGGED), + command_result=CommandResult(HandlingState.ANALYSE_LOGGED), ) assert ( @@ -142,7 +142,7 @@ async def test_GetCleanLogs_handle_error(caplog: pytest.LogCaptureFixture) -> No GetCleanLogs(), {}, None, - CommandResult(HandlingState.ERROR), + command_result=CommandResult(HandlingState.ERROR), ) assert ( diff --git a/tests/commands/json/test_custom.py b/tests/commands/json/test_custom.py index 6470cf99..41c020d9 100644 --- a/tests/commands/json/test_custom.py +++ b/tests/commands/json/test_custom.py @@ -31,4 +31,4 @@ async def test_CustomCommand( expected: CustomCommandEvent | None, command_result: CommandResult, ) -> None: - await assert_command(command, json, expected, command_result) + await assert_command(command, json, expected, command_result=command_result) diff --git a/tests/commands/json/test_map.py b/tests/commands/json/test_map.py index e6eb58e6..3eb9fd31 100644 --- a/tests/commands/json/test_map.py +++ b/tests/commands/json/test_map.py @@ -107,7 +107,7 @@ async def test_getCachedMapInfo() -> None: GetCachedMapInfo(), json, CachedMapInfoEvent(expected_name, active=True), - CommandResult( + command_result=CommandResult( HandlingState.SUCCESS, {"map_id": expected_mid}, [GetMapSet(expected_mid, entry) for entry in MapSetType], @@ -165,7 +165,7 @@ async def test_getMapSet() -> None: GetMapSet(mid), json, MapSetEvent(MapSetType.ROOMS, subsets), - CommandResult( + command_result=CommandResult( HandlingState.SUCCESS, {"id": "199390082", "set_id": "8", "type": "ar", "subsets": subsets}, [ @@ -195,5 +195,7 @@ async def test_getMapTrace() -> None: GetMapTrace(start), json, MapTraceEvent(start=start, total=total, data=trace_value), - CommandResult(HandlingState.SUCCESS, {"start": start, "total": total}, []), + command_result=CommandResult( + HandlingState.SUCCESS, {"start": start, "total": total}, [] + ), )