From 5eee2d0a5ef723802783bc690f722c6fa8d616eb Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Fri, 6 Sep 2024 15:27:56 +0100 Subject: [PATCH 01/15] Introduce task class + small refactoring Signed-off-by: Merel Theisen --- kedro/runner/__init__.py | 2 + kedro/runner/parallel_runner.py | 19 +-- kedro/runner/runner.py | 196 +++------------------------ kedro/runner/sequential_runner.py | 13 +- kedro/runner/task.py | 218 ++++++++++++++++++++++++++++++ kedro/runner/thread_runner.py | 18 +-- 6 files changed, 243 insertions(+), 223 deletions(-) create mode 100644 kedro/runner/task.py diff --git a/kedro/runner/__init__.py b/kedro/runner/__init__.py index 0725d56a8a..51910bf369 100644 --- a/kedro/runner/__init__.py +++ b/kedro/runner/__init__.py @@ -5,12 +5,14 @@ from .parallel_runner import ParallelRunner from .runner import AbstractRunner, run_node from .sequential_runner import SequentialRunner +from .task import Task from .thread_runner import ThreadRunner __all__ = [ "AbstractRunner", "ParallelRunner", "SequentialRunner", + "Task", "ThreadRunner", "run_node", ] diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 62d7e1216b..14c9a10a60 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -27,7 +27,7 @@ MemoryDataset, SharedMemoryDataset, ) -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -317,19 +317,4 @@ def _run( node = future.result() done_nodes.add(node) - # Decrement load counts, and release any datasets we - # have finished with. This is particularly important - # for the shared, default datasets we created above. - for dataset in node.inputs: - load_counts[dataset] -= 1 - if ( - load_counts[dataset] < 1 - and dataset not in pipeline.inputs() - ): - catalog.release(dataset) - for dataset in node.outputs: - if ( - load_counts[dataset] < 1 - and dataset not in pipeline.outputs() - ): - catalog.release(dataset) + decrement_and_release_datasets(node, catalog, load_counts, pipeline) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 2ffd0389e4..0742bf2602 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -5,24 +5,15 @@ from __future__ import annotations import inspect -import itertools as it import logging from abc import ABC, abstractmethod from collections import deque -from concurrent.futures import ( - ALL_COMPLETED, - Future, - ThreadPoolExecutor, - as_completed, - wait, -) -from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator - -from more_itertools import interleave +from typing import TYPE_CHECKING, Any, Collection, Iterable from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import DataCatalog, MemoryDataset from kedro.pipeline import Pipeline +from kedro.runner.task import Task if TYPE_CHECKING: from pluggy import PluginManager @@ -403,6 +394,7 @@ def run_node( The node argument. """ + if is_async and inspect.isgeneratorfunction(node.func): raise ValueError( f"Async data loading and saving does not work with " @@ -411,175 +403,19 @@ def run_node( f"in node {node!s}." 
) - if is_async: - node = _run_node_async(node, catalog, hook_manager, session_id) - else: - node = _run_node_sequential(node, catalog, hook_manager, session_id) - - for name in node.confirms: - catalog.confirm(name) - return node - - -def _collect_inputs_from_hook( # noqa: PLR0913 - node: Node, - catalog: DataCatalog, - inputs: dict[str, Any], - is_async: bool, - hook_manager: PluginManager, - session_id: str | None = None, -) -> dict[str, Any]: - inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook - hook_response = hook_manager.hook.before_node_run( - node=node, - catalog=catalog, - inputs=inputs, - is_async=is_async, - session_id=session_id, - ) - - additional_inputs = {} - if ( - hook_response is not None - ): # all hooks on a _NullPluginManager will return None instead of a list - for response in hook_response: - if response is not None and not isinstance(response, dict): - response_type = type(response).__name__ - raise TypeError( - f"'before_node_run' must return either None or a dictionary mapping " - f"dataset names to updated values, got '{response_type}' instead." - ) - additional_inputs.update(response or {}) - - return additional_inputs - - -def _call_node_run( # noqa: PLR0913 - node: Node, - catalog: DataCatalog, - inputs: dict[str, Any], - is_async: bool, - hook_manager: PluginManager, - session_id: str | None = None, -) -> dict[str, Any]: - try: - outputs = node.run(inputs) - except Exception as exc: - hook_manager.hook.on_node_error( - error=exc, - node=node, - catalog=catalog, - inputs=inputs, - is_async=is_async, - session_id=session_id, - ) - raise exc - hook_manager.hook.after_node_run( - node=node, - catalog=catalog, - inputs=inputs, - outputs=outputs, - is_async=is_async, - session_id=session_id, - ) - return outputs - - -def _run_node_sequential( - node: Node, - catalog: DataCatalog, - hook_manager: PluginManager, - session_id: str | None = None, -) -> Node: - inputs = {} - - for name in node.inputs: - hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node) - inputs[name] = catalog.load(name) - hook_manager.hook.after_dataset_loaded( - dataset_name=name, data=inputs[name], node=node - ) - - is_async = False - - additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - inputs.update(additional_inputs) - - outputs = _call_node_run( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - - items: Iterable = outputs.items() - # if all outputs are iterators, then the node is a generator node - if all(isinstance(d, Iterator) for d in outputs.values()): - # Python dictionaries are ordered, so we are sure - # the keys and the chunk streams are in the same order - # [a, b, c] - keys = list(outputs.keys()) - # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]] - streams = list(outputs.values()) - # zip an endless cycle of the keys - # with an interleaved iterator of the streams - # [(a, chunk_a), (b, chunk_b), ...] 
until all outputs complete - items = zip(it.cycle(keys), interleave(*streams)) - - for name, data in items: - hook_manager.hook.before_dataset_saved(dataset_name=name, data=data, node=node) - catalog.save(name, data) - hook_manager.hook.after_dataset_saved(dataset_name=name, data=data, node=node) + task = Task(node, catalog, hook_manager, is_async, session_id) + node = task.execute() return node -def _run_node_async( - node: Node, - catalog: DataCatalog, - hook_manager: PluginManager, - session_id: str | None = None, -) -> Node: - def _synchronous_dataset_load(dataset_name: str) -> Any: - """Minimal wrapper to ensure Hooks are run synchronously - within an asynchronous dataset load.""" - hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node) - return_ds = catalog.load(dataset_name) - hook_manager.hook.after_dataset_loaded( - dataset_name=dataset_name, data=return_ds, node=node - ) - return return_ds - - with ThreadPoolExecutor() as pool: - inputs: dict[str, Future] = {} - - for name in node.inputs: - inputs[name] = pool.submit(_synchronous_dataset_load, name) - - wait(inputs.values(), return_when=ALL_COMPLETED) - inputs = {key: value.result() for key, value in inputs.items()} - is_async = True - additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - inputs.update(additional_inputs) - - outputs = _call_node_run( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - - future_dataset_mapping = {} - for name, data in outputs.items(): - hook_manager.hook.before_dataset_saved( - dataset_name=name, data=data, node=node - ) - future = pool.submit(catalog.save, name, data) - future_dataset_mapping[future] = (name, data) - - for future in as_completed(future_dataset_mapping): - exception = future.exception() - if exception: - raise exception - name, data = future_dataset_mapping[future] - hook_manager.hook.after_dataset_saved( - dataset_name=name, data=data, node=node - ) - return node +def decrement_and_release_datasets( + node: Node, catalog: DataCatalog, load_counts, pipeline +): + """Decrement dataset load counts and release any datasets we've finished with""" + for dataset in node.inputs: + load_counts[dataset] -= 1 + if load_counts[dataset] < 1 and dataset not in pipeline.inputs(): + catalog.release(dataset) + for dataset in node.outputs: + if load_counts[dataset] < 1 and dataset not in pipeline.outputs(): + catalog.release(dataset) diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 48dac3cd54..f79e8b6aaa 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -9,7 +9,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -81,15 +81,8 @@ def _run( self._suggest_resume_scenario(pipeline, done_nodes, catalog) raise - # decrement load counts and release any data sets we've finished with - for dataset in node.inputs: - load_counts[dataset] -= 1 - if load_counts[dataset] < 1 and dataset not in pipeline.inputs(): - catalog.release(dataset) - for dataset in node.outputs: - if load_counts[dataset] < 1 and dataset not in pipeline.outputs(): - catalog.release(dataset) + decrement_and_release_datasets(node, catalog, load_counts, pipeline) self._logger.info( - "Completed %d out of %d tasks", exec_index + 1, 
len(nodes) + "Completed %d out of %d tasks", len(done_nodes), len(nodes) ) diff --git a/kedro/runner/task.py b/kedro/runner/task.py new file mode 100644 index 0000000000..4e3c4d1820 --- /dev/null +++ b/kedro/runner/task.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import itertools as it +from concurrent.futures import ( + ALL_COMPLETED, + Future, + ThreadPoolExecutor, + as_completed, + wait, +) +from typing import TYPE_CHECKING, Any, Iterable, Iterator + +from more_itertools import interleave + +if TYPE_CHECKING: + from pluggy import PluginManager + + from kedro.io import DataCatalog + from kedro.pipeline.node import Node + + +class Task: + def __init__(self, node, catalog, hook_manager, is_async, session_id): + self.node = node + self.catalog = catalog + self.hook_manager = hook_manager + self.is_async = is_async + self.session_id = session_id + + def execute(self): + if self.is_async: + node = self._run_node_async( + self.node, self.catalog, self.hook_manager, self.session_id + ) + else: + node = self._run_node_sequential( + self.node, self.catalog, self.hook_manager, self.session_id + ) + + for name in node.confirms: + self.catalog.confirm(name) + + return node + + def _run_node_sequential( + self, + node: Node, + catalog: DataCatalog, + hook_manager: PluginManager, + session_id: str | None = None, + ) -> Node: + inputs = {} + + for name in node.inputs: + hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node) + inputs[name] = catalog.load(name) + hook_manager.hook.after_dataset_loaded( + dataset_name=name, data=inputs[name], node=node + ) + + is_async = False + + additional_inputs = self._collect_inputs_from_hook( + node, catalog, inputs, is_async, hook_manager, session_id=session_id + ) + inputs.update(additional_inputs) + + outputs = self._call_node_run( + node, catalog, inputs, is_async, hook_manager, session_id=session_id + ) + + items: Iterable = outputs.items() + # if all outputs are iterators, then the node is a generator node + if all(isinstance(d, Iterator) for d in outputs.values()): + # Python dictionaries are ordered, so we are sure + # the keys and the chunk streams are in the same order + # [a, b, c] + keys = list(outputs.keys()) + # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]] + streams = list(outputs.values()) + # zip an endless cycle of the keys + # with an interleaved iterator of the streams + # [(a, chunk_a), (b, chunk_b), ...] 
until all outputs complete + items = zip(it.cycle(keys), interleave(*streams)) + + for name, data in items: + hook_manager.hook.before_dataset_saved( + dataset_name=name, data=data, node=node + ) + catalog.save(name, data) + hook_manager.hook.after_dataset_saved( + dataset_name=name, data=data, node=node + ) + return node + + def _run_node_async( + self, + node: Node, + catalog: DataCatalog, + hook_manager: PluginManager, + session_id: str | None = None, + ) -> Node: + with ThreadPoolExecutor() as pool: + inputs: dict[str, Future] = {} + + for name in node.inputs: + inputs[name] = pool.submit( + self._synchronous_dataset_load, name, node, catalog, hook_manager + ) + + wait(inputs.values(), return_when=ALL_COMPLETED) + inputs = {key: value.result() for key, value in inputs.items()} + is_async = True + additional_inputs = self._collect_inputs_from_hook( + node, catalog, inputs, is_async, hook_manager, session_id=session_id + ) + inputs.update(additional_inputs) + + outputs = self._call_node_run( + node, catalog, inputs, is_async, hook_manager, session_id=session_id + ) + + future_dataset_mapping = {} + for name, data in outputs.items(): + hook_manager.hook.before_dataset_saved( + dataset_name=name, data=data, node=node + ) + future = pool.submit(catalog.save, name, data) + future_dataset_mapping[future] = (name, data) + + for future in as_completed(future_dataset_mapping): + exception = future.exception() + if exception: + raise exception + name, data = future_dataset_mapping[future] + hook_manager.hook.after_dataset_saved( + dataset_name=name, data=data, node=node + ) + return node + + @staticmethod + def _synchronous_dataset_load( + dataset_name: str, node: Node, catalog: DataCatalog, hook_manager: PluginManager + ) -> Any: + """Minimal wrapper to ensure Hooks are run synchronously + within an asynchronous dataset load.""" + hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node) + return_ds = catalog.load(dataset_name) + hook_manager.hook.after_dataset_loaded( + dataset_name=dataset_name, data=return_ds, node=node + ) + return return_ds + + @staticmethod + def _collect_inputs_from_hook( # noqa: PLR0913 + node: Node, + catalog: DataCatalog, + inputs: dict[str, Any], + is_async: bool, + hook_manager: PluginManager, + session_id: str | None = None, + ) -> dict[str, Any]: + inputs = ( + inputs.copy() + ) # shallow copy to prevent in-place modification by the hook + hook_response = hook_manager.hook.before_node_run( + node=node, + catalog=catalog, + inputs=inputs, + is_async=is_async, + session_id=session_id, + ) + + additional_inputs = {} + if ( + hook_response is not None + ): # all hooks on a _NullPluginManager will return None instead of a list + for response in hook_response: + if response is not None and not isinstance(response, dict): + response_type = type(response).__name__ + raise TypeError( + f"'before_node_run' must return either None or a dictionary mapping " + f"dataset names to updated values, got '{response_type}' instead." 
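+                        # Each before_node_run hook may return None or a dict
+                        # of replacement inputs; any other return type is a
+                        # plugin bug and is rejected eagerly here.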
+ ) + additional_inputs.update(response or {}) + + return additional_inputs + + @staticmethod + def _call_node_run( # noqa: PLR0913 + node: Node, + catalog: DataCatalog, + inputs: dict[str, Any], + is_async: bool, + hook_manager: PluginManager, + session_id: str | None = None, + ) -> dict[str, Any]: + try: + outputs = node.run(inputs) + except Exception as exc: + hook_manager.hook.on_node_error( + error=exc, + node=node, + catalog=catalog, + inputs=inputs, + is_async=is_async, + session_id=session_id, + ) + raise exc + hook_manager.hook.after_node_run( + node=node, + catalog=catalog, + inputs=inputs, + outputs=outputs, + is_async=is_async, + session_id=session_id, + ) + return outputs diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index b4751a602a..734190b158 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -11,7 +11,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -143,18 +143,4 @@ def _run( "Completed %d out of %d tasks", len(done_nodes), len(nodes) ) - # Decrement load counts, and release any datasets we - # have finished with. - for dataset in node.inputs: - load_counts[dataset] -= 1 - if ( - load_counts[dataset] < 1 - and dataset not in pipeline.inputs() - ): - catalog.release(dataset) - for dataset in node.outputs: - if ( - load_counts[dataset] < 1 - and dataset not in pipeline.outputs() - ): - catalog.release(dataset) + decrement_and_release_datasets(node, catalog, load_counts, pipeline) From f094fe7ee64687d7b3038f846f41043218c5ca80 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 3 Oct 2024 15:18:00 +0100 Subject: [PATCH 02/15] Merge in main and make release_datasets private Signed-off-by: Merel Theisen --- kedro/runner/parallel_runner.py | 4 +- kedro/runner/runner.py | 200 ++---------------------------- kedro/runner/sequential_runner.py | 4 +- kedro/runner/task.py | 26 ++-- kedro/runner/thread_runner.py | 4 +- 5 files changed, 37 insertions(+), 201 deletions(-) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index b4b66e4010..82b84c5816 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -27,7 +27,7 @@ MemoryDataset, SharedMemoryDataset, ) -from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node +from kedro.runner.runner import AbstractRunner, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -319,4 +319,4 @@ def _run( node = future.result() done_nodes.add(node) - decrement_and_release_datasets(node, catalog, load_counts, pipeline) + self._release_datasets(node, catalog, load_counts, pipeline) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 7f7dfc022c..1e83b1547a 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -212,6 +212,19 @@ def _suggest_resume_scenario( f"argument to your previous command:\n{postfix}" ) + @staticmethod + def _release_datasets( + node: Node, catalog: CatalogProtocol, load_counts: dict, pipeline: Pipeline + ) -> None: + """Decrement dataset load counts and release any datasets we've finished with""" + for dataset in node.inputs: + load_counts[dataset] -= 1 + if load_counts[dataset] < 1 and dataset not in pipeline.inputs(): + catalog.release(dataset) + for dataset in node.outputs: + if load_counts[dataset] < 1 and dataset not in 
pipeline.outputs(): + catalog.release(dataset) + def _find_nodes_to_resume_from( pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol @@ -402,193 +415,6 @@ def run_node( f"in node {node!s}." ) - task = Task(node, catalog, hook_manager, is_async, session_id) node = task.execute() return node - - -def decrement_and_release_datasets( - node: Node, catalog: DataCatalog, load_counts, pipeline -): - """Decrement dataset load counts and release any datasets we've finished with""" - for dataset in node.inputs: - load_counts[dataset] -= 1 - if load_counts[dataset] < 1 and dataset not in pipeline.inputs(): - catalog.release(dataset) - for dataset in node.outputs: - if load_counts[dataset] < 1 and dataset not in pipeline.outputs(): - catalog.release(dataset) - - if is_async: - node = _run_node_async(node, catalog, hook_manager, session_id) - else: - node = _run_node_sequential(node, catalog, hook_manager, session_id) - - for name in node.confirms: - catalog.confirm(name) - return node - - -def _collect_inputs_from_hook( # noqa: PLR0913 - node: Node, - catalog: DataCatalog, - inputs: dict[str, Any], - is_async: bool, - hook_manager: PluginManager, - session_id: str | None = None, -) -> dict[str, Any]: - inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook - hook_response = hook_manager.hook.before_node_run( - node=node, - catalog=catalog, - inputs=inputs, - is_async=is_async, - session_id=session_id, - ) - - additional_inputs = {} - if ( - hook_response is not None - ): # all hooks on a _NullPluginManager will return None instead of a list - for response in hook_response: - if response is not None and not isinstance(response, dict): - response_type = type(response).__name__ - raise TypeError( - f"'before_node_run' must return either None or a dictionary mapping " - f"dataset names to updated values, got '{response_type}' instead." 
- ) - additional_inputs.update(response or {}) - - return additional_inputs - - -def _call_node_run( # noqa: PLR0913 - node: Node, - catalog: DataCatalog, - inputs: dict[str, Any], - is_async: bool, - hook_manager: PluginManager, - session_id: str | None = None, -) -> dict[str, Any]: - try: - outputs = node.run(inputs) - except Exception as exc: - hook_manager.hook.on_node_error( - error=exc, - node=node, - catalog=catalog, - inputs=inputs, - is_async=is_async, - session_id=session_id, - ) - raise exc - hook_manager.hook.after_node_run( - node=node, - catalog=catalog, - inputs=inputs, - outputs=outputs, - is_async=is_async, - session_id=session_id, - ) - return outputs - - -def _run_node_sequential( - node: Node, - catalog: DataCatalog, - hook_manager: PluginManager, - session_id: str | None = None, -) -> Node: - inputs = {} - - for name in node.inputs: - hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node) - inputs[name] = catalog.load(name) - hook_manager.hook.after_dataset_loaded( - dataset_name=name, data=inputs[name], node=node - ) - - is_async = False - - additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - inputs.update(additional_inputs) - - outputs = _call_node_run( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - - items: Iterable = outputs.items() - # if all outputs are iterators, then the node is a generator node - if all(isinstance(d, Iterator) for d in outputs.values()): - # Python dictionaries are ordered, so we are sure - # the keys and the chunk streams are in the same order - # [a, b, c] - keys = list(outputs.keys()) - # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]] - streams = list(outputs.values()) - # zip an endless cycle of the keys - # with an interleaved iterator of the streams - # [(a, chunk_a), (b, chunk_b), ...] 
until all outputs complete - items = zip(it.cycle(keys), interleave(*streams)) - - for name, data in items: - hook_manager.hook.before_dataset_saved(dataset_name=name, data=data, node=node) - catalog.save(name, data) - hook_manager.hook.after_dataset_saved(dataset_name=name, data=data, node=node) - return node - - -def _run_node_async( - node: Node, - catalog: CatalogProtocol, - hook_manager: PluginManager, - session_id: str | None = None, -) -> Node: - def _synchronous_dataset_load(dataset_name: str) -> Any: - """Minimal wrapper to ensure Hooks are run synchronously - within an asynchronous dataset load.""" - hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node) - return_ds = catalog.load(dataset_name) - hook_manager.hook.after_dataset_loaded( - dataset_name=dataset_name, data=return_ds, node=node - ) - return return_ds - - with ThreadPoolExecutor() as pool: - inputs: dict[str, Future] = {} - - for name in node.inputs: - inputs[name] = pool.submit(_synchronous_dataset_load, name) - - wait(inputs.values(), return_when=ALL_COMPLETED) - inputs = {key: value.result() for key, value in inputs.items()} - is_async = True - additional_inputs = _collect_inputs_from_hook( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - inputs.update(additional_inputs) - - outputs = _call_node_run( - node, catalog, inputs, is_async, hook_manager, session_id=session_id - ) - - future_dataset_mapping = {} - for name, data in outputs.items(): - hook_manager.hook.before_dataset_saved( - dataset_name=name, data=data, node=node - ) - future = pool.submit(catalog.save, name, data) - future_dataset_mapping[future] = (name, data) - - for future in as_completed(future_dataset_mapping): - exception = future.exception() - if exception: - raise exception - name, data = future_dataset_mapping[future] - hook_manager.hook.after_dataset_saved( - dataset_name=name, data=data, node=node - ) - return node diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 6810a066fc..dab940a9ac 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -9,7 +9,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node +from kedro.runner.runner import AbstractRunner, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -81,7 +81,7 @@ def _run( self._suggest_resume_scenario(pipeline, done_nodes, catalog) raise - decrement_and_release_datasets(node, catalog, load_counts, pipeline) + self._release_datasets(node, catalog, load_counts, pipeline) self._logger.info( "Completed %d out of %d tasks", len(done_nodes), len(nodes) diff --git a/kedro/runner/task.py b/kedro/runner/task.py index 4e3c4d1820..b5683a4c27 100644 --- a/kedro/runner/task.py +++ b/kedro/runner/task.py @@ -15,19 +15,26 @@ if TYPE_CHECKING: from pluggy import PluginManager - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline.node import Node class Task: - def __init__(self, node, catalog, hook_manager, is_async, session_id): + def __init__( + self, + node: Node, + catalog: CatalogProtocol, + hook_manager: PluginManager, + is_async: bool, + session_id: str | None = None, + ): self.node = node self.catalog = catalog self.hook_manager = hook_manager self.is_async = is_async self.session_id = session_id - def execute(self): + def execute(self) -> Node: if self.is_async: node = self._run_node_async( self.node, self.catalog, 
self.hook_manager, self.session_id @@ -45,7 +52,7 @@ def execute(self): def _run_node_sequential( self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: @@ -96,7 +103,7 @@ def _run_node_sequential( def _run_node_async( self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: @@ -140,7 +147,10 @@ def _run_node_async( @staticmethod def _synchronous_dataset_load( - dataset_name: str, node: Node, catalog: DataCatalog, hook_manager: PluginManager + dataset_name: str, + node: Node, + catalog: CatalogProtocol, + hook_manager: PluginManager, ) -> Any: """Minimal wrapper to ensure Hooks are run synchronously within an asynchronous dataset load.""" @@ -154,7 +164,7 @@ def _synchronous_dataset_load( @staticmethod def _collect_inputs_from_hook( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, @@ -189,7 +199,7 @@ def _collect_inputs_from_hook( # noqa: PLR0913 @staticmethod def _call_node_run( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index 730e743183..c0aedf912b 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -11,7 +11,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, decrement_and_release_datasets, run_node +from kedro.runner.runner import AbstractRunner, run_node if TYPE_CHECKING: from pluggy import PluginManager @@ -143,4 +143,4 @@ def _run( "Completed %d out of %d tasks", len(done_nodes), len(nodes) ) - decrement_and_release_datasets(node, catalog, load_counts, pipeline) + self._release_datasets(node, catalog, load_counts, pipeline) From 666cad3998620c9b36f42d952a568ce7563b71f8 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 3 Oct 2024 16:58:17 +0100 Subject: [PATCH 03/15] Fix session tests Signed-off-by: Merel Theisen --- .../session/test_session_extension_hooks.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/framework/session/test_session_extension_hooks.py b/tests/framework/session/test_session_extension_hooks.py index 81eda65b46..4eab743580 100644 --- a/tests/framework/session/test_session_extension_hooks.py +++ b/tests/framework/session/test_session_extension_hooks.py @@ -22,7 +22,7 @@ from kedro.pipeline import node, pipeline from kedro.pipeline.node import Node from kedro.runner import ParallelRunner -from kedro.runner.runner import _run_node_async +from kedro.runner.task import Task from tests.framework.session.conftest import ( _assert_hook_call_record_has_expected_parameters, _assert_pipeline_equal, @@ -580,11 +580,8 @@ def test_after_dataset_load_hook_async( mock_session.load_context() # run the node asynchronously with an instance of `LogCatalog` - _run_node_async( - node=sample_node, - catalog=memory_catalog, - hook_manager=mock_session._hook_manager, - ) + task = Task(node=sample_node, catalog=memory_catalog, hook_manager=mock_session._hook_manager, is_async=True) + task.execute() hooks_log_messages = [r.message for r in logs_listener.logs] @@ -604,11 +601,8 @@ def test_after_dataset_load_hook_async_multiple_outputs( hook_manager.hook, "after_dataset_saved" ) - _run_node_async( - 
node=sample_node_multiple_outputs, - catalog=memory_catalog, - hook_manager=hook_manager, - ) + task = Task(node=sample_node_multiple_outputs, catalog=memory_catalog, hook_manager=hook_manager, is_async=True) + task.execute() after_dataset_saved_mock.assert_has_calls( [ From ffc1113a48b4c1023bda7ec99aec3abd5822fb19 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Fri, 4 Oct 2024 14:33:28 +0100 Subject: [PATCH 04/15] Fix lint Signed-off-by: Merel Theisen --- .../session/test_session_extension_hooks.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/framework/session/test_session_extension_hooks.py b/tests/framework/session/test_session_extension_hooks.py index 4eab743580..64d2e4509f 100644 --- a/tests/framework/session/test_session_extension_hooks.py +++ b/tests/framework/session/test_session_extension_hooks.py @@ -580,7 +580,12 @@ def test_after_dataset_load_hook_async( mock_session.load_context() # run the node asynchronously with an instance of `LogCatalog` - task = Task(node=sample_node, catalog=memory_catalog, hook_manager=mock_session._hook_manager, is_async=True) + task = Task( + node=sample_node, + catalog=memory_catalog, + hook_manager=mock_session._hook_manager, + is_async=True, + ) task.execute() hooks_log_messages = [r.message for r in logs_listener.logs] @@ -601,7 +606,12 @@ def test_after_dataset_load_hook_async_multiple_outputs( hook_manager.hook, "after_dataset_saved" ) - task = Task(node=sample_node_multiple_outputs, catalog=memory_catalog, hook_manager=hook_manager, is_async=True) + task = Task( + node=sample_node_multiple_outputs, + catalog=memory_catalog, + hook_manager=hook_manager, + is_async=True, + ) task.execute() after_dataset_saved_mock.assert_has_calls( From 7d0f3b68ba9041e6a504bdc317f037c8dab734fc Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 8 Oct 2024 16:02:53 +0100 Subject: [PATCH 05/15] Make Task runnable and call inside runners Signed-off-by: Merel Theisen --- kedro/runner/parallel_runner.py | 35 +++++++++++++--------------- kedro/runner/sequential_runner.py | 12 ++++++++-- kedro/runner/task.py | 13 +++++++++++ kedro/runner/thread_runner.py | 19 +++++++-------- tests/runner/test_parallel_runner.py | 23 ++---------------- 5 files changed, 50 insertions(+), 52 deletions(-) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 82b84c5816..6ee9d5fa31 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -27,7 +27,7 @@ MemoryDataset, SharedMemoryDataset, ) -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner.runner import AbstractRunner if TYPE_CHECKING: from pluggy import PluginManager @@ -58,14 +58,10 @@ def _bootstrap_subprocess( configure_logging(logging_config) -def _run_node_synchronization( # noqa: PLR0913 - node: Node, - catalog: CatalogProtocol, - is_async: bool = False, - session_id: str | None = None, +def _run_node_synchronization( package_name: str | None = None, logging_config: dict[str, Any] | None = None, -) -> Node: +) -> None: """Run a single `Node` with inputs from and outputs to the `catalog`. 
A ``PluginManager`` instance is created in each subprocess because the @@ -91,8 +87,6 @@ def _run_node_synchronization( # noqa: PLR0913 _register_hooks(hook_manager, settings.HOOKS) _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) - return run_node(node, catalog, hook_manager, is_async, session_id) - class ParallelRunner(AbstractRunner): """``ParallelRunner`` is an ``AbstractRunner`` implementation. It can @@ -287,17 +281,20 @@ def _run( ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} todo_nodes -= ready for node in ready: - futures.add( - pool.submit( - _run_node_synchronization, - node, - catalog, - self._is_async, - session_id, - package_name=PACKAGE_NAME, - logging_config=LOGGING, # type: ignore[arg-type] - ) + from kedro.runner.task import Task + + _run_node_synchronization( + package_name=PACKAGE_NAME, + logging_config=LOGGING, # type: ignore[arg-type] + ) + task = Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=self._is_async, + session_id=session_id, ) + futures.add(pool.submit(task)) if not futures: if todo_nodes: debug_data = { diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index dab940a9ac..10f3024dc3 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -9,7 +9,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner.runner import AbstractRunner if TYPE_CHECKING: from pluggy import PluginManager @@ -75,7 +75,15 @@ def _run( for exec_index, node in enumerate(nodes): try: - run_node(node, catalog, hook_manager, self._is_async, session_id) + from kedro.runner.task import Task + + Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=self._is_async, + session_id=session_id, + ).execute() done_nodes.add(node) except Exception: self._suggest_resume_scenario(pipeline, done_nodes, catalog) diff --git a/kedro/runner/task.py b/kedro/runner/task.py index b5683a4c27..05a619235b 100644 --- a/kedro/runner/task.py +++ b/kedro/runner/task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import itertools as it from concurrent.futures import ( ALL_COMPLETED, @@ -35,6 +36,14 @@ def __init__( self.session_id = session_id def execute(self) -> Node: + if self.is_async and inspect.isgeneratorfunction(self.node.func): + raise ValueError( + f"Async data loading and saving does not work with " + f"nodes wrapping generator functions. Please make " + f"sure you don't use `yield` anywhere " + f"in node {self.node!s}." 
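+                # Generator nodes emit chunks that are saved one at a time as
+                # they are yielded (see _run_node_sequential); the threaded
+                # async load/save path cannot preserve that ordering.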
+ ) + if self.is_async: node = self._run_node_async( self.node, self.catalog, self.hook_manager, self.session_id @@ -49,6 +58,10 @@ def execute(self) -> Node: return node + def __call__(self) -> Node: + """Make the class instance callable by ProcessPoolExecutor.""" + return self.execute() + def _run_node_sequential( self, node: Node, diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index c0aedf912b..19cfaafdbd 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -11,7 +11,8 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from kedro.runner.runner import AbstractRunner, run_node +from kedro.runner import Task +from kedro.runner.runner import AbstractRunner if TYPE_CHECKING: from pluggy import PluginManager @@ -117,16 +118,14 @@ def _run( ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} todo_nodes -= ready for node in ready: - futures.add( - pool.submit( - run_node, - node, - catalog, - hook_manager, - self._is_async, - session_id, - ) + task = Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=self._is_async, + session_id=session_id, ) + futures.add(pool.submit(task)) if not futures: assert not todo_nodes, (todo_nodes, done_nodes, ready, done) # noqa: S101 break diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py index 11165799a0..dba04c1153 100644 --- a/tests/runner/test_parallel_runner.py +++ b/tests/runner/test_parallel_runner.py @@ -345,7 +345,6 @@ def test_release_transcoded(self, is_async): assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")] -@pytest.mark.parametrize("is_async", [False, True]) class TestRunNodeSynchronisationHelper: """Test class for _run_node_synchronization helper. 
It is tested manually in isolation since it's called in the subprocess, which ParallelRunner @@ -367,40 +366,22 @@ def mock_configure_project(self, mocker): def test_package_name_and_logging_provided( self, mock_logging, - mock_run_node, mock_configure_project, - is_async, mocker, ): mocker.patch("multiprocessing.get_start_method", return_value="spawn") - node_ = mocker.sentinel.node - catalog = mocker.sentinel.catalog - session_id = "fake_session_id" package_name = mocker.sentinel.package_name _run_node_synchronization( - node_, - catalog, - is_async, - session_id, package_name=package_name, logging_config={"fake_logging_config": True}, ) - mock_run_node.assert_called_once() mock_logging.assert_called_once_with({"fake_logging_config": True}) mock_configure_project.assert_called_once_with(package_name) - def test_package_name_not_provided( - self, mock_logging, mock_run_node, is_async, mocker - ): + def test_package_name_not_provided(self, mock_logging, mocker): mocker.patch("multiprocessing.get_start_method", return_value="fork") - node_ = mocker.sentinel.node - catalog = mocker.sentinel.catalog - session_id = "fake_session_id" package_name = mocker.sentinel.package_name - _run_node_synchronization( - node_, catalog, is_async, session_id, package_name=package_name - ) - mock_run_node.assert_called_once() + _run_node_synchronization(package_name=package_name) mock_logging.assert_not_called() From a40f9fb2d72ea2a6df65ca6b1fef776bc09513fb Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 8 Oct 2024 16:16:23 +0100 Subject: [PATCH 06/15] Fix lint Signed-off-by: Merel Theisen --- kedro/runner/runner.py | 10 ---------- kedro/runner/task.py | 3 ++- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 908813bdab..931fd46763 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -8,18 +8,8 @@ import logging from abc import ABC, abstractmethod from collections import deque -from collections.abc import Iterator -from concurrent.futures import ( - ALL_COMPLETED, - Future, - ThreadPoolExecutor, - as_completed, - wait, -) from typing import TYPE_CHECKING, Any -from more_itertools import interleave - from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import CatalogProtocol, MemoryDataset from kedro.pipeline import Pipeline diff --git a/kedro/runner/task.py b/kedro/runner/task.py index 05a619235b..9ec5def7c7 100644 --- a/kedro/runner/task.py +++ b/kedro/runner/task.py @@ -2,6 +2,7 @@ import inspect import itertools as it +from collections.abc import Iterable, Iterator from concurrent.futures import ( ALL_COMPLETED, Future, @@ -9,7 +10,7 @@ as_completed, wait, ) -from typing import TYPE_CHECKING, Any, Iterable, Iterator +from typing import TYPE_CHECKING, Any from more_itertools import interleave From 5eb7cab560711bcb2b0d21e30c06fc4a32d23a49 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 8 Oct 2024 16:53:56 +0100 Subject: [PATCH 07/15] Fix parallel runner Signed-off-by: Merel Theisen --- kedro/runner/parallel_runner.py | 40 +++++++++++++++++----------- tests/runner/test_parallel_runner.py | 27 +++++++++++++++++-- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 42b6dff8f9..0dad9780fa 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -60,10 +60,14 @@ def _bootstrap_subprocess( configure_logging(logging_config) -def _run_node_synchronization( +def _run_node_synchronization( 
# noqa: PLR0913 + node: Node, + catalog: CatalogProtocol, + is_async: bool = False, + session_id: str | None = None, package_name: str | None = None, logging_config: dict[str, Any] | None = None, -) -> None: +) -> Any: """Run a single `Node` with inputs from and outputs to the `catalog`. A ``PluginManager`` instance is created in each subprocess because the @@ -88,6 +92,15 @@ def _run_node_synchronization( hook_manager = _create_hook_manager() _register_hooks(hook_manager, settings.HOOKS) _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) + from kedro.runner import Task + + return Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=is_async, + session_id=session_id, + ).execute() class ParallelRunner(AbstractRunner): @@ -283,20 +296,17 @@ def _run( ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} todo_nodes -= ready for node in ready: - from kedro.runner.task import Task - - _run_node_synchronization( - package_name=PACKAGE_NAME, - logging_config=LOGGING, # type: ignore[arg-type] - ) - task = Task( - node=node, - catalog=catalog, - hook_manager=hook_manager, - is_async=self._is_async, - session_id=session_id, + futures.add( + pool.submit( + _run_node_synchronization, + node, + catalog, + self._is_async, + session_id, + package_name=PACKAGE_NAME, + logging_config=LOGGING, # type: ignore[arg-type] + ) ) - futures.add(pool.submit(task)) if not futures: if todo_nodes: debug_data = { diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py index dba04c1153..c73384719b 100644 --- a/tests/runner/test_parallel_runner.py +++ b/tests/runner/test_parallel_runner.py @@ -345,6 +345,7 @@ def test_release_transcoded(self, is_async): assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")] +@pytest.mark.parametrize("is_async", [False, True]) class TestRunNodeSynchronisationHelper: """Test class for _run_node_synchronization helper. 
It is tested manually in isolation since it's called in the subprocess, which ParallelRunner @@ -359,6 +360,10 @@ def mock_logging(self, mocker): def mock_run_node(self, mocker): return mocker.patch("kedro.runner.parallel_runner.run_node") + @pytest.fixture + def mock_execute_task(self, mocker): + return mocker.patch("kedro.runner.task.Task.execute") + @pytest.fixture def mock_configure_project(self, mocker): return mocker.patch("kedro.framework.project.configure_project") @@ -367,21 +372,39 @@ def test_package_name_and_logging_provided( self, mock_logging, mock_configure_project, + mock_execute_task, + is_async, mocker, ): mocker.patch("multiprocessing.get_start_method", return_value="spawn") + node_ = mocker.sentinel.node + catalog = mocker.sentinel.catalog + session_id = "fake_session_id" package_name = mocker.sentinel.package_name _run_node_synchronization( + node_, + catalog, + is_async, + session_id, package_name=package_name, logging_config={"fake_logging_config": True}, ) + mock_execute_task.assert_called_once() mock_logging.assert_called_once_with({"fake_logging_config": True}) mock_configure_project.assert_called_once_with(package_name) - def test_package_name_not_provided(self, mock_logging, mocker): + def test_package_name_not_provided( + self, mock_logging, mock_execute_task, is_async, mocker + ): mocker.patch("multiprocessing.get_start_method", return_value="fork") + node_ = mocker.sentinel.node + catalog = mocker.sentinel.catalog + session_id = "fake_session_id" package_name = mocker.sentinel.package_name - _run_node_synchronization(package_name=package_name) + _run_node_synchronization( + node_, catalog, is_async, session_id, package_name=package_name + ) + mock_execute_task.assert_called_once() mock_logging.assert_not_called() From 0fa8ceaf3e5460b88466731d7b73e4286872bec2 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Wed, 9 Oct 2024 14:26:52 +0200 Subject: [PATCH 08/15] Fix test coverage Signed-off-by: Merel Theisen --- tests/runner/test_task.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/runner/test_task.py diff --git a/tests/runner/test_task.py b/tests/runner/test_task.py new file mode 100644 index 0000000000..c098f5604a --- /dev/null +++ b/tests/runner/test_task.py @@ -0,0 +1,20 @@ +import pytest + +from kedro.framework.hooks.manager import _NullPluginManager +from kedro.pipeline import node +from kedro.runner import Task + + +def generate_one(): + yield from range(10) + + +class TestTask: + def test_generator_fail_async(self, mocker, catalog): + fake_dataset = mocker.Mock() + catalog.add("result", fake_dataset) + n = node(generate_one, inputs=None, outputs="result") + + with pytest.raises(Exception, match="nodes wrapping generator functions"): + task = Task(n, catalog, _NullPluginManager(), is_async=True) + task.execute() From 54389ee18fbfcba0e6adad9de58ec8c9caee775e Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 15 Oct 2024 12:40:09 +0100 Subject: [PATCH 09/15] Move helper methods from parallel runner to task Signed-off-by: Merel Theisen --- kedro/runner/parallel_runner.py | 61 +--------------------------- kedro/runner/task.py | 60 +++++++++++++++++++++++++++ pyproject.toml | 1 + tests/runner/test_parallel_runner.py | 2 +- 4 files changed, 63 insertions(+), 61 deletions(-) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 0dad9780fa..efe5b067b9 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -4,7 +4,6 @@ from __future__ import annotations 
-import multiprocessing import os import sys from collections import Counter @@ -15,12 +14,6 @@ from pickle import PicklingError from typing import TYPE_CHECKING, Any -from kedro.framework.hooks.manager import ( - _create_hook_manager, - _register_hooks, - _register_hooks_entry_points, -) -from kedro.framework.project import settings from kedro.io import ( CatalogProtocol, DatasetNotFoundError, @@ -50,59 +43,6 @@ class ParallelRunnerManager(SyncManager): ParallelRunnerManager.register("MemoryDataset", MemoryDataset) -def _bootstrap_subprocess( - package_name: str, logging_config: dict[str, Any] | None = None -) -> None: - from kedro.framework.project import configure_logging, configure_project - - configure_project(package_name) - if logging_config: - configure_logging(logging_config) - - -def _run_node_synchronization( # noqa: PLR0913 - node: Node, - catalog: CatalogProtocol, - is_async: bool = False, - session_id: str | None = None, - package_name: str | None = None, - logging_config: dict[str, Any] | None = None, -) -> Any: - """Run a single `Node` with inputs from and outputs to the `catalog`. - - A ``PluginManager`` instance is created in each subprocess because the - ``PluginManager`` can't be serialised. - - Args: - node: The ``Node`` to run. - catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. - is_async: If True, the node inputs and outputs are loaded and saved - asynchronously with threads. Defaults to False. - session_id: The session id of the pipeline run. - package_name: The name of the project Python package. - logging_config: A dictionary containing logging configuration. - - Returns: - The node argument. - - """ - if multiprocessing.get_start_method() == "spawn" and package_name: - _bootstrap_subprocess(package_name, logging_config) - - hook_manager = _create_hook_manager() - _register_hooks(hook_manager, settings.HOOKS) - _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) - from kedro.runner import Task - - return Task( - node=node, - catalog=catalog, - hook_manager=hook_manager, - is_async=is_async, - session_id=session_id, - ).execute() - - class ParallelRunner(AbstractRunner): """``ParallelRunner`` is an ``AbstractRunner`` implementation. It can be used to run the ``Pipeline`` in parallel groups formed by toposort. 
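 
     Example (illustrative sketch, assuming a built ``pipeline`` and a
     populated ``catalog``)::
 
         from kedro.runner import ParallelRunner
 
         runner = ParallelRunner(max_workers=2, is_async=False)
         runner.run(pipeline, catalog)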
@@ -290,6 +230,7 @@ def _run( max_workers = self._get_required_workers_count(pipeline) from kedro.framework.project import LOGGING, PACKAGE_NAME + from kedro.runner.task import _run_node_synchronization with ProcessPoolExecutor(max_workers=max_workers) as pool: while True: diff --git a/kedro/runner/task.py b/kedro/runner/task.py index 9ec5def7c7..0bab46eb99 100644 --- a/kedro/runner/task.py +++ b/kedro/runner/task.py @@ -2,6 +2,7 @@ import inspect import itertools as it +import multiprocessing from collections.abc import Iterable, Iterator from concurrent.futures import ( ALL_COMPLETED, @@ -14,6 +15,13 @@ from more_itertools import interleave +from kedro.framework.hooks.manager import ( + _create_hook_manager, + _register_hooks, + _register_hooks_entry_points, +) +from kedro.framework.project import settings + if TYPE_CHECKING: from pluggy import PluginManager @@ -21,6 +29,58 @@ from kedro.pipeline.node import Node +def _bootstrap_subprocess( + package_name: str, logging_config: dict[str, Any] | None = None +) -> None: + from kedro.framework.project import configure_logging, configure_project + + configure_project(package_name) + if logging_config: + configure_logging(logging_config) + + +def _run_node_synchronization( # noqa: PLR0913 + node: Node, + catalog: CatalogProtocol, + is_async: bool = False, + session_id: str | None = None, + package_name: str | None = None, + logging_config: dict[str, Any] | None = None, +) -> Any: + """Run a single `Node` with inputs from and outputs to the `catalog`. + + A ``PluginManager`` instance is created in each subprocess because the + ``PluginManager`` can't be serialised. + + Args: + node: The ``Node`` to run. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. + is_async: If True, the node inputs and outputs are loaded and saved + asynchronously with threads. Defaults to False. + session_id: The session id of the pipeline run. + package_name: The name of the project Python package. + logging_config: A dictionary containing logging configuration. + + Returns: + The node argument. 
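+
+    Example (illustrative only, assuming ``node`` and ``catalog`` objects
+    from a bootstrapped Kedro project)::
+
+        _run_node_synchronization(
+            node, catalog, is_async=False, package_name="my_project"
+        )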
+ + """ + if multiprocessing.get_start_method() == "spawn" and package_name: + _bootstrap_subprocess(package_name, logging_config) + + hook_manager = _create_hook_manager() + _register_hooks(hook_manager, settings.HOOKS) + _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS) + + return Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=is_async, + session_id=session_id, + ).execute() + + class Task: def __init__( self, diff --git a/pyproject.toml b/pyproject.toml index 23b60e9a61..ba92dfd4c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ layers = [ ] ignore_imports = [ "kedro.runner.parallel_runner -> kedro.framework.project", + "kedro.runner.task -> kedro.framework.project", "kedro.framework.hooks.specs -> kedro.framework.context" ] diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py index c73384719b..5c2cdde718 100644 --- a/tests/runner/test_parallel_runner.py +++ b/tests/runner/test_parallel_runner.py @@ -19,8 +19,8 @@ from kedro.runner.parallel_runner import ( _MAX_WINDOWS_WORKERS, ParallelRunnerManager, - _run_node_synchronization, ) +from kedro.runner.task import _run_node_synchronization from tests.runner.conftest import ( exception_fn, identity, From 94e489438be09186db17cb722ebd2ffb9352ad40 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 15 Oct 2024 12:45:13 +0100 Subject: [PATCH 10/15] Mark run_node as deprecated Signed-off-by: Merel Theisen --- kedro/runner/runner.py | 6 ++++++ tests/runner/test_run_node.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 931fd46763..9c12fb64f7 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -6,10 +6,12 @@ import inspect import logging +import warnings from abc import ABC, abstractmethod from collections import deque from typing import TYPE_CHECKING, Any +from kedro import KedroDeprecationWarning from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import CatalogProtocol, MemoryDataset from kedro.pipeline import Pipeline @@ -408,6 +410,10 @@ def run_node( The node argument. 
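 
     Note:
         Deprecated in favour of ``Task``; the equivalent call is
         ``Task(node=node, catalog=catalog, hook_manager=hook_manager,
         is_async=is_async, session_id=session_id).execute()``.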
""" + warnings.warn( + "`run_node()` has been deprecated and will be removed in Kedro 0.20.0", + KedroDeprecationWarning, + ) if is_async and inspect.isgeneratorfunction(node.func): raise ValueError( diff --git a/tests/runner/test_run_node.py b/tests/runner/test_run_node.py index ad95b4838b..5bd5cb77d8 100644 --- a/tests/runner/test_run_node.py +++ b/tests/runner/test_run_node.py @@ -1,5 +1,7 @@ import pytest +from pytest import warns +from kedro import KedroDeprecationWarning from kedro.framework.hooks.manager import _NullPluginManager from kedro.pipeline import node from kedro.runner import run_node @@ -87,3 +89,15 @@ def test_generator_node_dict(self, mocker, catalog): assert left.save.call_args_list == expected_left assert 10 == right.save.call_count assert right.save.call_args_list == expected_right + + def test_run_node_deprecated(self, mocker, catalog): + left = mocker.Mock() + right = mocker.Mock() + catalog.add("left", left) + catalog.add("right", right) + n = node(generate_dict, inputs=None, outputs={"idx": "left", "square": "right"}) + with warns( + KedroDeprecationWarning, + match=r"\`run_node\(\)\` has been deprecated", + ): + run_node(n, catalog, _NullPluginManager()) From 7346f9ae42d9eb30774a89fe5dc12d80201089ef Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 15 Oct 2024 14:38:18 +0100 Subject: [PATCH 11/15] Fix tests Signed-off-by: Merel Theisen --- tests/framework/session/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/framework/session/conftest.py b/tests/framework/session/conftest.py index e34082829f..899b41fc6d 100644 --- a/tests/framework/session/conftest.py +++ b/tests/framework/session/conftest.py @@ -358,7 +358,6 @@ def _mock_imported_settings_paths(mocker, mock_settings): for path in [ "kedro.framework.session.session.settings", "kedro.framework.project.settings", - "kedro.runner.parallel_runner.settings", ]: mocker.patch(path, mock_settings) return mock_settings From 7a0e4d9282232743cd65f49405008c49028ac1e4 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Tue, 15 Oct 2024 15:10:51 +0100 Subject: [PATCH 12/15] Fix tests Signed-off-by: Merel Theisen --- tests/framework/session/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/framework/session/conftest.py b/tests/framework/session/conftest.py index 899b41fc6d..ea9fbbdb0d 100644 --- a/tests/framework/session/conftest.py +++ b/tests/framework/session/conftest.py @@ -358,6 +358,7 @@ def _mock_imported_settings_paths(mocker, mock_settings): for path in [ "kedro.framework.session.session.settings", "kedro.framework.project.settings", + "kedro.runner.task.settings", ]: mocker.patch(path, mock_settings) return mock_settings From 14b10d8590b391b2d9de5ca14a689ea24df3c7a9 Mon Sep 17 00:00:00 2001 From: Merel Theisen Date: Thu, 17 Oct 2024 14:16:16 +0100 Subject: [PATCH 13/15] Refactor helper methods to go inside Task, making hook_manager an optional argument, and adding parallel as boolean flag Signed-off-by: Merel Theisen --- kedro/runner/parallel_runner.py | 20 +++--- kedro/runner/runner.py | 8 ++- kedro/runner/task.py | 103 +++++++++++++-------------- pyproject.toml | 1 - tests/runner/test_parallel_runner.py | 66 ----------------- tests/runner/test_task.py | 61 +++++++++++++++- 6 files changed, 124 insertions(+), 135 deletions(-) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 95434fdd30..5edfac464d 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -229,25 +229,21 @@ def _run( done = None max_workers = 
self._get_required_workers_count(pipeline) - from kedro.framework.project import LOGGING, PACKAGE_NAME - from kedro.runner.task import _run_node_synchronization + from kedro.runner.task import Task with ProcessPoolExecutor(max_workers=max_workers) as pool: while True: ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} todo_nodes -= ready for node in ready: - futures.add( - pool.submit( - _run_node_synchronization, - node, - catalog, - self._is_async, - session_id, - package_name=PACKAGE_NAME, - logging_config=LOGGING, # type: ignore[arg-type] - ) + task = Task( + node=node, + catalog=catalog, + is_async=self._is_async, + session_id=session_id, + parallel=True, ) + futures.add(pool.submit(task)) if not futures: if todo_nodes: debug_data = { diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 9c12fb64f7..4c4ac75e0d 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -423,6 +423,12 @@ def run_node( f"in node {node!s}." ) - task = Task(node, catalog, hook_manager, is_async, session_id) + task = Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=is_async, + session_id=session_id, + ) node = task.execute() return node diff --git a/kedro/runner/task.py b/kedro/runner/task.py index 0bab46eb99..c7de9b8edd 100644 --- a/kedro/runner/task.py +++ b/kedro/runner/task.py @@ -29,74 +29,32 @@ from kedro.pipeline.node import Node -def _bootstrap_subprocess( - package_name: str, logging_config: dict[str, Any] | None = None -) -> None: - from kedro.framework.project import configure_logging, configure_project - - configure_project(package_name) - if logging_config: - configure_logging(logging_config) - - -def _run_node_synchronization( # noqa: PLR0913 - node: Node, - catalog: CatalogProtocol, - is_async: bool = False, - session_id: str | None = None, - package_name: str | None = None, - logging_config: dict[str, Any] | None = None, -) -> Any: - """Run a single `Node` with inputs from and outputs to the `catalog`. - - A ``PluginManager`` instance is created in each subprocess because the - ``PluginManager`` can't be serialised. - - Args: - node: The ``Node`` to run. - catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. - is_async: If True, the node inputs and outputs are loaded and saved - asynchronously with threads. Defaults to False. - session_id: The session id of the pipeline run. - package_name: The name of the project Python package. - logging_config: A dictionary containing logging configuration. - - Returns: - The node argument. 
diff --git a/kedro/runner/task.py b/kedro/runner/task.py
index 0bab46eb99..c7de9b8edd 100644
--- a/kedro/runner/task.py
+++ b/kedro/runner/task.py
@@ -29,74 +29,32 @@
     from kedro.pipeline.node import Node
 
 
-def _bootstrap_subprocess(
-    package_name: str, logging_config: dict[str, Any] | None = None
-) -> None:
-    from kedro.framework.project import configure_logging, configure_project
-
-    configure_project(package_name)
-    if logging_config:
-        configure_logging(logging_config)
-
-
-def _run_node_synchronization(  # noqa: PLR0913
-    node: Node,
-    catalog: CatalogProtocol,
-    is_async: bool = False,
-    session_id: str | None = None,
-    package_name: str | None = None,
-    logging_config: dict[str, Any] | None = None,
-) -> Any:
-    """Run a single `Node` with inputs from and outputs to the `catalog`.
-
-    A ``PluginManager`` instance is created in each subprocess because the
-    ``PluginManager`` can't be serialised.
-
-    Args:
-        node: The ``Node`` to run.
-        catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs.
-        is_async: If True, the node inputs and outputs are loaded and saved
-            asynchronously with threads. Defaults to False.
-        session_id: The session id of the pipeline run.
-        package_name: The name of the project Python package.
-        logging_config: A dictionary containing logging configuration.
-
-    Returns:
-        The node argument.
-
-    """
-    if multiprocessing.get_start_method() == "spawn" and package_name:
-        _bootstrap_subprocess(package_name, logging_config)
-
-    hook_manager = _create_hook_manager()
-    _register_hooks(hook_manager, settings.HOOKS)
-    _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS)
-
-    return Task(
-        node=node,
-        catalog=catalog,
-        hook_manager=hook_manager,
-        is_async=is_async,
-        session_id=session_id,
-    ).execute()
-
-
 class Task:
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         node: Node,
         catalog: CatalogProtocol,
-        hook_manager: PluginManager,
         is_async: bool,
+        hook_manager: PluginManager | None = None,
         session_id: str | None = None,
+        parallel: bool = False,
     ):
         self.node = node
         self.catalog = catalog
         self.hook_manager = hook_manager
         self.is_async = is_async
         self.session_id = session_id
+        self.parallel = parallel
 
     def execute(self) -> Node:
+        if self.parallel:
+            from kedro.framework.project import LOGGING, PACKAGE_NAME
+
+            hook_manager = Task._run_node_synchronization(
+                package_name=PACKAGE_NAME, logging_config=LOGGING
+            )
+            self.hook_manager = hook_manager
+
         if self.is_async and inspect.isgeneratorfunction(self.node.func):
             raise ValueError(
                 f"Async data loading and saving does not work with "
@@ -123,6 +81,43 @@ def __call__(self) -> Node:
         """Make the class instance callable by ProcessPoolExecutor."""
         return self.execute()
 
+    @staticmethod
+    def _bootstrap_subprocess(
+        package_name: str, logging_config: dict[str, Any] | None = None
+    ) -> None:
+        from kedro.framework.project import configure_logging, configure_project
+
+        configure_project(package_name)
+        if logging_config:
+            configure_logging(logging_config)
+
+    @staticmethod
+    def _run_node_synchronization(
+        package_name: str | None = None,
+        logging_config: dict[str, Any] | None = None,
+    ) -> Any:
+        """Create and configure the hook manager for the subprocess.
+
+        A ``PluginManager`` instance is created in each subprocess because the
+        ``PluginManager`` can't be serialised.
+
+        Args:
+            package_name: The name of the project Python package.
+            logging_config: A dictionary containing logging configuration.
+
+        Returns:
+            The hook manager.
+
+        """
+        if multiprocessing.get_start_method() == "spawn" and package_name:
+            Task._bootstrap_subprocess(package_name, logging_config)
+
+        hook_manager = _create_hook_manager()
+        _register_hooks(hook_manager, settings.HOOKS)
+        _register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS)
+
+        return hook_manager
+
     def _run_node_sequential(
         self,
         node: Node,
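
The static helpers exist because a `PluginManager` cannot be pickled across process boundaries, so each worker rebuilds one, bootstrapping the project first when the start method is "spawn". A rough sketch of that shape, assuming pluggy is installed ("example" is just a placeholder project name and the bootstrap body is elided):

    import multiprocessing

    import pluggy


    def make_subprocess_hook_manager(package_name: str | None = None) -> pluggy.PluginManager:
        # "spawn" children start from a clean interpreter, so project and
        # logging configuration must be redone before hooks are registered;
        # "fork" children inherit the parent's already-configured state.
        if multiprocessing.get_start_method() == "spawn" and package_name:
            pass  # placeholder for configure_project()/configure_logging()
        return pluggy.PluginManager("example")
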
diff --git a/pyproject.toml b/pyproject.toml
index 9b4957147f..b53c266789 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -169,7 +169,6 @@ layers = [
   "config"
 ]
 ignore_imports = [
-    "kedro.runner.parallel_runner -> kedro.framework.project",
     "kedro.runner.task -> kedro.framework.project",
     "kedro.framework.hooks.specs -> kedro.framework.context"
 ]
diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py
index 5c2cdde718..0f989048f1 100644
--- a/tests/runner/test_parallel_runner.py
+++ b/tests/runner/test_parallel_runner.py
@@ -20,7 +20,6 @@
     _MAX_WINDOWS_WORKERS,
     ParallelRunnerManager,
 )
-from kedro.runner.task import _run_node_synchronization
 from tests.runner.conftest import (
     exception_fn,
     identity,
@@ -343,68 +342,3 @@ def test_release_transcoded(self, is_async):
 
         # we want to see both datasets being released
         assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")]
-
-
-@pytest.mark.parametrize("is_async", [False, True])
-class TestRunNodeSynchronisationHelper:
-    """Test class for _run_node_synchronization helper. It is tested manually
-    in isolation since it's called in the subprocess, which ParallelRunner
-    patches have no access to.
-    """
-
-    @pytest.fixture(autouse=True)
-    def mock_logging(self, mocker):
-        return mocker.patch("logging.config.dictConfig")
-
-    @pytest.fixture
-    def mock_run_node(self, mocker):
-        return mocker.patch("kedro.runner.parallel_runner.run_node")
-
-    @pytest.fixture
-    def mock_execute_task(self, mocker):
-        return mocker.patch("kedro.runner.task.Task.execute")
-
-    @pytest.fixture
-    def mock_configure_project(self, mocker):
-        return mocker.patch("kedro.framework.project.configure_project")
-
-    def test_package_name_and_logging_provided(
-        self,
-        mock_logging,
-        mock_configure_project,
-        mock_execute_task,
-        is_async,
-        mocker,
-    ):
-        mocker.patch("multiprocessing.get_start_method", return_value="spawn")
-        node_ = mocker.sentinel.node
-        catalog = mocker.sentinel.catalog
-        session_id = "fake_session_id"
-        package_name = mocker.sentinel.package_name
-
-        _run_node_synchronization(
-            node_,
-            catalog,
-            is_async,
-            session_id,
-            package_name=package_name,
-            logging_config={"fake_logging_config": True},
-        )
-        mock_execute_task.assert_called_once()
-        mock_logging.assert_called_once_with({"fake_logging_config": True})
-        mock_configure_project.assert_called_once_with(package_name)
-
-    def test_package_name_not_provided(
-        self, mock_logging, mock_execute_task, is_async, mocker
-    ):
-        mocker.patch("multiprocessing.get_start_method", return_value="fork")
-        node_ = mocker.sentinel.node
-        catalog = mocker.sentinel.catalog
-        session_id = "fake_session_id"
-        package_name = mocker.sentinel.package_name
-
-        _run_node_synchronization(
-            node_, catalog, is_async, session_id, package_name=package_name
-        )
-        mock_execute_task.assert_called_once()
-        mock_logging.assert_not_called()
diff --git a/tests/runner/test_task.py b/tests/runner/test_task.py
index c098f5604a..ac6f42e6af 100644
--- a/tests/runner/test_task.py
+++ b/tests/runner/test_task.py
@@ -10,11 +10,70 @@ def generate_one():
 
 
 class TestTask:
+    @pytest.fixture(autouse=True)
+    def mock_logging(self, mocker):
+        return mocker.patch("logging.config.dictConfig")
+
+    @pytest.fixture
+    def mock_configure_project(self, mocker):
+        return mocker.patch("kedro.framework.project.configure_project")
+
     def test_generator_fail_async(self, mocker, catalog):
         fake_dataset = mocker.Mock()
         catalog.add("result", fake_dataset)
         n = node(generate_one, inputs=None, outputs="result")
 
         with pytest.raises(Exception, match="nodes wrapping generator functions"):
-            task = Task(n, catalog, _NullPluginManager(), is_async=True)
+            task = Task(
+                node=n,
+                catalog=catalog,
+                hook_manager=_NullPluginManager(),
+                is_async=True,
+            )
             task.execute()
+
+    @pytest.mark.parametrize("is_async", [False, True])
+    def test_package_name_and_logging_provided(
+        self,
+        mock_logging,
+        mock_configure_project,
+        is_async,
+        mocker,
+    ):
+        mocker.patch("multiprocessing.get_start_method", return_value="spawn")
+        node_ = mocker.sentinel.node
+        catalog = mocker.sentinel.catalog
+        session_id = "fake_session_id"
+        package_name = mocker.sentinel.package_name
+
+        task = Task(
+            node=node_,
+            catalog=catalog,
+            session_id=session_id,
+            is_async=is_async,
+            parallel=True,
+        )
+        task._run_node_synchronization(
+            package_name=package_name,
+            logging_config={"fake_logging_config": True},
+        )
+        mock_logging.assert_called_once_with({"fake_logging_config": True})
+        mock_configure_project.assert_called_once_with(package_name)
+
+    @pytest.mark.parametrize("is_async", [False, True])
+    def test_package_name_not_provided(self, mock_logging, is_async, mocker):
+        mocker.patch("multiprocessing.get_start_method", return_value="fork")
+        node_ = mocker.sentinel.node
+        catalog = mocker.sentinel.catalog
+        session_id = "fake_session_id"
+        package_name = mocker.sentinel.package_name
+
+        task = Task(
+            node=node_,
+            catalog=catalog,
+            session_id=session_id,
+            is_async=is_async,
+            parallel=True,
+        )
+        task._run_node_synchronization(package_name=package_name)
+        mock_logging.assert_not_called()
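
The rewritten tests lean on `mocker.sentinel`, which is `unittest.mock.sentinel` re-exported by pytest-mock: it hands out unique, stable placeholder objects for values the test never actually exercises. A small demonstration of those guarantees:

    from unittest import mock


    def test_sentinels_are_stable():
        node_ = mock.sentinel.node
        assert repr(node_) == "sentinel.node"
        assert node_ is mock.sentinel.node         # same attribute, same object
        assert node_ is not mock.sentinel.catalog  # different attribute, different object
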

From 9cf2ed0d0e7d6271f5197e0a997b3d85167fe6f1 Mon Sep 17 00:00:00 2001
From: Merel Theisen
Date: Thu, 17 Oct 2024 15:14:29 +0100
Subject: [PATCH 14/15] Fix lint and handle no hook_manager

Signed-off-by: Merel Theisen
---
 kedro/runner/task.py      | 42 +++++++++++++++++++++++++++++----------
 tests/runner/test_task.py | 16 +++++++++++++++
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/kedro/runner/task.py b/kedro/runner/task.py
index c7de9b8edd..437eec94ab 100644
--- a/kedro/runner/task.py
+++ b/kedro/runner/task.py
@@ -29,6 +29,15 @@
 from kedro.pipeline.node import Node
 
 
+class TaskError(Exception):
+    """``TaskError`` raised by ``Task``
+    in case of failure of the provided task arguments.
+
+    """
+
+    pass
+
+
 class Task:
     def __init__(  # noqa: PLR0913
         self,
@@ -47,14 +56,6 @@ def __init__(  # noqa: PLR0913
         self.parallel = parallel
 
     def execute(self) -> Node:
-        if self.parallel:
-            from kedro.framework.project import LOGGING, PACKAGE_NAME
-
-            hook_manager = Task._run_node_synchronization(
-                package_name=PACKAGE_NAME, logging_config=LOGGING
-            )
-            self.hook_manager = hook_manager
-
         if self.is_async and inspect.isgeneratorfunction(self.node.func):
             raise ValueError(
                 f"Async data loading and saving does not work with "
@@ -63,13 +64,32 @@ def execute(self) -> Node:
                 f"in node {self.node!s}."
             )
 
+        if not self.hook_manager and not self.parallel:
+            raise TaskError(
+                "No hook_manager provided. This is only allowed when running a ``Task`` with ``ParallelRunner``."
+            )
+
+        if self.parallel:
+            from kedro.framework.project import LOGGING, PACKAGE_NAME
+
+            hook_manager = Task._run_node_synchronization(
+                package_name=PACKAGE_NAME,
+                logging_config=LOGGING,  # type: ignore[arg-type]
+            )
+            self.hook_manager = hook_manager
+
         if self.is_async:
             node = self._run_node_async(
-                self.node, self.catalog, self.hook_manager, self.session_id
+                self.node,
+                self.catalog,
+                self.hook_manager,  # type: ignore[arg-type]
+                self.session_id,
             )
         else:
             node = self._run_node_sequential(
-                self.node, self.catalog, self.hook_manager, self.session_id
+                self.node,
+                self.catalog,
+                self.hook_manager,  # type: ignore[arg-type]
+                self.session_id,
             )
 
         for name in node.confirms:
@@ -173,7 +193,7 @@ def _run_node_async(
         self,
         node: Node,
         catalog: CatalogProtocol,
-        hook_manager: PluginManager,
+        hook_manager: PluginManager,  # type: ignore[arg-type]
         session_id: str | None = None,
     ) -> Node:
         with ThreadPoolExecutor() as pool:
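
Restating the new guard in isolation: omitting `hook_manager` is legal only on the parallel path, because that path builds its own manager inside the worker process. The sketch below mirrors the check added above; `validate_task_args` is an illustrative name, not part of the patch:

    class TaskError(Exception):
        """Raised when a Task is given an invalid combination of arguments."""


    def validate_task_args(hook_manager: object | None, parallel: bool) -> None:
        # A falsy hook_manager is acceptable only when the Task will
        # bootstrap its own manager for a parallel (subprocess) run.
        if not hook_manager and not parallel:
            raise TaskError(
                "No hook_manager provided. This is only allowed when running a "
                "``Task`` with ``ParallelRunner``."
            )
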
diff --git a/tests/runner/test_task.py b/tests/runner/test_task.py
index ac6f42e6af..a00098ca90 100644
--- a/tests/runner/test_task.py
+++ b/tests/runner/test_task.py
@@ -3,6 +3,7 @@
 from kedro.framework.hooks.manager import _NullPluginManager
 from kedro.pipeline import node
 from kedro.runner import Task
+from kedro.runner.task import TaskError
 
 
 def generate_one():
@@ -77,3 +78,18 @@ def test_package_name_not_provided(self, mock_logging, is_async, mocker):
         )
         task._run_node_synchronization(package_name=package_name)
         mock_logging.assert_not_called()
+
+    def test_raise_task_exception(self, mocker):
+        node_ = mocker.sentinel.node
+        catalog = mocker.sentinel.catalog
+        session_id = "fake_session_id"
+
+        with pytest.raises(TaskError, match="No hook_manager provided."):
+            task = Task(
+                node=node_,
+                catalog=catalog,
+                is_async=False,
+                session_id=session_id,
+                parallel=False,
+            )
+            task.execute()

From e94d510c7e1df832ea499a343c3c214234b553a5 Mon Sep 17 00:00:00 2001
From: Merel Theisen
Date: Thu, 17 Oct 2024 15:17:01 +0100
Subject: [PATCH 15/15] Clean up imports

Signed-off-by: Merel Theisen
---
 kedro/runner/parallel_runner.py   | 3 +--
 kedro/runner/sequential_runner.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py
index 5edfac464d..4f20295285 100644
--- a/kedro/runner/parallel_runner.py
+++ b/kedro/runner/parallel_runner.py
@@ -21,6 +21,7 @@
     SharedMemoryDataset,
 )
 from kedro.runner.runner import AbstractRunner
+from kedro.runner.task import Task
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -229,8 +230,6 @@ def _run(
             done = None
             max_workers = self._get_required_workers_count(pipeline)
 
-        from kedro.runner.task import Task
-
         with ProcessPoolExecutor(max_workers=max_workers) as pool:
             while True:
                 ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py
index 10f3024dc3..508f95234f 100644
--- a/kedro/runner/sequential_runner.py
+++ b/kedro/runner/sequential_runner.py
@@ -10,6 +10,7 @@
 from typing import TYPE_CHECKING, Any
 
 from kedro.runner.runner import AbstractRunner
+from kedro.runner.task import Task
 
 if TYPE_CHECKING:
     from pluggy import PluginManager
@@ -75,8 +76,6 @@ def _run(
 
         for exec_index, node in enumerate(nodes):
             try:
-                from kedro.runner.task import Task
-
                 Task(
                     node=node,
                     catalog=catalog,
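
Taken together, the series leaves the runners with a single entry point: construct a `Task` and execute it. A hypothetical end-to-end sketch using only names introduced in the diffs above (the wrapper function and its defaults are illustrative, not part of the patch):

    from kedro.runner import Task


    def run_single_node(node, catalog, hook_manager, session_id=None):
        # Sequential and thread runners pass a ready-made hook_manager;
        # ParallelRunner instead sets parallel=True and lets each worker
        # build its own via Task._run_node_synchronization().
        return Task(
            node=node,
            catalog=catalog,
            hook_manager=hook_manager,
            is_async=False,
            session_id=session_id,
        ).execute()
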