From 7a571acc5b80c2d2fe13b2618251bdb946f74ac9 Mon Sep 17 00:00:00 2001
From: Min RK
Date: Fri, 16 Jul 2021 15:18:10 +0200
Subject: [PATCH] pause receiving while submitting tasks

results coming in during submission causes thread contention while submitting tasks

pause receiving messages while we are preparing tasks to be submitted
---
 ipyparallel/client/client.py | 83 +++++++++++++++++++++++++++++++-----
 ipyparallel/client/view.py   | 44 +++++++++++++------
 2 files changed, 103 insertions(+), 24 deletions(-)

diff --git a/ipyparallel/client/client.py b/ipyparallel/client/client.py
index adc8a2b4d..e07b15750 100644
--- a/ipyparallel/client/client.py
+++ b/ipyparallel/client/client.py
@@ -12,6 +12,7 @@
 import warnings
 from collections.abc import Iterable
 from concurrent.futures import Future
+from contextlib import contextmanager
 from getpass import getpass
 from pprint import pprint
 from threading import current_thread
@@ -990,21 +991,59 @@ def _stop_io_thread(self):
         self._io_thread.join()
 
     def _setup_streams(self):
-        self._query_stream = ZMQStream(self._query_socket, self._io_loop)
-        self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
-        self._control_stream = ZMQStream(self._control_socket, self._io_loop)
+        self._streams = []  # all streams
+        self._engine_streams = []  # streams that talk to engines
+        self._query_stream = s = ZMQStream(self._query_socket, self._io_loop)
+        self._streams.append(s)
+        self._notification_stream = s = ZMQStream(
+            self._notification_socket, self._io_loop
+        )
+        self._streams.append(s)
+
+        self._control_stream = s = ZMQStream(self._control_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._mux_stream = s = ZMQStream(self._mux_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._task_stream = s = ZMQStream(self._task_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._broadcast_stream = s = ZMQStream(self._broadcast_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._iopub_stream = s = ZMQStream(self._iopub_socket, self._io_loop)
+        self._streams.append(s)
+        self._engine_streams.append(s)
+        self._start_receiving(all=True)
+
+    def _start_receiving(self, all=False):
+        """Start receiving on streams
+
+        default: only engine streams
+
+        if all: include hub streams
+        """
+        if all:
+            self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
+            self._notification_stream.on_recv(self._dispatch_notification, copy=False)
         self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
-        self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
         self._mux_stream.on_recv(self._dispatch_reply, copy=False)
-        self._task_stream = ZMQStream(self._task_socket, self._io_loop)
         self._task_stream.on_recv(self._dispatch_reply, copy=False)
-        self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
+        self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
         self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
-        self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
-        self._notification_stream.on_recv(self._dispatch_notification, copy=False)
-        self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
-        self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
+
+    def _stop_receiving(self, all=False):
+        """Stop receiving on engine streams
+
+        If all: include hub streams
+        """
+        if all:
+            streams = self._streams
+        else:
+            streams = self._engine_streams
+        for s in streams:
+            s.stop_on_recv()
 
     def _start_io_thread(self):
         """Start IOLoop in a background thread."""
@@ -1034,6 +1073,30 @@ def _io_main(self, start_evt=None):
         self._io_loop.start()
         self._io_loop.close()
 
+    @contextmanager
+    def _pause_results(self):
+        """Context manager to pause receiving results
+
+        When submitting lots of tasks,
+        the arrival of results can disrupt the processing
+        of new submissions.
+
+        Threadsafe.
+        """
+        f = Future()
+
+        def _stop():
+            self._stop_receiving()
+            f.set_result(None)
+
+        # use add_callback to make it threadsafe
+        self._io_loop.add_callback(_stop)
+        f.result()
+        try:
+            yield
+        finally:
+            self._io_loop.add_callback(self._start_receiving)
+
     @unpack_message
     def _dispatch_single_reply(self, msg):
         """Dispatch single (non-execution) replies"""
diff --git a/ipyparallel/client/view.py b/ipyparallel/client/view.py
index 5fac07f4c..199f535de 100644
--- a/ipyparallel/client/view.py
+++ b/ipyparallel/client/view.py
@@ -578,11 +578,12 @@ def _really_apply(
         pargs = [PrePickled(arg) for arg in args]
         pkwargs = {k: PrePickled(v) for k, v in kwargs.items()}
 
-        for ident in _idents:
-            future = self.client.send_apply_request(
-                self._socket, pf, pargs, pkwargs, track=track, ident=ident
-            )
-            futures.append(future)
+        with self.client._pause_results():
+            for ident in _idents:
+                future = self.client.send_apply_request(
+                    self._socket, pf, pargs, pkwargs, track=track, ident=ident
+                )
+                futures.append(future)
         if track:
             trackers = [_.tracker for _ in futures]
         else:
@@ -641,9 +642,16 @@ def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
         assert len(sequences) > 0, "must have some sequences to map onto!"
 
         pf = ParallelFunction(
-            self, f, block=block, track=track, return_exceptions=return_exceptions
+            self, f, block=False, track=track, return_exceptions=return_exceptions
         )
-        return pf.map(*sequences)
+        with self.client._pause_results():
+            ar = pf.map(*sequences)
+        if block:
+            try:
+                return ar.get()
+            except KeyboardInterrupt:
+                return ar
+        return ar
 
     @sync_results
     @save_ids
@@ -665,11 +673,12 @@ def execute(self, code, silent=True, targets=None, block=None):
 
         _idents, _targets = self.client._build_targets(targets)
         futures = []
-        for ident in _idents:
-            future = self.client.send_execute_request(
-                self._socket, code, silent=silent, ident=ident
-            )
-            futures.append(future)
+        with self.client._pause_results():
+            for ident in _idents:
+                future = self.client.send_execute_request(
+                    self._socket, code, silent=silent, ident=ident
+                )
+                futures.append(future)
         if isinstance(targets, int):
             futures = futures[0]
         ar = AsyncResult(
@@ -1292,12 +1301,19 @@ def map(
         pf = ParallelFunction(
             self,
             f,
-            block=block,
+            block=False,
             chunksize=chunksize,
             ordered=ordered,
             return_exceptions=return_exceptions,
         )
-        return pf.map(*sequences)
+        with self.client._pause_results():
+            ar = pf.map(*sequences)
+        if block:
+            try:
+                return ar.get()
+            except KeyboardInterrupt:
+                return ar
+        return ar
 
     def imap(
         self,
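
The handshake behind `_pause_results()` is worth spelling out: the caller asks the client's IO thread to stop delivering replies via `add_callback`, blocks on a `Future` until the pause has actually taken effect, submits its batch of tasks, and then schedules the resume without waiting. The sketch below is a standalone illustration of that pattern, not code from the patch; it runs a plain asyncio loop in a background thread, with `call_soon_threadsafe` standing in for tornado's `IOLoop.add_callback`, and the `Receiver`, `pause_results`, and `receiving` names are invented for the example.

import asyncio
import threading
import time
from concurrent.futures import Future
from contextlib import contextmanager


class Receiver:
    """Toy stand-in for the Client: an event loop thread that 'receives' replies."""

    def __init__(self):
        self.receiving = True
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self.loop.run_forever, daemon=True).start()

    def _stop_receiving(self):
        # runs on the loop thread
        self.receiving = False

    def _start_receiving(self):
        # runs on the loop thread
        self.receiving = True

    @contextmanager
    def pause_results(self):
        """Pause receiving while the caller submits work (threadsafe)."""
        f = Future()

        def _stop():
            self._stop_receiving()
            f.set_result(None)

        # schedule the pause on the loop thread; call_soon_threadsafe plays
        # the role of IOLoop.add_callback in the real client
        self.loop.call_soon_threadsafe(_stop)
        f.result()  # block until the pause has actually taken effect
        try:
            yield
        finally:
            # resume asynchronously; no need to wait for the restart
            self.loop.call_soon_threadsafe(self._start_receiving)


r = Receiver()
with r.pause_results():
    print("paused:", not r.receiving)  # True: the pause was acknowledged
    # ... submit a large batch of tasks here ...
time.sleep(0.1)  # give the loop thread a moment to process the resume
print("resumed:", r.receiving)  # True

Blocking on the Future before yielding guarantees no task is sent until reply delivery has really stopped, while the resume in `finally` is fire-and-forget: nothing in the submission path depends on when receiving restarts, so there is no need for a second round trip to the IO thread.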