Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[parallel] If not num_threads, use queue instead of tail recursion #306

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions bin/fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def _run(self, bar):
return False
localbar.done()
else:
if not target_ansfile.is_file():
bar.error(f'{self.i}: {ansfile.name} does not exist and was not generated.')
return False
bar.error(f'{self.i}: {ansfile.name} was not generated.')
return False
else:
if not testcase.ans_path.is_file():
testcase.ans_path.write_text('')
Expand Down Expand Up @@ -187,27 +186,18 @@ def run(self):
self.start_time = time.monotonic()
self.iteration = 0
self.tasks = 0
self.queue = parallel.Parallel(lambda task: task.run(bar), pin=True)

if self.queue.num_threads:
# pool of ids used for generators
self.tmp_ids = 2 * max(1, self.queue.num_threads) + 1
self.free_tmp_id = {*range(self.tmp_ids)}
self.tmp_id_count = [0] * self.tmp_ids
self.queue = parallel.new_queue(lambda task: task.run(bar), pin=True)

# add first generator task
self.finish_task()
# pool of ids used for generators
self.tmp_ids = 2 * max(1, self.queue.num_threads) + 1
self.free_tmp_id = {*range(self.tmp_ids)}
self.tmp_id_count = [0] * self.tmp_ids

# wait for the queue to run empty (after config.args.time)
self.queue.join()
else:
self.tmp_ids = -1
while time.monotonic() - self.start_time <= config.args.time:
testcase_rule = self.testcase_rules[self.iteration % len(self.testcase_rules)]
self.iteration += 1
self.queue.put(GeneratorTask(self, testcase_rule, self.iteration, None))
self.queue.join()
# add first generator task
self.finish_task()

# wait for the queue to run empty (after config.args.time)
self.queue.join()
# At this point, no new tasks may be started anymore.
self.queue.done()
bar.done()
Expand Down
6 changes: 3 additions & 3 deletions bin/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ def build_program(p):
p.build(localbar)
localbar.done()

p = parallel.Parallel(build_program)
p = parallel.new_queue(build_program)
for pr in programs:
p.put(pr)
p.done()
Expand Down Expand Up @@ -1678,7 +1678,7 @@ def count_dir(d):
# after to deduplicate them against generated testcases.

# 1
p = parallel.Parallel(lambda t: t.listed and t.generate(self.problem, self, bar))
p = parallel.new_queue(lambda t: t.listed and t.generate(self.problem, self, bar))

def generate_dir(d):
p.join()
Expand All @@ -1688,7 +1688,7 @@ def generate_dir(d):
p.done()

# 2
p = parallel.Parallel(lambda t: not t.listed and t.generate(self.problem, self, bar))
p = parallel.new_queue(lambda t: not t.listed and t.generate(self.problem, self, bar))

def generate_dir_unlisted(d):
p.join()
Expand Down
195 changes: 117 additions & 78 deletions bin/parallel.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
#!/usr/bin/env python3
import threading
import signal
import heapq

import os
import signal
import threading

import config
import util


class ParallelItem:
class QueueItem:
def __init__(self, task, priority, id):
self.task = task
self.priority = priority
Expand All @@ -26,74 +25,119 @@ def __lt__(self, other):
return self.id < other.id


class Parallel:
# f(task): the function to run on each queue item.
# num_threads: True: the configured default
# None/False/0: disable parallelization
def __init__(self, f, num_threads=True, pin=False):
class AbstractQueue:
def __init__(self, f, pin):
self.f = f
self.num_threads = config.args.jobs if num_threads is True else num_threads
self.pin = pin and not util.is_windows() and not util.is_bsd()
self.pin = pin
self.num_threads = 1

# min heap
self.tasks: list[QueueItem] = []
self.total_tasks = 0
self.missing = 0

self.aborted = False

# mutex to lock parallel access
self.mutex = threading.RLock()

def __enter__(self):
self.mutex.__enter__()

def __exit__(self, *args):
self.mutex.__exit__(*args)

# Add one task. Higher priority => done first
def put(self, task, priority=0):
raise "Abstract method"

# By default, do nothing on .join(). This is overridden in ParallelQueue.
def join(self):
return

def done(self):
raise "Abstract method"

def abort(self):
self.aborted = True


class SequentialQueue(AbstractQueue):
def __init__(self, f, pin):
super().__init__(f, pin)

# Add one task. Higher priority => done first
def put(self, task, priority=0):
# no task will be handled after self.abort() so skip adding
if self.aborted:
return

self.total_tasks += 1
heapq.heappush(self.tasks, QueueItem(task, priority, self.total_tasks))

# Execute all tasks.
def done(self):
if self.pin:
RagnarGrootKoerkamp marked this conversation as resolved.
Show resolved Hide resolved
cores = list(os.sched_getaffinity(0))
os.sched_setaffinity(0, {cores[0]})

# no task will be handled after self.abort()
while self.tasks and not self.aborted:
self.f(heapq.heappop(self.tasks).task)

if self.pin:
os.sched_setaffinity(0, cores)


class ParallelQueue(AbstractQueue):
def __init__(self, f, pin, num_threads):
super().__init__(f, pin)

assert num_threads and type(num_threads) is int
self.num_threads = num_threads

# condition used to notify worker if the queue has changed
self.todo = threading.Condition(self.mutex)
# condition used to notify join that the queue is empty
self.all_done = threading.Condition(self.mutex)

# only used in parallel mode
self.first_error = None
# min heap
self.tasks = []
self.total_tasks = 0
self.missing = 0

# also used if num_threads is false
self.abort = False
self.finish = False

