Skip to content

Commit

Permalink
Close communication improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
gierwialo committed Jul 5, 2023
1 parent ecc6786 commit 4e24000
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 42 deletions.
43 changes: 20 additions & 23 deletions packages/python-runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import codecs
import json
import logging
# import debugpy
from pyee.asyncio import AsyncIOEventEmitter
from tecemux import Tecemux
import importlib.util
Expand All @@ -21,10 +20,6 @@
SERVER_HOST = os.getenv('INSTANCES_SERVER_HOST') or 'localhost'
INSTANCE_ID = os.getenv('INSTANCE_ID')

# debugpy.listen(5678)
# debugpy.wait_for_client()
# debugpy.breakpoint()

def send_encoded_msg(stream, msg_code, data={}):
message = json.dumps([msg_code.value, data])
stream.write(f'{message}\r\n'.encode())
Expand All @@ -40,11 +35,9 @@ def __init__(self, instance_id, sequence_path, log_setup) -> None:
self.emitter = AsyncIOEventEmitter()
self.keep_alive_requested = False
self.protocol = None
@staticmethod
def is_incoming(channel):
return channel in [CC.STDIN, CC.IN, CC.CONTROL]

async def main(self, server_host, server_port):
asyncio.current_task().set_name('RUNNER_MAIN')
input_stream = Stream()
await self.init_tecemux(server_host, server_port)
# Do this early to have access to any thrown exceptions and logs.
Expand All @@ -59,21 +52,26 @@ async def main(self, server_host, server_port):
connect_input_stream_task = asyncio.create_task(self.connect_input_stream(input_stream))

self.load_sequence()
await self.protocol.sync()

await self.protocol.sync()
await self.run_instance(config, input_stream, args)

await self.protocol.sync()
heartbeat_task.cancel()
connect_input_stream_task.cancel()
await asyncio.gather(*[heartbeat_task])

await asyncio.gather(*[connect_input_stream_task])

await self.protocol.get_channel(CC.IN).sync()
await self.protocol.get_channel(CC.CONTROL).sync()

control_stream_task.cancel()
await asyncio.gather(*[control_stream_task])

await asyncio.gather(*[heartbeat_task,
connect_input_stream_task,
control_stream_task])

await self.protocol.stop()

[ task.cancel() if task.get_name() != 'RUNNER_MAIN' else None for task in asyncio.all_tasks()]


async def init_tecemux(self, server_host, server_port):
self.logger.info('Connecting to host with TeceMux...')
Expand Down Expand Up @@ -152,6 +150,10 @@ async def connect_control_stream(self):
if code == msg_codes.EVENT.value:
await self.emitter.emit(data['eventName'], data['message'] if 'message' in data else None)
except asyncio.CancelledError:
task = self.protocol.get_channel(CC.CONTROL)._outcoming_process_task
task.cancel()
await asyncio.sleep(0)
await asyncio.gather(*[task])
return


Expand Down Expand Up @@ -182,6 +184,7 @@ async def setup_heartbeat(self):
)
await asyncio.sleep(1)
except asyncio.CancelledError:
await self.protocol.get_channel(CC.MONITORING).sync()
return


Expand All @@ -199,8 +202,7 @@ def load_sequence(self):
os.chdir(os.path.dirname(self.seq_path))

async def run_instance(self, config, input, args):
context = AppContext(self, config)
await self.protocol.sync()
context = AppContext(self, config)
self.logger.info('Running instance...')
try:
result = self.sequence.run(context, input, *args)
Expand Down Expand Up @@ -228,13 +230,11 @@ async def run_instance(self, config, input, args):
elif asyncio.iscoroutine(result):
result = await result
if result:
await self.protocol.sync()
await self.forward_output_stream(result)
else:
self.logger.debug('Sequence returned no output.')

self.logger.info('Finished.')
await self.protocol.sync()


async def connect_input_stream(self, input_stream):
Expand All @@ -249,8 +249,6 @@ async def connect_input_stream(self, input_stream):
self.logger.info(f'Input headers: {repr(headers)}')
input_type = headers.get('content-type')

await self.protocol.sync()

if input_type == 'text/plain':
input = Stream.read_from(self.protocol.get_channel(CC.IN))

Expand All @@ -266,7 +264,6 @@ async def connect_input_stream(self, input_stream):
input.pipe(input_stream)
self.logger.debug('Input stream forwarded to the instance.')


async def forward_output_stream(self, output):
if hasattr(output, 'content_type'):
content_type = output.content_type
Expand Down Expand Up @@ -341,4 +338,4 @@ async def keep_alive(self, timeout: int = 0):
sys.exit(2)

