From 908b8a94f05015a1dfbb9e2758bdbfb25df1bbc0 Mon Sep 17 00:00:00 2001
From: Mircea Trofin
Date: Tue, 17 Oct 2023 14:26:51 -0700
Subject: [PATCH] [local-worker-mgr] Force using 'spawn' for process creation

It turns out tflite conversion uses numexpr under the hood, and numexpr
spawns a bunch of threads at module init. So if the main process converts
a model, then forks, and the resulting processes try converting a model
too, the state the forked processes inherit looks initialized, but the
threads backing it were not copied by the fork. They hang waiting for
results.

Using `spawn` and not `forkserver` because the latter still runs this
risk. Spawning may be slower, but in the grand scheme of things (training
time), a slightly larger init time is not that problematic.
---
 compiler_opt/distributed/local/local_worker_manager.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py
index dcf60448..f1838c9b 100644
--- a/compiler_opt/distributed/local/local_worker_manager.py
+++ b/compiler_opt/distributed/local/local_worker_manager.py
@@ -60,6 +60,10 @@ class TaskResult:
   value: Any
 
 
+def _get_context():
+  return multiprocessing.get_context('spawn')
+
+
 SerializedClass = bytes
 
 
@@ -122,7 +126,7 @@ class _Stub:
     """Client stub to a worker hosted by a process."""
 
     def __init__(self):
-      parent_pipe, child_pipe = multiprocessing.get_context().Pipe()
+      parent_pipe, child_pipe = _get_context().Pipe()
       self._pipe = parent_pipe
       self._pipe_lock = threading.Lock()
 
@@ -130,7 +134,7 @@ def __init__(self):
       # we set aside 1 thread to coordinate running jobs, and the main thread
       # to handle high priority requests. The expectation is that the user
      # achieves concurrency through multiprocessing, not multithreading.
-      self._process = multiprocessing.get_context().Process(
+      self._process = _get_context().Process(
          target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls), *
                                   args, **kwargs))
      # lock for the msgid -> reply future map. The map will be set to None
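
Note (illustrative sketch, not part of the patch above): the pattern the
change adopts is to request the 'spawn' start method explicitly instead of
relying on the platform default ('fork' on Linux), and to create both the
Pipe and the Process from that same context. A minimal, self-contained
example follows; the `_work` function is a hypothetical stand-in for the
real worker entry point, and only the standard `multiprocessing` API is
used.

import multiprocessing


def _get_context():
  # Force 'spawn' so children start a fresh interpreter instead of
  # inheriting fork-copied state.
  return multiprocessing.get_context('spawn')


def _work(child_pipe):
  # Runs in a freshly spawned interpreter; module imports re-execute here,
  # so any library thread pools are created anew in this process.
  child_pipe.send('ready')
  child_pipe.close()


if __name__ == '__main__':
  ctx = _get_context()
  parent_pipe, child_pipe = ctx.Pipe()
  # Both the Pipe and the Process come from the same 'spawn' context,
  # mirroring what the patch does via _get_context().
  proc = ctx.Process(target=_work, args=(child_pipe,))
  proc.start()
  print(parent_pipe.recv())  # prints 'ready'
  proc.join()

Under 'fork', a child created after a library (numexpr in this case) has
already started its worker threads inherits state that says those threads
exist when they do not, which is the hang described in the commit message;
'spawn' trades a slower process start for a clean interpreter in each
worker.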