Skip to content

Commit

Permalink
fix: race condition in PruningThreadPoolExecutor (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Jun 18, 2023
1 parent 40b5895 commit 648b4ef
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions dank_mids/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ def _worker(executor_reference, work_queue, initializer, initargs, timeout): #
except queue.Empty: # NOTE: NEW
# Its been 'timeout' seconds and there are no new work items. # NOTE: NEW
executor = executor_reference() # NOTE: NEW
t = threading.current_thread() # NOTE: NEW
executor._threads.remove(t) # NOTE: NEW
thread._threads_queues.pop(t) # NOTE: NEW

with executor._adjusting_lock: # NOTE: NEW
t = threading.current_thread() # NOTE: NEW
executor._threads.remove(t) # NOTE: NEW
thread._threads_queues.pop(t) # NOTE: NEW
# Let the executor know we have one less idle thread available
executor._idle_semaphore.acquire(blocking=False) # NOTE: NEW

return # NOTE: NEW

if work_item is not None:
Expand Down Expand Up @@ -66,6 +71,7 @@ class PruningThreadPoolExecutor(ThreadPoolExecutor):
def __init__(self, max_workers=None, thread_name_prefix='',
initializer=None, initargs=(), timeout=TEN_MINUTES):
self._timeout=timeout
self._adjusting_lock = threading.Lock()
super().__init__(max_workers, thread_name_prefix, initializer, initargs)

def __repr__(self) -> str:
Expand All @@ -75,28 +81,29 @@ def __len__(self) -> int:
return len(self._threads)

def _adjust_thread_count(self):
# if idle threads are available, don't spin new threads
if self._idle_semaphore.acquire(timeout=0):
return
with self._adjusting_lock:
# if idle threads are available, don't spin new threads
if self._idle_semaphore.acquire(timeout=0):
return

# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue):
q.put(None)
# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue):
q.put(None)

num_threads = len(self._threads)
if num_threads < self._max_workers:
thread_name = '%s_%d' % (self._thread_name_prefix or self,
num_threads)
t = threading.Thread(name=thread_name, target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue,
self._initializer,
self._initargs,
self._timeout))
t.daemon = True
t.start()
self._threads.add(t)
thread._threads_queues[t] = self._work_queue
num_threads = len(self._threads)
if num_threads < self._max_workers:
thread_name = '%s_%d' % (self._thread_name_prefix or self,
num_threads)
t = threading.Thread(name=thread_name, target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue,
self._initializer,
self._initargs,
self._timeout))
t.daemon = True
t.start()
self._threads.add(t)
thread._threads_queues[t] = self._work_queue

executor = PruningThreadPoolExecutor(128)

0 comments on commit 648b4ef

Please sign in to comment.