Skip to content

Commit

Permalink
neaten up multiprocess_map, add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzoic committed Aug 28, 2024
1 parent cc59094 commit 94c280c
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 47 deletions.
1 change: 1 addition & 0 deletions countess/core/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def main() -> None:
run(sys.argv[1:])

logging_handler.stop()
logging_queue.close()


if __name__ == "__main__":
Expand Down
122 changes: 77 additions & 45 deletions countess/utils/parallel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import gc
import logging
import threading
import time
from multiprocessing import Process, Queue, Value
from os import cpu_count, getpid
from queue import Empty
from queue import Empty, Full
import time
from typing import Callable, Iterable

try:
Expand All @@ -22,6 +21,37 @@
logger = logging.getLogger(__name__)


class IterableMultiprocessQueue:
"""This connects a multiprocessing.Queue with a multiprocessing.Value
and gives us a queue that multiple reader processes can iterate over and
they'll each get a StopIteration when the Queue is both finished *and*
empty."""

def __init__(self, maxsize=3):
self.queue = Queue(maxsize=maxsize)
self.finished = Value("b", False)

def put(self, value, timeout=None):
if self.finished.value:
raise ValueError("IterableMultiprocessQueue Stopped")
self.queue.put(value, timeout=timeout)

def close(self):
self.finished.value = True
self.queue.close()

def __iter__(self):
return self

def __next__(self):
while True:
try:
return self.queue.get(timeout=0.1)
except Empty as exc:
if self.finished.value:
raise StopIteration from exc


def multiprocess_map(
function: Callable[Concatenate[V, P], Iterable[D]], values: Iterable[V], *args: P.args, **kwargs: P.kwargs
) -> Iterable[D]:
Expand All @@ -37,58 +67,60 @@ def multiprocess_map(function, values, *args, **kwargs):

# Start up several workers.
nproc = ((cpu_count() or 1) + 1) // 2
input_queue: Queue = Queue()
input_queue = IterableMultiprocessQueue(maxsize=nproc)
output_queue: Queue = Queue(maxsize=3)

# XXX is it actually necessary to have this in a separate thread or
# would it be sufficient to add items to input_queue alternately with
# removing items from output_queue?

enqueue_running = Value("b", True)

def __enqueue():
for v in values:
input_queue.put(v)
enqueue_running.value = False
def __process():
for data_in in input_queue:
for data_out in function(data_in, *args, **(kwargs or {})):
output_queue.put(data_out)

# Make sure large data is disposed of before we
# go around for the next loop
del data_out
del data_in
gc.collect()

thread = threading.Thread(target=__enqueue)
thread.start()
# Prevent processes from using up all available memory while waiting
# XXX this is probably a bad idea
while psutil.virtual_memory().percent > 90:
logger.warning("PID %d LOW MEMORY %f%%", getpid(), psutil.virtual_memory().percent)
time.sleep(1)

# XXX is this necessary?
time.sleep(1)
processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)]
for p in processes:
p.start()

def __process():
# push each of the input values onto the input_queue, if it gets full
# then also try to drain the output_queue.
for v in values:
while True:
try:
while True:
# Prevent processes from using up all
# available memory while waiting
# XXX this is probably a bad idea
while psutil.virtual_memory().percent > 90:
logger.warning("PID %d LOW MEMORY alert %f%%", getpid(), psutil.virtual_memory().percent)
time.sleep(1)

data_in = input_queue.get(timeout=1)
for data_out in function(data_in, *args, **(kwargs or {})):
output_queue.put(data_out)

# Make sure large data is disposed of before we
# go around for the next loop
del data_out
del data_in
input_queue.put(v, timeout=0.1)
break
except Full:
try:
yield output_queue.get(timeout=0.1)
except Empty:
# Waiting for the next output, might as well tidy up
gc.collect()

except Empty:
if not enqueue_running.value:
break
# we're finished with input values, so close the input_queue to
# signal to all the processes that there will be no new entries
# and once the queue is empty they can finish.
input_queue.close()

processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)]
for p in processes:
p.start()

while thread.is_alive() or any(p.is_alive() for p in processes):
# wait for all processes to finish and yield any data which appears
# on the output_queue
while any(p.is_alive() for p in processes):
try:
yield output_queue.get(timeout=0.1)
while True:
yield output_queue.get(timeout=0.1)
except Empty:
# Waiting for the next input, might as well tidy up
# Waiting for the next output, might as well tidy up
gc.collect()

