From 4e2400088204737586e842bf12bb81c5aa45c29e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Gierwia=C5=82o?= Date: Wed, 5 Jul 2023 11:14:59 +0200 Subject: [PATCH] Close communication improvements --- packages/python-runner/runner.py | 43 ++++++++++----------- packages/python-runner/tecemux.py | 63 +++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/packages/python-runner/runner.py b/packages/python-runner/runner.py index 37745f999..5bdabdc90 100644 --- a/packages/python-runner/runner.py +++ b/packages/python-runner/runner.py @@ -4,7 +4,6 @@ import codecs import json import logging -# import debugpy from pyee.asyncio import AsyncIOEventEmitter from tecemux import Tecemux import importlib.util @@ -21,10 +20,6 @@ SERVER_HOST = os.getenv('INSTANCES_SERVER_HOST') or 'localhost' INSTANCE_ID = os.getenv('INSTANCE_ID') -# debugpy.listen(5678) -# debugpy.wait_for_client() -# debugpy.breakpoint() - def send_encoded_msg(stream, msg_code, data={}): message = json.dumps([msg_code.value, data]) stream.write(f'{message}\r\n'.encode()) @@ -40,11 +35,9 @@ def __init__(self, instance_id, sequence_path, log_setup) -> None: self.emitter = AsyncIOEventEmitter() self.keep_alive_requested = False self.protocol = None - @staticmethod - def is_incoming(channel): - return channel in [CC.STDIN, CC.IN, CC.CONTROL] async def main(self, server_host, server_port): + asyncio.current_task().set_name('RUNNER_MAIN') input_stream = Stream() await self.init_tecemux(server_host, server_port) # Do this early to have access to any thrown exceptions and logs. @@ -59,21 +52,26 @@ async def main(self, server_host, server_port): connect_input_stream_task = asyncio.create_task(self.connect_input_stream(input_stream)) self.load_sequence() - - await self.protocol.sync() + + await self.protocol.sync() await self.run_instance(config, input_stream, args) await self.protocol.sync() heartbeat_task.cancel() - connect_input_stream_task.cancel() + await asyncio.gather(*[heartbeat_task]) + + await asyncio.gather(*[connect_input_stream_task]) + + await self.protocol.get_channel(CC.IN).sync() + await self.protocol.get_channel(CC.CONTROL).sync() + control_stream_task.cancel() + await asyncio.gather(*[control_stream_task]) - await asyncio.gather(*[heartbeat_task, - connect_input_stream_task, - control_stream_task]) - await self.protocol.stop() + [ task.cancel() if task.get_name() != 'RUNNER_MAIN' else None for task in asyncio.all_tasks()] + async def init_tecemux(self, server_host, server_port): self.logger.info('Connecting to host with TeceMux...') @@ -152,6 +150,10 @@ async def connect_control_stream(self): if code == msg_codes.EVENT.value: await self.emitter.emit(data['eventName'], data['message'] if 'message' in data else None) except asyncio.CancelledError: + task = self.protocol.get_channel(CC.CONTROL)._outcoming_process_task + task.cancel() + await asyncio.sleep(0) + await asyncio.gather(*[task]) return @@ -182,6 +184,7 @@ async def setup_heartbeat(self): ) await asyncio.sleep(1) except asyncio.CancelledError: + await self.protocol.get_channel(CC.MONITORING).sync() return @@ -199,8 +202,7 @@ def load_sequence(self): os.chdir(os.path.dirname(self.seq_path)) async def run_instance(self, config, input, args): - context = AppContext(self, config) - await self.protocol.sync() + context = AppContext(self, config) self.logger.info('Running instance...') try: result = self.sequence.run(context, input, *args) @@ -228,13 +230,11 @@ async def run_instance(self, config, input, args): elif asyncio.iscoroutine(result): result = await result if result: - await self.protocol.sync() await self.forward_output_stream(result) else: self.logger.debug('Sequence returned no output.') self.logger.info('Finished.') - await self.protocol.sync() async def connect_input_stream(self, input_stream): @@ -249,8 +249,6 @@ async def connect_input_stream(self, input_stream): self.logger.info(f'Input headers: {repr(headers)}') input_type = headers.get('content-type') - await self.protocol.sync() - if input_type == 'text/plain': input = Stream.read_from(self.protocol.get_channel(CC.IN)) @@ -266,7 +264,6 @@ async def connect_input_stream(self, input_stream): input.pipe(input_stream) self.logger.debug('Input stream forwarded to the instance.') - async def forward_output_stream(self, output): if hasattr(output, 'content_type'): content_type = output.content_type @@ -341,4 +338,4 @@ async def keep_alive(self, timeout: int = 0): sys.exit(2) runner = Runner(INSTANCE_ID, SEQUENCE_PATH, LOG_SETUP) -asyncio.run(runner.main(SERVER_HOST, SERVER_PORT), debug=False) +asyncio.run(runner.main(SERVER_HOST, SERVER_PORT)) diff --git a/packages/python-runner/tecemux.py b/packages/python-runner/tecemux.py index 9a7d174ae..90405a78d 100644 --- a/packages/python-runner/tecemux.py +++ b/packages/python-runner/tecemux.py @@ -61,14 +61,33 @@ def __init__(self, channel: CC, self._stop_channel_event = stop_event self._sync_channel_event = sync_event self._sync_barrier = sync_barrier - self._outcoming_process_task = asyncio.create_task(self._outcomming_process_tasks()) + self._outcoming_process_task = asyncio.create_task(self._outcomming_process_tasks(), name=f'{self._get_channel_name()}_PROCESS_TASK') self._ended = None self._read_buffer = bytearray() + async def sync(self): + + if self._is_incoming(): + await self._internal_incoming_queue.join() + + await self._internal_outcoming_queue.join() + + + + def _is_incoming(self): + return self._channel_enum in [CC.STDIN, CC.IN, CC.CONTROL] + async def _outcomming_process_tasks(self): - while not self._stop_channel_event.is_set(): - await self._queue_up_outcoming() - await asyncio.sleep(0) + while True: + try: + await self._queue_up_outcoming() + await asyncio.sleep(0) + + if self._stop_channel_event.is_set(): + return + except asyncio.CancelledError: + return + @@ -145,9 +164,15 @@ async def _get_data(self): self._read_buffer.extend(buf) break except asyncio.QueueEmpty: + await asyncio.sleep(0) + + if self._ended: + return False + except asyncio.CancelledError: return False + await asyncio.sleep(0) return True @@ -207,6 +232,7 @@ async def end(self) -> None: await self._internal_outcoming_queue.join() await self._global_queue.put(IPPacket(segment=TCPSegment(dst_port=self._get_channel_id(), data=b'', flags=['PSH']))) + self._ended = True async def close(self) -> None: """Close channel @@ -239,7 +265,8 @@ async def queue_up_incoming(self, pkt: IPPacket) -> None: if pkt.segment.is_flag('FIN'): self._ended = True return - self._internal_incoming_queue.put_nowait(pkt.get_segment().data) + await self._internal_incoming_queue.put(pkt.get_segment().data) + self._internal_incoming_queue.task_done() async def _queue_up_outcoming(self) -> None: """Redirects raw data from currect channel to global queue. @@ -254,12 +281,14 @@ def wrap(channel_enum, buf): if not self._channel_paused: while not self._internal_outcoming_queue.empty(): try: - buf = await self._internal_outcoming_queue.get() + buf = await asyncio.wait_for(self._internal_outcoming_queue.get(),1) await self._global_queue.put(wrap(self._channel_enum, buf)) self._internal_outcoming_queue.task_done() except asyncio.QueueEmpty: self._debug(f'Tecemux/{self._get_channel_name()}: [-] All data stored during pause were redirected to global queue') break + except asyncio.TimeoutError: + pass else: self._debug(f'Tecemux/{self._get_channel_name()}: [-] Channel paused. Data queued up internally for future') @@ -451,9 +480,6 @@ async def stop(self) -> None: await self._finish_channels() await self._finish_incoming() await self._finish_outcoming() - await self._read_eof() - - self._debug('Tecemux/MAIN: [-] Finished') async def _finish_channels(self) -> None: """Stops protocol @@ -464,22 +490,21 @@ async def _finish_channels(self) -> None: for channel in self.get_channels(): await channel._internal_outcoming_queue.join() - + self._global_stop_channel_event.set() await self._global_stop_channel_event.wait() - - tasks = asyncio.gather(*[channel._outcoming_process_task for channel in self.get_channels()]) - tasks.cancel() - await tasks - - + for channel in self.get_channels(): + try: + await asyncio.wait_for(channel._outcoming_process_task, timeout=1) + except asyncio.TimeoutError: + pass + async def _finish_outcoming(self): self._global_stop_outcoming_event.set() + await asyncio.sleep(0) await self._global_stop_outcoming_event.wait() - await asyncio.gather(*[self._outcoming_data_forwarder]) - self._writer.write_eof() await self._writer.drain() self._writer.close() await self._writer.wait_closed() @@ -575,6 +600,7 @@ async def outcoming_data_forward(self) -> None: while True: try: pkt = await asyncio.wait_for(self._queue.get(),1) + self._queue.task_done() # inject sequence number if pkt.segment.seq == 0: @@ -587,7 +613,6 @@ async def outcoming_data_forward(self) -> None: self._debug(f'Tecemux/MAIN: [>] Outcoming chunk {Tecemux._chunk_preview(chunk)} is waiting to send to Transform Hub') self._writer.write(chunk) await self._writer.drain() - self._queue.task_done() self._debug(f'Tecemux/MAIN: [>] Chunk {Tecemux._chunk_preview(chunk)} with sequence number: {pkt.segment.seq} was sent to Transform Hub') except asyncio.QueueEmpty: if self._global_stop_outcoming_event.is_set():