Skip to content

Commit

Permalink
Docs for new methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gierwialo committed Jul 5, 2023
1 parent 4e24000 commit b7ab9ac
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 62 deletions.
15 changes: 5 additions & 10 deletions packages/python-runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,14 @@ async def main(self, server_host, server_port):

await self.protocol.sync()
await self.run_instance(config, input_stream, args)

await self.protocol.sync()

heartbeat_task.cancel()
await asyncio.gather(*[heartbeat_task])
control_stream_task.cancel()
await asyncio.gather(*[connect_input_stream_task,
heartbeat_task,
control_stream_task])

await asyncio.gather(*[connect_input_stream_task])

await self.protocol.get_channel(CC.IN).sync()
await self.protocol.get_channel(CC.CONTROL).sync()

control_stream_task.cancel()
await asyncio.gather(*[control_stream_task])

await self.protocol.stop()

[ task.cancel() if task.get_name() != 'RUNNER_MAIN' else None for task in asyncio.all_tasks()]
Expand Down
122 changes: 70 additions & 52 deletions packages/python-runner/tecemux.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import random
import socket
from typing import Any, Coroutine

from attrs import define, field
from barrier import Barrier
Expand Down Expand Up @@ -61,23 +60,30 @@ def __init__(self, channel: CC,
self._stop_channel_event = stop_event
self._sync_channel_event = sync_event
self._sync_barrier = sync_barrier
self._outcoming_process_task = asyncio.create_task(self._outcomming_process_tasks(), name=f'{self._get_channel_name()}_PROCESS_TASK')
self._outcoming_process_task = asyncio.create_task(self._outcoming_process_tasks(), name=f'{self._get_channel_name()}_PROCESS_TASK')
self._ended = None
self._read_buffer = bytearray()

async def sync(self):
async def sync(self) -> None:
"""Waits until internal queues will be empty (packets will be processed)
"""

if self._is_incoming():
await self._internal_incoming_queue.join()

await self._internal_outcoming_queue.join()

def _is_incoming(self) -> bool:
"""Returns whether channel is mostly read by Tememux

def _is_incoming(self):
Returns:
bool: True is mostly read by Runner, False otherwise
"""
return self._channel_enum in [CC.STDIN, CC.IN, CC.CONTROL]

async def _outcomming_process_tasks(self):

async def _outcoming_process_tasks(self) -> None:
""" Internal loop for outcoming data from channel
"""
while True:
try:
await self._queue_up_outcoming()
Expand All @@ -88,14 +94,15 @@ async def _outcomming_process_tasks(self):
except asyncio.CancelledError:
return




def _debug(self, msg):
"""Wrapper for printing debug messages
Args:
msg (str): Debug message
"""
if TECEMUX_INTERNAL_VERBOSE_DEBUG:
self._logger.debug(msg)


def _get_channel_name(self) -> str:
"""Returns channel name
Expand All @@ -121,12 +128,25 @@ def _set_logger(self, logger: logging.Logger) -> None:
"""
self._logger = logger

async def readline(self) -> str:
async def readline(self) -> bytes:
"""Reads data from current channel until '\n' appears
Returns:
bytes: Bufer data from channel
"""
sep = b'\n'
line = await self.readuntil(sep)
return line

async def readuntil(self, separator=b'\n'):
async def readuntil(self, separator=b'\n') -> bytes:
"""Reads data from current channel until provided separator appears
Args:
separator (bytes, optional): Defaults to b'\n'.
Returns:
bytes: Buffer data from channel
"""
seplen = len(separator)

offset = 0
Expand All @@ -149,8 +169,12 @@ async def readuntil(self, separator=b'\n'):
del self._read_buffer[:isep + seplen]
return bytes(chunk)

async def _get_data(self):
async def _get_data(self) -> bool:
"""Internal method, returns True if data are moved from queue to _read_buffer, false otherwise
Returns:
bool: True if data are moved from queue to _read_buffer, false otherwise
"""
if self._ended and self._internal_incoming_queue.empty():
return False

Expand All @@ -176,8 +200,15 @@ async def _get_data(self):
await asyncio.sleep(0)
return True

async def read(self, n: int = -1):
async def read(self, n: int = -1) -> bytes:
"""Reads up to 'n' bytes of data from current channel
Args:
n (int, optional): Number of bytes. Defaults to -1.
Returns:
bytes: Buffer data from channel
"""
if n == 0:
return b''

Expand All @@ -204,6 +235,7 @@ async def send_ACK(self, sequence_number: int) -> None:
Args:
sequence_number (int): Value for Acknowledge field
"""

await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), flags=['ACK'], ack=sequence_number)))

async def _send_pause_ACK(self, sequence_number: int) -> None:
Expand All @@ -212,10 +244,11 @@ async def _send_pause_ACK(self, sequence_number: int) -> None:
Args:
sequence_number (int): Value for Acknowledge field
"""

await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), flags=['ACK', 'SYN'], ack=sequence_number)))