if self.num_threads:
if self.pin:
# only use available cores and reserve one
cores = list(os.sched_getaffinity(0))
if self.num_threads > len(cores) - 1:
self.num_threads = len(cores) - 1

# sort cores by id. If num_threads << len(cores) this ensures that we
# use different physical cores instead of hyperthreads
cores.sort()

self.threads = []
for i in range(self.num_threads):
args = [{cores[i]}] if self.pin else []
t = threading.Thread(target=self._worker, args=args, daemon=True)
t.start()
self.threads.append(t)
if self.pin:
# only use available cores and reserve one
cores = list(os.sched_getaffinity(0))
if self.num_threads > len(cores) - 1:
self.num_threads = len(cores) - 1

signal.signal(signal.SIGINT, self._interrupt_handler)
# sort cores by id. If num_threads << len(cores) this ensures that we
# use different physical cores instead of hyperthreads
cores.sort()

def __enter__(self):
self.mutex.__enter__()
self.threads = []
for i in range(self.num_threads):
args = [{cores[i]}] if self.pin else []
t = threading.Thread(target=self._worker, args=args, daemon=True)
t.start()
self.threads.append(t)

def __exit__(self, *args):
self.mutex.__exit__(*args)
signal.signal(signal.SIGINT, self._interrupt_handler)

def _worker(self, cores=False):
def _worker(self, cores: bool | list[int] = False):
if cores is not False:
os.sched_setaffinity(0, cores)
while True:
with self.mutex:
# if self.abort we need no item in the queue and can stop
# if self.aborted we need no item in the queue and can stop
# if self.finish we may need to wake up if all tasks were completed earlier
# else we need an item to handle
self.todo.wait_for(lambda: len(self.tasks) > 0 or self.abort or self.finish)
self.todo.wait_for(lambda: len(self.tasks) > 0 or self.aborted or self.finish)

if self.abort:
# we dont handle the queue on abort
if self.aborted:
# we don't handle the queue if self.aborted
break
elif self.finish and len(self.tasks) == 0:
# on finish we can only stop after the queue runs empty
# if self.finish, we can only stop after the queue runs empty
break
else:
# get item from queue (update self.missing after the task is done)
Expand All @@ -105,7 +149,7 @@ def _worker(self, cores=False):
current_error = None
self.f(task)
except Exception as e:
self.stop()
self.abort()
current_error = e

with self.mutex:
Expand All @@ -121,35 +165,18 @@ def _interrupt_handler(self, sig, frame):

# Add one task. Higher priority => done first
def put(self, task, priority=0):
if not self.num_threads:
# no task should be added after .done() was called
assert not self.finish
# no task will be handled after self.abort
if not self.abort:
if self.pin:
cores = list(os.sched_getaffinity(0))
os.sched_setaffinity(0, {cores[0]})
self.f(task)
os.sched_setaffinity(0, cores)
else:
self.f(task)
return

with self.mutex:
# no task should be added after .done() was called
assert not self.finish
# no task will be handled after self.abort so skip adding
if not self.abort:
# no task will be handled after self.aborted so skip adding
if not self.aborted:
# mark task as to be done and notify workers
self.missing += 1
self.total_tasks += 1
heapq.heappush(self.tasks, ParallelItem(task, priority, self.total_tasks))
heapq.heappush(self.tasks, QueueItem(task, priority, self.total_tasks))
self.todo.notify()

def join(self):
if not self.num_threads:
return

# wait for all current task to be completed
with self.all_done:
self.all_done.wait_for(lambda: self.missing == 0)
Expand All @@ -160,10 +187,7 @@ def join(self):
def done(self):
self.finish = True
mpsijm marked this conversation as resolved.
Show resolved Hide resolved

if not self.num_threads:
return

# notify all workes with permission to leave main loop
# notify all workers with permission to leave main loop
with self.todo:
self.todo.notify_all()

Expand All @@ -172,17 +196,14 @@ def done(self):
t.join()

# mutex is no longer needed
# report first error occured during execution
# report first error occurred during execution
if self.first_error is not None:
raise self.first_error

# Discard all remaining work in the queue and stop all workers.
# Call done() to join the threads.
def stop(self):
self.abort = True

if not self.num_threads:
return
def abort(self):
super().abort()

with self.mutex:
# drop all items in the queue at once
Expand All @@ -193,3 +214,21 @@ def stop(self):
# notify .join() if queue runs empty
if self.missing == 0:
self.all_done.notify_all()


def new_queue(f, pin=False, num_threads=True):
"""
f(task): the function to run on each queue item.

num_threads: True: the configured default
None/False/0: disable parallelization

pin: whether to pin the threads to (physical) CPU cores.
"""
num_threads = config.args.jobs if num_threads is True else num_threads
pin = pin and not util.is_windows() and not util.is_bsd()

if num_threads:
return ParallelQueue(f, pin, num_threads)
else:
return SequentialQueue(f, pin)
4 changes: 2 additions & 2 deletions bin/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def build_program(p):
p.build(localbar)
localbar.done()

p = parallel.Parallel(build_program)
p = parallel.new_queue(build_program)
for pr in programs:
p.put(pr)
p.done()
Expand Down Expand Up @@ -465,7 +465,7 @@ def build_program(p):
build_ok &= p.build(localbar)
localbar.done()

p = parallel.Parallel(build_program)
p = parallel.new_queue(build_program)
for pr in validators:
p.put(pr)
p.done()
Expand Down
4 changes: 2 additions & 2 deletions bin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ def process_run(run, p):
return

bar.count = None
p.stop()
p.abort()

p = parallel.Parallel(lambda run: process_run(run, p), pin=True)
p = parallel.new_queue(lambda run: process_run(run, p), pin=True)

for run in runs:
p.put(run)
Expand Down
Loading