diff --git a/countess/core/cmd.py b/countess/core/cmd.py
index e934a20..a8d0364 100644
--- a/countess/core/cmd.py
+++ b/countess/core/cmd.py
@@ -26,6 +26,7 @@ def main() -> None:
     run(sys.argv[1:])
 
     logging_handler.stop()
+    logging_queue.close()
 
 
 if __name__ == "__main__":
diff --git a/countess/utils/parallel.py b/countess/utils/parallel.py
index 0ccfed8..e0d120d 100644
--- a/countess/utils/parallel.py
+++ b/countess/utils/parallel.py
@@ -1,10 +1,9 @@
 import gc
 import logging
-import threading
-import time
 from multiprocessing import Process, Queue, Value
 from os import cpu_count, getpid
-from queue import Empty
+from queue import Empty, Full
+import time
 from typing import Callable, Iterable
 
 try:
@@ -22,6 +21,37 @@
 logger = logging.getLogger(__name__)
 
 
+class IterableMultiprocessQueue:
+    """This connects a multiprocessing.Queue with a multiprocessing.Value
+    and gives us a queue that multiple reader processes can iterate over and
+    they'll each get a StopIteration when the Queue is both finished *and*
+    empty."""
+
+    def __init__(self, maxsize=3):
+        self.queue = Queue(maxsize=maxsize)
+        self.finished = Value("b", False)
+
+    def put(self, value, timeout=None):
+        if self.finished.value:
+            raise ValueError("IterableMultiprocessQueue Stopped")
+        self.queue.put(value, timeout=timeout)
+
+    def close(self):
+        self.finished.value = True
+        self.queue.close()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            try:
+                return self.queue.get(timeout=0.1)
+            except Empty as exc:
+                if self.finished.value:
+                    raise StopIteration from exc
+
+
 def multiprocess_map(
     function: Callable[Concatenate[V, P], Iterable[D]], values: Iterable[V], *args: P.args, **kwargs: P.kwargs
 ) -> Iterable[D]:
@@ -37,58 +67,60 @@ def multiprocess_map(function, values, *args, **kwargs):
 
     # Start up several workers.
     nproc = ((cpu_count() or 1) + 1) // 2
-    input_queue: Queue = Queue()
+    input_queue = IterableMultiprocessQueue(maxsize=nproc)
     output_queue: Queue = Queue(maxsize=3)
 
-    # XXX is it actually necessary to have this in a separate thread or
-    # would it be sufficient to add items to input_queue alternately with
-    # removing items from output_queue?
-
-    enqueue_running = Value("b", True)
-
-    def __enqueue():
-        for v in values:
-            input_queue.put(v)
-        enqueue_running.value = False
+    def __process():
+        for data_in in input_queue:
+            for data_out in function(data_in, *args, **(kwargs or {})):
+                output_queue.put(data_out)
+
+                # Make sure large data is disposed of before we
+                # go around for the next loop
+                del data_out
+            del data_in
+            gc.collect()
 
-    thread = threading.Thread(target=__enqueue)
-    thread.start()
+            # Prevent processes from using up all available memory while waiting
+            # XXX this is probably a bad idea
+            while psutil.virtual_memory().percent > 90:
+                logger.warning("PID %d LOW MEMORY %f%%", getpid(), psutil.virtual_memory().percent)
+                time.sleep(1)
 
-    # XXX is this necessary?
-    time.sleep(1)
+    processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)]
+    for p in processes:
+        p.start()
 
-    def __process():
+    # push each of the input values onto the input_queue, if it gets full
+    # then also try to drain the output_queue.
+    for v in values:
         while True:
             try:
-                while True:
-                    # Prevent processes from using up all
-                    # available memory while waiting
-                    # XXX this is probably a bad idea
-                    while psutil.virtual_memory().percent > 90:
-                        logger.warning("PID %d LOW MEMORY alert %f%%", getpid(), psutil.virtual_memory().percent)
-                        time.sleep(1)
-
-                    data_in = input_queue.get(timeout=1)
-                    for data_out in function(data_in, *args, **(kwargs or {})):
-                        output_queue.put(data_out)
-
-                    # Make sure large data is disposed of before we
-                    # go around for the next loop
-                    del data_out
-                    del data_in
+                input_queue.put(v, timeout=0.1)
+                break
+            except Full:
+                try:
+                    yield output_queue.get(timeout=0.1)
+                except Empty:
+                    # Waiting for the next output, might as well tidy up
                     gc.collect()
-            except Empty:
-                if not enqueue_running.value:
-                    break
+    # we're finished with input values, so close the input_queue to
+    # signal to all the processes that there will be no new entries
+    # and once the queue is empty they can finish.
+    input_queue.close()
 
-    processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)]
-    for p in processes:
-        p.start()
-
-    while thread.is_alive() or any(p.is_alive() for p in processes):
+    # wait for all processes to finish and yield any data which appears
+    # on the output_queue
+    while any(p.is_alive() for p in processes):
         try:
-            yield output_queue.get(timeout=0.1)
+            while True:
+                yield output_queue.get(timeout=0.1)
         except Empty:
-            # Waiting for the next input, might as well tidy up
+            # Waiting for the next output, might as well tidy up
             gc.collect()
+
+    # once all processes have finished, we can clean up the queue.
+    for p in processes:
+        p.join()
+    output_queue.close()
 
diff --git a/tests/gui/test_main.py b/tests/gui/test_main.py
index 11731dc..d0ce589 100644
--- a/tests/gui/test_main.py
+++ b/tests/gui/test_main.py
@@ -44,6 +44,7 @@ def test_main():
 
     root.update()
 
+    wrap.destroy()
     root.destroy()
 
 
diff --git a/tests/plugins/test_csv.py b/tests/plugins/test_csv.py
index 3a29209..549851a 100644
--- a/tests/plugins/test_csv.py
+++ b/tests/plugins/test_csv.py
@@ -142,7 +142,7 @@ def test_save_csv_bz2():
 df2 = pd.DataFrame([[10, 11, 12]], columns=["a", "b", "d"])
 
 
-def test_save_csv_multi():
+def test_save_csv_multi(caplog):
     plugin = SaveCsvPlugin()
     plugin.set_parameter("header", True)
     plugin.set_parameter("filename", "tests/output2.csv")
@@ -154,3 +154,4 @@ def test_save_csv_multi():
     with open("tests/output2.csv", "r", encoding="utf-8") as fh:
         text = fh.read()
         assert text == "a,b,c\n1,2,3\n4,5,6\n7,8,9\n10,11,,12\n"
+        assert "Added CSV Column" in caplog.text
diff --git a/tests/test_plugins.py b/tests/test_plugins.py
index 0bbb3fb..a658f1c 100644
--- a/tests/test_plugins.py
+++ b/tests/test_plugins.py
@@ -1,8 +1,11 @@
 import importlib.metadata
 from unittest.mock import patch
 
+import pandas as pd
+import pytest
+
 import countess
-from countess.core.plugins import BasePlugin, get_plugin_classes
+from countess.core.plugins import BasePlugin, FileInputPlugin, PandasProductPlugin, get_plugin_classes
 
 empty_entry_points_dict = {"countess_plugins": []}
 
@@ -31,3 +34,94 @@ def test_get_plugin_classes_wrongparent(caplog):
     with patch("importlib.metadata.EntryPoint.load", lambda x: NoParentPlugin):
         get_plugin_classes()
         assert "not a valid CountESS plugin" in caplog.text
+
+
+class PPP(PandasProductPlugin):
+    version = "0"
+
+    def process_dataframes(self, dataframe1, dataframe2):
+        return dataframe1 + dataframe2
+
+
+def test_product_plugin():
+    ppp = PPP()
+
+    df1 = pd.DataFrame([{"a": 1}])
+    df2 = pd.DataFrame([{"a": 2}])
+    df3 = pd.DataFrame([{"a": 4}])
+    df4 = pd.DataFrame([{"a": 8}])
+    df5 = pd.DataFrame([{"a": 16}])
+
+    ppp.prepare(["source1", "source2"])
+
+    dfs = list(ppp.process(df1, "source1"))
+    assert len(dfs) == 0
+
+    dfs = list(ppp.process(df2, "source1"))
+    assert len(dfs) == 0
+
+    dfs = list(ppp.process(df3, "source2"))
+    assert len(dfs) == 2
+    assert dfs[0]["a"][0] == 5
+    assert dfs[1]["a"][0] == 6
+
+    dfs = list(ppp.process(df4, "source1"))
+    assert len(dfs) == 1
+    assert dfs[0]["a"][0] == 12
+
+    dfs = list(ppp.finished("source1"))
+    assert len(dfs) == 0
+
+    dfs = list(ppp.process(df5, "source2"))
+    assert len(dfs) == 3
+    assert dfs[0]["a"][0] == 17
+    assert dfs[1]["a"][0] == 18
+    assert dfs[2]["a"][0] == 24
+
+    dfs = list(ppp.finished("source2"))
+    assert len(dfs) == 0
+
+
+def test_product_plugin_sources():
+    with pytest.raises(ValueError):
+        ppp = PPP()
+        ppp.prepare(["source1"])
+
+    with pytest.raises(ValueError):
+        ppp = PPP()
+        ppp.prepare(["source1", "source2", "source3"])
+
+    with pytest.raises(ValueError):
+        ppp = PPP()
+        ppp.prepare(["source1", "source2"])
+        list(ppp.process(pd.DataFrame(), "source3"))
+
+    with pytest.raises(ValueError):
+        ppp = PPP()
+        ppp.prepare(["source1", "source2"])
+        list(ppp.finished("source3"))
+
+
+class FIP(FileInputPlugin):
+    version = "0"
+
+    def num_files(self):
+        return 3
+
+    def load_file(self, file_number, row_limit):
+        if row_limit is None:
+            row_limit = 1000000
+        return [f"hello{file_number}"] * row_limit
+
+
+def test_fileinput(caplog):
+    caplog.set_level("INFO")
+    fip = FIP("fip")
+
+    fip.prepare([], 1000)
+    data = list(fip.finalize())
+
+    assert len(data) >= 999
+    assert sorted(set(data)) == ["hello0", "hello1", "hello2"]
+
+    assert "100%" in caplog.text
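As context for the parallel.py rework above, a minimal sketch of how the revised multiprocess_map generator might be consumed; the double() worker and the sample values here are illustrative assumptions, not part of this diff:

    from countess.utils.parallel import multiprocess_map

    def double(value):
        # worker callables yield zero or more outputs per input value
        yield value * 2

    if __name__ == "__main__":
        # multiprocess_map is a generator: results are yielded as the worker
        # processes produce them, so ordering across inputs is not guaranteed
        results = sorted(multiprocess_map(double, range(10)))
        assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

Because outputs are interleaved across workers, callers that care about input order need to sort or key the results themselves.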