runner = Runner(INSTANCE_ID, SEQUENCE_PATH, LOG_SETUP)
asyncio.run(runner.main(SERVER_HOST, SERVER_PORT), debug=False)
asyncio.run(runner.main(SERVER_HOST, SERVER_PORT))
63 changes: 44 additions & 19 deletions packages/python-runner/tecemux.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,33 @@ def __init__(self, channel: CC,
self._stop_channel_event = stop_event
self._sync_channel_event = sync_event
self._sync_barrier = sync_barrier
self._outcoming_process_task = asyncio.create_task(self._outcomming_process_tasks())
self._outcoming_process_task = asyncio.create_task(self._outcomming_process_tasks(), name=f'{self._get_channel_name()}_PROCESS_TASK')
self._ended = None
self._read_buffer = bytearray()

async def sync(self):

if self._is_incoming():
await self._internal_incoming_queue.join()

await self._internal_outcoming_queue.join()



def _is_incoming(self):
return self._channel_enum in [CC.STDIN, CC.IN, CC.CONTROL]

async def _outcomming_process_tasks(self):
while not self._stop_channel_event.is_set():
await self._queue_up_outcoming()
await asyncio.sleep(0)
while True:
try:
await self._queue_up_outcoming()
await asyncio.sleep(0)

if self._stop_channel_event.is_set():
return
except asyncio.CancelledError:
return




Expand Down Expand Up @@ -145,9 +164,15 @@ async def _get_data(self):
self._read_buffer.extend(buf)
break
except asyncio.QueueEmpty:

await asyncio.sleep(0)

if self._ended:
return False

except asyncio.CancelledError:
return False

await asyncio.sleep(0)
return True

Expand Down Expand Up @@ -207,6 +232,7 @@ async def end(self) -> None:

await self._internal_outcoming_queue.join()
await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), data=b'', flags=['PSH'])))
self._ended = True

async def close(self) -> None:
"""Close channel
Expand Down Expand Up @@ -239,7 +265,8 @@ async def queue_up_incoming(self, pkt: IPPacket) -> None:
if pkt.segment.is_flag('FIN'):
self._ended = True
return
self._internal_incoming_queue.put_nowait(pkt.get_segment().data)
await self._internal_incoming_queue.put(pkt.get_segment().data)
self._internal_incoming_queue.task_done()

async def _queue_up_outcoming(self) -> None:
"""Redirects raw data from currect channel to global queue.
Expand All @@ -254,12 +281,14 @@ def wrap(channel_enum, buf):
if not self._channel_paused:
while not self._internal_outcoming_queue.empty():
try:
buf = await self._internal_outcoming_queue.get()
buf = await asyncio.wait_for(self._internal_outcoming_queue.get(),1)
await self._global_queue.put(wrap(self._channel_enum, buf))
self._internal_outcoming_queue.task_done()
except asyncio.QueueEmpty:
self._debug(f'Tecemux/{self._get_channel_name()}: [-] All data stored during pause were redirected to global queue')
break
except asyncio.TimeoutError:
pass
else:
self._debug(f'Tecemux/{self._get_channel_name()}: [-] Channel paused. Data queued up internally for future')

Expand Down Expand Up @@ -451,9 +480,6 @@ async def stop(self) -> None:
await self._finish_channels()
await self._finish_incoming()
await self._finish_outcoming()
await self._read_eof()

self._debug('Tecemux/MAIN: [-] Finished')

async def _finish_channels(self) -> None:
"""Stops protocol
Expand All @@ -464,22 +490,21 @@ async def _finish_channels(self) -> None:

for channel in self.get_channels():
await channel._internal_outcoming_queue.join()

self._global_stop_channel_event.set()
await self._global_stop_channel_event.wait()

tasks = asyncio.gather(*[channel._outcoming_process_task for channel in self.get_channels()])
tasks.cancel()
await tasks


for channel in self.get_channels():
try:
await asyncio.wait_for(channel._outcoming_process_task, timeout=1)
except asyncio.TimeoutError:
pass
async def _finish_outcoming(self):

self._global_stop_outcoming_event.set()
await asyncio.sleep(0)
await self._global_stop_outcoming_event.wait()

await asyncio.gather(*[self._outcoming_data_forwarder])
self._writer.write_eof()
await self._writer.drain()
self._writer.close()
await self._writer.wait_closed()
Expand Down Expand Up @@ -575,6 +600,7 @@ async def outcoming_data_forward(self) -> None:
while True:
try:
pkt = await asyncio.wait_for(self._queue.get(),1)
self._queue.task_done()

# inject sequence number
if pkt.segment.seq == 0:
Expand All @@ -587,7 +613,6 @@ async def outcoming_data_forward(self) -> None:
self._debug(f'Tecemux/MAIN: [>] Outcoming chunk {Tecemux._chunk_preview(chunk)} is waiting to send to Transform Hub')
self._writer.write(chunk)
await self._writer.drain()
self._queue.task_done()
self._debug(f'Tecemux/MAIN: [>] Chunk {Tecemux._chunk_preview(chunk)} with sequence number: {pkt.segment.seq} was sent to Transform Hub')
except asyncio.QueueEmpty:
if self._global_stop_outcoming_event.is_set():
Expand Down

0 comments on commit 4e24000

Please sign in to comment.