Skip to content

Commit

Permalink
Prevent task cancellation from propagating to ASH (#628)
Browse files Browse the repository at this point in the history
* Do not allow task cancellation to propagate to ASH

* Enter a failed state if we cannot send a frame

* Support resolving multiple frames at once (we still limit TX_K=1)
  • Loading branch information
puddly committed Jun 14, 2024
1 parent 3875d11 commit 3657cf3
Showing 1 changed file with 46 additions and 14 deletions.
60 changes: 46 additions & 14 deletions bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import abc
import asyncio
import binascii
from collections.abc import Coroutine
import dataclasses
import enum
import logging
import sys
import time
import typing

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
Expand Down Expand Up @@ -55,7 +57,7 @@ class Reserved(enum.IntEnum):

# Maximum number of DATA frames the NCP can transmit without having received
# acknowledgements
TX_K = 1
TX_K = 1 # TODO: investigate why this cannot be raised without causing a firmware crash

# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
Expand All @@ -81,6 +83,23 @@ def generate_random_sequence(length: int) -> bytes:
# Since the sequence is static for every frame, we only need to generate it once
PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256)

if sys.version_info[:2] < (3, 12):
    # `asyncio.Task(..., eager_start=True)` only exists on Python 3.12+. On older
    # interpreters fall back to ordinary task creation.
    create_eager_task = asyncio.create_task
else:
    # The name passed to `TypeVar` must match the variable it is bound to
    _T = typing.TypeVar("_T")

    def create_eager_task(
        coro: Coroutine[typing.Any, typing.Any, _T],
        *,
        name: str | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> asyncio.Task[_T]:
        """Create a task from a coroutine and schedule it to run immediately.

        The task is constructed with `eager_start=True`, so the coroutine begins
        executing while the task is being created instead of waiting for the
        event loop to schedule it.

        :param coro: coroutine to wrap in a task
        :param name: optional name assigned to the task
        :param loop: event loop to attach the task to; the currently running loop
            is used when omitted
        :returns: the created task
        :raises RuntimeError: if `loop` is omitted and no event loop is running
        """
        if loop is None:
            loop = asyncio.get_running_loop()

        return asyncio.Task(coro, loop=loop, name=name, eager_start=True)


class NcpState(enum.Enum):
CONNECTED = "connected"
Expand Down Expand Up @@ -463,15 +482,14 @@ def data_received(self, data: bytes) -> None:
def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
# Note that ackNum is the number of the next frame the receiver expects and it
# is one greater than the last frame received.
ack_num = (frame.ack_num - 1) % 8
for ack_num_offset in range(-TX_K, 0):
ack_num = (frame.ack_num + ack_num_offset) % 8
fut = self._pending_data_frames.get(ack_num)

fut = self._pending_data_frames.get(ack_num)
if fut is None or fut.done():
continue

if fut is None or fut.done():
return

# _LOGGER.debug("Resolving frame %d", ack_num)
self._pending_data_frames[ack_num].set_result(True)
self._pending_data_frames[ack_num].set_result(True)

def frame_received(self, frame: AshFrame) -> None:
_LOGGER.debug("Received frame %r", frame)
Expand Down Expand Up @@ -537,13 +555,16 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
self._ncp_state = NcpState.FAILED

# Cancel all pending requests
exc = NcpFailure(code=self._ncp_reset_code)
self._enter_failed_state(self._ncp_reset_code)

def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
exc = NcpFailure(code=reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

self._ezsp_protocol.reset_received(frame.reset_code)
self._ezsp_protocol.reset_received(reset_code)

def _write_frame(
self,
Expand Down Expand Up @@ -582,7 +603,7 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
for attempt in range(ACK_TIMEOUTS):
if self._ncp_state == NcpState.FAILED:
_LOGGER.debug(
"NCP is in a failed state, not re-sending: %r", frame
"NCP is in a failed state, not sending: %r", frame
)
raise NcpFailure(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
Expand Down Expand Up @@ -618,6 +639,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta)

if attempt >= ACK_TIMEOUTS - 1:
self._enter_failed_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)
raise
except NcpFailure:
_LOGGER.debug(
Expand All @@ -635,6 +659,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._change_ack_timeout(2 * self._t_rx_ack)

if attempt >= ACK_TIMEOUTS - 1:
self._enter_failed_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)
raise
else:
# Whenever an acknowledgement is received, t_rx_ack is set to
Expand All @@ -649,9 +676,14 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._pending_data_frames.pop(frm_num)

async def send_data(self, data: bytes) -> None:
await self._send_data_frame(
# All of the other fields will be set during transmission/retries
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
# Sending data is a critical operation and cannot really be cancelled
await asyncio.shield(
create_eager_task(
self._send_data_frame(
# All of the other fields will be set during transmission/retries
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
)
)
)

def send_reset(self) -> None:
Expand Down

0 comments on commit 3657cf3

Please sign in to comment.