Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix for multi-lines imports %load_node #4068

Merged
merged 9 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

## Bug fixes and other changes
* Moved `_find_run_command()` and `_find_run_command_in_plugins()` from `__main__.py` in the project template to the framework itself.
* Fixed a bug where `%load_node` breaks with multi-lines import statements.

## Breaking changes to the API

Expand Down
26 changes: 23 additions & 3 deletions kedro/ipython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,30 @@ def _prepare_imports(node_func: Callable) -> str:
if python_file:
import_statement = []
with open(python_file) as file:
# Handle multiline imports, i.e.
# from lib import (
# a,
# b,
# c
# )
# This will not work with all edge cases but good enough with common cases that
# are formatted automatically by black, ruff etc.
inside_bracket = False
# Parse any line start with from or import statement
for line in file.readlines():
if line.startswith("from") or line.startswith("import"):
import_statement.append(line.strip())

for _ in file.readlines():
line = _.strip()
if not inside_bracket:
# The common case
if line.startswith("from") or line.startswith("import"):
import_statement.append(line)
if line.endswith("("):
inside_bracket = True
# Inside multi-lines import, append everything.
else:
import_statement.append(line)
if line.endswith(")"):
inside_bracket = False

clean_imports = "\n".join(import_statement).strip()
return clean_imports
Expand Down
3 changes: 2 additions & 1 deletion tests/ipython/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from kedro.pipeline import node
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from . import dummy_function_fixtures # noqa It is needed for the inspect module
from . import dummy_function_fixtures # noqa: F401
from .dummy_function_fixtures import (
dummy_function,
dummy_function_with_loop,
dummy_function_with_variable_length,
dummy_nested_function,
)
from .dummy_multiline_fixtures import dummy_multiline_import_function # noqa: F401

# Constants
PACKAGE_NAME = "fake_package_name"
Expand Down
18 changes: 18 additions & 0 deletions tests/ipython/dummy_multiline_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ruff: noqa
# multi-lines import
from logging import (
INFO,
DEBUG,
WARN,
ERROR,
)


def dummy_multiline_import_function(dummy_input, my_input):
"""
Returns True if input is not
"""
# this is an in-line comment in the body of the function
random_assignment = "Added for a longer function"
random_assignment += "make sure to modify variable"
return not dummy_input
20 changes: 18 additions & 2 deletions tests/ipython/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
)
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from .conftest import dummy_function, dummy_function_with_loop, dummy_nested_function
from .conftest import (
dummy_function,
dummy_function_with_loop,
dummy_multiline_import_function,
dummy_nested_function,
)


class TestLoadKedroObjects:
Expand Down Expand Up @@ -338,7 +343,7 @@ def test_node_not_found(self, dummy_pipelines):
in str(excinfo.value)
)

def test_prepare_imports(self, mocker, dummy_module_literal):
def test_prepare_imports(self, mocker):
func_imports = """import logging # noqa
from logging import config # noqa
import logging as dummy_logging # noqa
Expand All @@ -347,6 +352,17 @@ def test_prepare_imports(self, mocker, dummy_module_literal):
result = _prepare_imports(dummy_function)
assert result == func_imports

def test_prepare_imports_multiline(self, mocker):
func_imports = """from logging import (
INFO,
DEBUG,
WARN,
ERROR,
)"""

result = _prepare_imports(dummy_multiline_import_function)
assert result == func_imports

def test_prepare_imports_func_not_found(self, mocker):
mocker.patch("inspect.getsourcefile", return_value=None)

Expand Down