# once all processes have finished, we can clean up the queue.
for p in processes:
p.join()
output_queue.close()
1 change: 1 addition & 0 deletions tests/gui/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_main():

root.update()

wrap.destroy()
root.destroy()


Expand Down
3 changes: 2 additions & 1 deletion tests/plugins/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_save_csv_bz2():
df2 = pd.DataFrame([[10, 11, 12]], columns=["a", "b", "d"])


def test_save_csv_multi():
def test_save_csv_multi(caplog):
plugin = SaveCsvPlugin()
plugin.set_parameter("header", True)
plugin.set_parameter("filename", "tests/output2.csv")
Expand All @@ -154,3 +154,4 @@ def test_save_csv_multi():
with open("tests/output2.csv", "r", encoding="utf-8") as fh:
text = fh.read()
assert text == "a,b,c\n1,2,3\n4,5,6\n7,8,9\n10,11,,12\n"
assert "Added CSV Column" in caplog.text
96 changes: 95 additions & 1 deletion tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import importlib.metadata
from unittest.mock import patch

import pandas as pd
import pytest

import countess
from countess.core.plugins import BasePlugin, get_plugin_classes
from countess.core.plugins import BasePlugin, FileInputPlugin, PandasProductPlugin, get_plugin_classes

empty_entry_points_dict = {"countess_plugins": []}

Expand Down Expand Up @@ -31,3 +34,94 @@ def test_get_plugin_classes_wrongparent(caplog):
with patch("importlib.metadata.EntryPoint.load", lambda x: NoParentPlugin):
get_plugin_classes()
assert "not a valid CountESS plugin" in caplog.text


class PPP(PandasProductPlugin):
version = "0"

def process_dataframes(self, dataframe1, dataframe2):
return dataframe1 + dataframe2


def test_product_plugin():
ppp = PPP()

df1 = pd.DataFrame([{"a": 1}])
df2 = pd.DataFrame([{"a": 2}])
df3 = pd.DataFrame([{"a": 4}])
df4 = pd.DataFrame([{"a": 8}])
df5 = pd.DataFrame([{"a": 16}])

ppp.prepare(["source1", "source2"])

dfs = list(ppp.process(df1, "source1"))
assert len(dfs) == 0

dfs = list(ppp.process(df2, "source1"))
assert len(dfs) == 0

dfs = list(ppp.process(df3, "source2"))
assert len(dfs) == 2
assert dfs[0]["a"][0] == 5
assert dfs[1]["a"][0] == 6

dfs = list(ppp.process(df4, "source1"))
assert len(dfs) == 1
assert dfs[0]["a"][0] == 12

dfs = list(ppp.finished("source1"))
assert len(dfs) == 0

dfs = list(ppp.process(df5, "source2"))
assert len(dfs) == 3
assert dfs[0]["a"][0] == 17
assert dfs[1]["a"][0] == 18
assert dfs[2]["a"][0] == 24

dfs = list(ppp.finished("source2"))
assert len(dfs) == 0


def test_product_plugin_sources():
with pytest.raises(ValueError):
ppp = PPP()
ppp.prepare(["source1"])

with pytest.raises(ValueError):
ppp = PPP()
ppp.prepare(["source1", "source2", "source3"])

with pytest.raises(ValueError):
ppp = PPP()
ppp.prepare(["source1", "source2"])
list(ppp.process(pd.DataFrame(), "source3"))

with pytest.raises(ValueError):
ppp = PPP()
ppp.prepare(["source1", "source2"])
list(ppp.finished("source3"))


class FIP(FileInputPlugin):
version = "0"

def num_files(self):
return 3

def load_file(self, file_number, row_limit):
if row_limit is None:
row_limit = 1000000
return [f"hello{file_number}"] * row_limit


def test_fileinput(caplog):
caplog.set_level("INFO")
fip = FIP("fip")

fip.prepare([], 1000)
data = list(fip.finalize())

assert len(data) >= 999
assert sorted(set(data)) == ["hello0", "hello1", "hello2"]

assert "100%" in caplog.text

0 comments on commit 94c280c

Please sign in to comment.