async def open(self) -> None:
"""Open channel.
"""Open channel
"""

if not self._channel_opened:
Expand All @@ -229,19 +262,22 @@ async def open(self) -> None:
self._ended = False

async def end(self) -> None:
"""Send EOF on current channel
"""

await self._internal_outcoming_queue.join()
await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), data=b'', flags=['PSH'])))
self._ended = True

async def close(self) -> None:
"""Close channel
"""Close current channel
"""

self._debug(f'Tecemux/{self._get_channel_name()}: [-] Channel close request is send')
await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), data=b'' if self._global_instance_id is None else self._global_instance_id, flags=['FIN'])))

async def queue_up_incoming(self, pkt: IPPacket) -> None:
"""Redirects incomming data from provided packet to current channel
"""Redirects incoming data from provided packet to current channel
Args:
pkt (IPPacket): Redirected packet from outside
Expand Down Expand Up @@ -308,21 +344,6 @@ def write(self, data: bytes) -> None:

self._internal_outcoming_queue.put_nowait(data)

def drain(self) -> bool:
"""Drain channel
"""
return

def write_eof(self) -> None:
"""asyncio.StreamWriter API
"""
return

def wait_closed(self) -> Coroutine:
"""asyncio.StreamWriter API
"""
return

def set_pause(self, state: bool) -> None:
"""Sets pause state for current channel
Expand Down Expand Up @@ -445,7 +466,7 @@ def get_channel(self, channel: CC) -> _ChannelContext:
"""
return self._channels[channel]

async def sync(self):
async def sync(self) -> None:
"""Waits until all write tasks will be done
"""
self._global_sync_channel_event.set()
Expand Down Expand Up @@ -477,13 +498,17 @@ def _chunk_preview(value: bytes) -> str:
return f'{value[0:5]}... <len:{len(value)}>' if len(value) > 5 else f'{value}'

async def stop(self) -> None:
""" Stops protocol
"""

await self._finish_channels()
await self._finish_incoming()
await self._finish_outcoming()

async def _finish_channels(self) -> None:
"""Stops protocol
""" Close all channels
"""

for channel in self.get_channels():
await channel.end()
await channel.close()
Expand All @@ -498,9 +523,11 @@ async def _finish_channels(self) -> None:
await asyncio.wait_for(channel._outcoming_process_task, timeout=1)
except asyncio.TimeoutError:
pass

async def _finish_outcoming(self):


async def _finish_outcoming(self) -> None:
""" Finish outcoming forwarder and main writer to STH
"""

self._global_stop_outcoming_event.set()
await asyncio.sleep(0)
await self._global_stop_outcoming_event.wait()
Expand All @@ -509,22 +536,14 @@ async def _finish_outcoming(self):
self._writer.close()
await self._writer.wait_closed()

async def _finish_incoming(self) -> None:
""" Finish incoming forwarder
"""

async def _finish_incoming(self):
self._global_stop_incoming_event.set()
await self._global_stop_incoming_event.wait()
await asyncio.gather(*[self._incoming_data_forwarder])


async def _read_eof(self):
while True:
try:
data = await asyncio.wait_for(self._reader.read(),0.5)
if not data:
break
except asyncio.TimeoutError as e:
break

async def loop(self) -> None:
"""Main loop of Tecemux protocol. Starts forwarders tasks
"""
Expand Down Expand Up @@ -572,7 +591,7 @@ async def incoming_data_forward(self) -> None:
single_packet_buffer = buffer[:current_packet_size]
pkt = IPPacket().from_buffer_with_pseudoheader(single_packet_buffer)
self._last_sequence_received = pkt.get_segment().seq
self._debug(f'Tecemux/MAIN: [<] Full incomming packet with sequence number {self._last_sequence_received} from Transform Hub was received')
self._debug(f'Tecemux/MAIN: [<] Full incoming packet with sequence number {self._last_sequence_received} from Transform Hub was received')

channel = CC(str(pkt.get_segment().dst_port))

Expand All @@ -591,7 +610,7 @@ async def incoming_data_forward(self) -> None:
except asyncio.CancelledError:
break

self._debug('Tecemux/MAIN: Incomming data forwarder finished')
self._debug('Tecemux/MAIN: Incoming data forwarder finished')

async def outcoming_data_forward(self) -> None:
"""Loop for outcoming data to Transform Hub
Expand Down Expand Up @@ -627,5 +646,4 @@ async def outcoming_data_forward(self) -> None:
except asyncio.CancelledError:
break


self._debug('Tecemux/MAIN: Outcoming data forwarder finished')

0 comments on commit b7ab9ac

Please sign in to comment.