Skip to content

Commit

Permalink
Merge branch 'fix/4007-thread-runner-with-dataset-factories-bu' into fix/4007-thread-runner-with-dataset-factories
Browse files Browse the repository at this point in the history
  • Loading branch information
ElenaKhaustova committed Aug 27, 2024
2 parents 669e930 + 0ee3330 commit 0ba5eed
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 6 deletions.
8 changes: 7 additions & 1 deletion kedro/framework/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
validate_settings,
)
from kedro.io.core import generate_timestamp
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.runner import AbstractRunner, SequentialRunner, ThreadRunner
from kedro.utils import _find_kedro_project

if TYPE_CHECKING:
Expand Down Expand Up @@ -395,6 +395,12 @@ def run( # noqa: PLR0913
)

try:
if isinstance(runner, ThreadRunner):
for ds in filtered_pipeline.datasets():
if catalog._match_pattern(
catalog._dataset_patterns, ds
) or catalog._match_pattern(catalog._default_pattern, ds):
_ = catalog._get_dataset(ds)
run_result = runner.run(
filtered_pipeline, catalog, hook_manager, session_id
)
Expand Down
5 changes: 0 additions & 5 deletions kedro/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
as_completed,
wait,
)
from threading import Lock
from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator

from more_itertools import interleave
Expand All @@ -30,8 +29,6 @@

from kedro.pipeline.node import Node

load_dataset_lock = Lock()


class AbstractRunner(ABC):
"""``AbstractRunner`` is the base class for all ``Pipeline`` runner
Expand Down Expand Up @@ -498,9 +495,7 @@ def _run_node_sequential(

for name in node.inputs:
hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node)
load_dataset_lock.acquire()
inputs[name] = catalog.load(name)
load_dataset_lock.release()
hook_manager.hook.after_dataset_loaded(
dataset_name=name, data=inputs[name], node=node
)
Expand Down
71 changes: 71 additions & 0 deletions tests/framework/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def mock_runner(mocker):
return mock_runner


@pytest.fixture
def mock_thread_runner(mocker):
    """Patch ``ThreadRunner`` so session runs can be asserted without executing nodes.

    Returns the autospecced mock class. ``__name__`` is set explicitly because
    the session records ``runner.__name__`` in its run parameters and mocks do
    not carry one by default.
    """
    mock_runner = mocker.patch(
        "kedro.runner.thread_runner.ThreadRunner",
        autospec=True,
    )
    # Fix: the name previously contained a stray trailing backtick
    # ("MockThreadRunner`"), which leaked into run_params["runner"].
    mock_runner.__name__ = "MockThreadRunner"
    return mock_runner


@pytest.fixture
def mock_context_class(mocker):
mock_cls = create_attrs_autospec(KedroContext)
Expand Down Expand Up @@ -693,6 +703,67 @@ def test_run(
catalog=mock_catalog,
)

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
def test_run_thread_runner(
    self,
    fake_project,
    fake_session_id,
    fake_pipeline_name,
    mock_context_class,
    mock_thread_runner,
    mocker,
):
    """Test running the project via the session"""
    # Patch the hook manager and the pipelines registry used by the session.
    hook = mocker.patch(
        "kedro.framework.session.session._create_hook_manager"
    ).return_value.hook
    pipelines_mock = mocker.patch(
        "kedro.framework.session.session.pipelines",
        return_value={
            _FAKE_PIPELINE_NAME: mocker.Mock(),
            "__default__": mocker.Mock(),
        },
    )

    # Resolve the context, catalog and filtered pipeline the run will use.
    context = mock_context_class.return_value
    catalog = context._get_catalog.return_value
    filtered_pipeline = pipelines_mock.__getitem__.return_value.filter.return_value

    with KedroSession.create(fake_project) as session:
        session.run(runner=mock_thread_runner, pipeline_name=fake_pipeline_name)

    # Parameters the session is expected to record for this run.
    expected_run_params = {
        "session_id": fake_session_id,
        "project_path": fake_project.as_posix(),
        "env": context.env,
        "kedro_version": kedro_version,
        "tags": None,
        "from_nodes": None,
        "to_nodes": None,
        "node_names": None,
        "from_inputs": None,
        "to_outputs": None,
        "load_versions": None,
        "extra_params": {},
        "pipeline_name": fake_pipeline_name,
        "namespace": None,
        "runner": mock_thread_runner.__name__,
    }

    hook.before_pipeline_run.assert_called_once_with(
        run_params=expected_run_params,
        pipeline=filtered_pipeline,
        catalog=catalog,
    )
    mock_thread_runner.run.assert_called_once_with(
        filtered_pipeline, catalog, session._hook_manager, fake_session_id
    )
    hook.after_pipeline_run.assert_called_once_with(
        run_params=expected_run_params,
        run_result=mock_thread_runner.run.return_value,
        pipeline=filtered_pipeline,
        catalog=catalog,
    )

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
def test_run_multiple_times(
Expand Down

0 comments on commit 0ba5eed

Please sign in to comment.