diff --git a/.github/workflows/run-tests-push.yml b/.github/workflows/run-tests-push.yml index e71f440..75c4dcb 100644 --- a/.github/workflows/run-tests-push.yml +++ b/.github/workflows/run-tests-push.yml @@ -11,148 +11,60 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.9" + cache: 'pip' - run: sudo apt install xvfb - run: pip install --upgrade pip - run: pip install .[dev] - - run: xvfb-run pytest tests/ + - run: xvfb-run pytest -v -rP --doctest-modules countess/ tests/ - run-tests-ubuntu-22_04-python-3_10-with-coverage: + run-tests-ubuntu-22_04-python-3_10: runs-on: ubuntu-22.04 - name: Ubuntu 22.04, Python 3.10 (with coverage) + name: Ubuntu 22.04, Python 3.10 steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.10" + cache: 'pip' - run: sudo apt install xvfb - run: pip install --upgrade pip - run: pip install .[dev] - - run: xvfb-run coverage run --source countess -m pytest tests/ -# - run: coverage html -# - uses: actions/upload-artifact@v3 -# with: -# name: test coverage report -# path: htmlcov/* - - run: echo '### Coverage Report' >> $GITHUB_STEP_SUMMARY - - run: coverage report --format=markdown --skip-empty --sort=-cover >> $GITHUB_STEP_SUMMARY - -# run-tests-ubuntu-22_04-python-3_11_0rc2: -# runs-on: ubuntu-22.04 -# name: Ubuntu 22.04, Python 3.11.0rc2 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v4 -# with: -# python-version: "3.11.0-rc.2" -# - run: sudo apt install xvfb -# - run: pip install --upgrade pip -# - run: pip install .[dev] -# - run: xvfb-run pytest tests/ -# -# run-tests-ubuntu-22_04-python-3_11_0: -# runs-on: ubuntu-22.04 -# name: Ubuntu 22.04, Python 3.11.1 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v4 -# with: -# python-version: "3.11.0" -# - run: sudo apt install xvfb -# - run: pip install --upgrade pip -# - run: pip install .[dev] -# - run: xvfb-run pytest tests/ -# -# run-tests-ubuntu-22_04-python-3_11_1: -# runs-on: ubuntu-22.04 -# name: Ubuntu 22.04, Python 3.11.1 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v4 -# with: -# python-version: "3.11.1" -# - run: sudo apt install xvfb -# - run: pip install --upgrade pip -# - run: pip install .[dev] -# - run: xvfb-run pytest tests/ - - run-tests-ubuntu-22_04-python-3_11_2: - runs-on: ubuntu-22.04 - name: Ubuntu 22.04, Python 3.11.2 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.11.2" - - run: sudo apt install xvfb - - run: pip install --upgrade pip - - run: pip install .[dev] - - run: xvfb-run pytest tests/ - - run-tests-ubuntu-22_04-python-3_11: - runs-on: ubuntu-22.04 - name: Ubuntu 22.04, Python 3.11 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - run: sudo apt install xvfb - - run: pip install --upgrade pip - - run: pip install .[dev] - - run: xvfb-run pytest tests/ - -# run-tests-ubuntu-22_04-python-3_12_0_alpha5: -# runs-on: ubuntu-22.04 -# name: Ubuntu 22.04, Python 3.12.0alpha5 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v4 -# with: -# python-version: "3.12.0-alpha.5" -# - run: sudo apt install xvfb -# - run: pip install --upgrade pip -# - run: pip install .[dev] -# - run: xvfb-run pytest tests/ + - run: xvfb-run pytest -v -rP --doctest-modules countess/ tests/ run-tests-ubuntu-22_04-python-3_11_from_apt: runs-on: ubuntu-22.04 - name: Ubuntu 22.04, Python from Apt + name: Ubuntu 22.04, Python 3.11 from Apt steps: - uses: 
actions/checkout@v4 - run: sudo apt install python3.11-full python3-pip xvfb - run: python3.11 -m pip install --upgrade pip - run: python3.11 -m pip install -e .[dev] - - run: xvfb-run python3.11 -mpytest tests/ + - run: xvfb-run python3.11 -m pytest -v -rP --doctest-modules countess/ tests/ - # run-tests-ubuntu-22_10-python-3_11_from_apt: - #runs-on: ubuntu-22.10 - #name: Ubuntu 22.10, Python from Apt - #steps: - #- uses: actions/checkout@v4 - #- run: sudo apt install python3.11-full python3-pip xvfb - #- run: python3.11 -m pip install --upgrade pip - #- run: python3.11 -m pip install -e .[dev] - #- run: xvfb-run python3.11 -mpytest tests/ - - #run-tests-ubuntu-23_04-python-3_11_from_apt: - #runs-on: ubuntu-23.04 - #name: Ubuntu 23.04, Python from Apt - #steps: - #- uses: actions/checkout@v4 - #- run: sudo apt install python3.11-full python3-pip xvfb - #- run: python3.11 -m pip install --upgrade pip - #- run: python3.11 -m pip install -e .[dev] - #- run: xvfb-run python3.11 -mpytest tests/ + run-tests-ubuntu-24_04-python-3_12_from_apt: + runs-on: ubuntu-24.04 + name: Ubuntu 24.04, Python 3.12 from Apt + steps: + - uses: actions/checkout@v4 + - run: sudo apt install python3.12-full python3-pip xvfb + - run: python3.12 -m venv /tmp/venv + - run: /tmp/venv/bin/python -m pip install --upgrade pip + - run: /tmp/venv/bin/python -m pip install -e .[dev] + - run: xvfb-run /tmp/venv/bin/python -m pytest -v -rP --doctest-modules countess/ tests/ -# run-tests-ubuntu-22_04-pypy3: -# runs-on: ubuntu-22.04 -# name: Ubuntu 22.04, PyPy 3 -# steps: -# - uses: actions/checkout@v4 -# - uses: actions/setup-python@v4 -# with: -# python-version: "pypy3.9" -# - run: sudo apt install pypy3 pypy3-tk pypy3-dev xvfb -# - run: pypy3 -mpip install -U pip wheel -# - run: pypy3 -mpip install .[dev] -# - run: xvfb-run pytest tests/ + run-tests-ubuntu-24_04-python-3_x: + runs-on: ubuntu-24.04 + name: Ubuntu 24.04, Python 3.x + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + cache: 'pip' + - run: sudo apt install xvfb + - run: python3 -m venv /tmp/venv + - run: /tmp/venv/bin/python -m pip install --upgrade pip + - run: /tmp/venv/bin/python -m pip install -e .[dev] + - run: xvfb-run /tmp/venv/bin/coverage run --source countess -m pytest -v -rP --doctest-modules countess/ tests/ + - run: echo '### Coverage Report' >> $GITHUB_STEP_SUMMARY + - run: /tmp/venv/bin/coverage report --format=markdown --skip-empty --sort=-cover >> $GITHUB_STEP_SUMMARY diff --git a/countess/core/cmd.py b/countess/core/cmd.py index e934a20..b173060 100644 --- a/countess/core/cmd.py +++ b/countess/core/cmd.py @@ -17,15 +17,28 @@ def run(argv) -> None: def main() -> None: + # set up a default stderr StreamHandler for logs + logging_handler = logging.StreamHandler() + + # set up a QueueHandler/QueueListener to forward the logs between + # processes and send them to the logging_handler logging_queue: multiprocessing.Queue = multiprocessing.Queue() - logging.getLogger().addHandler(logging.handlers.QueueHandler(logging_queue)) - logging.getLogger().setLevel(logging.INFO) - logging_handler = logging.handlers.QueueListener(logging_queue, logging.StreamHandler()) - logging_handler.start() + logging_queue_handler = logging.handlers.QueueHandler(logging_queue) + logging_queue_listener = logging.handlers.QueueListener(logging_queue, logging_handler) + logging_queue_listener.start() + + # set up all loggers to be handled by the QueueHandler. 
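+ # (QueueHandler pushes each log record onto the queue; the QueueListener's + # background thread pulls records off and passes them to logging_handler, + # so worker processes can log without garbling each other's output.)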
+ root_logger = logging.getLogger() + root_logger.addHandler(logging_queue_handler) + root_logger.setLevel(logging.INFO) run(sys.argv[1:]) - logging_handler.stop() + # shut down the logging subsystem, in case this function is being + # called as part of something else (eg: tests) + root_logger.handlers.clear() + logging_queue_listener.stop() + logging_queue.close() if __name__ == "__main__": diff --git a/countess/core/parameters.py b/countess/core/parameters.py index a723c67..77d48c1 100644 --- a/countess/core/parameters.py +++ b/countess/core/parameters.py @@ -110,13 +110,13 @@ def __ne__(self, other): def __gt__(self, other): return self._value > other - def __gte__(self, other): + def __ge__(self, other): return self._value >= other def __lt__(self, other): return self._value < other - def __lte__(self, other): + def __le__(self, other): return self._value <= other @@ -187,6 +187,15 @@ def __int__(self): def __float__(self): return float(self._value) + def __abs__(self): + return abs(self._value) + + def __pos__(self): + return self._value + + def __neg__(self): + return 0 - (self._value) + # XXX should include many more numeric operator methods here, see # https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types # matmul, truediv, floordiv, mod, divmod, pow, lshift, rshift, and, xor, or, @@ -219,10 +228,12 @@ def set_value(self, value): if isinstance(value, str): if value in ("t", "T", "true", "True", "1"): self._value = True - if value in ("f", "F", "false", "False", "0"): + elif value in ("f", "F", "false", "False", "0"): self._value = False - raise ValueError(f"Can't convert {value} to boolean") - self._value = bool(value) + else: + raise ValueError(f"Can't convert {value} to boolean") + else: + self._value = bool(value) def __bool__(self): return self._value or False @@ -402,10 +413,10 @@ def __init__(self, label: str, value: Optional[str] = None, choices: Optional[It super().__init__(label, value, choices) def get_selected_type(self): - if self.value is None: - return None - else: + try: return self.DATA_TYPES[self.value][0] + except KeyError: + return None def cast_value(self, value): if value is not None: @@ -519,7 +530,7 @@ def is_not_none(self): def get_column_prefix(self): if self.is_none(): return None - return self.value.removesuffix("*") + return super().get_column_prefix() class ColumnOrIndexChoiceParam(ColumnChoiceParam): @@ -548,33 +559,32 @@ class ColumnOrStringParam(ColumnChoiceParam): def set_column_choices(self, choices): self.set_choices([self.PREFIX + c for c in choices]) - def get_column_name(self): - if self.value.startswith(self.PREFIX): + def get_column_name(self) -> Optional[str]: + if type(self.value) is str and self.value.startswith(self.PREFIX): return self.value[len(self.PREFIX) :] return None - def get_value_from_dict(self, data: dict): + def get_value_from_dict(self, data: dict) -> str: if type(self.value) is str and self.value.startswith(self.PREFIX): return data[self.value[len(self.PREFIX) :]] else: return self.value - def get_column_or_value(self, df: pd.DataFrame, numeric: bool): - if self.value.startswith(self.PREFIX): + def get_column_or_value(self, df: pd.DataFrame, numeric: bool) -> Union[float, str, pd.Series]: + if type(self.value) is str and self.value.startswith(self.PREFIX): col = df[self.value[len(self.PREFIX) :]] - return col.astype("f" if numeric else "string") + return col.astype(float if numeric else str) else: return float(self.value) if numeric else str(self.value) - def get_column_or_value_numeric(self, df: 
pd.DataFrame): - if self.value.startswith(self.PREFIX): - return df[self.value[len(self.PREFIX) :]] - else: - return self.value - def set_choices(self, choices: Iterable[str]): self.choices = list(choices) - if self._value is not None and self._value.startswith(self.PREFIX) and self._value not in self.choices: + if ( + self._value is not None + and type(self._value) is str + and self._value.startswith(self.PREFIX) + and self._value not in self.choices + ): self._value = self.DEFAULT_VALUE self._choice = None @@ -862,17 +872,6 @@ def __init__(self, label: str, params: Optional[Mapping[str, BaseParam]] = None) def copy(self) -> "MultiParam": return self.__class__(self.label, self.params) - # XXX decide if these "dict-like" accessors are worth keeping - - def __getitem__(self, key): - return self.params[key] - - def __contains__(self, item): - return item in self.params - - def __setitem__(self, key, value): - self.params[key].value = value - def __iter__(self): return self.params.__iter__() diff --git a/countess/core/pipeline.py b/countess/core/pipeline.py index 94f6475..3d88750 100644 --- a/countess/core/pipeline.py +++ b/countess/core/pipeline.py @@ -18,8 +18,12 @@ class SentinelQueue(Queue): The writer is expected to call `queue.finish()` when it is done and the reader can treat the queue like an iterable.""" - # XXX this is an attempt to handle multiple threads reading from the - # queue in parallel: they should all get StopIterations. + # catch attempts to 'put' more data onto the queue after it has finished. + finished = False + + # Handle multiple threads reading from the + # queue in parallel: once the sentinel has been received by any thread + # all further attempts to read get StopIterations. stopped = False class SENTINEL: @@ -27,6 +31,7 @@ class SENTINEL: def finish(self): self.put(self.SENTINEL) + self.finished = True def __iter__(self): return self @@ -47,13 +52,8 @@ def __next__(self): raise StopIteration return val - def get(self, block=True, timeout=None): - if self.stopped: - raise ValueError("SentinelQueue stopped") - return super().get(block, timeout) - def put(self, item, block=True, timeout=None): - if self.stopped: + if self.finished: raise ValueError("SentinelQueue stopped") super().put(item, block, timeout) @@ -83,7 +83,7 @@ def __init__(self, name, plugin=None, config=None, position=None, notes=None, so self.name = name self.plugin = plugin self.config = config or [] - self.position = position + self.position = position or (0.5, 0.5) self.sort_column = sort_column self.sort_descending = sort_descending self.notes = notes @@ -105,9 +105,6 @@ def add_output_queue(self): self.output_queues.add(queue) return queue - def clear_output_queues(self): - self.output_queues = set() - def queue_output(self, result): for data in result: self.counter_out += 1 @@ -237,39 +234,21 @@ def del_parent(self, parent): parent.child_nodes.discard(self) self.mark_dirty() - def has_sibling(self): - return any(len(pn.child_nodes) > 1 for pn in self.parent_nodes) - def configure_plugin(self, key, value, base_dir="."): self.plugin.set_parameter(key, value, base_dir) self.mark_dirty() - def final_descendants(self): - if self.child_nodes: - return set(n2 for n1 in self.child_nodes for n2 in n1.final_descendants()) - else: - return set(self) - - def detatch(self): + def detach(self): for parent_node in self.parent_nodes: parent_node.child_nodes.discard(self) for child_node in self.child_nodes: child_node.parent_nodes.discard(self) - @classmethod - def get_ancestor_list(cls, nodes): - """Given a 
bunch of nodes, find the list of all the ancestors in a - sensible order""" - parents = set((p for n in nodes for p in n.parent_nodes)) - if not parents: - return list(nodes) - return cls.get_ancestor_list(parents) + list(nodes) - class PipelineGraph: - def __init__(self): + def __init__(self, nodes: Optional[list[PipelineNode]] = None): self.plugin_classes = get_plugin_classes() - self.nodes = [] + self.nodes = nodes or [] def reset_node_name(self, node): node_names_seen = set(n.name for n in self.nodes if n != node) @@ -285,7 +264,7 @@ def add_node(self, node): self.nodes.append(node) def del_node(self, node): - node.detatch() + node.detach() self.nodes.remove(node) def traverse_nodes(self): diff --git a/countess/core/plugins.py b/countess/core/plugins.py index f890100..85cb9de 100644 --- a/countess/core/plugins.py +++ b/countess/core/plugins.py @@ -250,7 +250,7 @@ def process(self, data: pd.DataFrame, source: str) -> Iterable[pd.DataFrame]: yield result except Exception as exc: # pylint: disable=broad-exception-caught - logger.warning("Exception", exc_info=exc) + logger.warning("Exception", exc_info=exc) # pragma: no cover def process_dataframe(self, dataframe: pd.DataFrame) -> Optional[pd.DataFrame]: """Override this to process a single dataframe""" @@ -383,6 +383,7 @@ def dataframe_to_series(self, dataframe: pd.DataFrame) -> pd.Series: raise NotImplementedError(f"{self.__class__}.dataframe_to_series()") def process_dataframe(self, dataframe: pd.DataFrame) -> Optional[pd.DataFrame]: + dataframe_merged = None try: # 1. A dataframe with duplicates in its index can't be merged back correctly # in Step 4, so we add in an extra RangeIndex to guarantee uniqueness, @@ -407,9 +408,8 @@ def process_dataframe(self, dataframe: pd.DataFrame) -> Optional[pd.DataFrame]: if "__tmpidx" in dataframe_merged.index.names: dataframe_merged.reset_index("__tmpidx", drop=True, inplace=True) - except Exception as exc: # pylint: disable=broad-exception-caught - logger.warning("Exception", exc_info=exc) - return None + except Exception as exc: # pylint: disable=broad-exception-caught # pragma: no cover + logger.warning("Exception", exc_info=exc) # pragma: no cover return dataframe_merged @@ -523,13 +523,9 @@ def series_to_dataframe(self, series: pd.Series) -> pd.DataFrame: series.dropna(inplace=True) data = series.tolist() - if len(data): - max_cols = max(len(d) for d in data) - column_names = column_names[:max_cols] - df = pd.DataFrame(data, columns=column_names, index=series.index) - return df - else: - return pd.DataFrame() + max_cols = max(len(d) for d in data) if len(data) else 0 + column_names = column_names[:max_cols] + return pd.DataFrame(data, columns=column_names, index=series.index) class PandasTransformXToDictMixin: diff --git a/countess/plugins/hgvs_parser.py b/countess/plugins/hgvs_parser.py index f4a9db3..07bf2ab 100644 --- a/countess/plugins/hgvs_parser.py +++ b/countess/plugins/hgvs_parser.py @@ -38,8 +38,8 @@ def process_dict(self, data: dict): if self.guides_str: guides += self.guides_str.value.split(";") - if m := re.match(r"([\w.]+):([ncg].)(.*)", value): - output["reference"] = m.group(1) + if m := re.match(r"(?:([\w.]+):)?([ncg]\.)(.*)", value): + output["reference"] = m.group(1) or '' output["prefix"] = m.group(2) value = m.group(3) diff --git a/countess/utils/files.py b/countess/utils/files.py index a4aec69..f33295e 100644 --- a/countess/utils/files.py +++ b/countess/utils/files.py @@ -1,7 +1,7 @@ import re -def clean_filename(filename): +def clean_filename(filename: str) -> str: m = 
re.match(r"(?:.*/)*([^.]+).*", filename) if m and m.group(1): return m.group(1) diff --git a/countess/utils/parallel.py b/countess/utils/parallel.py index 0ccfed8..368abfa 100644 --- a/countess/utils/parallel.py +++ b/countess/utils/parallel.py @@ -1,17 +1,17 @@ import gc import logging -import threading -import time from multiprocessing import Process, Queue, Value from os import cpu_count, getpid -from queue import Empty +from queue import Empty, Full +import threading +import time from typing import Callable, Iterable try: from typing import Concatenate, ParamSpec, TypeVar -except ImportError: +except ImportError: # pragma: no cover # for Python 3.9 compatibility - from typing_extensions import Concatenate, ParamSpec, TypeVar # type: ignore + from typing_extensions import Concatenate, ParamSpec, TypeVar # type: ignore import psutil @@ -22,6 +21,39 @@ logger = logging.getLogger(__name__) +class IterableMultiprocessQueue: + """This connects a multiprocessing.Queue with a multiprocessing.Value + and gives us a queue that multiple reader processes can iterate over and + they'll each get a StopIteration when the Queue is both finished *and* + empty.""" + + def __init__(self, maxsize=3): + self.queue = Queue(maxsize=maxsize) + self.finished = Value("b", False) + + def put(self, value, timeout=None): + if self.finished.value: + raise ValueError("IterableMultiprocessQueue Stopped") + self.queue.put(value, timeout=timeout) + + def finish(self): + self.finished.value = True + + def close(self): + self.queue.close() + + def __iter__(self): + return self + + def __next__(self): + while True: + try: + return self.queue.get(timeout=0.1) + except Empty as exc: + if self.finished.value: + raise StopIteration from exc + + def multiprocess_map( function: Callable[Concatenate[V, P], Iterable[D]], values: Iterable[V], *args: P.args, **kwargs: P.kwargs ) -> Iterable[D]: @@ -37,58 +69,56 @@ def multiprocess_map(function, values, *args, **kwargs): # Start up several workers. nproc = ((cpu_count() or 1) + 1) // 2 - input_queue: Queue = Queue() + input_queue = IterableMultiprocessQueue(maxsize=nproc) output_queue: Queue = Queue(maxsize=3) - # XXX is it actually necessary to have this in a separate thread or - # would it be sufficient to add items to input_queue alternately with - # removing items from output_queue? + def __process(): # pragma: no cover + # this is run in a pool of `nproc` processes to handle resource-intensive + # workloads which don't play nicely with the GIL. + # XXX Coverage doesn't seem to understand this so we exclude it from coverage. + + for data_in in input_queue: + for data_out in function(data_in, *args, **(kwargs or {})): + output_queue.put(data_out) + + # Make sure large data is disposed of before we + # go around for the next loop + del data_out + del data_in + gc.collect() - enqueue_running = Value("b", True) + # Prevent processes from using up all available memory while waiting + # XXX this is probably a bad idea + while psutil.virtual_memory().percent > 90: + logger.warning("PID %d LOW MEMORY %f%%", getpid(), psutil.virtual_memory().percent) + time.sleep(1) + processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)] + for p in processes: + p.start() + + # separate thread is in charge of pushing items into the input_queue def __enqueue(): for v in values: input_queue.put(v) - enqueue_running.value = False + input_queue.finish() thread = threading.Thread(target=__enqueue) thread.start() - # XXX is this necessary?
- time.sleep(1) - - def __process(): - while True: - try: - while True: - # Prevent processes from using up all - # available memory while waiting - # XXX this is probably a bad idea - while psutil.virtual_memory().percent > 90: - logger.warning("PID %d LOW MEMORY alert %f%%", getpid(), psutil.virtual_memory().percent) - time.sleep(1) - - data_in = input_queue.get(timeout=1) - for data_out in function(data_in, *args, **(kwargs or {})): - output_queue.put(data_out) - - # Make sure large data is disposed of before we - # go around for the next loop - del data_out - del data_in - gc.collect() - - except Empty: - if not enqueue_running.value: - break - - processes = [Process(target=__process, name=f"worker {n}") for n in range(0, nproc)] - for p in processes: - p.start() - - while thread.is_alive() or any(p.is_alive() for p in processes): + # wait for all processes to finish and yield any data which appears + # on the output_queue as soon as it is available. while any(p.is_alive() for p in processes): try: - yield output_queue.get(timeout=0.1) + while True: + yield output_queue.get(timeout=0.1) except Empty: - # Waiting for the next input, might as well tidy up + # Waiting for the next output, might as well tidy up gc.collect() + + # once all processes have finished, we can clean up the queue. + thread.join() + for p in processes: + p.join() + input_queue.close() + output_queue.close() diff --git a/docs/contributing/index.md b/docs/contributing/index.md index bda5e9c..9d71bd5 100644 --- a/docs/contributing/index.md +++ b/docs/contributing/index.md @@ -152,6 +152,9 @@ For issues with these pages, especially accessibility issues, please ## Deployment Github actions are set up to run code checks and tests on every push. +Tests are run across Python 3.9, 3.10, 3.11 and 3.12. Even though +3.9 is quite old, it is still very widely deployed, and there are a lot of +small but breaking changes between minor Python versions. Deployment is not yet automated. There's a couple of small scripts to set a new version number in the code, documentation and git tags. @@ -162,7 +165,8 @@ name on the command line: script/set_version 1.2.3 -There's also a script to automate upload to PyPI using twine. +Releases are not yet automated: they are published to
PyPI (not GitHub), +and there's a script to automate the upload using twine: script/build_and_upload diff --git a/tests/gui/test_main.py b/tests/gui/test_main.py index 11731dc..d0ce589 100644 --- a/tests/gui/test_main.py +++ b/tests/gui/test_main.py @@ -44,6 +44,7 @@ def test_main(): root.update() + wrap.destroy() root.destroy() diff --git a/tests/plugins/test_csv.py b/tests/plugins/test_csv.py index 3a29209..549851a 100644 --- a/tests/plugins/test_csv.py +++ b/tests/plugins/test_csv.py @@ -142,7 +142,7 @@ def test_save_csv_bz2(): df2 = pd.DataFrame([[10, 11, 12]], columns=["a", "b", "d"]) -def test_save_csv_multi(): +def test_save_csv_multi(caplog): plugin = SaveCsvPlugin() plugin.set_parameter("header", True) plugin.set_parameter("filename", "tests/output2.csv") @@ -154,3 +154,4 @@ with open("tests/output2.csv", "r", encoding="utf-8") as fh: text = fh.read() assert text == "a,b,c\n1,2,3\n4,5,6\n7,8,9\n10,11,,12\n" + assert "Added CSV Column" in caplog.text diff --git a/tests/plugins/test_filter.py b/tests/plugins/test_filter.py new file mode 100644 index 0000000..31e7b0f --- /dev/null +++ b/tests/plugins/test_filter.py @@ -0,0 +1,26 @@ +import pandas as pd + +from countess.plugins.filter import FilterPlugin + +df1 = pd.DataFrame( + [ + {"foo": 1, "bar": 2, "baz": 3}, + {"foo": 4, "bar": 5, "baz": 6}, + {"foo": 7, "bar": 8, "baz": 9}, + ], +) + +df2 = df1.set_index("foo") + +code_1 = "qux = bar + baz\n\nquux = bar * baz\n" + +code_2 = "bar + baz != 11" + + +def test_filter_0(): + plugin = FilterPlugin() + # plugin.set_parameter("drop.1", True) + plugin.prepare(["x"]) + + dfs = list(plugin.process(df1, "x")) + assert len(dfs) == 0 diff --git a/tests/plugins/test_hgvs_parser.py b/tests/plugins/test_hgvs_parser.py index a31339f..a71311f 100644 --- a/tests/plugins/test_hgvs_parser.py +++ b/tests/plugins/test_hgvs_parser.py @@ -75,7 +75,7 @@ def test_hgvs_parser_split_and_multi(): assert df["loc"].iloc[1] == "43124111" -df2 = pd.DataFrame([{"fnords": "whatever"}, {"hgvs": None}, {"hgvs": "g.="}, {"hgvs": "g.[1A>T;2G>C;3C>T;4A>T;5A>T]"}]) +df2 = pd.DataFrame([{"fnords": "whatever"}, {"hgvs": None}, {"hgvs": "g.[1A>T;2G>C;3C>T;4A>T;5A>T]"}]) def test_hgvs_parser_bad(): @@ -85,7 +85,14 @@ plugin = HgvsParserPlugin() plugin.set_parameter("column", "hgvs") df = plugin.process_dataframe(df2) print(df) - assert np.isnan(df["var_1"].iloc[0]) - assert np.isnan(df["var_1"].iloc[1]) - assert df["var_1"].iloc[2] == "g.=" - assert np.isnan(df["var_1"].iloc[3]) + assert all(np.isnan(df["var_1"])) + + +def test_hgvs_parser_very_bad(): + plugin = HgvsParserPlugin() + plugin.set_parameter("column", "hgvs") + + dfi = pd.DataFrame([{'a': 1}]) + dfo = plugin.process_dataframe(dfi) + + assert all(dfo == dfi) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..e27216b --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,72 @@ +import io +from unittest.mock import patch + +from countess.core.config import export_config_graphviz, read_config_dict, write_config, write_config_node_string +from countess.core.parameters import IntegerParam +from countess.core.pipeline import PipelineGraph, PipelineNode +from countess.core.plugins import BasePlugin + + +class NothingPlugin(BasePlugin): + version = "0" + param = IntegerParam("param", 0) + + +def test_read_config_dict(): + pn = read_config_dict("node", ".", {"_module": __name__, "_class": "NothingPlugin", "foo":
'"bar"'}) + assert pn + assert isinstance(pn.plugin, NothingPlugin) + assert list(pn.config[0]) == ["foo", "bar", "."] + + +def test_read_config_dict_no_plugin(): + pn = read_config_dict("node", ".", {"foo": '"bar"'}) + assert pn.plugin is None + assert list(pn.config[0]) == ["foo", "bar", "."] + + +def test_write_config(): + pn = PipelineNode("node", plugin=NothingPlugin("node"), config=[("foo", "bar", "baz")]) + pg = PipelineGraph([pn]) + + buf = io.StringIO() + buf.close = lambda: None + with patch("builtins.open", lambda *_, **__: buf): + write_config(pg, "whatever") + + s = buf.getvalue() + assert s.startswith("[node]") + assert "foo = 'bar'" in s + + +def test_write_config_node_string(): + pn = PipelineNode("node", plugin=NothingPlugin("node")) + pn.plugin.param = 12 + pn.notes = "hello" + + s = write_config_node_string(pn) + + assert "[node]" in s + assert "_module = %s" % __name__ in s + assert "_class = NothingPlugin" in s + assert "_notes = hello" in s + assert "param = 12" in s + + +def test_export_graphviz(): + pn1 = PipelineNode("node 1") + pn2 = PipelineNode("node 2") + pn3 = PipelineNode("node 3") + pn3.add_parent(pn2) + pn2.add_parent(pn1) + pg = PipelineGraph([pn1, pn2, pn3]) + + buf = io.StringIO() + buf.close = lambda: None + with patch("builtins.open", lambda *_, **__: buf): + export_config_graphviz(pg, "filename") + + s = buf.getvalue() + assert s.startswith("digraph {") + assert '"node 1" -> "node 2";' in s + assert '"node 2" -> "node 3";' in s diff --git a/tests/test_parameters.py b/tests/test_parameters.py new file mode 100644 index 0000000..3435ff6 --- /dev/null +++ b/tests/test_parameters.py @@ -0,0 +1,300 @@ +import io +from unittest.mock import mock_open, patch + +import pandas as pd +import pytest + +from countess.core.parameters import ( + ArrayParam, + BooleanParam, + ChoiceParam, + ColumnChoiceParam, + ColumnGroupOrNoneChoiceParam, + ColumnOrIndexChoiceParam, + ColumnOrIntegerParam, + ColumnOrNoneChoiceParam, + DataTypeChoiceParam, + DataTypeOrNoneChoiceParam, + FileParam, + FloatParam, + IntegerParam, + MultiParam, + ScalarParam, + StringCharacterSetParam, + StringParam, + make_prefix_groups, +) + + +def test_make_prefix_groups(): + x = make_prefix_groups(["one_two", "one_three", "two_one", "two_two", "two_three", "three_four_five"]) + + assert x == {"one_": ["two", "three"], "two_": ["one", "two", "three"]} + + +def test_scalarparm(): + sp1 = ScalarParam("x") + sp1.value = "hello" + sp2 = sp1.copy_and_set_value("goodbye") + assert sp1.value == "hello" + assert sp2.value == "goodbye" + + +def test_stringparam(): + sp = StringParam("i'm a frayed knot") + + sp.set_value("hello") + + assert sp == "hello" + assert sp != "goodbye" + assert sp > "hell" + assert sp >= "hell" + assert sp >= "hello" + assert sp < "help" + assert sp <= "hello" + assert sp <= "help" + assert sp + "world" == "helloworld" + assert "why" + sp == "whyhello" + assert "ell" in sp + assert hash(sp) == hash("hello") + + +def test_floatparam(): + fp = FloatParam("whatever") + + for v in (0, 1, 106.7, -45): + fp.set_value(v) + + assert fp == v + assert fp != v + 1 + assert fp > v - 1 + assert fp < v + 1 + assert fp + 1 == v + 1 + assert 1 + fp == 1 + v + assert fp * 2 == v * 2 + assert 3 * fp == 3 * v + assert -fp == -v + assert fp - 1 == v - 1 + assert 2 - fp == 2 - v + assert float(fp) == v + assert abs(fp) == abs(v) + assert +fp == +v + + +def test_booleanparam(): + bp = BooleanParam("dude") + + with pytest.raises(ValueError): + bp.set_value("Yeah, Nah") + + bp.set_value("T") + + assert 
bool(bp) + assert str(bp) == "True" + + bp.set_value("F") + + assert not bool(bp) + assert str(bp) == "False" + + +def test_multiparam(): + mp = MultiParam( + "x", + { + "foo": StringParam("Foo"), + "bar": StringParam("Bar"), + }, + ) + + assert "foo" in mp + + mp["foo"] = "hello" + assert mp.foo == "hello" + assert mp["foo"] == "hello" + assert "bar" in mp + + for key in mp: + assert isinstance(mp[key], StringParam) + + for key, param in mp.items(): + assert isinstance(param, StringParam) + + mp.set_parameter("foo._label", "fnord") + assert mp["foo"].label == "fnord" + + +def test_scsp(): + pp = StringCharacterSetParam("x", "hello", character_set=set("HelO")) + pp.value = "helicopter" + assert pp.value == "HelOe" + + +def test_choiceparam(): + cp = ChoiceParam("x", value="a", choices=["a", "b", "c", "d"]) + + cp.value = None + assert cp.value == "" + + cp.choice = 2 + assert cp.choice == 2 + assert cp.value == "c" + + cp.choice = 5 + assert cp.choice is None + assert cp.value == "" + + cp.value = "b" + cp.set_choices(["a", "b", "c"]) + assert cp.choice == 1 + assert cp.value == "b" + + cp.set_choices(["x", "y"]) + assert cp.choice == 0 + assert cp.value == "x" + + cp.set_choices([]) + assert cp.choice is None + assert cp.value == "" + + +def test_dtcp1(): + cp = DataTypeChoiceParam("x") + assert cp.get_selected_type() is None + + +def test_dtcp2(): + cp = DataTypeOrNoneChoiceParam("x") + + assert cp.get_selected_type() is None + assert cp.cast_value("whatever") is None + assert cp.is_none() + + cp.value = "integer" + assert cp.get_selected_type() == int + assert cp.cast_value(7.3) == 7 + assert cp.cast_value("whatever") == 0 + assert not cp.is_none() + + +def test_ccp1(): + cp = ColumnChoiceParam("x", "a") + df = pd.DataFrame([]) + with pytest.raises(ValueError): + cp.get_column(df) + + +def test_ccp2(): + df = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + cp = ColumnOrNoneChoiceParam("x") + cp.set_choices(["a", "b"]) + assert cp.is_none() + assert cp.get_column(df) is None + + cp.value = "a" + assert cp.is_not_none() + assert isinstance(cp.get_column(df), pd.Series) + + df = df.set_index("a") + assert isinstance(cp.get_column(df), pd.Series) + + df = df.reset_index().set_index(["a", "b"]) + assert isinstance(cp.get_column(df), pd.Series) + + df = pd.DataFrame([], columns=["x", "y"]) + with pytest.raises(ValueError): + cp.get_column(df) + + +def test_coindex(): + cp = ColumnOrIndexChoiceParam("x", choices=["a", "b"]) + df = pd.DataFrame(columns=["a", "b"]).set_index("a") + assert cp.is_index() + assert isinstance(cp.get_column(df), pd.Series) + + cp.choice = 1 + assert cp.is_not_index() + assert isinstance(cp.get_column(df), pd.Series) + + +def test_columnorintegerparam(): + df = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + cp = ColumnOrIntegerParam("x") + cp.set_column_choices(["a", "b"]) + + assert cp.get_column_name() is None + + cp.value = "7" + assert cp.choice is None + assert cp.get_column_name() is None + assert cp.get_column_or_value(df, False) == "7" + assert cp.get_column_or_value(df, True) == 7 + + cp.choice = 0 + assert cp.get_column_name() == "a" + assert isinstance(cp.get_column_or_value(df, False), pd.Series) + + cp.set_column_choices(["c", "d"]) + assert cp.choice is None + + cp.value = "hello" + assert cp.value == 0 + + +def test_columngroup(): + df = pd.DataFrame([], columns=["one_two", "one_three", "two_one", "two_two", "two_three", "three_four_five"]) + cp = ColumnGroupOrNoneChoiceParam("x") + cp.set_column_choices(df.columns) + assert cp.is_none() + 
assert "one_*" in cp.choices + assert "two_*" in cp.choices + assert cp.get_column_prefix() is None + + cp.choice = 2 + assert cp.is_not_none() + assert cp.get_column_prefix() == "two_" + assert cp.get_column_suffixes(df) == ["one", "two", "three"] + assert cp.get_column_names(df) == ["two_one", "two_two", "two_three"] + + +def test_fileparam(): + fp = FileParam("x") + assert fp.get_file_hash() == "0" + + fp.value = "filename" + buf = io.BytesIO(b"hello") + + with patch("builtins.open", lambda *_, **__: buf): + h = fp.get_file_hash() + assert h == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + + def dummy(*_, **__): + raise IOError("x") + + with patch("builtins.open", dummy): + assert fp.get_file_hash() == "0" + + fp.value = "/foo/bar/baz" + assert fp.get_parameters("fnord", "/foo") == [("fnord", "bar/baz")] + + +def test_arrayparam_minmax(): + pp = IntegerParam("x") + ap = ArrayParam("y", param=pp, min_size=2, max_size=3) + assert len(ap) == 2 + + assert isinstance(ap.add_row(), IntegerParam) + assert len(ap) == 3 + + assert ap.add_row() is None + assert len(ap) == 3 + + ap.del_row(0) + assert len(ap) == 2 + + ap.del_row(1) + assert len(ap) == 2 + + # FIX minimum and maximum constraints! + #ap.del_subparam(ap[1]) + #assert len(ap) == 2 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..4a1e09e --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,143 @@ +import pytest + +from countess.core.parameters import IntegerParam +from countess.core.pipeline import PipelineGraph, PipelineNode +from countess.core.plugins import ProcessPlugin + + +@pytest.fixture(name="pg") +def fixture_pg(): + pn0 = PipelineNode("node") + pn1 = PipelineNode("node") + pn2 = PipelineNode("node") + pn3 = PipelineNode("node") + pn4 = PipelineNode("node") + + pg = PipelineGraph([pn0, pn1, pn2, pn3, pn4]) + + pn4.add_parent(pn2) + pn4.add_parent(pn3) + pn3.add_parent(pn1) + pn2.add_parent(pn1) + pn1.add_parent(pn0) + pn1.add_parent(pn0) + + return pg + + +def test_ancestor_descendant(pg): + pns = list(pg.traverse_nodes()) + for pn in pns[1:]: + assert pns[0].is_ancestor_of(pn) + assert not pn.is_ancestor_of(pns[0]) + + for pn in pns[:-1]: + assert pns[-1].is_descendant_of(pn) + assert not pn.is_descendant_of(pns[-1]) + + +def test_pipeline_graph_tidy(pg): + pg.tidy() + + pns = list(pg.traverse_nodes()) + + # check that all nodes have different positions + assert len(set(pn.position for pn in pns)) == len(pns) + + # check that first coordinate is monotonic increasing + xs = [pn.position[0] for pn in pns] + assert sorted(xs) == xs + + +def test_pipeline_del_node(pg): + pns = list(pg.traverse_nodes()) + pg.del_node(pns[2]) + + assert not pns[2].is_descendant_of(pns[0]) + assert not pns[2].is_ancestor_of(pns[-1]) + + +def test_pipeline_del_parent(pg): + pns = list(pg.traverse_nodes()) + pns[2].del_parent(pns[1]) + + assert not pns[1].is_ancestor_of(pns[2]) + assert pns[2].is_ancestor_of(pns[-1]) + + +def test_pipeline_graph_reset_node_name(pg): + pns = list(pg.traverse_nodes()) + pg.reset_node_name(pns[1]) + assert pns[1].name == "node 2" + + pg.reset_node_name(pns[3]) + assert pns[3].name == "node 3" + + +def test_pipeline_graph_reset_node_names(pg): + pg.reset_node_names() + names = [pn.name for pn in pg.traverse_nodes()] + assert sorted(set(names)) == names + + pn = PipelineNode("node") + pg.add_node(pn) + assert pn.name == "node 5" + + +def test_pg_reset(pg): + pg.reset() + + assert all(pn.result is None for pn in pg.traverse_nodes()) + assert 
all(pn.is_dirty for pn in pg.traverse_nodes()) + + +class DoesNothingPlugin(ProcessPlugin): + version = "0" + param = IntegerParam("param", 0) + + def process(self, data, source): + yield data + + def finished(self, source): + yield 107 + + +def test_plugin_config(caplog): + dnp = DoesNothingPlugin() + dnn = PipelineNode( + "node", + plugin=dnp, + config=[ + ("param", 1, "."), + ("noparam", "whatever", "."), + ], + ) + dnn.load_config() + + assert "noparam=whatever" in caplog.text + assert dnp.param == 1 + + +def test_noplugin_prerun(): + pn = PipelineNode("node") + + with pytest.raises(AssertionError): + pn.load_config() + + pn.prerun() + + +def test_mark_dirty(): + pn1 = PipelineNode("node1", plugin=DoesNothingPlugin()) + pn2 = PipelineNode("node2", plugin=DoesNothingPlugin()) + pn2.add_parent(pn1) + + pn2.prerun() + + assert not pn1.is_dirty + assert not pn2.is_dirty + + pn1.configure_plugin("param", 2) + + assert pn1.is_dirty + assert pn2.is_dirty diff --git a/tests/test_pipeline_sentinelqueue.py b/tests/test_pipeline_sentinelqueue.py new file mode 100644 index 0000000..bf66f4a --- /dev/null +++ b/tests/test_pipeline_sentinelqueue.py @@ -0,0 +1,29 @@ +import pytest + +from countess.core.pipeline import SentinelQueue + + +def test_sentinelqueue(): + sq = SentinelQueue() + + sq.put("hello") + sq.put("world") + + sq.finish() + + # can't add more messages once finished + with pytest.raises(ValueError): + sq.put("oh, no!") + + sqr = iter(sq) + assert next(sqr) == "hello" + assert next(sqr) == "world" + + # when the iterator hits the sentinel it + # raises StopIteration ... + with pytest.raises(StopIteration): + next(sqr) + + # ... and keeps doing so if asked again + with pytest.raises(StopIteration): + next(sqr) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 0bbb3fb..13ccb58 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,8 +1,21 @@ import importlib.metadata from unittest.mock import patch +import pandas as pd +import pytest + import countess +from countess.core.parameters import ColumnChoiceParam, StringParam -from countess.core.plugins import BasePlugin, get_plugin_classes +from countess.core.plugins import ( + BasePlugin, + FileInputPlugin, + PandasConcatProcessPlugin, + PandasProductPlugin, + PandasTransformDictToDictPlugin, + PandasTransformRowToDictPlugin, + PandasTransformSingleToSinglePlugin, + get_plugin_classes, +) empty_entry_points_dict = {"countess_plugins": []} @@ -31,3 +44,187 @@ def test_get_plugin_classes_wrongparent(caplog): with patch("importlib.metadata.EntryPoint.load", lambda x: NoParentPlugin): get_plugin_classes() assert "not a valid CountESS plugin" in caplog.text + + +class PPP(PandasProductPlugin): + version = "0" + + def process_dataframes(self, dataframe1, dataframe2): + return dataframe1 + dataframe2 + + +def test_product_plugin(): + ppp = PPP() + + df1 = pd.DataFrame([{"a": 1}]) + df2 = pd.DataFrame([{"a": 2}]) + df3 = pd.DataFrame([{"a": 4}]) + df4 = pd.DataFrame([{"a": 8}]) + df5 = pd.DataFrame([{"a": 16}]) + + ppp.prepare(["source1", "source2"]) + + dfs = list(ppp.process(df1, "source1")) + assert len(dfs) == 0 + + dfs = list(ppp.process(df2, "source1")) + assert len(dfs) == 0 + + dfs = list(ppp.process(df3, "source2")) + assert len(dfs) == 2 + assert dfs[0]["a"][0] == 5 + assert dfs[1]["a"][0] == 6 + + dfs = list(ppp.process(df4, "source1"))
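+ # df4 arrives on source1 when source2 has produced only df3, so exactly + # one product dataframe is emitted (4 + 8 = 12). +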
assert len(dfs) == 1 + assert dfs[0]["a"][0] == 12 + + dfs = list(ppp.finished("source1")) + assert len(dfs) == 0 + + dfs = list(ppp.process(df5, "source2")) + assert len(dfs) == 3 + assert dfs[0]["a"][0] == 17 + assert dfs[1]["a"][0] == 18 + assert dfs[2]["a"][0] == 24 + + dfs = list(ppp.finished("source2")) + assert len(dfs) == 0 + + +def test_product_plugin_sources(): + with pytest.raises(ValueError): + ppp = PPP() + ppp.prepare(["source1"]) + + with pytest.raises(ValueError): + ppp = PPP() + ppp.prepare(["source1", "source2", "source3"]) + + with pytest.raises(ValueError): + ppp = PPP() + ppp.prepare(["source1", "source2"]) + list(ppp.process(pd.DataFrame(), "source3")) + + with pytest.raises(ValueError): + ppp = PPP() + ppp.prepare(["source1", "source2"]) + list(ppp.finished("source3")) + + +class FIP(FileInputPlugin): + version = "0" + + def num_files(self): + return 3 + + def load_file(self, file_number, row_limit): + if row_limit is None: + row_limit = 1000000 + return [f"hello{file_number}"] * row_limit + + +def test_fileinput(caplog): + caplog.set_level("INFO") + fip = FIP("fip") + + fip.prepare([], 1000) + data = list(fip.finalize()) + + assert len(data) >= 999 + assert sorted(set(data)) == ["hello0", "hello1", "hello2"] + + assert "100%" in caplog.text + + +class TPCPP(PandasConcatProcessPlugin): + version = "0" + + def process_dataframe(self, dataframe): + return dataframe + + +def test_concat(): + df1 = pd.DataFrame([{"a": 1}]) + df2 = pd.DataFrame([{"a": 2}]) + df3 = pd.DataFrame([{"a": 4}]) + + pcpp = TPCPP() + pcpp.prepare(["a"]) + pcpp.process(df1, "a") + pcpp.process(df2, "a") + pcpp.process(df3, "a") + + dfs = list(pcpp.finalize()) + assert len(dfs) == 1 + assert all(dfs[0]["a"] == [1, 2, 4]) + + +class TPTSTSP(PandasTransformSingleToSinglePlugin): + version = "0" + column = ColumnChoiceParam("Column", "a") + output = StringParam("Output", "c") + + def process_value(self, value): + return value * 3 + 1 if value else None + + +def test_transform_sts(): + thing = TPTSTSP() + dfi = pd.DataFrame([[1, 4], [2, 5], [3, 6]], columns=["a", "b"]) + + dfo = thing.process_dataframe(dfi) + assert all(dfo["c"] == [4, 7, 10]) + + dfo = thing.process_dataframe(dfi.set_index("a")) + assert all(dfo["c"] == [4, 7, 10]) + + dfo = thing.process_dataframe(dfi.set_index(["a", "b"])) + assert all(dfo["c"] == [4, 7, 10]) + + thing.column = "d" + dfo = thing.process_dataframe(dfi) + assert list(dfo["c"]) == [None, None, None] + + dfi = pd.DataFrame([[1, 4], [1, 5], [1, 6]], columns=["i", "d"]).set_index("i") + dfo = thing.process_dataframe(dfi).reset_index() + + assert list(dfo["i"]) == [1, 1, 1] + assert list(dfo["d"]) == [4, 5, 6] + assert list(dfo["c"]) == [13, 16, 19] + + +class TPTRTDP(PandasTransformRowToDictPlugin): + version = "0" + + def process_row(self, row): + return {"c": row["a"] * 3 + 1} + + +def test_transform_rtd(): + thing = TPTRTDP() + dfi = pd.DataFrame([[1, 4], [2, 5], [3, 6]], columns=["a", "b"]) + dfo = thing.process_dataframe(dfi) + assert all(dfo["c"] == [4, 7, 10]) + + +class TPTDTDP(PandasTransformDictToDictPlugin): + version = "0" + + def process_dict(self, data): + return {"c": data["a"] * 3 + 1} + + +def test_transform_dtd(): + thing = TPTDTDP() + dfi = pd.DataFrame([[1, 4], [2, 5], [3, 6]], columns=["a", "b"]) + dfo = thing.process_dataframe(dfi) + assert all(dfo["c"] == [4, 7, 10]) + + dfo = thing.process_dataframe(dfi.set_index("a")) + assert all(dfo["c"] == [4, 7, 10]) + + dfo = thing.process_dataframe(dfi.set_index(["a", "b"])) + assert all(dfo["c"] == [4, 7, 
10]) diff --git a/tests/utils/test_files.py b/tests/utils/test_files.py new file mode 100644 index 0000000..9605f83 --- /dev/null +++ b/tests/utils/test_files.py @@ -0,0 +1,9 @@ +from countess.utils.files import clean_filename + + +def test_clean_filename(): + assert clean_filename("baz.qux.quux") == "baz" + assert clean_filename("foo/bar/baz.qux.quux") == "baz" + assert clean_filename("foo/bar/baz") == "baz" + assert clean_filename("fnord") == "fnord" + assert clean_filename("") == "" diff --git a/tests/utils/test_parallel.py b/tests/utils/test_parallel.py index 76d455c..ade5387 100644 --- a/tests/utils/test_parallel.py +++ b/tests/utils/test_parallel.py @@ -1,4 +1,6 @@ -from countess.utils.parallel import multiprocess_map +import pytest + +from countess.utils.parallel import IterableMultiprocessQueue, multiprocess_map def test_multiprocess_map(): @@ -40,3 +42,17 @@ def function(value): 42, 49, ] + + +def test_multiprocess_map_stopped(): + impq = IterableMultiprocessQueue() + + impq.put("1") + impq.put("2") + impq.put("3") + impq.finish() + + with pytest.raises(ValueError): + impq.put("4") + + assert sorted(list(impq)) == ["1", "2", "3"]
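
Usage note (editor's sketch, not part of the patch above): the tests in tests/utils/test_parallel.py exercise the new multiprocess_map() and IterableMultiprocessQueue. A minimal example of calling multiprocess_map() follows; the worker function "double" is hypothetical, a worker may yield zero or more outputs per input, and because outputs are collected from worker processes as they appear their order is not guaranteed (this assumes the default "fork" start method on Linux, as used in CI):

    from countess.utils.parallel import multiprocess_map

    def double(value, offset=0):
        # runs in a worker process; extra args/kwargs are passed through
        yield value * 2 + offset

    results = sorted(multiprocess_map(double, [1, 2, 3], offset=1))
    assert results == [3, 5, 7]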