diff --git a/.github/styles/Kedro/ignore.txt b/.github/styles/Kedro/ignore.txt index 9634d1b14b..3d568cddc9 100644 --- a/.github/styles/Kedro/ignore.txt +++ b/.github/styles/Kedro/ignore.txt @@ -44,3 +44,5 @@ transcoding transcode Claypot ethanknights +Aneira +Printify diff --git a/.github/workflows/all-checks.yml b/.github/workflows/all-checks.yml index 2dfb971e3d..afd9caced0 100644 --- a/.github/workflows/all-checks.yml +++ b/.github/workflows/all-checks.yml @@ -26,7 +26,7 @@ jobs: strategy: matrix: os: [ windows-latest, ubuntu-latest ] - python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11", "3.12" ] uses: ./.github/workflows/unit-tests.yml with: os: ${{ matrix.os }} @@ -36,7 +36,7 @@ jobs: strategy: matrix: os: [ windows-latest, ubuntu-latest ] - python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11", "3.12" ] uses: ./.github/workflows/e2e-tests.yml with: os: ${{ matrix.os }} @@ -59,7 +59,7 @@ jobs: strategy: matrix: os: [ windows-latest, ubuntu-latest ] - python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11", "3.12" ] uses: ./.github/workflows/pip-compile.yml with: os: ${{ matrix.os }} diff --git a/.github/workflows/benchmark-performance.yml b/.github/workflows/benchmark-performance.yml new file mode 100644 index 0000000000..30922193c3 --- /dev/null +++ b/.github/workflows/benchmark-performance.yml @@ -0,0 +1,59 @@ +name: ASV Benchmark + +on: + push: + branches: + - main # Run benchmarks on every commit to the main branch + workflow_dispatch: + + +jobs: + + benchmark: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: "kedro" + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install asv # Install ASV + + - name: Run ASV benchmarks + run: | + cd kedro + asv machine --machine=github-actions + asv run -v --machine=github-actions + + - name: Set git email and name + run: | + git config --global user.email "kedro@kedro.com" + git config --global user.name "Kedro" + + - name: Checkout target repository + uses: actions/checkout@v4 + with: + repository: kedro-org/kedro-benchmark-results + token: ${{ secrets.GH_TAGGING_TOKEN }} + ref: 'main' + path: "kedro-benchmark-results" + + - name: Copy files to target repository + run: | + cp -r /home/runner/work/kedro/kedro/kedro/.asv /home/runner/work/kedro/kedro/kedro-benchmark-results/ + + - name: Commit and Push changes to kedro-org/kedro-benchmark-results + run: | + cd kedro-benchmark-results + git add . 
+ git commit -m "Add results" + git push diff --git a/.github/workflows/docs-only-checks.yml b/.github/workflows/docs-only-checks.yml index d0cca88abd..edd95de234 100644 --- a/.github/workflows/docs-only-checks.yml +++ b/.github/workflows/docs-only-checks.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.9", "3.10", "3.11", "3.12" ] uses: ./.github/workflows/lint.yml with: os: ${{ matrix.os }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9cd73b3ad2..c9a5c5397c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks -default_stages: [commit, manual] +default_stages: [pre-commit, manual] repos: - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/README.md b/README.md index 5cc0bda930..5e82fcebc2 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@

-[![Python version](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue.svg)](https://pypi.org/project/kedro/) +[![Python version](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue.svg)](https://pypi.org/project/kedro/) [![PyPI version](https://badge.fury.io/py/kedro.svg)](https://pypi.org/project/kedro/) [![Conda version](https://img.shields.io/conda/vn/conda-forge/kedro.svg)](https://anaconda.org/conda-forge/kedro) [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://github.com/kedro-org/kedro/blob/main/LICENSE.md) diff --git a/RELEASE.md b/RELEASE.md index 24722396d4..5447340938 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,12 +1,29 @@ # Upcoming Release ## Major features and improvements +* Dropped Python 3.8 support. +* Implemented `KedroDataCatalog`, which replicates `DataCatalog` functionality with a few API enhancements: + * Removed `_FrozenDatasets`; datasets are now accessed as properties; + * Added the ability to get a dataset by name; + * `add_feed_dict()` was simplified to only add raw data; + * Datasets' initialisation was moved out of the `from_config()` method and into the constructor. +* Moved development requirements from `requirements.txt` to a dedicated section in `pyproject.toml` in the project template. +* Implemented a `Protocol` abstraction for the current `DataCatalog` and for adding new catalog implementations. +* Refactored `kedro run` and `kedro catalog` commands. +* Moved pattern resolution logic from `DataCatalog` to a separate component - `CatalogConfigResolver`. Updated `DataCatalog` to use `CatalogConfigResolver` internally. * Made packaged Kedro projects return `session.run()` output to be used when running it in the interactive environment. * Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking. + +**Note:** ``KedroDataCatalog`` is an experimental feature and is under active development. Therefore, it is possible we'll introduce breaking changes to this class, so be mindful of that if you decide to use it at this early stage. Let us know if you have any feedback about the ``KedroDataCatalog`` or ideas for new features. + ## Bug fixes and other changes * Fixed bug where using dataset factories breaks with `ThreadRunner`. +* Fixed a bug where `SharedMemoryDataset.exists` would not call the underlying `MemoryDataset`. +* Fixed template projects example tests. +* Made credentials loading consistent between `KedroContext._get_catalog()` and `resolve_patterns` so that both use `_get_config_credentials()`. ## Breaking changes to the API +* Removed `ShelveStore` to address a security vulnerability. ## Documentation changes * Fix logo on PyPI page. 
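As a rough illustration of the `CatalogConfigResolver` component referenced in these release notes, here is a minimal sketch. It is based only on the constructor and the `match_pattern`/`resolve_pattern` usages visible elsewhere in this diff; the dataset factory pattern, filepath, and the printed output shown in the comment are invented for illustration and are not part of this PR.

```python
# Hypothetical usage sketch of the pattern-resolution component introduced here.
from kedro.io import CatalogConfigResolver

# A single dataset factory pattern: "{name}_csv" matches names such as "companies_csv".
config = {
    "{name}_csv": {
        "type": "pandas.CSVDataset",
        "filepath": "data/01_raw/{name}.csv",
    }
}

resolver = CatalogConfigResolver(config=config)

if resolver.match_pattern("companies_csv"):
    # Resolves the placeholders for this dataset name,
    # e.g. {'type': 'pandas.CSVDataset', 'filepath': 'data/01_raw/companies.csv'}
    print(resolver.resolve_pattern("companies_csv"))
```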
@@ -15,6 +32,10 @@ ## Community contributions * [Puneet](https://github.com/puneeter) * [ethanknights](https://github.com/ethanknights) +* [Manezki](https://github.com/Manezki) +* [MigQ2](https://github.com/MigQ2) +* [Felix Scherz](https://github.com/felixscherz) +* [Yu-Sheng Li](https://github.com/kevin1kevin1k) # Release 0.19.8 diff --git a/asv.conf.json b/asv.conf.json new file mode 100644 index 0000000000..2cfcd3a057 --- /dev/null +++ b/asv.conf.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "project": "Kedro", + "project_url": "https://kedro.org/", + "repo": ".", + "install_command": ["pip install -e ."], + "branches": ["main"], + "environment_type": "virtualenv", + "show_commit_url": "http://github.com/kedro-org/kedro/commit/", + "results_dir": ".asv/results", + "html_dir": ".asv/html" +} diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/benchmarks/benchmark_dummy.py b/benchmarks/benchmark_dummy.py new file mode 100644 index 0000000000..fc047eb712 --- /dev/null +++ b/benchmarks/benchmark_dummy.py @@ -0,0 +1,16 @@ +# Write the benchmarking functions here. +# See "Writing benchmarks" in the asv docs for more information. + + +class TimeSuite: + """ + A dummy benchmark suite to test with asv framework. + """ + def setup(self): + self.d = {} + for x in range(500): + self.d[x] = None + + def time_keys(self): + for key in self.d.keys(): + pass diff --git a/docs/source/api/kedro.framework.session.shelvestore.ShelveStore.rst b/docs/source/api/kedro.framework.session.shelvestore.ShelveStore.rst deleted file mode 100644 index bb1b278487..0000000000 --- a/docs/source/api/kedro.framework.session.shelvestore.ShelveStore.rst +++ /dev/null @@ -1,6 +0,0 @@ -kedro.framework.session.shelvestore.ShelveStore -================================================ - -.. currentmodule:: kedro.framework.session.shelvestore - -.. autoclass:: ShelveStore diff --git a/docs/source/conf.py b/docs/source/conf.py index 562f5a4b0e..a883f76bd6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -70,7 +70,7 @@ intersphinx_mapping = { "kedro-viz": ("https://docs.kedro.org/projects/kedro-viz/en/v6.6.1/", None), "kedro-datasets": ("https://docs.kedro.org/projects/kedro-datasets/en/kedro-datasets-2.0.0/", None), - "cpython": ("https://docs.python.org/3.8/", None), + "cpython": ("https://docs.python.org/3.9/", None), "ipython": ("https://ipython.readthedocs.io/en/8.21.0/", None), "mlflow": ("https://www.mlflow.org/docs/2.12.1/", None), "kedro-mlflow": ("https://kedro-mlflow.readthedocs.io/en/0.12.2/", None), @@ -127,11 +127,14 @@ "typing.Type", "typing.Set", "kedro.config.config.ConfigLoader", + "kedro.io.catalog_config_resolver.CatalogConfigResolver", "kedro.io.core.AbstractDataset", "kedro.io.core.AbstractVersionedDataset", + "kedro.io.core.CatalogProtocol", "kedro.io.core.DatasetError", "kedro.io.core.Version", "kedro.io.data_catalog.DataCatalog", + "kedro.io.kedro_data_catalog.KedroDataCatalog", "kedro.io.memory_dataset.MemoryDataset", "kedro.io.partitioned_dataset.PartitionedDataset", "kedro.pipeline.pipeline.Pipeline", @@ -168,6 +171,9 @@ "D[k] if k in D, else d. d defaults to None.", "None. 
Update D from mapping/iterable E and F.", "Patterns", + "CatalogConfigResolver", + "CatalogProtocol", + "KedroDataCatalog", ), "py:data": ( "typing.Any", diff --git a/docs/source/contribution/technical_steering_committee.md b/docs/source/contribution/technical_steering_committee.md index a17590bdad..b324c15910 100644 --- a/docs/source/contribution/technical_steering_committee.md +++ b/docs/source/contribution/technical_steering_committee.md @@ -61,10 +61,10 @@ We look for commitment markers who can do the following: | [Huong Nguyen](https://github.com/Huongg) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Ivan Danov](https://github.com/idanov) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Jitendra Gundaniya](https://github.com/jitu5) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | -| [Joel Schwarzmann](https://github.com/datajoely) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | +| [Joel Schwarzmann](https://github.com/datajoely) | [Aneira Health](https://www.aneira.health) | | [Juan Luis Cano](https://github.com/astrojuanlu) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Laura Couto](https://github.com/lrcouto) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | -| [Marcin Zabłocki](https://github.com/marrrcin) | [Printify, Inc.](https://printify.com/) | +| [Marcin Zabłocki](https://github.com/marrrcin) | [Printify, Inc.](https://printify.com/) | | [Merel Theisen](https://github.com/merelcht) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Nok Lam Chan](https://github.com/noklam) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Rashida Kanchwala](https://github.com/rashidakanchwala) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | diff --git a/docs/source/data/how_to_create_a_custom_dataset.md b/docs/source/data/how_to_create_a_custom_dataset.md index 01ad199f55..7f39987dd7 100644 --- a/docs/source/data/how_to_create_a_custom_dataset.md +++ b/docs/source/data/how_to_create_a_custom_dataset.md @@ -4,7 +4,7 @@ ## AbstractDataset -If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to override the `_load` and `_save` and provides `load` and `save` methods that enrich the corresponding private methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. +If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to implement the `load` and `save` methods while providing wrappers that enrich the corresponding methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. 
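To make the updated interface concrete before the scenario below, here is a minimal illustrative sketch of a custom dataset that implements the public `load` and `save` methods described above. The `SimpleTextDataset` name and its plain-file handling are invented for this example; the rest of this guide builds the real `ImageDataset` step by step.

```python
from typing import Any

from kedro.io import AbstractDataset


class SimpleTextDataset(AbstractDataset[str, str]):
    """Hypothetical minimal dataset that loads and saves a local text file."""

    def __init__(self, filepath: str):
        self._filepath = filepath

    def load(self) -> str:
        # Implemented directly; the base class wraps it with uniform error handling.
        with open(self._filepath) as f:
            return f.read()

    def save(self, data: str) -> None:
        with open(self._filepath, "w") as f:
            f.write(data)

    def _describe(self) -> dict[str, Any]:
        # Used by Kedro when logging information about this dataset instance.
        return {"filepath": self._filepath}
```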
## Scenario @@ -31,8 +31,8 @@ Consult the [Pillow documentation](https://pillow.readthedocs.io/en/stable/insta At the minimum, a valid Kedro dataset needs to subclass the base {py:class}`~kedro.io.AbstractDataset` and provide an implementation for the following abstract methods: -* `_load` -* `_save` +* `load` +* `save` * `_describe` `AbstractDataset` is generically typed with an input data type for saving data, and an output data type for loading data. @@ -70,7 +70,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ self._filepath = filepath - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -78,7 +78,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ ... - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath""" ... @@ -96,11 +96,11 @@ src/kedro_pokemon/datasets └── image_dataset.py ``` -## Implement the `_load` method with `fsspec` +## Implement the `load` method with `fsspec` Many of the built-in Kedro datasets rely on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) as a consistent interface to different data sources, as described earlier in the section about the [Data Catalog](../data/data_catalog.md#dataset-filepath). In this example, it's particularly convenient to use `fsspec` in conjunction with `Pillow` to read image data, since it allows the dataset to work flexibly with different image locations and formats. -Here is the implementation of the `_load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array: +Here is the implementation of the `load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array:
Click to expand @@ -130,7 +130,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -168,14 +168,14 @@ In [2]: from PIL import Image In [3]: Image.fromarray(image).show() ``` -## Implement the `_save` method with `fsspec` +## Implement the `save` method with `fsspec` Similarly, we can implement the `_save` method as follows: ```python class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems save_path = get_filepath_str(self._filepath, self._protocol) @@ -243,7 +243,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -254,7 +254,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._filepath, self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -312,7 +312,7 @@ To add versioning support to the new dataset we need to extend the {py:class}`~kedro.io.AbstractVersionedDataset` to: * Accept a `version` keyword argument as part of the constructor -* Adapt the `_load` and `_save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively +* Adapt the `load` and `save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively The following amends the full implementation of our basic `ImageDataset`. It now loads and saves data to and from a versioned subfolder (`data/01_raw/pokemon-images-and-types/images/images/pikachu.png//pikachu.png` with `version` being a datetime-formatted string `YYYY-MM-DDThh.mm.ss.sssZ` by default): @@ -359,7 +359,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): glob_function=self._fs.glob, ) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -370,7 +370,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -435,7 +435,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas + glob_function=self._fs.glob, + ) + - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. 
Returns: @@ -447,7 +447,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" - save_path = get_filepath_str(self._filepath, self._protocol) + save_path = get_filepath_str(self._get_save_path(), self._protocol) diff --git a/docs/source/deployment/aws_step_functions.md b/docs/source/deployment/aws_step_functions.md index 8a3d70da53..9802eb542b 100644 --- a/docs/source/deployment/aws_step_functions.md +++ b/docs/source/deployment/aws_step_functions.md @@ -156,7 +156,7 @@ This file acts as the handler for each Lambda function in our pipeline, receives ```Dockerfile # Define global args ARG FUNCTION_DIR="/home/app/" -ARG RUNTIME_VERSION="3.8" +ARG RUNTIME_VERSION="3.9" # Stage 1 - bundle base image + runtime # Grab a fresh copy of the image and install GCC diff --git a/docs/source/deployment/databricks/databricks_deployment_workflow.md b/docs/source/deployment/databricks/databricks_deployment_workflow.md index aebb1cb94f..4cc4b2c57b 100644 --- a/docs/source/deployment/databricks/databricks_deployment_workflow.md +++ b/docs/source/deployment/databricks/databricks_deployment_workflow.md @@ -10,7 +10,7 @@ Here are some typical use cases for running a packaged Kedro project as a Databr - **Data engineering pipeline**: the output of your Kedro project is a file or set of files containing cleaned and processed data. - **Machine learning with MLflow**: your Kedro project runs an ML model; metrics about your experiments are tracked in MLflow. -- **Automated and scheduled runs**: your Kedro project should be [run on Databricks automatically](https://docs.databricks.com/workflows/jobs/schedule-jobs.html#add-a-job-schedule). +- **Automated and scheduled runs**: your Kedro project should be [run on Databricks automatically](https://docs.databricks.com/en/jobs/scheduled.html#add-a-job-schedule). - **CI/CD integration**: you have a CI/CD pipeline that produces a packaged Kedro project. Running your packaged project as a Databricks job is very different from running it from a Databricks notebook. The Databricks job cluster has to be provisioned and started for each run, which is significantly slower than running it as a notebook on a cluster that has already been started. In addition, there is no way to change your project's code once it has been packaged. Instead, you must change your code, create a new package, and then upload it to Databricks again. diff --git a/docs/source/deployment/databricks/databricks_ide_development_workflow.md b/docs/source/deployment/databricks/databricks_ide_development_workflow.md index f85799272d..7146ec7927 100644 --- a/docs/source/deployment/databricks/databricks_ide_development_workflow.md +++ b/docs/source/deployment/databricks/databricks_ide_development_workflow.md @@ -31,7 +31,7 @@ The main steps in this tutorial are as follows: - An active [Databricks deployment](https://docs.databricks.com/getting-started/index.html). - A [Databricks cluster](https://docs.databricks.com/clusters/configure.html) configured with a recent version (>= 11.3 is recommended) of the Databricks runtime. -- [Conda installed](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) on your local machine in order to create a virtual environment with a specific version of Python (>= 3.8 is required). 
If you have Python >= 3.8 installed, you can use other software to create a virtual environment. +- [Conda installed](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) on your local machine in order to create a virtual environment with a specific version of Python (>= 3.9 is required). If you have Python >= 3.9 installed, you can use other software to create a virtual environment. ## Set up your project diff --git a/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md b/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md index 8fc4b76f4c..8a0616708b 100644 --- a/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md +++ b/docs/source/deployment/databricks/databricks_notebooks_development_workflow.md @@ -19,7 +19,7 @@ This tutorial introduces a Kedro project development workflow using only the Dat - An active [Databricks deployment](https://docs.databricks.com/getting-started/index.html). - A [Databricks cluster](https://docs.databricks.com/clusters/configure.html) configured with a recent version (>= 11.3 is recommended) of the Databricks runtime. -- Python >= 3.8 installed. +- Python >= 3.9 installed. - Git installed. - A [GitHub](https://github.com/) account. - A Python environment management system installed, [venv](https://docs.python.org/3/library/venv.html), [virtualenv](https://virtualenv.pypa.io/en/latest/) or [Conda](https://docs.conda.io/en/latest/) are popular choices. diff --git a/docs/source/development/automated_testing.md b/docs/source/development/automated_testing.md index ed3efe3287..c4ca9a6538 100644 --- a/docs/source/development/automated_testing.md +++ b/docs/source/development/automated_testing.md @@ -19,21 +19,36 @@ There are many testing frameworks available for Python. One of the most popular Let's look at how you can start working with `pytest` in your Kedro project. -### Prerequisite: Install your Kedro project +### Install test requirements +Before getting started with test requirements, it is important to ensure you have installed your project locally. This allows you to test different parts of your project by importing them into your test files. + + +To install your project including all the project-specific dependencies and test requirements: +1. Add the following section to the `pyproject.toml` file located in the project root: +```toml +[project.optional-dependencies] +dev = [ + "pytest-cov", + "pytest-mock", + "pytest", +] +``` + +2. Navigate to the root directory of the project and run: +```bash +pip install ."[dev]" +``` -Before getting started with `pytest`, it is important to ensure you have installed your project locally. This allows you to test different parts of your project by importing them into your test files. +Alternatively, you can individually install test requirements as you would install other packages with `pip`, making sure you have installed your project locally and your [project's virtual environment is active](../get_started/install.md#create-a-virtual-environment-for-your-kedro-project). -To install your project, navigate to your project root and run the following command: +1. To install your project, navigate to your project root and run the following command: ```bash pip install -e . ``` - >**NOTE**: The option `-e` installs an editable version of your project, allowing you to make changes to the project files without needing to re-install them each time. 
-### Install `pytest` - -Install `pytest` as you would install other packages with `pip`, making sure your [project's virtual environment is active](../get_started/install.md#create-a-virtual-environment-for-your-kedro-project). +2. Install test requirements one by one: ```bash pip install pytest ``` diff --git a/docs/source/development/linting.md b/docs/source/development/linting.md index 61989cdf85..fbc0b0147c 100644 --- a/docs/source/development/linting.md +++ b/docs/source/development/linting.md @@ -18,17 +18,17 @@ There are a variety of Python tools available to use with your Kedro projects. T type. ### Install the tools -Install `ruff` by adding the following lines to your project's `requirements.txt` -file: -```text -ruff # Used for linting, formatting and sorting module imports +To install `ruff` add the following section to the `pyproject.toml` file located in the project root: +```toml +[project.optional-dependencies] +dev = ["ruff"] ``` -To install all the project-specific dependencies, including the linting tools, navigate to the root directory of the +Then to install your project including all the project-specific dependencies and the linting tools, navigate to the root directory of the project and run: ```bash -pip install -r requirements.txt +pip install ."[dev]" ``` Alternatively, you can individually install the linting tools using the following shell commands: diff --git a/docs/source/get_started/install.md b/docs/source/get_started/install.md index d61d084a89..fe75decda2 100644 --- a/docs/source/get_started/install.md +++ b/docs/source/get_started/install.md @@ -1,7 +1,7 @@ # Set up Kedro ## Installation prerequisites -* **Python**: Kedro supports macOS, Linux, and Windows and is built for Python 3.8+. You'll select a version of Python when you create a virtual environment for your Kedro project. +* **Python**: Kedro supports macOS, Linux, and Windows and is built for Python 3.9+. You'll select a version of Python when you create a virtual environment for your Kedro project. * **Virtual environment**: You should create a new virtual environment for *each* new Kedro project you work on to isolate its Python dependencies from those of other projects. @@ -55,7 +55,7 @@ deactivate conda create --name kedro-environment python=3.10 -y ``` -The example below uses Python 3.10, and creates a virtual environment called `kedro-environment`. You can opt for a different version of Python (any version >= 3.8 and <3.12) for your project, and you can name it anything you choose. +The example below uses Python 3.10, and creates a virtual environment called `kedro-environment`. You can opt for a different version of Python (any version >= 3.9 and <3.12) for your project, and you can name it anything you choose. The `conda` virtual environment is not dependent on your current working directory and can be activated from any directory: @@ -136,7 +136,7 @@ When migrating an existing project to a newer Kedro version, make sure you also ## Summary * Kedro can be used on Windows, macOS or Linux. -* Installation prerequisites include a virtual environment manager like `conda`, Python 3.8+, and `git`. +* Installation prerequisites include a virtual environment manager like `conda`, Python 3.9+, and `git`. * You should install Kedro using `pip install kedro`. If you encounter any problems as you set up Kedro, ask for help on Kedro's [Slack organisation](https://slack.kedro.org) or review the [searchable archive of Slack discussions](https://linen-slack.kedro.org/). 
diff --git a/docs/source/index.rst b/docs/source/index.rst index ce8d85a9a1..ecbdbbe381 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,9 +23,9 @@ Welcome to Kedro's award-winning documentation! :target: https://opensource.org/license/apache2-0-php/ :alt: License is Apache 2.0 -.. image:: https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue.svg +.. image:: https://img.shields.io/badge/3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue.svg :target: https://pypi.org/project/kedro/ - :alt: Python version 3.8, 3.9, 3.10, 3.11 + :alt: Python version 3.9, 3.10, 3.11, 3.12 .. image:: https://badge.fury.io/py/kedro.svg :target: https://pypi.org/project/kedro/ diff --git a/docs/source/integrations/pyspark_integration.md b/docs/source/integrations/pyspark_integration.md index f1278955f5..e8bbb359ea 100644 --- a/docs/source/integrations/pyspark_integration.md +++ b/docs/source/integrations/pyspark_integration.md @@ -35,7 +35,7 @@ class SparkHooks: """ # Load the spark configuration in spark.yaml using the config loader - parameters = context.config_loader.get("spark*", "spark*/**") + parameters = context.config_loader["spark"] spark_conf = SparkConf().setAll(parameters.items()) # Initialise the spark session diff --git a/docs/source/introduction/index.md b/docs/source/introduction/index.md index 38dab844a3..d9de994933 100644 --- a/docs/source/introduction/index.md +++ b/docs/source/introduction/index.md @@ -18,6 +18,6 @@ Use the left-hand table of contents to explore the documentation available for m ```{note} We have designed the preliminary documentation and the [spaceflights tutorial](../tutorial/spaceflights_tutorial.md) for anyone new to Kedro. The more knowledge of Python you have, the easier you will find the learning curve. -There are many excellent online resources for learning Python; you should choose those that reference Python 3, as Kedro is built for Python 3.8+. There are curated lists of online resources, such as the [official Python programming language website](https://www.python.org/) and this list of [free programming books and tutorials](https://github.com/EbookFoundation/free-programming-books/blob/master/books/free-programming-books-langs.md#python). +There are many excellent online resources for learning Python; you should choose those that reference Python 3, as Kedro is built for Python 3.9+. There are curated lists of online resources, such as the [official Python programming language website](https://www.python.org/) and this list of [free programming books and tutorials](https://github.com/EbookFoundation/free-programming-books/blob/master/books/free-programming-books-langs.md#python). ``` diff --git a/docs/source/meta/images/slice_pipeline_kedro_viz.gif b/docs/source/meta/images/slice_pipeline_kedro_viz.gif new file mode 100644 index 0000000000..2d49c9e766 Binary files /dev/null and b/docs/source/meta/images/slice_pipeline_kedro_viz.gif differ diff --git a/docs/source/nodes_and_pipelines/slice_a_pipeline.md b/docs/source/nodes_and_pipelines/slice_a_pipeline.md index 2324a12fb0..2b2871dffe 100644 --- a/docs/source/nodes_and_pipelines/slice_a_pipeline.md +++ b/docs/source/nodes_and_pipelines/slice_a_pipeline.md @@ -1,6 +1,13 @@ # Slice a pipeline -Sometimes it is desirable to run a subset, or a 'slice' of a pipeline's nodes. In this page, we illustrate the programmatic options that Kedro provides. 
You can also use the [Kedro CLI to pass parameters to `kedro run`](../development/commands_reference.md#run-the-project) command and slice a pipeline. +Sometimes it is desirable to run a subset, or a 'slice' of a pipeline's nodes. There are two primary ways to achieve this: + + +1. **Visually through Kedro-Viz:** This approach allows you to visually choose and slice pipeline nodes, which then generates a run command for executing the slice within your Kedro project. Detailed steps on how to achieve this are available in the Kedro-Viz documentation: [Slice a Pipeline](https://docs.kedro.org/projects/kedro-viz/en/stable/slice_a_pipeline.html). + +![](../meta/images/slice_pipeline_kedro_viz.gif) + +2. **Programmatically with the Kedro CLI.** You can also use the [Kedro CLI to pass parameters to `kedro run`](../development/commands_reference.md#run-the-project) command and slice a pipeline. In this page, we illustrate the programmatic options that Kedro provides. Let's look again at the example pipeline from the [pipeline introduction documentation](./pipeline_introduction.md#how-to-build-a-pipeline), which computes the variance of a set of numbers: diff --git a/docs/source/notebooks_and_ipython/notebook-example/add_kedro_to_a_notebook.ipynb b/docs/source/notebooks_and_ipython/notebook-example/add_kedro_to_a_notebook.ipynb index 3c036386cc..1e91ef0d29 100644 --- a/docs/source/notebooks_and_ipython/notebook-example/add_kedro_to_a_notebook.ipynb +++ b/docs/source/notebooks_and_ipython/notebook-example/add_kedro_to_a_notebook.ipynb @@ -683,7 +683,6 @@ "####################\n", "# Data processing #\n", "####################\n", - "from typing import Dict, Tuple\n", "\n", "import pandas as pd\n", "from sklearn.linear_model import LinearRegression\n", @@ -736,7 +735,7 @@ "##################################\n", "\n", "\n", - "def split_data(data: pd.DataFrame, parameters: Dict) -> Tuple:\n", + "def split_data(data: pd.DataFrame, parameters: dict) -> tuple:\n", " X = data[parameters[\"features\"]]\n", " y = data[\"price\"]\n", " X_train, X_test, y_train, y_test = train_test_split(\n", @@ -796,7 +795,6 @@ "outputs": [], "source": [ "# Kedro setup for data management and configuration\n", - "from typing import Dict, Tuple\n", "\n", "import pandas as pd\n", "from sklearn.linear_model import LinearRegression\n", @@ -873,7 +871,7 @@ "##################################\n", "\n", "\n", - "def split_data(data: pd.DataFrame, parameters: Dict) -> Tuple:\n", + "def split_data(data: pd.DataFrame, parameters: dict) -> tuple:\n", " X = data[parameters[\"features\"]]\n", " y = data[\"price\"]\n", " X_train, X_test, y_train, y_test = train_test_split(\n", diff --git a/docs/source/starters/new_project_tools.md b/docs/source/starters/new_project_tools.md index ab43308ecb..c405200407 100644 --- a/docs/source/starters/new_project_tools.md +++ b/docs/source/starters/new_project_tools.md @@ -44,7 +44,7 @@ To skip this step in future use --tools To find out more: https://docs.kedro.org/en/stable/starters/new_project_tools.html Tools -1) Lint: Basic linting with Black and Ruff +1) Lint: Basic linting with Ruff 2) Test: Basic testing with pytest 3) Log: Additional, environment-specific logging options 4) Docs: A Sphinx documentation setup @@ -65,8 +65,7 @@ A list of available tools can also be accessed by running `kedro new --help` Tools - 1) Linting: Provides a basic linting setup with Black - and Ruff + 1) Linting: Provides a basic linting setup with Ruff 2) Testing: Provides basic testing setup with pytest @@ -165,7 +164,7 
@@ The available tools include: [linting](#linting), [testing](#testing), [custom l ### Linting -The Kedro linting tool introduces [`black`](https://black.readthedocs.io/en/stable/index.html) and [`ruff`](https://docs.astral.sh/ruff/) as dependencies in your new project's requirements. After project creation, make sure these are installed by running the following command from the project root: +The Kedro linting tool introduces [`ruff`](https://docs.astral.sh/ruff/) as dependency in your new project's requirements. After project creation, make sure these are installed by running the following command from the project root: ```bash pip install -r requirements.txt @@ -175,7 +174,6 @@ The linting tool will configure `ruff` with the following settings by default: ```toml #pyproject.toml -[tool.ruff] line-length = 88 show-fixes = true select = [ @@ -187,7 +185,7 @@ select = [ "PL", # Pylint "T201", # Print Statement ] -ignore = ["E501"] # Black takes care of line-too-long +ignore = ["E501"] # Ruff format takes care of line-too-long ``` With these installed, you can then make use of the following commands to format and lint your code: diff --git a/docs/source/tutorial/tutorial_template.md b/docs/source/tutorial/tutorial_template.md index d8462f1b20..1586487707 100644 --- a/docs/source/tutorial/tutorial_template.md +++ b/docs/source/tutorial/tutorial_template.md @@ -32,7 +32,6 @@ The spaceflights project dependencies are stored in `requirements.txt`(you may f ```text # code quality packages -ipython>=7.31.1, <8.0; python_version < '3.8' ipython~=8.10; python_version >= '3.8' ruff==0.1.8 diff --git a/features/environment.py b/features/environment.py index 75bd183ed8..f1d45df926 100644 --- a/features/environment.py +++ b/features/environment.py @@ -5,7 +5,6 @@ import os import shutil import subprocess -import sys import tempfile import venv from pathlib import Path @@ -131,11 +130,6 @@ def _install_project_requirements(context): .splitlines() ) install_reqs = [req for req in install_reqs if "{" not in req and "#" not in req] - # For Python versions 3.9 and above we use the new dataset dependency format introduced in `kedro-datasets` 3.0.0 - if sys.version_info.minor > MINOR_PYTHON_38_VERSION: - install_reqs.append("kedro-datasets[pandas-csvdataset]") - # For Python 3.8 we use the older `kedro-datasets` dependency format - else: - install_reqs.append("kedro-datasets[pandas.CSVDataset]") + install_reqs.append("kedro-datasets[pandas-csvdataset]") call([context.pip, "install", *install_reqs], env=context.env) return context diff --git a/features/load_node.feature b/features/load_node.feature index fbc5a65a07..e745378e22 100644 --- a/features/load_node.feature +++ b/features/load_node.feature @@ -5,5 +5,6 @@ Feature: load_node in new project And I have run a non-interactive kedro new with starter "default" Scenario: Execute ipython load_node magic - When I execute the load_node magic command + When I install project and its dev dependencies + And I execute the load_node magic command Then the logs should show that load_node executed successfully diff --git a/features/steps/cli_steps.py b/features/steps/cli_steps.py index 7ee2c153d8..62cda23001 100644 --- a/features/steps/cli_steps.py +++ b/features/steps/cli_steps.py @@ -755,3 +755,13 @@ def exec_magic_command(context): def change_dir(context, dir): """Execute Kedro target.""" util.chdir(dir) + + +@when("I install project and its dev dependencies") +def pip_install_project_and_dev_dependencies(context): + """Install project and its development 
dependencies using pip.""" + _ = run( + [context.pip, "install", ".[dev]"], + env=context.env, + cwd=str(context.root_project_dir), + ) diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml index 462dd26eee..eb7cb5f113 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml @@ -12,15 +12,21 @@ dynamic = ["dependencies", "version"] [project.optional-dependencies] docs = [ - "docutils<0.18.0", - "sphinx~=3.4.3", - "sphinx_rtd_theme==0.5.1", + "docutils<0.21", + "sphinx>=5.3,<7.3", + "sphinx_rtd_theme==2.0.0", "nbsphinx==0.8.1", - "sphinx-autodoc-typehints==1.11.1", - "sphinx_copybutton==0.3.1", + "sphinx-autodoc-typehints==1.20.2", + "sphinx_copybutton==0.5.2", "ipykernel>=5.3, <7.0", - "Jinja2<3.1.0", - "myst-parser~=0.17.2", + "Jinja2<3.2.0", + "myst-parser>=1.0,<2.1" +] +dev = [ + "pytest-cov~=3.0", + "pytest-mock>=1.7.1, <2.0", + "pytest~=7.2", + "ruff~=0.1.8" ] [tool.setuptools.dynamic] diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt b/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt index 8da5d60851..014df14d12 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt @@ -1,10 +1,5 @@ -ruff==0.1.8 ipython>=8.10 jupyterlab>=3.0 notebook kedro~={{ cookiecutter.kedro_version}} -kedro-datasets[pandas-csvdataset]; python_version >= "3.9" -kedro-datasets[pandas.CSVDataset]<2.0.0; python_version < '3.9' -pytest-cov~=3.0 -pytest-mock>=1.7.1, <2.0 -pytest~=7.2 +kedro-datasets[pandas-csvdataset] diff --git a/features/steps/util.py b/features/steps/util.py index 437b3f6f5e..588e4530e1 100644 --- a/features/steps/util.py +++ b/features/steps/util.py @@ -6,9 +6,10 @@ import re from contextlib import contextmanager from time import sleep, time -from typing import TYPE_CHECKING, Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: + from collections.abc import Iterator from pathlib import Path diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index c4850159a1..8d82ebf360 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -8,10 +8,10 @@ import logging import mimetypes import typing -from collections.abc import KeysView +from collections.abc import Iterable, KeysView from enum import Enum, auto from pathlib import Path -from typing import Any, Callable, Iterable +from typing import Any, Callable import fsspec from omegaconf import DictConfig, OmegaConf diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py index 223980dade..25fad6083d 100644 --- a/kedro/framework/cli/catalog.py +++ b/kedro/framework/cli/catalog.py @@ -2,9 +2,8 @@ from __future__ import annotations -import copy from collections import defaultdict -from itertools import chain +from itertools import chain, filterfalse from typing import TYPE_CHECKING, Any import click @@ -28,6 +27,11 @@ def _create_session(package_name: str, **kwargs: Any) -> KedroSession: return KedroSession.create(**kwargs) +def is_parameter(dataset_name: str) -> bool: + """Check if dataset is a parameter.""" + return dataset_name.startswith("params:") or dataset_name == "parameters" + + @click.group(name="Kedro") def catalog_cli() -> None: # pragma: no cover pass @@ 
-88,21 +92,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None: # resolve any factory datasets in the pipeline factory_ds_by_type = defaultdict(list) - for ds_name in default_ds: - matched_pattern = data_catalog._match_pattern( - data_catalog._dataset_patterns, ds_name - ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name) - if matched_pattern: - ds_config_copy = copy.deepcopy( - data_catalog._dataset_patterns.get(matched_pattern) - or data_catalog._default_pattern.get(matched_pattern) - or {} - ) - ds_config = data_catalog._resolve_config( - ds_name, matched_pattern, ds_config_copy + for ds_name in default_ds: + if data_catalog.config_resolver.match_pattern(ds_name): + ds_config = data_catalog.config_resolver.resolve_pattern(ds_name) + factory_ds_by_type[ds_config.get("type", "DefaultDataset")].append( + ds_name ) - factory_ds_by_type[ds_config["type"]].append(ds_name) default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values())) @@ -128,12 +124,10 @@ def _map_type_to_datasets( datasets of the specific type as a value. """ mapping = defaultdict(list) # type: ignore[var-annotated] - for dataset in datasets: - is_param = dataset.startswith("params:") or dataset == "parameters" - if not is_param: - ds_type = datasets_meta[dataset].__class__.__name__ - if dataset not in mapping[ds_type]: - mapping[ds_type].append(dataset) + for dataset_name in filterfalse(is_parameter, datasets): + ds_type = datasets_meta[dataset_name].__class__.__name__ + if dataset_name not in mapping[ds_type]: + mapping[ds_type].append(dataset_name) return mapping @@ -170,20 +164,12 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N f"'{pipeline_name}' pipeline not found! Existing pipelines: {existing_pipelines}" ) - pipe_datasets = { - ds_name - for ds_name in pipeline.datasets() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + pipeline_datasets = set(filterfalse(is_parameter, pipeline.datasets())) - catalog_datasets = { - ds_name - for ds_name in context.catalog._datasets.keys() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + catalog_datasets = set(filterfalse(is_parameter, context.catalog.list())) # Datasets that are missing in Data Catalog - missing_ds = sorted(pipe_datasets - catalog_datasets) + missing_ds = sorted(pipeline_datasets - catalog_datasets) if missing_ds: catalog_path = ( context.project_path @@ -221,12 +207,9 @@ def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None: session = _create_session(metadata.package_name, env=env) context = session.load_context() - catalog_factories = { - **context.catalog._dataset_patterns, - **context.catalog._default_pattern, - } + catalog_factories = context.catalog.config_resolver.list_patterns() if catalog_factories: - click.echo(yaml.dump(list(catalog_factories.keys()))) + click.echo(yaml.dump(catalog_factories)) else: click.echo("There are no dataset factories in the catalog.") @@ -242,7 +225,7 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: context = session.load_context() catalog_config = context.config_loader["catalog"] - credentials_config = context.config_loader.get("credentials", None) + credentials_config = context._get_config_credentials() data_catalog = DataCatalog.from_config( catalog=catalog_config, credentials=credentials_config ) @@ -250,35 +233,25 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: explicit_datasets = { ds_name: ds_config for ds_name, 
ds_config in catalog_config.items() - if not data_catalog._is_pattern(ds_name) + if not data_catalog.config_resolver.is_pattern(ds_name) } target_pipelines = pipelines.keys() - datasets = set() + pipeline_datasets = set() for pipe in target_pipelines: pl_obj = pipelines.get(pipe) if pl_obj: - datasets.update(pl_obj.datasets()) + pipeline_datasets.update(pl_obj.datasets()) - for ds_name in datasets: - is_param = ds_name.startswith("params:") or ds_name == "parameters" - if ds_name in explicit_datasets or is_param: + for ds_name in pipeline_datasets: + if ds_name in explicit_datasets or is_parameter(ds_name): continue - matched_pattern = data_catalog._match_pattern( - data_catalog._dataset_patterns, ds_name - ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name) - if matched_pattern: - ds_config_copy = copy.deepcopy( - data_catalog._dataset_patterns.get(matched_pattern) - or data_catalog._default_pattern.get(matched_pattern) - or {} - ) + ds_config = data_catalog.config_resolver.resolve_pattern(ds_name) - ds_config = data_catalog._resolve_config( - ds_name, matched_pattern, ds_config_copy - ) + # Exclude MemoryDatasets not set in the catalog explicitly + if ds_config: explicit_datasets[ds_name] = ds_config secho(yaml.dump(explicit_datasets)) diff --git a/kedro/framework/cli/cli.py b/kedro/framework/cli/cli.py index b22fa70d00..f5917e1b87 100644 --- a/kedro/framework/cli/cli.py +++ b/kedro/framework/cli/cli.py @@ -10,10 +10,13 @@ import traceback from collections import defaultdict from pathlib import Path -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any import click +if TYPE_CHECKING: + from collections.abc import Sequence + from kedro import __version__ as version from kedro.framework.cli import BRIGHT_BLACK, ORANGE from kedro.framework.cli.hooks import get_cli_hook_manager diff --git a/kedro/framework/cli/micropkg.py b/kedro/framework/cli/micropkg.py index d2733ff712..a80efd172a 100644 --- a/kedro/framework/cli/micropkg.py +++ b/kedro/framework/cli/micropkg.py @@ -12,7 +12,8 @@ import toml from importlib import import_module from pathlib import Path -from typing import Any, Iterable, Iterator, TYPE_CHECKING +from typing import Any, TYPE_CHECKING + import click from omegaconf import OmegaConf @@ -42,6 +43,7 @@ if TYPE_CHECKING: from kedro.framework.startup import ProjectMetadata from importlib_metadata import PackageMetadata + from collections.abc import Iterable, Iterator _PYPROJECT_TOML_TEMPLATE = """ [build-system] diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index e165c15579..1b50408cc5 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -15,10 +15,14 @@ import typing import warnings from collections import defaultdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence from importlib import import_module from itertools import chain from pathlib import Path -from typing import IO, Any, Callable, Iterable, Sequence +from typing import IO, Any, Callable import click import importlib_metadata @@ -55,7 +59,7 @@ def call(cmd: list[str], **kwargs: Any) -> None: # pragma: no cover Raises: click.exceptions.Exit: If `subprocess.run` returns non-zero code. 
""" - click.echo(" ".join(shlex.quote(c) for c in cmd)) + click.echo(shlex.join(cmd)) code = subprocess.run(cmd, **kwargs).returncode # noqa: PLW1510, S603 if code: raise click.exceptions.Exit(code=code) diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index 3b61b747f6..5c14cbae38 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -14,7 +14,7 @@ from kedro.config import AbstractConfigLoader, MissingConfigException from kedro.framework.project import settings -from kedro.io import DataCatalog # noqa: TCH001 +from kedro.io import CatalogProtocol, DataCatalog # noqa: TCH001 from kedro.pipeline.transcoding import _transcode_split if TYPE_CHECKING: @@ -123,7 +123,7 @@ def _convert_paths_to_absolute_posix( return conf_dictionary -def _validate_transcoded_datasets(catalog: DataCatalog) -> None: +def _validate_transcoded_datasets(catalog: CatalogProtocol) -> None: """Validates transcoded datasets are correctly named Args: @@ -178,13 +178,13 @@ class KedroContext: ) @property - def catalog(self) -> DataCatalog: - """Read-only property referring to Kedro's ``DataCatalog`` for this context. + def catalog(self) -> CatalogProtocol: + """Read-only property referring to Kedro's catalog` for this context. Returns: - DataCatalog defined in `catalog.yml`. + catalog defined in `catalog.yml`. Raises: - KedroContextError: Incorrect ``DataCatalog`` registered for the project. + KedroContextError: Incorrect catalog registered for the project. """ return self._get_catalog() @@ -213,13 +213,13 @@ def _get_catalog( self, save_version: str | None = None, load_versions: dict[str, str] | None = None, - ) -> DataCatalog: - """A hook for changing the creation of a DataCatalog instance. + ) -> CatalogProtocol: + """A hook for changing the creation of a catalog instance. Returns: - DataCatalog defined in `catalog.yml`. + catalog defined in `catalog.yml`. Raises: - KedroContextError: Incorrect ``DataCatalog`` registered for the project. + KedroContextError: Incorrect catalog registered for the project. """ # '**/catalog*' reads modular pipeline configs diff --git a/kedro/framework/hooks/manager.py b/kedro/framework/hooks/manager.py index 5cbbcf9f27..ceec064246 100644 --- a/kedro/framework/hooks/manager.py +++ b/kedro/framework/hooks/manager.py @@ -3,7 +3,8 @@ """ import logging -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any from pluggy import PluginManager diff --git a/kedro/framework/hooks/specs.py b/kedro/framework/hooks/specs.py index b0037a0878..3b32eb294c 100644 --- a/kedro/framework/hooks/specs.py +++ b/kedro/framework/hooks/specs.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from kedro.framework.context import KedroContext - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -22,7 +22,7 @@ class DataCatalogSpecs: @hook_spec def after_catalog_created( # noqa: PLR0913 self, - catalog: DataCatalog, + catalog: CatalogProtocol, conf_catalog: dict[str, Any], conf_creds: dict[str, Any], feed_dict: dict[str, Any], @@ -53,7 +53,7 @@ class NodeSpecs: def before_node_run( self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, session_id: str, @@ -63,7 +63,7 @@ def before_node_run( Args: node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. 
+ catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. @@ -81,7 +81,7 @@ def before_node_run( def after_node_run( # noqa: PLR0913 self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], outputs: dict[str, Any], is_async: bool, @@ -93,7 +93,7 @@ def after_node_run( # noqa: PLR0913 Args: node: The ``Node`` that ran. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. @@ -110,7 +110,7 @@ def on_node_error( # noqa: PLR0913 self, error: Exception, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, session_id: str, @@ -122,7 +122,7 @@ def on_node_error( # noqa: PLR0913 Args: error: The uncaught exception thrown during the node run. node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. @@ -137,7 +137,7 @@ class PipelineSpecs: @hook_spec def before_pipeline_run( - self, run_params: dict[str, Any], pipeline: Pipeline, catalog: DataCatalog + self, run_params: dict[str, Any], pipeline: Pipeline, catalog: CatalogProtocol ) -> None: """Hook to be invoked before a pipeline runs. @@ -164,7 +164,7 @@ def before_pipeline_run( } pipeline: The ``Pipeline`` that will be run. - catalog: The ``DataCatalog`` to be used during the run. + catalog: An implemented instance of ``CatalogProtocol`` to be used during the run. """ pass @@ -174,7 +174,7 @@ def after_pipeline_run( run_params: dict[str, Any], run_result: dict[str, Any], pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """Hook to be invoked after a pipeline runs. @@ -202,7 +202,7 @@ def after_pipeline_run( run_result: The output of ``Pipeline`` run. pipeline: The ``Pipeline`` that was run. - catalog: The ``DataCatalog`` used during the run. + catalog: An implemented instance of ``CatalogProtocol`` used during the run. """ pass @@ -212,7 +212,7 @@ def on_pipeline_error( error: Exception, run_params: dict[str, Any], pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """Hook to be invoked if a pipeline run throws an uncaught Exception. The signature of this error hook should match the signature of ``before_pipeline_run`` @@ -242,7 +242,7 @@ def on_pipeline_error( } pipeline: The ``Pipeline`` that will was run. - catalog: The ``DataCatalog`` used during the run. + catalog: An implemented instance of ``CatalogProtocol`` used during the run. 
""" pass diff --git a/kedro/framework/project/__init__.py b/kedro/framework/project/__init__.py index a3248b9daf..195fa077f6 100644 --- a/kedro/framework/project/__init__.py +++ b/kedro/framework/project/__init__.py @@ -20,6 +20,7 @@ from dynaconf import LazySettings from dynaconf.validator import ValidationError, Validator +from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline, pipeline if TYPE_CHECKING: @@ -59,6 +60,25 @@ def validate( ) +class _ImplementsCatalogProtocolValidator(Validator): + """A validator to check if the supplied setting value is a subclass of the default class""" + + def validate( + self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any + ) -> None: + super().validate(settings, *args, **kwargs) + + protocol = CatalogProtocol + for name in self.names: + setting_value = getattr(settings, name) + if not isinstance(setting_value(), protocol): + raise ValidationError( + f"Invalid value '{setting_value.__module__}.{setting_value.__qualname__}' " + f"received for setting '{name}'. It must implement " + f"'{protocol.__module__}.{protocol.__qualname__}'." + ) + + class _HasSharedParentClassValidator(Validator): """A validator to check that the parent of the default class is an ancestor of the settings value.""" @@ -115,8 +135,9 @@ class _ProjectSettings(LazySettings): _CONFIG_LOADER_ARGS = Validator( "CONFIG_LOADER_ARGS", default={"base_env": "base", "default_run_env": "local"} ) - _DATA_CATALOG_CLASS = _IsSubclassValidator( - "DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog") + _DATA_CATALOG_CLASS = _ImplementsCatalogProtocolValidator( + "DATA_CATALOG_CLASS", + default=_get_default_class("kedro.io.DataCatalog"), ) def __init__(self, *args: Any, **kwargs: Any): diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 91928f7c4b..ec0dc9bf4d 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -11,7 +11,7 @@ import traceback from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any import click @@ -28,6 +28,8 @@ from kedro.utils import _find_kedro_project if TYPE_CHECKING: + from collections.abc import Iterable + from kedro.config import AbstractConfigLoader from kedro.framework.context import KedroContext from kedro.framework.session.store import BaseSessionStore @@ -394,13 +396,11 @@ def run( # noqa: PLR0913 run_params=record_data, pipeline=filtered_pipeline, catalog=catalog ) + if isinstance(runner, ThreadRunner): + for ds in filtered_pipeline.datasets(): + if catalog.config_resolver.match_pattern(ds): + _ = catalog._get_dataset(ds) try: - if isinstance(runner, ThreadRunner): - for ds in filtered_pipeline.datasets(): - if catalog._match_pattern( - catalog._dataset_patterns, ds - ) or catalog._match_pattern(catalog._default_pattern, ds): - _ = catalog._get_dataset(ds) run_result = runner.run( filtered_pipeline, catalog, hook_manager, session_id ) diff --git a/kedro/framework/session/shelvestore.py b/kedro/framework/session/shelvestore.py deleted file mode 100644 index 5fbac073ef..0000000000 --- a/kedro/framework/session/shelvestore.py +++ /dev/null @@ -1,46 +0,0 @@ -"""This module implements a dict-like store object used to persist Kedro sessions. -This module is separated from store.py to ensure it's only imported when exported explicitly. 
-""" - -from __future__ import annotations - -import dbm -import shelve -from multiprocessing import Lock -from pathlib import Path -from typing import Any - -from .store import BaseSessionStore - - -class ShelveStore(BaseSessionStore): - """Stores the session data on disk using `shelve` package. - This is an example of how to persist data on disk.""" - - _lock = Lock() - - @property - def _location(self) -> Path: - return Path(self._path).expanduser().resolve() / self._session_id / "store" - - def read(self) -> dict[str, Any]: - """Read the data from disk using `shelve` package.""" - data: dict[str, Any] = {} - try: - with shelve.open(str(self._location), flag="r") as _sh: # noqa: S301 - data = dict(_sh) - except dbm.error: - pass - return data - - def save(self) -> None: - """Save the data on disk using `shelve` package.""" - location = self._location - location.parent.mkdir(parents=True, exist_ok=True) - - with self._lock, shelve.open(str(location)) as _sh: # noqa: S301 - keys_to_del = _sh.keys() - self.data.keys() - for key in keys_to_del: - del _sh[key] - - _sh.update(self.data) diff --git a/kedro/io/__init__.py b/kedro/io/__init__.py index aba59827e9..9697e1bd35 100644 --- a/kedro/io/__init__.py +++ b/kedro/io/__init__.py @@ -5,15 +5,18 @@ from __future__ import annotations from .cached_dataset import CachedDataset +from .catalog_config_resolver import CatalogConfigResolver from .core import ( AbstractDataset, AbstractVersionedDataset, + CatalogProtocol, DatasetAlreadyExistsError, DatasetError, DatasetNotFoundError, Version, ) from .data_catalog import DataCatalog +from .kedro_data_catalog import KedroDataCatalog from .lambda_dataset import LambdaDataset from .memory_dataset import MemoryDataset from .shared_memory_dataset import SharedMemoryDataset @@ -22,10 +25,13 @@ "AbstractDataset", "AbstractVersionedDataset", "CachedDataset", + "CatalogProtocol", "DataCatalog", + "CatalogConfigResolver", "DatasetAlreadyExistsError", "DatasetError", "DatasetNotFoundError", + "KedroDataCatalog", "LambdaDataset", "MemoryDataset", "SharedMemoryDataset", diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index 5f8d96dc36..85d9341db5 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -103,7 +103,7 @@ def __repr__(self) -> str: } return self._pretty_repr(object_description) - def _load(self) -> Any: + def load(self) -> Any: data = self._cache.load() if self._cache.exists() else self._dataset.load() if not self._cache.exists(): @@ -111,7 +111,7 @@ def _load(self) -> Any: return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: self._dataset.save(data) self._cache.save(data) diff --git a/kedro/io/catalog_config_resolver.py b/kedro/io/catalog_config_resolver.py new file mode 100644 index 0000000000..dc55d18b3c --- /dev/null +++ b/kedro/io/catalog_config_resolver.py @@ -0,0 +1,291 @@ +"""``CatalogConfigResolver`` resolves dataset configurations and datasets' +patterns based on catalog configuration and credentials provided. 
+""" + +from __future__ import annotations + +import copy +import logging +import re +from typing import Any + +from parse import parse + +from kedro.io.core import DatasetError + +Patterns = dict[str, dict[str, Any]] + +CREDENTIALS_KEY = "credentials" + + +class CatalogConfigResolver: + """Resolves dataset configurations based on patterns and credentials.""" + + def __init__( + self, + config: dict[str, dict[str, Any]] | None = None, + credentials: dict[str, dict[str, Any]] | None = None, + ): + self._runtime_patterns: Patterns = {} + self._dataset_patterns, self._default_pattern = self._extract_patterns( + config, credentials + ) + self._resolved_configs = self._resolve_config_credentials(config, credentials) + + @property + def config(self) -> dict[str, dict[str, Any]]: + return self._resolved_configs + + @property + def _logger(self) -> logging.Logger: + return logging.getLogger(__name__) + + @staticmethod + def is_pattern(pattern: str) -> bool: + """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" + return "{" in pattern + + @staticmethod + def _pattern_specificity(pattern: str) -> int: + """Calculate the specificity of a pattern based on characters outside curly brackets.""" + # Remove all the placeholders from the pattern and count the number of remaining chars + result = re.sub(r"\{.*?\}", "", pattern) + return len(result) + + @classmethod + def _sort_patterns(cls, dataset_patterns: Patterns) -> Patterns: + """Sort a dictionary of dataset patterns according to parsing rules. + + In order: + 1. Decreasing specificity (number of characters outside the curly brackets) + 2. Decreasing number of placeholders (number of curly bracket pairs) + 3. Alphabetically + """ + sorted_keys = sorted( + dataset_patterns, + key=lambda pattern: ( + -(cls._pattern_specificity(pattern)), + -pattern.count("{"), + pattern, + ), + ) + catch_all = [ + pattern for pattern in sorted_keys if cls._pattern_specificity(pattern) == 0 + ] + if len(catch_all) > 1: + raise DatasetError( + f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras." + ) + return {key: dataset_patterns[key] for key in sorted_keys} + + @staticmethod + def _fetch_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: + """Fetch the specified credentials from the provided credentials dictionary. + + Args: + credentials_name: Credentials name. + credentials: A dictionary with all credentials. + + Returns: + The set of requested credentials. + + Raises: + KeyError: When a data set with the given name has not yet been + registered. + + """ + try: + return credentials[credentials_name] + except KeyError as exc: + raise KeyError( + f"Unable to find credentials '{credentials_name}': check your data " + "catalog and credentials configuration. See " + "https://kedro.readthedocs.io/en/stable/kedro.io.DataCatalog.html " + "for an example." + ) from exc + + @classmethod + def _resolve_credentials( + cls, config: dict[str, Any], credentials: dict[str, Any] + ) -> dict[str, Any]: + """Return the dataset configuration where credentials are resolved using + credentials dictionary provided. + + Args: + config: Original dataset config, which may contain unresolved credentials. + credentials: A dictionary with all credentials. + + Returns: + The dataset config, where all the credentials are successfully resolved. 
+ """ + config = copy.deepcopy(config) + + def _resolve_value(key: str, value: Any) -> Any: + if key == CREDENTIALS_KEY and isinstance(value, str): + return cls._fetch_credentials(value, credentials) + if isinstance(value, dict): + return {k: _resolve_value(k, v) for k, v in value.items()} + return value + + return {k: _resolve_value(k, v) for k, v in config.items()} + + @classmethod + def _validate_pattern_config(cls, ds_name: str, ds_config: dict[str, Any]) -> None: + """Checks whether a dataset factory pattern configuration is valid - all + keys used in the configuration present in the dataset factory pattern name. + + Args: + ds_name: Dataset factory pattern name. + ds_config: Dataset pattern configuration. + + Raises: + DatasetError: when keys used in the configuration do not present in the dataset factory pattern name. + + """ + # Find all occurrences of {} in the string including brackets + search_regex = r"\{.*?\}" + name_placeholders = set(re.findall(search_regex, ds_name)) + config_placeholders = set() + + def _traverse_config(config: Any) -> None: + if isinstance(config, dict): + for value in config.values(): + _traverse_config(value) + elif isinstance(config, (list, tuple)): + for value in config: + _traverse_config(value) + elif isinstance(config, str) and "}" in config: + config_placeholders.update(set(re.findall(search_regex, config))) + + _traverse_config(ds_config) + + if config_placeholders - name_placeholders: + raise DatasetError( + f"Incorrect dataset configuration provided. " + f"Keys used in the configuration {config_placeholders - name_placeholders} " + f"should present in the dataset factory pattern name {ds_name}." + ) + + @classmethod + def _resolve_dataset_config( + cls, + ds_name: str, + pattern: str, + config: Any, + ) -> Any: + """Resolve dataset configuration based on the provided pattern.""" + resolved_vars = parse(pattern, ds_name) + # Resolve the factory config for the dataset + if isinstance(config, dict): + for key, value in config.items(): + config[key] = cls._resolve_dataset_config(ds_name, pattern, value) + elif isinstance(config, (list, tuple)): + config = [ + cls._resolve_dataset_config(ds_name, pattern, value) for value in config + ] + elif isinstance(config, str) and "}" in config: + config = config.format_map(resolved_vars.named) + return config + + def list_patterns(self) -> list[str]: + """List al patterns available in the catalog.""" + return ( + list(self._dataset_patterns.keys()) + + list(self._default_pattern.keys()) + + list(self._runtime_patterns.keys()) + ) + + def match_pattern(self, ds_name: str) -> str | None: + """Match a dataset name against patterns in a dictionary.""" + all_patterns = self.list_patterns() + matches = (pattern for pattern in all_patterns if parse(pattern, ds_name)) + return next(matches, None) + + def _get_pattern_config(self, pattern: str) -> dict[str, Any]: + return ( + self._dataset_patterns.get(pattern) + or self._default_pattern.get(pattern) + or self._runtime_patterns.get(pattern) + or {} + ) + + @classmethod + def _extract_patterns( + cls, + config: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None, + ) -> tuple[Patterns, Patterns]: + """Extract and sort patterns from the configuration.""" + config = config or {} + credentials = credentials or {} + dataset_patterns = {} + user_default = {} + + for ds_name, ds_config in config.items(): + if cls.is_pattern(ds_name): + cls._validate_pattern_config(ds_name, ds_config) + dataset_patterns[ds_name] = cls._resolve_credentials( + 
ds_config, credentials + ) + + sorted_patterns = cls._sort_patterns(dataset_patterns) + if sorted_patterns: + # If the last pattern is a catch-all pattern, pop it and set it as the default + if cls._pattern_specificity(list(sorted_patterns.keys())[-1]) == 0: + last_pattern = sorted_patterns.popitem() + user_default = {last_pattern[0]: last_pattern[1]} + + return sorted_patterns, user_default + + def _resolve_config_credentials( + self, + config: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None, + ) -> dict[str, dict[str, Any]]: + """Initialize the dataset configuration with resolved credentials.""" + config = config or {} + credentials = credentials or {} + resolved_configs = {} + + for ds_name, ds_config in config.items(): + if not isinstance(ds_config, dict): + raise DatasetError( + f"Catalog entry '{ds_name}' is not a valid dataset configuration. " + "\nHint: If this catalog entry is intended for variable interpolation, " + "make sure that the key is preceded by an underscore." + ) + if not self.is_pattern(ds_name): + resolved_configs[ds_name] = self._resolve_credentials( + ds_config, credentials + ) + + return resolved_configs + + def resolve_pattern(self, ds_name: str) -> dict[str, Any]: + """Resolve dataset patterns and return resolved configurations based on the existing patterns.""" + matched_pattern = self.match_pattern(ds_name) + + if matched_pattern and ds_name not in self._resolved_configs: + pattern_config = self._get_pattern_config(matched_pattern) + ds_config = self._resolve_dataset_config( + ds_name, matched_pattern, copy.deepcopy(pattern_config) + ) + + if ( + self._pattern_specificity(matched_pattern) == 0 + and matched_pattern in self._default_pattern + ): + self._logger.warning( + "Config from the dataset factory pattern '%s' in the catalog will be used to " + "override the default dataset creation for '%s'", + matched_pattern, + ds_name, + ) + return ds_config # type: ignore[no-any-return] + + return self._resolved_configs.get(ds_name, {}) + + def add_runtime_patterns(self, dataset_patterns: Patterns) -> None: + """Add new runtime patterns and re-sort them.""" + self._runtime_patterns = {**self._runtime_patterns, **dataset_patterns} + self._runtime_patterns = self._sort_patterns(self._runtime_patterns) diff --git a/kedro/io/core.py b/kedro/io/core.py index f3975c9c3c..53b660835c 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -17,7 +17,15 @@ from glob import iglob from operator import attrgetter from pathlib import Path, PurePath, PurePosixPath -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Protocol, + TypeVar, + runtime_checkable, +) from urllib.parse import urlsplit from cachetools import Cache, cachedmethod @@ -29,12 +37,25 @@ if TYPE_CHECKING: import os + from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns + VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ" VERSIONED_FLAG_KEY = "versioned" VERSION_KEY = "version" HTTP_PROTOCOLS = ("http", "https") PROTOCOL_DELIMITER = "://" -CLOUD_PROTOCOLS = ("s3", "s3n", "s3a", "gcs", "gs", "adl", "abfs", "abfss", "gdrive") +CLOUD_PROTOCOLS = ( + "abfs", + "abfss", + "adl", + "gcs", + "gdrive", + "gs", + "oss", + "s3", + "s3a", + "s3n", +) class DatasetError(Exception): @@ -484,20 +505,13 @@ def parse_dataset_definition( class_obj = tmp break else: - hint = "" - if "DataSet" in dataset_type: - hint = ( # pragma: no cover # To remove when we drop support for python 3.8 - "Hint: If you 
are trying to use a dataset from `kedro-datasets`>=2.0.0, " - "make sure that the dataset name uses the `Dataset` spelling instead of `DataSet`." - ) - else: - hint = ( - "Hint: If you are trying to use a dataset from `kedro-datasets`, " - "make sure that the package is installed in your current environment. " - "You can do so by running `pip install kedro-datasets` or " - "`pip install kedro-datasets[]` to install `kedro-datasets` along with " - "related dependencies for the specific dataset group." - ) + hint = ( + "Hint: If you are trying to use a dataset from `kedro-datasets`, " + "make sure that the package is installed in your current environment. " + "You can do so by running `pip install kedro-datasets` or " + "`pip install kedro-datasets[]` to install `kedro-datasets` along with " + "related dependencies for the specific dataset group." + ) raise DatasetError( f"Class '{dataset_type}' not found, is this a typo?" f"\n{hint}" ) @@ -871,3 +885,70 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None: raise DatasetError( f"Neither white-space nor semicolon are allowed in '{key}'." ) + + +_C = TypeVar("_C") + + +@runtime_checkable +class CatalogProtocol(Protocol[_C]): + _datasets: dict[str, AbstractDataset] + + def __contains__(self, ds_name: str) -> bool: + """Check if a dataset is in the catalog.""" + ... + + @property + def config_resolver(self) -> CatalogConfigResolver: + """Return a copy of the datasets dictionary.""" + ... + + @classmethod + def from_config(cls, catalog: dict[str, dict[str, Any]] | None) -> _C: + """Create a catalog instance from configuration.""" + ... + + def _get_dataset( + self, + dataset_name: str, + version: Any = None, + suggest: bool = True, + ) -> AbstractDataset: + """Retrieve a dataset by its name.""" + ... + + def list(self, regex_search: str | None = None) -> list[str]: + """List all dataset names registered in the catalog.""" + ... + + def save(self, name: str, data: Any) -> None: + """Save data to a registered dataset.""" + ... + + def load(self, name: str, version: str | None = None) -> Any: + """Load data from a registered dataset.""" + ... + + def add(self, ds_name: str, dataset: Any, replace: bool = False) -> None: + """Add a new dataset to the catalog.""" + ... + + def add_feed_dict(self, datasets: dict[str, Any], replace: bool = False) -> None: + """Add datasets to the catalog using the data provided through the `feed_dict`.""" + ... + + def exists(self, name: str) -> bool: + """Checks whether registered data set exists by calling its `exists()` method.""" + ... + + def release(self, name: str) -> None: + """Release any cached data associated with a dataset.""" + ... + + def confirm(self, name: str) -> None: + """Confirm a dataset by its name.""" + ... + + def shallow_copy(self, extra_dataset_patterns: Patterns | None = None) -> _C: + """Returns a shallow copy of the current object.""" + ... 
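The new ``CatalogProtocol`` above is decorated with ``@runtime_checkable``, so a catalog implementation can be validated with a plain ``isinstance`` check, which is the same check the new ``_ImplementsCatalogProtocolValidator`` applies to the ``DATA_CATALOG_CLASS`` setting. A minimal sketch (not part of the diff; the catalog entry is illustrative):

from kedro.io import CatalogProtocol, DataCatalog, MemoryDataset

catalog = DataCatalog(datasets={"example": MemoryDataset(data=[1, 2, 3])})

# isinstance() only verifies that the protocol's members exist on the instance;
# it does not check their signatures.
assert isinstance(catalog, CatalogProtocol)

Because the check is structural, alternative implementations such as the new ``KedroDataCatalog`` can be plugged in through settings without subclassing ``DataCatalog``.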
diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index d3fd163230..a010f3e852 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -7,15 +7,17 @@ from __future__ import annotations -import copy import difflib import logging import pprint import re -from typing import Any, Dict - -from parse import parse +from typing import Any +from kedro.io.catalog_config_resolver import ( + CREDENTIALS_KEY, # noqa: F401 + CatalogConfigResolver, + Patterns, +) from kedro.io.core import ( AbstractDataset, AbstractVersionedDataset, @@ -28,64 +30,10 @@ from kedro.io.memory_dataset import MemoryDataset from kedro.utils import _format_rich, _has_rich_handler -Patterns = Dict[str, Dict[str, Any]] - -CATALOG_KEY = "catalog" -CREDENTIALS_KEY = "credentials" +CATALOG_KEY = "catalog" # Kept to avoid the breaking change WORDS_REGEX_PATTERN = re.compile(r"\W+") -def _get_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: - """Return a set of credentials from the provided credentials dict. - - Args: - credentials_name: Credentials name. - credentials: A dictionary with all credentials. - - Returns: - The set of requested credentials. - - Raises: - KeyError: When a data set with the given name has not yet been - registered. - - """ - try: - return credentials[credentials_name] - except KeyError as exc: - raise KeyError( - f"Unable to find credentials '{credentials_name}': check your data " - "catalog and credentials configuration. See " - "https://docs.kedro.org/en/stable/api/kedro.io.DataCatalog.html " - "for an example." - ) from exc - - -def _resolve_credentials( - config: dict[str, Any], credentials: dict[str, Any] -) -> dict[str, Any]: - """Return the dataset configuration where credentials are resolved using - credentials dictionary provided. - - Args: - config: Original dataset config, which may contain unresolved credentials. - credentials: A dictionary with all credentials. - - Returns: - The dataset config, where all the credentials are successfully resolved. - """ - config = copy.deepcopy(config) - - def _map_value(key: str, value: Any) -> Any: - if key == CREDENTIALS_KEY and isinstance(value, str): - return _get_credentials(value, credentials) - if isinstance(value, dict): - return {k: _map_value(k, v) for k, v in value.items()} - return value - - return {k: _map_value(k, v) for k, v in config.items()} - - def _sub_nonword_chars(dataset_name: str) -> str: """Replace non-word characters in data set names since Kedro 0.16.2. @@ -103,13 +51,15 @@ class _FrozenDatasets: def __init__( self, - *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset], + *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset] | None, ): """Return a _FrozenDatasets instance from some datasets collections. Each collection could either be another _FrozenDatasets or a dictionary. """ self._original_names: dict[str, str] = {} for collection in datasets_collections: + if collection is None: + continue if isinstance(collection, _FrozenDatasets): self.__dict__.update(collection.__dict__) self._original_names.update(collection._original_names) @@ -125,7 +75,7 @@ def __setattr__(self, key: str, value: Any) -> None: if key == "_original_names": super().__setattr__(key, value) return - msg = "Operation not allowed! " + msg = "Operation not allowed. " if key in self.__dict__: msg += "Please change datasets through configuration." 
else: @@ -161,10 +111,11 @@ def __init__( # noqa: PLR0913 self, datasets: dict[str, AbstractDataset] | None = None, feed_dict: dict[str, Any] | None = None, - dataset_patterns: Patterns | None = None, + dataset_patterns: Patterns | None = None, # Kept for interface compatibility load_versions: dict[str, str] | None = None, save_version: str | None = None, - default_pattern: Patterns | None = None, + default_pattern: Patterns | None = None, # Kept for interface compatibility + config_resolver: CatalogConfigResolver | None = None, ) -> None: """``DataCatalog`` stores instances of ``AbstractDataset`` implementations to provide ``load`` and ``save`` capabilities from @@ -195,6 +146,8 @@ def __init__( # noqa: PLR0913 sorted in lexicographical order. default_pattern: A dictionary of the default catch-all pattern that overrides the default pattern provided through the runners. + config_resolver: An instance of CatalogConfigResolver to resolve dataset patterns and configurations. + Example: :: @@ -206,14 +159,21 @@ def __init__( # noqa: PLR0913 >>> save_args={"index": False}) >>> catalog = DataCatalog(datasets={'cars': cars}) """ - self._datasets = dict(datasets or {}) - self.datasets = _FrozenDatasets(self._datasets) - # Keep a record of all patterns in the catalog. - # {dataset pattern name : dataset pattern body} - self._dataset_patterns = dataset_patterns or {} + self._config_resolver = config_resolver or CatalogConfigResolver() + + # Kept to avoid breaking changes + if not config_resolver: + self._config_resolver._dataset_patterns = dataset_patterns or {} + self._config_resolver._default_pattern = default_pattern or {} + + self._datasets: dict[str, AbstractDataset] = {} + self.datasets: _FrozenDatasets | None = None + + self.add_all(datasets or {}) + self._load_versions = load_versions or {} self._save_version = save_version - self._default_pattern = default_pattern or {} + self._use_rich_markup = _has_rich_handler() if feed_dict: @@ -222,6 +182,23 @@ def __init__( # noqa: PLR0913 def __repr__(self) -> str: return self.datasets.__repr__() + def __contains__(self, dataset_name: str) -> bool: + """Check if an item is in the catalog as a materialised dataset or pattern""" + return ( + dataset_name in self._datasets + or self._config_resolver.match_pattern(dataset_name) is not None + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return (self._datasets, self._config_resolver.list_patterns()) == ( + other._datasets, + other.config_resolver.list_patterns(), + ) + + @property + def config_resolver(self) -> CatalogConfigResolver: + return self._config_resolver + @property def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @@ -303,44 +280,28 @@ class to be loaded is specified with the key ``type`` and their >>> df = catalog.load("cars") >>> catalog.save("boats", df) """ + catalog = catalog or {} datasets = {} - dataset_patterns = {} - catalog = copy.deepcopy(catalog) or {} - credentials = copy.deepcopy(credentials) or {} + config_resolver = CatalogConfigResolver(catalog, credentials) save_version = save_version or generate_timestamp() - load_versions = copy.deepcopy(load_versions) or {} - user_default = {} - - for ds_name, ds_config in catalog.items(): - if not isinstance(ds_config, dict): - raise DatasetError( - f"Catalog entry '{ds_name}' is not a valid dataset configuration. " - "\nHint: If this catalog entry is intended for variable interpolation, " - "make sure that the key is preceded by an underscore." 
- ) + load_versions = load_versions or {} - ds_config = _resolve_credentials( # noqa: PLW2901 - ds_config, credentials - ) - if cls._is_pattern(ds_name): - # Add each factory to the dataset_patterns dict. - dataset_patterns[ds_name] = ds_config - - else: + for ds_name in catalog: + if not config_resolver.is_pattern(ds_name): datasets[ds_name] = AbstractDataset.from_config( - ds_name, ds_config, load_versions.get(ds_name), save_version + ds_name, + config_resolver.config.get(ds_name, {}), + load_versions.get(ds_name), + save_version, ) - sorted_patterns = cls._sort_patterns(dataset_patterns) - if sorted_patterns: - # If the last pattern is a catch-all pattern, pop it and set it as the default - if cls._specificity(list(sorted_patterns.keys())[-1]) == 0: - last_pattern = sorted_patterns.popitem() - user_default = {last_pattern[0]: last_pattern[1]} missing_keys = [ - key - for key in load_versions.keys() - if not (key in catalog or cls._match_pattern(sorted_patterns, key)) + ds_name + for ds_name in load_versions + if not ( + ds_name in config_resolver.config + or config_resolver.match_pattern(ds_name) + ) ] if missing_keys: raise DatasetNotFoundError( @@ -350,107 +311,29 @@ class to be loaded is specified with the key ``type`` and their return cls( datasets=datasets, - dataset_patterns=sorted_patterns, + dataset_patterns=config_resolver._dataset_patterns, load_versions=load_versions, save_version=save_version, - default_pattern=user_default, + default_pattern=config_resolver._default_pattern, + config_resolver=config_resolver, ) - @staticmethod - def _is_pattern(pattern: str) -> bool: - """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" - return "{" in pattern - - @staticmethod - def _match_pattern(dataset_patterns: Patterns, dataset_name: str) -> str | None: - """Match a dataset name against patterns in a dictionary.""" - matches = ( - pattern - for pattern in dataset_patterns.keys() - if parse(pattern, dataset_name) - ) - return next(matches, None) - - @classmethod - def _sort_patterns(cls, dataset_patterns: Patterns) -> dict[str, dict[str, Any]]: - """Sort a dictionary of dataset patterns according to parsing rules. - - In order: - - 1. Decreasing specificity (number of characters outside the curly brackets) - 2. Decreasing number of placeholders (number of curly bracket pairs) - 3. Alphabetically - """ - sorted_keys = sorted( - dataset_patterns, - key=lambda pattern: ( - -(cls._specificity(pattern)), - -pattern.count("{"), - pattern, - ), - ) - catch_all = [ - pattern for pattern in sorted_keys if cls._specificity(pattern) == 0 - ] - if len(catch_all) > 1: - raise DatasetError( - f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras." - ) - return {key: dataset_patterns[key] for key in sorted_keys} - - @staticmethod - def _specificity(pattern: str) -> int: - """Helper function to check the length of exactly matched characters not inside brackets. 
- - Example: - :: - - >>> specificity("{namespace}.companies") = 10 - >>> specificity("{namespace}.{dataset}") = 1 - >>> specificity("france.companies") = 16 - """ - # Remove all the placeholders from the pattern and count the number of remaining chars - result = re.sub(r"\{.*?\}", "", pattern) - return len(result) - def _get_dataset( self, dataset_name: str, version: Version | None = None, suggest: bool = True, ) -> AbstractDataset: - matched_pattern = self._match_pattern( - self._dataset_patterns, dataset_name - ) or self._match_pattern(self._default_pattern, dataset_name) - if dataset_name not in self._datasets and matched_pattern: - # If the dataset is a patterned dataset, materialise it and add it to - # the catalog - config_copy = copy.deepcopy( - self._dataset_patterns.get(matched_pattern) - or self._default_pattern.get(matched_pattern) - or {} - ) - dataset_config = self._resolve_config( - dataset_name, matched_pattern, config_copy - ) - dataset = AbstractDataset.from_config( + ds_config = self._config_resolver.resolve_pattern(dataset_name) + + if dataset_name not in self._datasets and ds_config: + ds = AbstractDataset.from_config( dataset_name, - dataset_config, + ds_config, self._load_versions.get(dataset_name), self._save_version, ) - if ( - self._specificity(matched_pattern) == 0 - and matched_pattern in self._default_pattern - ): - self._logger.warning( - "Config from the dataset factory pattern '%s' in the catalog will be used to " - "override the default dataset creation for '%s'", - matched_pattern, - dataset_name, - ) - - self.add(dataset_name, dataset) + self.add(dataset_name, ds) if dataset_name not in self._datasets: error_msg = f"Dataset '{dataset_name}' not found in the catalog" @@ -462,7 +345,9 @@ def _get_dataset( suggestions = ", ".join(matches) error_msg += f" - did you mean one of these instead: {suggestions}" raise DatasetNotFoundError(error_msg) + dataset = self._datasets[dataset_name] + if version and isinstance(dataset, AbstractVersionedDataset): # we only want to return a similar-looking dataset, # not modify the one stored in the current catalog @@ -470,41 +355,6 @@ def _get_dataset( return dataset - def __contains__(self, dataset_name: str) -> bool: - """Check if an item is in the catalog as a materialised dataset or pattern""" - matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name) - if dataset_name in self._datasets or matched_pattern: - return True - return False - - @classmethod - def _resolve_config( - cls, - dataset_name: str, - matched_pattern: str, - config: dict, - ) -> dict[str, Any]: - """Get resolved AbstractDataset from a factory config""" - result = parse(matched_pattern, dataset_name) - # Resolve the factory config for the dataset - if isinstance(config, dict): - for key, value in config.items(): - config[key] = cls._resolve_config(dataset_name, matched_pattern, value) - elif isinstance(config, (list, tuple)): - config = [ - cls._resolve_config(dataset_name, matched_pattern, value) - for value in config - ] - elif isinstance(config, str) and "}" in config: - try: - config = str(config).format_map(result.named) - except KeyError as exc: - raise DatasetError( - f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration " - f"should be present in the dataset factory pattern." - ) from exc - return config - def load(self, name: str, version: str | None = None) -> Any: """Loads a registered data set. 
@@ -619,7 +469,10 @@ def release(self, name: str) -> None: dataset.release() def add( - self, dataset_name: str, dataset: AbstractDataset, replace: bool = False + self, + dataset_name: str, + dataset: AbstractDataset, + replace: bool = False, ) -> None: """Adds a new ``AbstractDataset`` object to the ``DataCatalog``. @@ -657,7 +510,9 @@ def add( self.datasets = _FrozenDatasets(self.datasets, {dataset_name: dataset}) def add_all( - self, datasets: dict[str, AbstractDataset], replace: bool = False + self, + datasets: dict[str, AbstractDataset], + replace: bool = False, ) -> None: """Adds a group of new data sets to the ``DataCatalog``. @@ -688,8 +543,8 @@ def add_all( >>> >>> assert catalog.list() == ["cars", "planes", "boats"] """ - for name, dataset in datasets.items(): - self.add(name, dataset, replace) + for ds_name, ds in datasets.items(): + self.add(ds_name, ds, replace) def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None: """Add datasets to the ``DataCatalog`` using the data provided through the `feed_dict`. @@ -726,13 +581,13 @@ def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> Non >>> >>> assert catalog.load("data_csv_dataset").equals(df) """ - for dataset_name in feed_dict: - if isinstance(feed_dict[dataset_name], AbstractDataset): - dataset = feed_dict[dataset_name] - else: - dataset = MemoryDataset(data=feed_dict[dataset_name]) # type: ignore[abstract] - - self.add(dataset_name, dataset, replace) + for ds_name, ds_data in feed_dict.items(): + dataset = ( + ds_data + if isinstance(ds_data, AbstractDataset) + else MemoryDataset(data=ds_data) # type: ignore[abstract] + ) + self.add(ds_name, dataset, replace) def list(self, regex_search: str | None = None) -> list[str]: """ @@ -777,7 +632,7 @@ def list(self, regex_search: str | None = None) -> list[str]: raise SyntaxError( f"Invalid regular expression provided: '{regex_search}'" ) from exc - return [dset_name for dset_name in self._datasets if pattern.search(dset_name)] + return [ds_name for ds_name in self._datasets if pattern.search(ds_name)] def shallow_copy( self, extra_dataset_patterns: Patterns | None = None @@ -787,26 +642,15 @@ def shallow_copy( Returns: Copy of the current object. """ - if not self._default_pattern and extra_dataset_patterns: - unsorted_dataset_patterns = { - **self._dataset_patterns, - **extra_dataset_patterns, - } - dataset_patterns = self._sort_patterns(unsorted_dataset_patterns) - else: - dataset_patterns = self._dataset_patterns + if extra_dataset_patterns: + self._config_resolver.add_runtime_patterns(extra_dataset_patterns) return self.__class__( datasets=self._datasets, - dataset_patterns=dataset_patterns, + dataset_patterns=self._config_resolver._dataset_patterns, + default_pattern=self._config_resolver._default_pattern, load_versions=self._load_versions, save_version=self._save_version, - default_pattern=self._default_pattern, - ) - - def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] - return (self._datasets, self._dataset_patterns) == ( - other._datasets, - other._dataset_patterns, + config_resolver=self._config_resolver, ) def confirm(self, name: str) -> None: diff --git a/kedro/io/kedro_data_catalog.py b/kedro/io/kedro_data_catalog.py new file mode 100644 index 0000000000..d07de8151a --- /dev/null +++ b/kedro/io/kedro_data_catalog.py @@ -0,0 +1,355 @@ +"""``KedroDataCatalog`` stores instances of ``AbstractDataset`` implementations to +provide ``load`` and ``save`` capabilities from anywhere in the program. 
To +use a ``KedroDataCatalog``, you need to instantiate it with a dictionary of datasets. +Then it will act as a single point of reference for your calls, relaying load and +save functions to the underlying datasets. + +``KedroDataCatalog`` is an experimental feature aimed to replace ``DataCatalog`` in the future. +Expect possible breaking changes while using it. +""" + +from __future__ import annotations + +import copy +import difflib +import logging +import re +from typing import Any + +from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns +from kedro.io.core import ( + AbstractDataset, + AbstractVersionedDataset, + CatalogProtocol, + DatasetAlreadyExistsError, + DatasetError, + DatasetNotFoundError, + Version, + generate_timestamp, +) +from kedro.io.memory_dataset import MemoryDataset +from kedro.utils import _format_rich, _has_rich_handler + + +class KedroDataCatalog(CatalogProtocol): + def __init__( + self, + datasets: dict[str, AbstractDataset] | None = None, + raw_data: dict[str, Any] | None = None, + config_resolver: CatalogConfigResolver | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, + ) -> None: + """``KedroDataCatalog`` stores instances of ``AbstractDataset`` + implementations to provide ``load`` and ``save`` capabilities from + anywhere in the program. To use a ``KedroDataCatalog``, you need to + instantiate it with a dictionary of datasets. Then it will act as a + single point of reference for your calls, relaying load and save + functions to the underlying datasets. + + Note: ``KedroDataCatalog`` is an experimental feature and is under active development. Therefore, it is possible we'll introduce breaking changes to this class, so be mindful of that if you decide to use it already. + + Args: + datasets: A dictionary of dataset names and dataset instances. + raw_data: A dictionary with data to be added in memory as `MemoryDataset`` instances. + Keys represent dataset names and the values are raw data. + config_resolver: An instance of CatalogConfigResolver to resolve dataset patterns and configurations. + load_versions: A mapping between dataset names and versions + to load. Has no effect on datasets without enabled versioning. + save_version: Version string to be used for ``save`` operations + by all datasets with enabled versioning. It must: a) be a + case-insensitive string that conforms with operating system + filename limitations, b) always return the latest version when + sorted in lexicographical order. + + Example: + :: + >>> # settings.py + >>> from kedro.io import KedroDataCatalog + >>> + >>> DATA_CATALOG_CLASS = KedroDataCatalog + """ + self._config_resolver = config_resolver or CatalogConfigResolver() + self._datasets = datasets or {} + self._load_versions = load_versions or {} + self._save_version = save_version + + self._use_rich_markup = _has_rich_handler() + + for ds_name, ds_config in self._config_resolver.config.items(): + self._add_from_config(ds_name, ds_config) + + if raw_data: + self.add_feed_dict(raw_data) + + @property + def datasets(self) -> dict[str, Any]: + return copy.copy(self._datasets) + + @datasets.setter + def datasets(self, value: Any) -> None: + raise AttributeError( + "Operation not allowed. Please use KedroDataCatalog.add() instead." 
+ ) + + @property + def config_resolver(self) -> CatalogConfigResolver: + return self._config_resolver + + def __repr__(self) -> str: + return repr(self._datasets) + + def __contains__(self, dataset_name: str) -> bool: + """Check if an item is in the catalog as a materialised dataset or pattern""" + return ( + dataset_name in self._datasets + or self._config_resolver.match_pattern(dataset_name) is not None + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return (self._datasets, self._config_resolver.list_patterns()) == ( + other._datasets, + other.config_resolver.list_patterns(), + ) + + @property + def _logger(self) -> logging.Logger: + return logging.getLogger(__name__) + + @classmethod + def from_config( + cls, + catalog: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, + ) -> KedroDataCatalog: + """Create a ``KedroDataCatalog`` instance from configuration. This is a + factory method used to provide developers with a way to instantiate + ``KedroDataCatalog`` with configuration parsed from configuration files. + """ + catalog = catalog or {} + config_resolver = CatalogConfigResolver(catalog, credentials) + save_version = save_version or generate_timestamp() + load_versions = load_versions or {} + + missing_keys = [ + ds_name + for ds_name in load_versions + if not ( + ds_name in config_resolver.config + or config_resolver.match_pattern(ds_name) + ) + ] + if missing_keys: + raise DatasetNotFoundError( + f"'load_versions' keys [{', '.join(sorted(missing_keys))}] " + f"are not found in the catalog." + ) + + return cls( + load_versions=load_versions, + save_version=save_version, + config_resolver=config_resolver, + ) + + @staticmethod + def _validate_dataset_config(ds_name: str, ds_config: Any) -> None: + if not isinstance(ds_config, dict): + raise DatasetError( + f"Catalog entry '{ds_name}' is not a valid dataset configuration. " + "\nHint: If this catalog entry is intended for variable interpolation, " + "make sure that the key is preceded by an underscore." + ) + + def _add_from_config(self, ds_name: str, ds_config: dict[str, Any]) -> None: + # TODO: Add lazy loading feature to store the configuration but not to init actual dataset + # TODO: Initialise actual dataset when load or save + self._validate_dataset_config(ds_name, ds_config) + ds = AbstractDataset.from_config( + ds_name, + ds_config, + self._load_versions.get(ds_name), + self._save_version, + ) + + self.add(ds_name, ds) + + def get_dataset( + self, ds_name: str, version: Version | None = None, suggest: bool = True + ) -> AbstractDataset: + """Get a dataset by name from an internal collection of datasets. + + If a dataset is not in the collection but matches any pattern + it is instantiated and added to the collection first, then returned. + + Args: + ds_name: A dataset name. + version: Optional argument for concrete dataset version to be loaded. + Works only with versioned datasets. + suggest: Optional argument whether to suggest fuzzy-matching datasets' names + in the DatasetNotFoundError message. + + Returns: + An instance of AbstractDataset. + + Raises: + DatasetNotFoundError: When a dataset with the given name + is not in the collection and do not match patterns. 
+ """ + if ds_name not in self._datasets: + ds_config = self._config_resolver.resolve_pattern(ds_name) + if ds_config: + self._add_from_config(ds_name, ds_config) + + dataset = self._datasets.get(ds_name, None) + + if dataset is None: + error_msg = f"Dataset '{ds_name}' not found in the catalog" + # Flag to turn on/off fuzzy-matching which can be time consuming and + # slow down plugins like `kedro-viz` + if suggest: + matches = difflib.get_close_matches(ds_name, self._datasets.keys()) + if matches: + suggestions = ", ".join(matches) + error_msg += f" - did you mean one of these instead: {suggestions}" + raise DatasetNotFoundError(error_msg) + + if version and isinstance(dataset, AbstractVersionedDataset): + # we only want to return a similar-looking dataset, + # not modify the one stored in the current catalog + dataset = dataset._copy(_version=version) + + return dataset + + def _get_dataset( + self, dataset_name: str, version: Version | None = None, suggest: bool = True + ) -> AbstractDataset: + # TODO: remove when removing old catalog + return self.get_dataset(dataset_name, version, suggest) + + def add( + self, ds_name: str, dataset: AbstractDataset, replace: bool = False + ) -> None: + """Adds a new ``AbstractDataset`` object to the ``KedroDataCatalog``.""" + if ds_name in self._datasets: + if replace: + self._logger.warning("Replacing dataset '%s'", ds_name) + else: + raise DatasetAlreadyExistsError( + f"Dataset '{ds_name}' has already been registered" + ) + self._datasets[ds_name] = dataset + + def list(self, regex_search: str | None = None) -> list[str]: + """ + List of all dataset names registered in the catalog. + This can be filtered by providing an optional regular expression + which will only return matching keys. + """ + + if regex_search is None: + return list(self._datasets.keys()) + + if not regex_search.strip(): + self._logger.warning("The empty string will not match any datasets") + return [] + + try: + pattern = re.compile(regex_search, flags=re.IGNORECASE) + except re.error as exc: + raise SyntaxError( + f"Invalid regular expression provided: '{regex_search}'" + ) from exc + return [ds_name for ds_name in self._datasets if pattern.search(ds_name)] + + def save(self, name: str, data: Any) -> None: + """Save data to a registered dataset.""" + dataset = self.get_dataset(name) + + self._logger.info( + "Saving data to %s (%s)...", + _format_rich(name, "dark_orange") if self._use_rich_markup else name, + type(dataset).__name__, + extra={"markup": True}, + ) + + dataset.save(data) + + def load(self, name: str, version: str | None = None) -> Any: + """Loads a registered dataset.""" + load_version = Version(version, None) if version else None + dataset = self.get_dataset(name, version=load_version) + + self._logger.info( + "Loading data from %s (%s)...", + _format_rich(name, "dark_orange") if self._use_rich_markup else name, + type(dataset).__name__, + extra={"markup": True}, + ) + + return dataset.load() + + def release(self, name: str) -> None: + """Release any cached data associated with a dataset + Args: + name: A dataset to be checked. + Raises: + DatasetNotFoundError: When a dataset with the given name + has not yet been registered. + """ + dataset = self.get_dataset(name) + dataset.release() + + def confirm(self, name: str) -> None: + """Confirm a dataset by its name. + Args: + name: Name of the dataset. + Raises: + DatasetError: When the dataset does not have `confirm` method. 
+ """ + self._logger.info("Confirming dataset '%s'", name) + dataset = self.get_dataset(name) + + if hasattr(dataset, "confirm"): + dataset.confirm() + else: + raise DatasetError(f"Dataset '{name}' does not have 'confirm' method") + + def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None: + # TODO: remove when removing old catalog + # This method was simplified to add memory datasets only, since + # adding AbstractDataset can be done via add() method + for ds_name, ds_data in feed_dict.items(): + self.add(ds_name, MemoryDataset(data=ds_data), replace) # type: ignore[abstract] + + def shallow_copy( + self, extra_dataset_patterns: Patterns | None = None + ) -> KedroDataCatalog: + # TODO: remove when removing old catalog + """Returns a shallow copy of the current object. + + Returns: + Copy of the current object. + """ + if extra_dataset_patterns: + self._config_resolver.add_runtime_patterns(extra_dataset_patterns) + return self + + def exists(self, name: str) -> bool: + """Checks whether registered dataset exists by calling its `exists()` + method. Raises a warning and returns False if `exists()` is not + implemented. + + Args: + name: A dataset to be checked. + + Returns: + Whether the dataset output exists. + + """ + try: + dataset = self._get_dataset(name) + except DatasetNotFoundError: + return False + return dataset.exists() diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index 56ad92b7f2..1b4bb8a371 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -59,7 +59,7 @@ def __init__( if data is not _EMPTY: self.save.__wrapped__(self, data) # type: ignore[attr-defined] - def _load(self) -> Any: + def load(self) -> Any: if self._data is _EMPTY: raise DatasetError("Data for MemoryDataset has not been saved yet.") @@ -67,7 +67,7 @@ def _load(self) -> Any: data = _copy_with_mode(self._data, copy_mode=copy_mode) return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: copy_mode = self._copy_mode or _infer_copy_mode(data) self._data = _copy_with_mode(data, copy_mode=copy_mode) diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index e2bd63bf7e..139180b578 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -36,10 +36,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError() return getattr(self.shared_memory_dataset, name) # pragma: no cover - def _load(self) -> Any: + def load(self) -> Any: return self.shared_memory_dataset.load() - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: """Calls save method of a shared MemoryDataset in SyncManager.""" try: self.shared_memory_dataset.save(data) @@ -57,3 +57,8 @@ def _save(self, data: Any) -> None: def _describe(self) -> dict[str, Any]: """SharedMemoryDataset doesn't have any constructor argument to return.""" return {} + + def _exists(self) -> bool: + if not self.shared_memory_dataset: + return False + return self.shared_memory_dataset.exists() # type: ignore[no-any-return] diff --git a/kedro/ipython/__init__.py b/kedro/ipython/__init__.py index 7cdbf92138..0e479e8a68 100644 --- a/kedro/ipython/__init__.py +++ b/kedro/ipython/__init__.py @@ -14,7 +14,10 @@ import warnings from pathlib import Path from types import MappingProxyType -from typing import Any, Callable, OrderedDict +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from collections import OrderedDict from IPython.core.getipython import get_ipython from 
IPython.core.magic import needs_local_scope, register_line_magic diff --git a/kedro/pipeline/modular_pipeline.py b/kedro/pipeline/modular_pipeline.py index 172a5f9593..c6372f37a4 100644 --- a/kedro/pipeline/modular_pipeline.py +++ b/kedro/pipeline/modular_pipeline.py @@ -4,13 +4,15 @@ import copy import difflib -from typing import TYPE_CHECKING, AbstractSet, Iterable +from typing import TYPE_CHECKING from kedro.pipeline.pipeline import Pipeline from .transcoding import TRANSCODING_SEPARATOR, _strip_transcoding, _transcode_split if TYPE_CHECKING: + from collections.abc import Iterable, Set + from kedro.pipeline.node import Node @@ -35,7 +37,7 @@ def _is_parameter(name: str) -> bool: def _validate_inputs_outputs( - inputs: AbstractSet[str], outputs: AbstractSet[str], pipe: Pipeline + inputs: Set[str], outputs: Set[str], pipe: Pipeline ) -> None: """Safeguards to ensure that: - parameters are not specified under inputs @@ -64,9 +66,9 @@ def _validate_inputs_outputs( def _validate_datasets_exist( - inputs: AbstractSet[str], - outputs: AbstractSet[str], - parameters: AbstractSet[str], + inputs: Set[str], + outputs: Set[str], + parameters: Set[str], pipe: Pipeline, ) -> None: """Validate that inputs, parameters and outputs map correctly onto the provided nodes.""" diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 09a83410aa..b382bee8cf 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -9,13 +9,16 @@ import logging import re from collections import Counter -from typing import Any, Callable, Iterable +from typing import TYPE_CHECKING, Any, Callable from warnings import warn from more_itertools import spy, unzip from .transcoding import _strip_transcoding +if TYPE_CHECKING: + from collections.abc import Iterable + class Node: """``Node`` is an auxiliary class facilitating the operations required to diff --git a/kedro/pipeline/pipeline.py b/kedro/pipeline/pipeline.py index 7810a98049..ab7365a154 100644 --- a/kedro/pipeline/pipeline.py +++ b/kedro/pipeline/pipeline.py @@ -8,16 +8,18 @@ import json from collections import Counter, defaultdict -from itertools import chain -from typing import Any, Iterable - from graphlib import CycleError, TopologicalSorter +from itertools import chain +from typing import TYPE_CHECKING, Any import kedro from kedro.pipeline.node import Node, _to_list from .transcoding import _strip_transcoding +if TYPE_CHECKING: + from collections.abc import Iterable + def __getattr__(name: str) -> Any: if name == "TRANSCODING_SEPARATOR": diff --git a/kedro/pipeline/transcoding.py b/kedro/pipeline/transcoding.py index 71f0dac342..eae9a10cf7 100644 --- a/kedro/pipeline/transcoding.py +++ b/kedro/pipeline/transcoding.py @@ -1,9 +1,7 @@ -from typing import Tuple - TRANSCODING_SEPARATOR = "@" -def _transcode_split(element: str) -> Tuple[str, str]: +def _transcode_split(element: str) -> tuple[str, str]: """Split the name by the transcoding separator. If the transcoding part is missing, empty string will be put in. 
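Referring back to the experimental ``KedroDataCatalog`` introduced a few hunks above, here is a minimal usage sketch (not taken from the diff; the catalog entries and data are illustrative) showing ``from_config``, ``add_feed_dict`` and the dict-like ``__contains__``:

from kedro.io import KedroDataCatalog

catalog = KedroDataCatalog.from_config({"cars": {"type": "MemoryDataset"}})
catalog.add_feed_dict({"numbers": [1, 2, 3]})  # raw data is wrapped in a MemoryDataset

assert "cars" in catalog  # __contains__ covers materialised datasets and patterns
assert catalog.list() == ["cars", "numbers"]
assert catalog.load("numbers") == [1, 2, 3]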
diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 62d7e1216b..7626bf8679 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -13,7 +13,7 @@ from multiprocessing.managers import BaseProxy, SyncManager from multiprocessing.reduction import ForkingPickler from pickle import PicklingError -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from kedro.framework.hooks.manager import ( _create_hook_manager, @@ -22,7 +22,7 @@ ) from kedro.framework.project import settings from kedro.io import ( - DataCatalog, + CatalogProtocol, DatasetNotFoundError, MemoryDataset, SharedMemoryDataset, @@ -30,6 +30,8 @@ from kedro.runner.runner import AbstractRunner, run_node if TYPE_CHECKING: + from collections.abc import Iterable + from pluggy import PluginManager from kedro.pipeline import Pipeline @@ -60,7 +62,7 @@ def _bootstrap_subprocess( def _run_node_synchronization( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, is_async: bool = False, session_id: str | None = None, package_name: str | None = None, @@ -73,7 +75,7 @@ def _run_node_synchronization( # noqa: PLR0913 Args: node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. session_id: The session id of the pipeline run. @@ -118,7 +120,7 @@ def __init__( cannot be larger than 61 and will be set to min(61, max_workers). is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to SharedMemoryDataset for `ParallelRunner`. @@ -168,7 +170,7 @@ def _validate_nodes(cls, nodes: Iterable[Node]) -> None: ) @classmethod - def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline) -> None: + def _validate_catalog(cls, catalog: CatalogProtocol, pipeline: Pipeline) -> None: """Ensure that all data sets are serialisable and that we do not have any non proxied memory data sets being used as outputs as their content will not be synchronized across threads. @@ -213,7 +215,9 @@ def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline) -> None: f"MemoryDatasets" ) - def _set_manager_datasets(self, catalog: DataCatalog, pipeline: Pipeline) -> None: + def _set_manager_datasets( + self, catalog: CatalogProtocol, pipeline: Pipeline + ) -> None: for dataset in pipeline.datasets(): try: catalog.exists(dataset) @@ -240,7 +244,7 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -248,7 +252,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. 
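With these changes ``ParallelRunner`` only requires a ``CatalogProtocol`` implementation, and its ``extra_dataset_patterns`` reach the catalog through ``shallow_copy``, which now forwards them to the ``CatalogConfigResolver`` as runtime patterns. A sketch assuming the runner's catch-all pattern looks like the ``SharedMemoryDataset`` default mentioned in the docstring (the pattern key is illustrative):

from kedro.io import DataCatalog

catalog = DataCatalog.from_config({"cars": {"type": "MemoryDataset"}})

# Assumed catch-all pattern, mirroring the ParallelRunner default described above.
extra_patterns = {"{default}": {"type": "kedro.io.SharedMemoryDataset"}}
catalog = catalog.shallow_copy(extra_dataset_patterns=extra_patterns)

assert "{default}" in catalog.config_resolver.list_patterns()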
diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 2ffd0389e4..f6716e070f 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -9,6 +9,7 @@ import logging from abc import ABC, abstractmethod from collections import deque +from collections.abc import Iterator from concurrent.futures import ( ALL_COMPLETED, Future, @@ -16,15 +17,17 @@ as_completed, wait, ) -from typing import TYPE_CHECKING, Any, Collection, Iterable, Iterator +from typing import TYPE_CHECKING, Any from more_itertools import interleave from kedro.framework.hooks.manager import _NullPluginManager -from kedro.io import DataCatalog, MemoryDataset +from kedro.io import CatalogProtocol, MemoryDataset from kedro.pipeline import Pipeline if TYPE_CHECKING: + from collections.abc import Collection, Iterable + from pluggy import PluginManager from kedro.pipeline.node import Node @@ -45,7 +48,7 @@ def __init__( Args: is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets on the Runner instances. """ @@ -59,7 +62,7 @@ def _logger(self) -> logging.Logger: def run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> dict[str, Any]: @@ -68,7 +71,7 @@ def run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. @@ -76,14 +79,13 @@ def run( ValueError: Raised when ``Pipeline`` inputs cannot be satisfied. Returns: - Any node outputs that cannot be processed by the ``DataCatalog``. + Any node outputs that cannot be processed by the catalog. These are returned in a dictionary, where the keys are defined by the node outputs. """ hook_or_null_manager = hook_manager or _NullPluginManager() - catalog = catalog.shallow_copy() # Check which datasets used in the pipeline are in the catalog or match # a pattern in the catalog @@ -95,7 +97,7 @@ def run( if unsatisfied: raise ValueError( - f"Pipeline input(s) {unsatisfied} not found in the DataCatalog" + f"Pipeline input(s) {unsatisfied} not found in the {catalog.__class__.__name__}" ) # Identify MemoryDataset in the catalog @@ -125,7 +127,7 @@ def run( return {ds_name: catalog.load(ds_name) for ds_name in free_outputs} def run_only_missing( - self, pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager + self, pipeline: Pipeline, catalog: CatalogProtocol, hook_manager: PluginManager ) -> dict[str, Any]: """Run only the missing outputs from the ``Pipeline`` using the datasets provided by ``catalog``, and save results back to the @@ -133,7 +135,7 @@ def run_only_missing( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. Raises: ValueError: Raised when ``Pipeline`` inputs cannot be @@ -141,7 +143,7 @@ def run_only_missing( Returns: Any node outputs that cannot be processed by the - ``DataCatalog``. These are returned in a dictionary, where + catalog. 
These are returned in a dictionary, where the keys are defined by the node outputs. """ @@ -165,7 +167,7 @@ def run_only_missing( def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -174,7 +176,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. @@ -185,7 +187,7 @@ def _suggest_resume_scenario( self, pipeline: Pipeline, done_nodes: Iterable[Node], - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """ Suggest a command to the user to resume a run after it fails. @@ -195,7 +197,7 @@ def _suggest_resume_scenario( Args: pipeline: the ``Pipeline`` of the run. done_nodes: the ``Node``s that executed successfully. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. """ remaining_nodes = set(pipeline.nodes) - set(done_nodes) @@ -224,7 +226,7 @@ def _suggest_resume_scenario( def _find_nodes_to_resume_from( - pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: DataCatalog + pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol ) -> set[str]: """Given a collection of unfinished nodes in a pipeline using a certain catalog, find the node names to pass to pipeline.from_nodes() @@ -234,7 +236,7 @@ def _find_nodes_to_resume_from( Args: pipeline: the ``Pipeline`` to find starting nodes for. unfinished_nodes: collection of ``Node``s that have not finished yet - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: Set of node names to pass to pipeline.from_nodes() to continue @@ -252,7 +254,7 @@ def _find_nodes_to_resume_from( def _find_all_nodes_for_resumed_pipeline( - pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: DataCatalog + pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: CatalogProtocol ) -> set[Node]: """Breadth-first search approach to finding the complete set of ``Node``s which need to run to cover all unfinished nodes, @@ -262,7 +264,7 @@ def _find_all_nodes_for_resumed_pipeline( Args: pipeline: the ``Pipeline`` to analyze. unfinished_nodes: the iterable of ``Node``s which have not finished yet. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: A set containing all input unfinished ``Node``s and all remaining @@ -310,12 +312,12 @@ def _nodes_with_external_inputs(nodes_of_interest: Iterable[Node]) -> set[Node]: return set(p_nodes_with_external_inputs.nodes) -def _enumerate_non_persistent_inputs(node: Node, catalog: DataCatalog) -> set[str]: +def _enumerate_non_persistent_inputs(node: Node, catalog: CatalogProtocol) -> set[str]: """Enumerate non-persistent input datasets of a ``Node``. Args: node: the ``Node`` to check the inputs of. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: Set of names of non-persistent inputs of given ``Node``. @@ -370,7 +372,7 @@ def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[ A list of initial ``Node``s to run given inputs (in topological order). 
""" - node_names = set(n.name for n in nodes) + node_names = {n.name for n in nodes} if len(node_names) == 0: return [] sub_pipeline = pipeline.only_nodes(*node_names) @@ -380,7 +382,7 @@ def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[ def run_node( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, is_async: bool = False, session_id: str | None = None, @@ -389,7 +391,7 @@ def run_node( Args: node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. hook_manager: The ``PluginManager`` to activate hooks. is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. @@ -423,7 +425,7 @@ def run_node( def _collect_inputs_from_hook( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, @@ -456,7 +458,7 @@ def _collect_inputs_from_hook( # noqa: PLR0913 def _call_node_run( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, @@ -487,7 +489,7 @@ def _call_node_run( # noqa: PLR0913 def _run_node_sequential( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: @@ -534,7 +536,7 @@ def _run_node_sequential( def _run_node_async( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 48dac3cd54..c888e737cf 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from pluggy import PluginManager - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline @@ -34,7 +34,7 @@ def __init__( Args: is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to MemoryDataset for `SequentialRunner`. @@ -48,7 +48,7 @@ def __init__( def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -56,7 +56,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. 
diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index b4751a602a..5ad13b9153 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from pluggy import PluginManager - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -43,7 +43,7 @@ def __init__( is_async: If True, set to False, because `ThreadRunner` doesn't support loading and saving the node inputs and outputs asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to MemoryDataset for `ThreadRunner`. @@ -87,7 +87,7 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -95,7 +95,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. diff --git a/kedro/templates/project/hooks/utils.py b/kedro/templates/project/hooks/utils.py index 947c234f19..595884248e 100644 --- a/kedro/templates/project/hooks/utils.py +++ b/kedro/templates/project/hooks/utils.py @@ -20,7 +20,9 @@ ] # Configuration key for documentation dependencies -docs_pyproject_requirements = ["project.optional-dependencies"] # For pyproject.toml +docs_pyproject_requirements = ["project.optional-dependencies.docs"] # For pyproject.toml +# Configuration key for linting and testing dependencies +dev_pyproject_requirements = ["project.optional-dependencies.dev"] # For pyproject.toml # Requirements for example pipelines example_pipeline_requirements = "seaborn~=0.12.1\nscikit-learn~=1.0\n" @@ -34,7 +36,7 @@ def _remove_from_file(file_path: Path, content_to_remove: str) -> None: file_path (Path): The path of the file from which to remove content. content_to_remove (str): The content to be removed from the file. """ - with open(file_path, "r") as file: + with open(file_path) as file: lines = file.readlines() # Split the content to remove into lines and remove trailing whitespaces/newlines @@ -84,7 +86,7 @@ def _remove_from_toml(file_path: Path, sections_to_remove: list) -> None: sections_to_remove (list): A list of section keys to remove from the TOML file. """ # Load the TOML file - with open(file_path, "r") as file: + with open(file_path) as file: data = toml.load(file) # Remove the specified sections @@ -160,7 +162,7 @@ def _remove_extras_from_kedro_datasets(file_path: Path) -> None: Args: file_path (Path): The path of the requirements file. """ - with open(file_path, "r") as file: + with open(file_path) as file: lines = file.readlines() for i, line in enumerate(lines): @@ -191,12 +193,14 @@ def setup_template_tools( python_package_name (str): The name of the python package. 
example_pipeline (str): 'True' if example pipeline was selected """ + + if "Linting" not in selected_tools_list and "Testing" not in selected_tools_list: + _remove_from_toml(pyproject_file_path, dev_pyproject_requirements) + if "Linting" not in selected_tools_list: - _remove_from_file(requirements_file_path, lint_requirements) _remove_from_toml(pyproject_file_path, lint_pyproject_requirements) if "Testing" not in selected_tools_list: - _remove_from_file(requirements_file_path, test_requirements) _remove_from_toml(pyproject_file_path, test_pyproject_requirements) _remove_dir(current_dir / "tests") diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml index f22a91242f..b2ab54c3bb 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml @@ -24,6 +24,12 @@ docs = [ "Jinja2<3.2.0", "myst-parser>=1.0,<2.1" ] +dev = [ + "pytest-cov~=3.0", + "pytest-mock>=1.7.1, <2.0", + "pytest~=7.2", + "ruff~=0.1.8" +] [tool.setuptools.dynamic] diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt b/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt index 9301f4e3f3..1be43016fb 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt @@ -2,7 +2,3 @@ ipython>=8.10 jupyterlab>=3.0 notebook kedro~={{ cookiecutter.kedro_version }} -pytest-cov~=3.0 -pytest-mock>=1.7.1, <2.0 -pytest~=7.2 -ruff~=0.1.8 diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py b/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py index eb57d1908e..c7b3cf08a8 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py @@ -12,7 +12,7 @@ import pytest -from kedro.config import ConfigLoader +from kedro.config import OmegaConfigLoader from kedro.framework.context import KedroContext from kedro.framework.hooks import _create_hook_manager from kedro.framework.project import settings @@ -20,7 +20,7 @@ @pytest.fixture def config_loader(): - return ConfigLoader(conf_source=str(Path.cwd() / settings.CONF_SOURCE)) + return OmegaConfigLoader(conf_source=str(Path.cwd() / settings.CONF_SOURCE)) @pytest.fixture @@ -28,6 +28,7 @@ def project_context(config_loader): return KedroContext( package_name="{{ cookiecutter.python_package }}", project_path=Path.cwd(), + env="local", config_loader=config_loader, hook_manager=_create_hook_manager(), ) diff --git a/pyproject.toml b/pyproject.toml index 8b7b4cb09b..23b60e9a61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ {name = "Kedro"} ] description = "Kedro helps you build production-ready data and analytics pipelines" -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "attrs>=21.3", "build>=0.7.0", @@ -33,7 +33,6 @@ dependencies = [ "rope>=0.21,<2.0", # subject to LGPLv3 license "toml>=0.10.0", "typing_extensions>=4.0", - "graphlib_backport>=1.0.0; python_version < '3.9'", ] keywords = [ "pipelines", @@ -45,10 +44,10 @@ keywords = [ license = {text = "Apache Software License (Apache 2.0)"} classifiers = [ "Development Status :: 4 - Beta", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming 
Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dynamic = ["readme", "version"] @@ -58,13 +57,11 @@ test = [ "coverage[toml]", "import-linter==2.0", "ipylab>=1.0.0", - "ipython>=7.31.1, <8.0; python_version < '3.8'", - "ipython~=8.10; python_version >= '3.8'", + "ipython~=8.10", "jupyterlab_server>=2.11.1", "jupyterlab>=3,<5", "jupyter~=1.0", - "kedro-datasets; python_version >= '3.9'", - "kedro-datasets<2.0.0; python_version < '3.9'", + "kedro-datasets", "mypy~=1.0", "pandas~=2.0", "pluggy>=1.0", @@ -134,7 +131,7 @@ omit = [ "kedro/runner/parallel_runner.py", "*/site-packages/*", ] -exclude_also = ["raise NotImplementedError", "if TYPE_CHECKING:"] +exclude_also = ["raise NotImplementedError", "if TYPE_CHECKING:", "class CatalogProtocol"] [tool.pytest.ini_options] addopts=""" diff --git a/tests/framework/cli/pipeline/test_pipeline.py b/tests/framework/cli/pipeline/test_pipeline.py index df01d685a6..099790221e 100644 --- a/tests/framework/cli/pipeline/test_pipeline.py +++ b/tests/framework/cli/pipeline/test_pipeline.py @@ -127,11 +127,11 @@ def test_create_pipeline_template_command_line_override( assert not (pipelines_dir / PIPELINE_NAME).exists() # Rename the local template dir to something else so we know the command line flag is taking precedence - try: - # Can skip if already there but copytree has a dirs_exist_ok flag in >python 3.8 only - shutil.copytree(fake_local_template_dir, fake_repo_path / "local_templates") - except FileExistsError: - pass + shutil.copytree( + fake_local_template_dir, + fake_repo_path / "local_templates", + dirs_exist_ok=True, + ) cmd = ["pipeline", "create", PIPELINE_NAME] cmd += ["-t", str(fake_repo_path / "local_templates/pipeline")] diff --git a/tests/framework/cli/test_catalog.py b/tests/framework/cli/test_catalog.py index f34034296e..8905da9c94 100644 --- a/tests/framework/cli/test_catalog.py +++ b/tests/framework/cli/test_catalog.py @@ -490,7 +490,6 @@ def test_rank_catalog_factories( mocked_context.catalog = DataCatalog.from_config( fake_catalog_with_overlapping_factories ) - print("!!!!", mocked_context.catalog._dataset_patterns) result = CliRunner().invoke( fake_project_cli, ["catalog", "rank"], obj=fake_metadata ) @@ -544,10 +543,11 @@ def test_catalog_resolve( "catalog": fake_catalog_config, "credentials": fake_credentials_config, } + mocked_context._get_config_credentials.return_value = fake_credentials_config mocked_context.catalog = DataCatalog.from_config( catalog=fake_catalog_config, credentials=fake_credentials_config ) - placeholder_ds = mocked_context.catalog._dataset_patterns.keys() + placeholder_ds = mocked_context.catalog.config_resolver.list_patterns() pipeline_datasets = {"csv_example", "parquet_example", "explicit_dataset"} mocker.patch.object( diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index 32f618d68f..7f2641da10 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -147,17 +147,11 @@ def _assert_requirements_ok( assert "Congratulations!" 
in result.output assert f"has been created in the directory \n{root_path}" in result.output - requirements_file_path = root_path / "requirements.txt" pyproject_file_path = root_path / "pyproject.toml" tools_list = _parse_tools_input(tools) if "1" in tools_list: - with open(requirements_file_path) as requirements_file: - requirements = requirements_file.read() - - assert "ruff" in requirements - pyproject_config = toml.load(pyproject_file_path) expected = { "tool": { @@ -171,15 +165,11 @@ def _assert_requirements_ok( } } assert expected["tool"]["ruff"] == pyproject_config["tool"]["ruff"] + assert ( + "ruff~=0.1.8" in pyproject_config["project"]["optional-dependencies"]["dev"] + ) if "2" in tools_list: - with open(requirements_file_path) as requirements_file: - requirements = requirements_file.read() - - assert "pytest-cov~=3.0" in requirements - assert "pytest-mock>=1.7.1, <2.0" in requirements - assert "pytest~=7.2" in requirements - pyproject_config = toml.load(pyproject_file_path) expected = { "pytest": { @@ -198,6 +188,18 @@ def _assert_requirements_ok( assert expected["pytest"] == pyproject_config["tool"]["pytest"] assert expected["coverage"] == pyproject_config["tool"]["coverage"] + assert ( + "pytest-cov~=3.0" + in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + assert ( + "pytest-mock>=1.7.1, <2.0" + in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + assert ( + "pytest~=7.2" in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + if "4" in tools_list: pyproject_config = toml.load(pyproject_file_path) expected = { diff --git a/tests/framework/context/test_context.py b/tests/framework/context/test_context.py index 61e4bbaa6f..ea62cb04c9 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -261,7 +261,7 @@ def test_wrong_catalog_type(self, mock_settings_file_bad_data_catalog_class): pattern = ( "Invalid value 'tests.framework.context.test_context.BadCatalog' received " "for setting 'DATA_CATALOG_CLASS'. " - "It must be a subclass of 'kedro.io.data_catalog.DataCatalog'." + "It must implement 'kedro.io.core.CatalogProtocol'." 
) mock_settings = _ProjectSettings( settings_file=str(mock_settings_file_bad_data_catalog_class) diff --git a/tests/framework/project/test_settings.py b/tests/framework/project/test_settings.py index 74a2ac50ca..ec718067cc 100644 --- a/tests/framework/project/test_settings.py +++ b/tests/framework/project/test_settings.py @@ -7,7 +7,6 @@ from kedro.config import OmegaConfigLoader from kedro.framework.context.context import KedroContext from kedro.framework.project import configure_project, settings, validate_settings -from kedro.framework.session.shelvestore import ShelveStore from kedro.framework.session.store import BaseSessionStore from kedro.io import DataCatalog @@ -40,8 +39,8 @@ def mock_package_name_with_settings_file(tmpdir): DISABLE_HOOKS_FOR_PLUGINS = ("kedro-viz",) - from kedro.framework.session.shelvestore import ShelveStore - SESSION_STORE_CLASS = ShelveStore + from kedro.framework.session.store import BaseSessionStore + SESSION_STORE_CLASS = BaseSessionStore SESSION_STORE_ARGS = {{ "path": "./sessions" }} @@ -103,7 +102,7 @@ def test_settings_after_configuring_project_shows_updated_values( configure_project(mock_package_name_with_settings_file) assert len(settings.HOOKS) == 1 and isinstance(settings.HOOKS[0], ProjectHooks) assert settings.DISABLE_HOOKS_FOR_PLUGINS.to_list() == ["kedro-viz"] - assert settings.SESSION_STORE_CLASS is ShelveStore + assert settings.SESSION_STORE_CLASS is BaseSessionStore assert settings.SESSION_STORE_ARGS == {"path": "./sessions"} assert settings.CONTEXT_CLASS is MyContext assert settings.CONF_SOURCE == "test_conf" diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index bc25db37c7..1c67824d8d 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -5,7 +5,7 @@ import textwrap from collections.abc import Mapping from pathlib import Path -from typing import Any, Type +from typing import Any from unittest.mock import create_autospec import pytest @@ -22,12 +22,10 @@ ValidationError, Validator, _HasSharedParentClassValidator, - _IsSubclassValidator, _ProjectSettings, ) from kedro.framework.session import KedroSession from kedro.framework.session.session import KedroSessionError -from kedro.framework.session.shelvestore import ShelveStore from kedro.framework.session.store import BaseSessionStore from kedro.utils import _has_rich_handler @@ -52,7 +50,7 @@ class BadConfigLoader: NEW_TYPING = sys.version_info[:3] >= (3, 7, 0) # PEP 560 -def create_attrs_autospec(spec: Type, spec_set: bool = True) -> Any: +def create_attrs_autospec(spec: type, spec_set: bool = True) -> Any: """Creates a mock of an attr class (creates mocks recursively on all attributes). 
https://github.com/python-attrs/attrs/issues/462#issuecomment-1134656377 @@ -235,21 +233,6 @@ class MockSettings(_ProjectSettings): ) -@pytest.fixture -def mock_settings_shelve_session_store(mocker, fake_project): - shelve_location = fake_project / "nested" / "sessions" - - class MockSettings(_ProjectSettings): - _SESSION_STORE_CLASS = _IsSubclassValidator( - "SESSION_STORE_CLASS", default=lambda *_: ShelveStore - ) - _SESSION_STORE_ARGS = Validator( - "SESSION_STORE_ARGS", default={"path": shelve_location.as_posix()} - ) - - return _mock_imported_settings_paths(mocker, MockSettings()) - - @pytest.fixture def fake_session_id(mocker): session_id = "fake_session_id" @@ -502,26 +485,6 @@ def test_default_store(self, fake_project, fake_session_id, caplog): ] assert actual_log_messages == expected_log_messages - @pytest.mark.usefixtures("mock_settings_shelve_session_store") - def test_shelve_store(self, fake_project, fake_session_id, caplog, mocker): - mocker.patch("pathlib.Path.is_file", return_value=True) - shelve_location = fake_project / "nested" / "sessions" - other = KedroSession.create(fake_project) - assert other._store.__class__ is ShelveStore - assert other._store._path == shelve_location.as_posix() - assert other._store._location == shelve_location / fake_session_id / "store" - assert other._store._session_id == fake_session_id - assert not shelve_location.is_dir() - - other.close() # session data persisted - assert shelve_location.is_dir() - actual_log_messages = [ - rec.getMessage() - for rec in caplog.records - if rec.name == STORE_LOGGER_NAME and rec.levelno == logging.DEBUG - ] - assert not actual_log_messages - def test_wrong_store_type(self, mock_settings_file_bad_session_store_class): pattern = ( "Invalid value 'tests.framework.session.test_session.BadStore' received " @@ -730,7 +693,7 @@ def test_run_thread_runner( } mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret) mocker.patch( - "kedro.io.data_catalog.DataCatalog._match_pattern", + "kedro.io.data_catalog.CatalogConfigResolver.match_pattern", return_value=match_pattern, ) diff --git a/tests/framework/session/test_store.py b/tests/framework/session/test_store.py index fa728271e7..0ad1c054ce 100644 --- a/tests/framework/session/test_store.py +++ b/tests/framework/session/test_store.py @@ -1,9 +1,5 @@ import logging -from pathlib import Path -import pytest - -from kedro.framework.session.shelvestore import ShelveStore from kedro.framework.session.store import BaseSessionStore FAKE_SESSION_ID = "fake_session_id" @@ -48,42 +44,3 @@ def test_save(self, caplog): if rec.name == STORE_LOGGER_NAME and rec.levelno == logging.DEBUG ] assert actual_debug_messages == expected_debug_messages - - -@pytest.fixture -def shelve_path(tmp_path): - return Path(tmp_path / "path" / "to" / "sessions") - - -class TestShelveStore: - def test_empty(self, shelve_path): - shelve = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - assert shelve == {} - assert shelve._location == shelve_path / FAKE_SESSION_ID / "store" - assert not shelve_path.exists() - - def test_save(self, shelve_path): - assert not shelve_path.exists() - - shelve = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - shelve["shelve_path"] = shelve_path - shelve.save() - - assert (shelve_path / FAKE_SESSION_ID).is_dir() - - reloaded = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - assert reloaded == {"shelve_path": shelve_path} - - def test_update(self, shelve_path): - shelve = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - shelve["shelve_path"] = shelve_path - 
shelve.save() - - shelve.update(new_key="new_value") - del shelve["shelve_path"] - reloaded = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - assert reloaded == {"shelve_path": shelve_path} # changes not saved yet - - shelve.save() - reloaded = ShelveStore(str(shelve_path), FAKE_SESSION_ID) - assert reloaded == {"new_key": "new_value"} diff --git a/tests/io/conftest.py b/tests/io/conftest.py index 2cc38aa1ea..9abce4c83e 100644 --- a/tests/io/conftest.py +++ b/tests/io/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from kedro_datasets.pandas import CSVDataset @pytest.fixture @@ -21,3 +22,68 @@ def input_data(request): @pytest.fixture def new_data(): return pd.DataFrame({"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}) + + +@pytest.fixture +def filepath(tmp_path): + return (tmp_path / "some" / "dir" / "test.csv").as_posix() + + +@pytest.fixture +def dataset(filepath): + return CSVDataset(filepath=filepath, save_args={"index": False}) + + +@pytest.fixture +def correct_config(filepath): + return { + "catalog": { + "boats": {"type": "pandas.CSVDataset", "filepath": filepath}, + "cars": { + "type": "pandas.CSVDataset", + "filepath": "s3://test_bucket/test_file.csv", + "credentials": "s3_credentials", + }, + }, + "credentials": { + "s3_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + }, + } + + +@pytest.fixture +def correct_config_with_nested_creds(correct_config): + correct_config["catalog"]["cars"]["credentials"] = { + "client_kwargs": {"credentials": "other_credentials"}, + "key": "secret", + } + correct_config["credentials"]["other_credentials"] = { + "client_kwargs": { + "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY", + "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY", + } + } + return correct_config + + +@pytest.fixture +def bad_config(filepath): + return { + "bad": {"type": "tests.io.test_data_catalog.BadDataset", "filepath": filepath} + } + + +@pytest.fixture +def correct_config_with_tracking_ds(tmp_path): + boat_path = (tmp_path / "some" / "dir" / "test.csv").as_posix() + plane_path = (tmp_path / "some" / "dir" / "metrics.json").as_posix() + return { + "catalog": { + "boats": { + "type": "pandas.CSVDataset", + "filepath": boat_path, + "versioned": True, + }, + "planes": {"type": "tracking.MetricsDataset", "filepath": plane_path}, + }, + } diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index dbec57e64d..54cbdf340d 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -1,6 +1,5 @@ import logging import re -import sys from copy import deepcopy from datetime import datetime, timezone from pathlib import Path @@ -29,64 +28,6 @@ ) -@pytest.fixture -def filepath(tmp_path): - return (tmp_path / "some" / "dir" / "test.csv").as_posix() - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -@pytest.fixture -def sane_config(filepath): - return { - "catalog": { - "boats": {"type": "pandas.CSVDataset", "filepath": filepath}, - "cars": { - "type": "pandas.CSVDataset", - "filepath": "s3://test_bucket/test_file.csv", - "credentials": "s3_credentials", - }, - }, - "credentials": { - "s3_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - }, - } - - -@pytest.fixture -def sane_config_with_nested_creds(sane_config): - sane_config["catalog"]["cars"]["credentials"] = { - "client_kwargs": {"credentials": "other_credentials"}, - "key": "secret", - } - sane_config["credentials"]["other_credentials"] = { - 
"client_kwargs": { - "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY", - "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY", - } - } - return sane_config - - -@pytest.fixture -def sane_config_with_tracking_ds(tmp_path): - boat_path = (tmp_path / "some" / "dir" / "test.csv").as_posix() - plane_path = (tmp_path / "some" / "dir" / "metrics.json").as_posix() - return { - "catalog": { - "boats": { - "type": "pandas.CSVDataset", - "filepath": boat_path, - "versioned": True, - }, - "planes": {"type": "tracking.MetricsDataset", "filepath": plane_path}, - }, - } - - @pytest.fixture def config_with_dataset_factories(): return { @@ -180,11 +121,6 @@ def config_with_dataset_factories_only_patterns_no_default( return config_with_dataset_factories_only_patterns -@pytest.fixture -def dataset(filepath): - return CSVDataset(filepath=filepath, save_args={"index": False}) - - @pytest.fixture def multi_catalog(): csv = CSVDataset(filepath="abc.csv") @@ -220,21 +156,14 @@ def _describe(self): return {} -@pytest.fixture -def bad_config(filepath): - return { - "bad": {"type": "tests.io.test_data_catalog.BadDataset", "filepath": filepath} - } - - @pytest.fixture def data_catalog(dataset): return DataCatalog(datasets={"test": dataset}) @pytest.fixture -def data_catalog_from_config(sane_config): - return DataCatalog.from_config(**sane_config) +def data_catalog_from_config(correct_config): + return DataCatalog.from_config(**correct_config) class TestDataCatalog: @@ -468,143 +397,126 @@ def test_key_completions(self, data_catalog_from_config): class TestDataCatalogFromConfig: - def test_from_sane_config(self, data_catalog_from_config, dummy_dataframe): + def test_from_correct_config(self, data_catalog_from_config, dummy_dataframe): """Test populating the data catalog from config""" data_catalog_from_config.save("boats", dummy_dataframe) reloaded_df = data_catalog_from_config.load("boats") assert_frame_equal(reloaded_df, dummy_dataframe) - def test_config_missing_type(self, sane_config): + def test_config_missing_type(self, correct_config): """Check the error if type attribute is missing for some data set(s) in the config""" - del sane_config["catalog"]["boats"]["type"] + del correct_config["catalog"]["boats"]["type"] pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "'type' is missing from dataset catalog configuration" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_module(self, sane_config): + def test_config_invalid_module(self, correct_config): """Check the error if the type points to nonexistent module""" - sane_config["catalog"]["boats"]["type"] = ( + correct_config["catalog"]["boats"]["type"] = ( "kedro.invalid_module_name.io.CSVDataset" ) error_msg = "Class 'kedro.invalid_module_name.io.CSVDataset' not found" with pytest.raises(DatasetError, match=re.escape(error_msg)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_relative_import(self, sane_config): + def test_config_relative_import(self, correct_config): """Check the error if the type points to a relative import""" - sane_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid" + correct_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid" pattern = "'type' class path does not support relative paths" with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def 
test_config_import_kedro_datasets(self, sane_config, mocker): + def test_config_import_kedro_datasets(self, correct_config, mocker): """Test kedro_datasets default path to the dataset class""" # Spy _load_obj because kedro_datasets is not installed and we can't import it. import kedro.io.core spy = mocker.spy(kedro.io.core, "_load_obj") - parse_dataset_definition(sane_config["catalog"]["boats"]) + parse_dataset_definition(correct_config["catalog"]["boats"]) for prefix, call_args in zip(_DEFAULT_PACKAGES, spy.call_args_list): # In Python 3.7 call_args.args is not available thus we access the call # arguments with less meaningful index. # The 1st index returns a tuple, the 2nd index return the name of module. assert call_args[0][0] == f"{prefix}pandas.CSVDataset" - def test_config_import_extras(self, sane_config): + def test_config_import_extras(self, correct_config): """Test kedro_datasets default path to the dataset class""" - sane_config["catalog"]["boats"]["type"] = "pandas.CSVDataset" - assert DataCatalog.from_config(**sane_config) + correct_config["catalog"]["boats"]["type"] = "pandas.CSVDataset" + assert DataCatalog.from_config(**correct_config) - def test_config_missing_class(self, sane_config): + def test_config_missing_class(self, correct_config): """Check the error if the type points to nonexistent class""" - sane_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid" + correct_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid" pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "Class 'kedro.io.CSVDatasetInvalid' not found, is this a typo?" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) - - @pytest.mark.skipif( - sys.version_info < (3, 9), - reason="for python 3.8 kedro-datasets version 1.8 is used which has the old spelling", - ) - def test_config_incorrect_spelling(self, sane_config): - """Check hint if the type uses the old DataSet spelling""" - sane_config["catalog"]["boats"]["type"] = "pandas.CSVDataSet" - - pattern = ( - "An exception occurred when parsing config for dataset 'boats':\n" - "Class 'pandas.CSVDataSet' not found, is this a typo?" - "\nHint: If you are trying to use a dataset from `kedro-datasets`>=2.0.0," - " make sure that the dataset name uses the `Dataset` spelling instead of `DataSet`." 
- ) - with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_dataset(self, sane_config): + def test_config_invalid_dataset(self, correct_config): """Check the error if the type points to invalid class""" - sane_config["catalog"]["boats"]["type"] = "DataCatalog" + correct_config["catalog"]["boats"]["type"] = "DataCatalog" pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "Dataset type 'kedro.io.data_catalog.DataCatalog' is invalid: " "all data set types must extend 'AbstractDataset'" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_arguments(self, sane_config): + def test_config_invalid_arguments(self, correct_config): """Check the error if the data set config contains invalid arguments""" - sane_config["catalog"]["boats"]["save_and_load_args"] = False + correct_config["catalog"]["boats"]["save_and_load_args"] = False pattern = ( r"Dataset 'boats' must only contain arguments valid for " r"the constructor of '.*CSVDataset'" ) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_dataset_config(self, sane_config): - sane_config["catalog"]["invalid_entry"] = "some string" + def test_config_invalid_dataset_config(self, correct_config): + correct_config["catalog"]["invalid_entry"] = "some string" pattern = ( "Catalog entry 'invalid_entry' is not a valid dataset configuration. " "\nHint: If this catalog entry is intended for variable interpolation, " "make sure that the key is preceded by an underscore." 
) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) def test_empty_config(self): """Test empty config""" assert DataCatalog.from_config(None) - def test_missing_credentials(self, sane_config): + def test_missing_credentials(self, correct_config): """Check the error if credentials can't be located""" - sane_config["catalog"]["cars"]["credentials"] = "missing" + correct_config["catalog"]["cars"]["credentials"] = "missing" with pytest.raises(KeyError, match=r"Unable to find credentials \'missing\'"): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_link_credentials(self, sane_config, mocker): + def test_link_credentials(self, correct_config, mocker): """Test credentials being linked to the relevant data set""" mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") - config = deepcopy(sane_config) + config = deepcopy(correct_config) del config["catalog"]["boats"] DataCatalog.from_config(**config) - expected_client_kwargs = sane_config["credentials"]["s3_credentials"] + expected_client_kwargs = correct_config["credentials"]["s3_credentials"] mock_client.filesystem.assert_called_with("s3", **expected_client_kwargs) - def test_nested_credentials(self, sane_config_with_nested_creds, mocker): + def test_nested_credentials(self, correct_config_with_nested_creds, mocker): mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") - config = deepcopy(sane_config_with_nested_creds) + config = deepcopy(correct_config_with_nested_creds) del config["catalog"]["boats"] DataCatalog.from_config(**config) @@ -621,13 +533,13 @@ def test_nested_credentials(self, sane_config_with_nested_creds, mocker): } mock_client.filesystem.assert_called_once_with("s3", **expected_client_kwargs) - def test_missing_nested_credentials(self, sane_config_with_nested_creds): - del sane_config_with_nested_creds["credentials"]["other_credentials"] + def test_missing_nested_credentials(self, correct_config_with_nested_creds): + del correct_config_with_nested_creds["credentials"]["other_credentials"] pattern = "Unable to find credentials 'other_credentials'" with pytest.raises(KeyError, match=pattern): - DataCatalog.from_config(**sane_config_with_nested_creds) + DataCatalog.from_config(**correct_config_with_nested_creds) - def test_missing_dependency(self, sane_config, mocker): + def test_missing_dependency(self, correct_config, mocker): """Test that dependency is missing.""" pattern = "dependency issue" @@ -639,12 +551,12 @@ def dummy_load(obj_path, *args, **kwargs): mocker.patch("kedro.io.core.load_obj", side_effect=dummy_load) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_idempotent_catalog(self, sane_config): + def test_idempotent_catalog(self, correct_config): """Test that data catalog instantiations are idempotent""" - _ = DataCatalog.from_config(**sane_config) - catalog = DataCatalog.from_config(**sane_config) + _ = DataCatalog.from_config(**correct_config) + catalog = DataCatalog.from_config(**correct_config) assert catalog def test_error_dataset_init(self, bad_config): @@ -684,18 +596,18 @@ def test_confirm(self, tmp_path, caplog, mocker): ("boats", "Dataset 'boats' does not have 'confirm' method"), ], ) - def test_bad_confirm(self, sane_config, dataset_name, pattern): + def test_bad_confirm(self, correct_config, dataset_name, pattern): """Test confirming non existent 
dataset or the one that does not have `confirm` method""" - data_catalog = DataCatalog.from_config(**sane_config) + data_catalog = DataCatalog.from_config(**correct_config) with pytest.raises(DatasetError, match=re.escape(pattern)): data_catalog.confirm(dataset_name) class TestDataCatalogVersioned: - def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): + def test_from_correct_config_versioned(self, correct_config, dummy_dataframe): """Test load and save of versioned data sets from config""" - sane_config["catalog"]["boats"]["versioned"] = True + correct_config["catalog"]["boats"]["versioned"] = True # Decompose `generate_timestamp` to keep `current_ts` reference. current_ts = datetime.now(tz=timezone.utc) @@ -706,13 +618,13 @@ def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): version = fmt.format(d=current_ts, ms=current_ts.microsecond // 1000) catalog = DataCatalog.from_config( - **sane_config, + **correct_config, load_versions={"boats": version}, save_version=version, ) catalog.save("boats", dummy_dataframe) - path = Path(sane_config["catalog"]["boats"]["filepath"]) + path = Path(correct_config["catalog"]["boats"]["filepath"]) path = path / version / path.name assert path.is_file() @@ -733,12 +645,14 @@ def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): assert actual_timestamp == expected_timestamp @pytest.mark.parametrize("versioned", [True, False]) - def test_from_sane_config_versioned_warn(self, caplog, sane_config, versioned): + def test_from_correct_config_versioned_warn( + self, caplog, correct_config, versioned + ): """Check the warning if `version` attribute was added to the data set config""" - sane_config["catalog"]["boats"]["versioned"] = versioned - sane_config["catalog"]["boats"]["version"] = True - DataCatalog.from_config(**sane_config) + correct_config["catalog"]["boats"]["versioned"] = versioned + correct_config["catalog"]["boats"]["version"] = True + DataCatalog.from_config(**correct_config) log_record = caplog.records[0] expected_log_message = ( "'version' attribute removed from data set configuration since it " @@ -747,21 +661,21 @@ def test_from_sane_config_versioned_warn(self, caplog, sane_config, versioned): assert log_record.levelname == "WARNING" assert expected_log_message in log_record.message - def test_from_sane_config_load_versions_warn(self, sane_config): - sane_config["catalog"]["boats"]["versioned"] = True + def test_from_correct_config_load_versions_warn(self, correct_config): + correct_config["catalog"]["boats"]["versioned"] = True version = generate_timestamp() - load_version = {"non-boart": version} - pattern = r"\'load_versions\' keys \[non-boart\] are not found in the catalog\." + load_version = {"non-boat": version} + pattern = r"\'load_versions\' keys \[non-boat\] are not found in the catalog\." 
with pytest.raises(DatasetNotFoundError, match=pattern): - DataCatalog.from_config(**sane_config, load_versions=load_version) + DataCatalog.from_config(**correct_config, load_versions=load_version) def test_compare_tracking_and_other_dataset_versioned( - self, sane_config_with_tracking_ds, dummy_dataframe + self, correct_config_with_tracking_ds, dummy_dataframe ): """Test saving of tracking data sets from config results in the same save version as other versioned datasets.""" - catalog = DataCatalog.from_config(**sane_config_with_tracking_ds) + catalog = DataCatalog.from_config(**correct_config_with_tracking_ds) catalog.save("boats", dummy_dataframe) dummy_data = {"col1": 1, "col2": 2, "col3": 3} @@ -779,20 +693,20 @@ def test_compare_tracking_and_other_dataset_versioned( assert tracking_timestamp == csv_timestamp - def test_load_version(self, sane_config, dummy_dataframe, mocker): + def test_load_version(self, correct_config, dummy_dataframe, mocker): """Test load versioned data sets from config""" new_dataframe = pd.DataFrame({"col1": [0, 0], "col2": [0, 0], "col3": [0, 0]}) - sane_config["catalog"]["boats"]["versioned"] = True + correct_config["catalog"]["boats"]["versioned"] = True mocker.patch( "kedro.io.data_catalog.generate_timestamp", side_effect=["first", "second"] ) # save first version of the dataset - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", dummy_dataframe) # save second version of the dataset - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", new_dataframe) assert_frame_equal(catalog.load("boats", version="first"), dummy_dataframe) @@ -800,11 +714,11 @@ def test_load_version(self, sane_config, dummy_dataframe, mocker): assert_frame_equal(catalog.load("boats"), new_dataframe) def test_load_version_on_unversioned_dataset( - self, sane_config, dummy_dataframe, mocker + self, correct_config, dummy_dataframe, mocker ): mocker.patch("kedro.io.data_catalog.generate_timestamp", return_value="first") - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", dummy_dataframe) with pytest.raises(DatasetError): @@ -846,7 +760,7 @@ def test_match_added_to_datasets_on_get(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert "{brand}_cars" not in catalog._datasets assert "tesla_cars" not in catalog._datasets - assert "{brand}_cars" in catalog._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns tesla_cars = catalog._get_dataset("tesla_cars") assert isinstance(tesla_cars, CSVDataset) @@ -875,8 +789,8 @@ def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert "audi_cars" in catalog._datasets assert "{brand}_cars" not in catalog._datasets - assert "audi_cars" not in catalog._dataset_patterns - assert "{brand}_cars" in catalog._dataset_patterns + assert "audi_cars" not in catalog.config_resolver._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns def test_explicit_entry_not_overwritten(self, config_with_dataset_factories): """Check that the existing catalog entry is not overwritten by config in pattern""" @@ -909,11 +823,7 @@ def test_sorting_order_patterns(self, config_with_dataset_factories_only_pattern "{dataset}s", "{user_default}", ] - assert ( 
- list(catalog._dataset_patterns.keys()) - + list(catalog._default_pattern.keys()) - == sorted_keys_expected - ) + assert catalog.config_resolver.list_patterns() == sorted_keys_expected def test_multiple_catch_all_patterns_not_allowed( self, config_with_dataset_factories @@ -953,13 +863,13 @@ def test_sorting_order_with_other_dataset_through_extra_pattern( ) sorted_keys_expected = [ "{country}_companies", - "{another}#csv", "{namespace}_{dataset}", "{dataset}s", + "{another}#csv", "{default}", ] assert ( - list(catalog_with_default._dataset_patterns.keys()) == sorted_keys_expected + catalog_with_default.config_resolver.list_patterns() == sorted_keys_expected ) def test_user_default_overwrites_runner_default(self): @@ -988,11 +898,15 @@ def test_user_default_overwrites_runner_default(self): sorted_keys_expected = [ "{dataset}s", "{a_default}", + "{another}#csv", + "{default}", ] - assert "{a_default}" in catalog_with_runner_default._default_pattern assert ( - list(catalog_with_runner_default._dataset_patterns.keys()) - + list(catalog_with_runner_default._default_pattern.keys()) + "{a_default}" + in catalog_with_runner_default.config_resolver._default_pattern + ) + assert ( + catalog_with_runner_default.config_resolver.list_patterns() == sorted_keys_expected ) @@ -1014,13 +928,12 @@ def test_unmatched_key_error_when_parsing_config( self, config_with_dataset_factories_bad_pattern ): """Check error raised when key mentioned in the config is not in pattern name""" - catalog = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern) pattern = ( - "Unable to resolve 'data/01_raw/{brand}_plane.pq' from the pattern '{type}@planes'. " - "Keys used in the configuration should be present in the dataset factory pattern." + "Incorrect dataset configuration provided. Keys used in the configuration {'{brand}'} " + "should present in the dataset factory pattern name {type}@planes." 
) with pytest.raises(DatasetError, match=re.escape(pattern)): - catalog._get_dataset("jet@planes") + _ = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern) def test_factory_config_versioned( self, config_with_dataset_factories, filepath, dummy_dataframe diff --git a/tests/io/test_kedro_data_catalog.py b/tests/io/test_kedro_data_catalog.py new file mode 100644 index 0000000000..5e0c463e7d --- /dev/null +++ b/tests/io/test_kedro_data_catalog.py @@ -0,0 +1,632 @@ +import logging +import re +from copy import deepcopy +from datetime import datetime, timezone +from pathlib import Path + +import pandas as pd +import pytest +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from pandas.testing import assert_frame_equal + +from kedro.io import ( + DatasetAlreadyExistsError, + DatasetError, + DatasetNotFoundError, + KedroDataCatalog, + LambdaDataset, + MemoryDataset, +) +from kedro.io.core import ( + _DEFAULT_PACKAGES, + VERSION_FORMAT, + generate_timestamp, + parse_dataset_definition, +) + + +@pytest.fixture +def data_catalog(dataset): + return KedroDataCatalog(datasets={"test": dataset}) + + +@pytest.fixture +def memory_catalog(): + ds1 = MemoryDataset({"data": 42}) + ds2 = MemoryDataset([1, 2, 3, 4, 5]) + return KedroDataCatalog({"ds1": ds1, "ds2": ds2}) + + +@pytest.fixture +def conflicting_feed_dict(): + return {"ds1": 0, "ds3": 1} + + +@pytest.fixture +def multi_catalog(): + csv = CSVDataset(filepath="abc.csv") + parq = ParquetDataset(filepath="xyz.parq") + return KedroDataCatalog({"abc": csv, "xyz": parq}) + + +@pytest.fixture +def data_catalog_from_config(correct_config): + return KedroDataCatalog.from_config(**correct_config) + + +class TestKedroDataCatalog: + def test_save_and_load(self, data_catalog, dummy_dataframe): + """Test saving and reloading the dataset""" + data_catalog.save("test", dummy_dataframe) + reloaded_df = data_catalog.load("test") + + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_add_save_and_load(self, dataset, dummy_dataframe): + """Test adding and then saving and reloading the dataset""" + catalog = KedroDataCatalog(datasets={}) + catalog.add("test", dataset) + catalog.save("test", dummy_dataframe) + reloaded_df = catalog.load("test") + + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_load_error(self, data_catalog): + """Check the error when attempting to load a dataset + from nonexistent source""" + pattern = r"Failed while loading data from data set CSVDataset" + with pytest.raises(DatasetError, match=pattern): + data_catalog.load("test") + + def test_add_dataset_twice(self, data_catalog, dataset): + """Check the error when attempting to add the dataset twice""" + pattern = r"Dataset 'test' has already been registered" + with pytest.raises(DatasetAlreadyExistsError, match=pattern): + data_catalog.add("test", dataset) + + def test_load_from_unregistered(self): + """Check the error when attempting to load unregistered dataset""" + catalog = KedroDataCatalog(datasets={}) + pattern = r"Dataset 'test' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern): + catalog.load("test") + + def test_save_to_unregistered(self, dummy_dataframe): + """Check the error when attempting to save to unregistered dataset""" + catalog = KedroDataCatalog(datasets={}) + pattern = r"Dataset 'test' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern): + catalog.save("test", dummy_dataframe) + + def test_feed_dict(self, memory_catalog, conflicting_feed_dict): + """Test feed 
dict overriding some of the datasets""" + assert "data" in memory_catalog.load("ds1") + memory_catalog.add_feed_dict(conflicting_feed_dict, replace=True) + assert memory_catalog.load("ds1") == 0 + assert isinstance(memory_catalog.load("ds2"), list) + assert memory_catalog.load("ds3") == 1 + + def test_exists(self, data_catalog, dummy_dataframe): + """Test `exists` method invocation""" + assert not data_catalog.exists("test") + data_catalog.save("test", dummy_dataframe) + assert data_catalog.exists("test") + + def test_exists_not_implemented(self, caplog): + """Test calling `exists` on the dataset, which didn't implement it""" + catalog = KedroDataCatalog(datasets={"test": LambdaDataset(None, None)}) + result = catalog.exists("test") + + log_record = caplog.records[0] + assert log_record.levelname == "WARNING" + assert ( + "'exists()' not implemented for 'LambdaDataset'. " + "Assuming output does not exist." in log_record.message + ) + assert result is False + + def test_exists_invalid(self, data_catalog): + """Check the error when calling `exists` on invalid dataset""" + assert not data_catalog.exists("wrong_key") + + def test_release_unregistered(self, data_catalog): + """Check the error when calling `release` on unregistered dataset""" + pattern = r"Dataset \'wrong_key\' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern) as e: + data_catalog.release("wrong_key") + assert "did you mean" not in str(e.value) + + def test_release_unregistered_typo(self, data_catalog): + """Check the error when calling `release` on mistyped dataset""" + pattern = ( + "Dataset 'text' not found in the catalog" + " - did you mean one of these instead: test" + ) + with pytest.raises(DatasetNotFoundError, match=re.escape(pattern)): + data_catalog.release("text") + + def test_multi_catalog_list(self, multi_catalog): + """Test data catalog which contains multiple datasets""" + entries = multi_catalog.list() + assert "abc" in entries + assert "xyz" in entries + + @pytest.mark.parametrize( + "pattern,expected", + [ + ("^a", ["abc"]), + ("a|x", ["abc", "xyz"]), + ("^(?!(a|x))", []), + ("def", []), + ("", []), + ], + ) + def test_multi_catalog_list_regex(self, multi_catalog, pattern, expected): + """Test that regex patterns filter datasets accordingly""" + assert multi_catalog.list(regex_search=pattern) == expected + + def test_multi_catalog_list_bad_regex(self, multi_catalog): + """Test that bad regex is caught accordingly""" + escaped_regex = r"\(\(" + pattern = f"Invalid regular expression provided: '{escaped_regex}'" + with pytest.raises(SyntaxError, match=pattern): + multi_catalog.list("((") + + def test_eq(self, multi_catalog, data_catalog): + assert multi_catalog == multi_catalog.shallow_copy() + assert multi_catalog != data_catalog + + def test_datasets_on_init(self, data_catalog_from_config): + """Check datasets are loaded correctly on construction""" + assert isinstance(data_catalog_from_config.datasets["boats"], CSVDataset) + assert isinstance(data_catalog_from_config.datasets["cars"], CSVDataset) + + def test_datasets_on_add(self, data_catalog_from_config): + """Check datasets are updated correctly after adding""" + data_catalog_from_config.add("new_dataset", CSVDataset(filepath="some_path")) + assert isinstance(data_catalog_from_config.datasets["new_dataset"], CSVDataset) + assert isinstance(data_catalog_from_config.datasets["boats"], CSVDataset) + + def test_adding_datasets_not_allowed(self, data_catalog_from_config): + """Check error if user tries to update the datasets 
attribute""" + pattern = r"Operation not allowed. Please use KedroDataCatalog.add\(\) instead." + with pytest.raises(AttributeError, match=pattern): + data_catalog_from_config.datasets = None + + def test_confirm(self, mocker, caplog): + """Confirm the dataset""" + with caplog.at_level(logging.INFO): + mock_ds = mocker.Mock() + data_catalog = KedroDataCatalog(datasets={"mocked": mock_ds}) + data_catalog.confirm("mocked") + mock_ds.confirm.assert_called_once_with() + assert caplog.record_tuples == [ + ( + "kedro.io.kedro_data_catalog", + logging.INFO, + "Confirming dataset 'mocked'", + ) + ] + + @pytest.mark.parametrize( + "dataset_name,error_pattern", + [ + ("missing", "Dataset 'missing' not found in the catalog"), + ("test", "Dataset 'test' does not have 'confirm' method"), + ], + ) + def test_bad_confirm(self, data_catalog, dataset_name, error_pattern): + """Test confirming a non-existent dataset or one that + does not have `confirm` method""" + with pytest.raises(DatasetError, match=re.escape(error_pattern)): + data_catalog.confirm(dataset_name) + + def test_shallow_copy_returns_correct_class_type( + self, + ): + class MyDataCatalog(KedroDataCatalog): + pass + + data_catalog = MyDataCatalog() + copy = data_catalog.shallow_copy() + assert isinstance(copy, MyDataCatalog) + + @pytest.mark.parametrize( + "runtime_patterns,sorted_keys_expected", + [ + ( + { + "{default}": {"type": "MemoryDataset"}, + "{another}#csv": { + "type": "pandas.CSVDataset", + "filepath": "data/{another}.csv", + }, + }, + ["{another}#csv", "{default}"], + ) + ], + ) + def test_shallow_copy_adds_patterns( + self, data_catalog, runtime_patterns, sorted_keys_expected + ): + assert not data_catalog.config_resolver.list_patterns() + data_catalog = data_catalog.shallow_copy(runtime_patterns) + assert data_catalog.config_resolver.list_patterns() == sorted_keys_expected + + def test_init_with_raw_data(self, dummy_dataframe, dataset): + """Test catalog initialisation with raw data""" + catalog = KedroDataCatalog( + datasets={"ds": dataset}, raw_data={"df": dummy_dataframe} + ) + assert "ds" in catalog + assert "df" in catalog + assert isinstance(catalog.datasets["ds"], CSVDataset) + assert isinstance(catalog.datasets["df"], MemoryDataset) + + def test_repr(self, data_catalog): + assert data_catalog.__repr__() == str(data_catalog) + + def test_missing_keys_from_load_versions(self, correct_config): + """Test load versions include keys missing in the catalog""" + pattern = "'load_versions' keys [version] are not found in the catalog." 
+        with pytest.raises(DatasetNotFoundError, match=re.escape(pattern)):
+            KedroDataCatalog.from_config(
+                **correct_config, load_versions={"version": "test_version"}
+            )
+
+    def test_get_dataset_matching_pattern(self, data_catalog):
+        """Test get_dataset() when the dataset is not in the catalog but a pattern matches"""
+        match_pattern_ds = "match_pattern_ds"
+        assert match_pattern_ds not in data_catalog
+        data_catalog.config_resolver.add_runtime_patterns(
+            {"{default}": {"type": "MemoryDataset"}}
+        )
+        ds = data_catalog.get_dataset(match_pattern_ds)
+        assert isinstance(ds, MemoryDataset)
+
+    def test_release(self, data_catalog):
+        """Test that release is called without errors"""
+        data_catalog.release("test")
+
+    class TestKedroDataCatalogFromConfig:
+        def test_from_correct_config(self, data_catalog_from_config, dummy_dataframe):
+            """Test populating the data catalog from config"""
+            data_catalog_from_config.save("boats", dummy_dataframe)
+            reloaded_df = data_catalog_from_config.load("boats")
+            assert_frame_equal(reloaded_df, dummy_dataframe)
+
+        def test_config_missing_type(self, correct_config):
+            """Check the error if the type attribute is missing for some dataset(s)
+            in the config"""
+            del correct_config["catalog"]["boats"]["type"]
+            pattern = (
+                "An exception occurred when parsing config for dataset 'boats':\n"
+                "'type' is missing from dataset catalog configuration"
+            )
+            with pytest.raises(DatasetError, match=re.escape(pattern)):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_invalid_module(self, correct_config):
+            """Check the error if the type points to a nonexistent module"""
+            correct_config["catalog"]["boats"]["type"] = (
+                "kedro.invalid_module_name.io.CSVDataset"
+            )
+
+            error_msg = "Class 'kedro.invalid_module_name.io.CSVDataset' not found"
+            with pytest.raises(DatasetError, match=re.escape(error_msg)):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_relative_import(self, correct_config):
+            """Check the error if the type points to a relative import"""
+            correct_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid"
+
+            pattern = "'type' class path does not support relative paths"
+            with pytest.raises(DatasetError, match=re.escape(pattern)):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_import_kedro_datasets(self, correct_config, mocker):
+            """Test that kedro_datasets is tried as a default path to the dataset class"""
+            # Spy _load_obj because kedro_datasets is not installed and we can't import it.
+
+            import kedro.io.core
+
+            spy = mocker.spy(kedro.io.core, "_load_obj")
+            parse_dataset_definition(correct_config["catalog"]["boats"])
+            for prefix, call_args in zip(_DEFAULT_PACKAGES, spy.call_args_list):
+                assert call_args.args[0] == f"{prefix}pandas.CSVDataset"
+
+        def test_config_missing_class(self, correct_config):
+            """Check the error if the type points to a nonexistent class"""
+            correct_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid"
+
+            pattern = (
+                "An exception occurred when parsing config for dataset 'boats':\n"
+                "Class 'kedro.io.CSVDatasetInvalid' not found, is this a typo?"
+            )
+            with pytest.raises(DatasetError, match=re.escape(pattern)):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_invalid_dataset(self, correct_config):
+            """Check the error if the type points to an invalid class"""
+            correct_config["catalog"]["boats"]["type"] = "KedroDataCatalog"
+            pattern = (
+                "An exception occurred when parsing config for dataset 'boats':\n"
+                "Dataset type 'kedro.io.kedro_data_catalog.KedroDataCatalog' is invalid: "
+                "all data set types must extend 'AbstractDataset'"
+            )
+            with pytest.raises(DatasetError, match=re.escape(pattern)):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_invalid_arguments(self, correct_config):
+            """Check the error if the dataset config contains invalid arguments"""
+            correct_config["catalog"]["boats"]["save_and_load_args"] = False
+            pattern = (
+                r"Dataset 'boats' must only contain arguments valid for "
+                r"the constructor of '.*CSVDataset'"
+            )
+            with pytest.raises(DatasetError, match=pattern):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_config_invalid_dataset_config(self, correct_config):
+            correct_config["catalog"]["invalid_entry"] = "some string"
+            pattern = (
+                "Catalog entry 'invalid_entry' is not a valid dataset configuration. "
+                "\nHint: If this catalog entry is intended for variable interpolation, "
+                "make sure that the key is preceded by an underscore."
+            )
+            with pytest.raises(DatasetError, match=pattern):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_empty_config(self):
+            """Test empty config"""
+            assert KedroDataCatalog.from_config(None)
+
+        def test_missing_credentials(self, correct_config):
+            """Check the error if credentials can't be located"""
+            correct_config["catalog"]["cars"]["credentials"] = "missing"
+            with pytest.raises(
+                KeyError, match=r"Unable to find credentials \'missing\'"
+            ):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_link_credentials(self, correct_config, mocker):
+            """Test credentials being linked to the relevant dataset"""
+            mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec")
+            config = deepcopy(correct_config)
+            del config["catalog"]["boats"]
+
+            KedroDataCatalog.from_config(**config)
+
+            expected_client_kwargs = correct_config["credentials"]["s3_credentials"]
+            mock_client.filesystem.assert_called_with("s3", **expected_client_kwargs)
+
+        def test_nested_credentials(self, correct_config_with_nested_creds, mocker):
+            mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec")
+            config = deepcopy(correct_config_with_nested_creds)
+            del config["catalog"]["boats"]
+            KedroDataCatalog.from_config(**config)
+
+            expected_client_kwargs = {
+                "client_kwargs": {
+                    "credentials": {
+                        "client_kwargs": {
+                            "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY",
+                            "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY",
+                        }
+                    }
+                },
+                "key": "secret",
+            }
+            mock_client.filesystem.assert_called_once_with(
+                "s3", **expected_client_kwargs
+            )
+
+        def test_missing_nested_credentials(self, correct_config_with_nested_creds):
+            del correct_config_with_nested_creds["credentials"]["other_credentials"]
+            pattern = "Unable to find credentials 'other_credentials'"
+            with pytest.raises(KeyError, match=pattern):
+                KedroDataCatalog.from_config(**correct_config_with_nested_creds)
+
+        def test_missing_dependency(self, correct_config, mocker):
+            """Test the error raised when a dataset dependency is missing."""
+            pattern = "dependency issue"
+
+            def dummy_load(obj_path, *args, **kwargs):
+                if obj_path == "kedro_datasets.pandas.CSVDataset":
+                    raise AttributeError(pattern)
+                if obj_path == "kedro_datasets.pandas.__all__":
+                    return ["CSVDataset"]
+
+            mocker.patch("kedro.io.core.load_obj", side_effect=dummy_load)
+            with pytest.raises(DatasetError, match=pattern):
+                KedroDataCatalog.from_config(**correct_config)
+
+        def test_idempotent_catalog(self, correct_config):
+            """Test that data catalog instantiations are idempotent"""
+            _ = KedroDataCatalog.from_config(**correct_config)
+            catalog = KedroDataCatalog.from_config(**correct_config)
+            assert catalog
+
+        def test_error_dataset_init(self, bad_config):
+            """Check the error when trying to instantiate an erroneous dataset"""
+            pattern = r"Failed to instantiate dataset \'bad\' of type '.*BadDataset'"
+            with pytest.raises(DatasetError, match=pattern):
+                KedroDataCatalog.from_config(bad_config, None)
+
+        def test_validate_dataset_config(self):
+            """Test that _validate_dataset_config raises an error when the wrong dataset config type is passed"""
+            pattern = (
+                "Catalog entry 'bad' is not a valid dataset configuration. \n"
+                "Hint: If this catalog entry is intended for variable interpolation, make sure that the key is preceded by an underscore."
+            )
+            with pytest.raises(DatasetError, match=pattern):
+                KedroDataCatalog._validate_dataset_config(
+                    ds_name="bad", ds_config="not_dict"
+                )
+
+        def test_confirm(self, tmp_path, caplog, mocker):
+            """Confirm the dataset"""
+            with caplog.at_level(logging.INFO):
+                mock_confirm = mocker.patch(
+                    "kedro_datasets.partitions.incremental_dataset.IncrementalDataset.confirm"
+                )
+                catalog = {
+                    "ds_to_confirm": {
+                        "type": "kedro_datasets.partitions.incremental_dataset.IncrementalDataset",
+                        "dataset": "pandas.CSVDataset",
+                        "path": str(tmp_path),
+                    }
+                }
+                data_catalog = KedroDataCatalog.from_config(catalog=catalog)
+                data_catalog.confirm("ds_to_confirm")
+                assert caplog.record_tuples == [
+                    (
+                        "kedro.io.kedro_data_catalog",
+                        logging.INFO,
+                        "Confirming dataset 'ds_to_confirm'",
+                    )
+                ]
+                mock_confirm.assert_called_once_with()
+
+        @pytest.mark.parametrize(
+            "dataset_name,pattern",
+            [
+                ("missing", "Dataset 'missing' not found in the catalog"),
+                ("boats", "Dataset 'boats' does not have 'confirm' method"),
+            ],
+        )
+        def test_bad_confirm(self, correct_config, dataset_name, pattern):
+            """Test confirming a non-existent dataset or one that
+            does not have a `confirm` method"""
+            data_catalog = KedroDataCatalog.from_config(**correct_config)
+            with pytest.raises(DatasetError, match=re.escape(pattern)):
+                data_catalog.confirm(dataset_name)
+
+    class TestDataCatalogVersioned:
+        def test_from_correct_config_versioned(self, correct_config, dummy_dataframe):
+            """Test load and save of versioned datasets from config"""
+            correct_config["catalog"]["boats"]["versioned"] = True
+
+            # Decompose `generate_timestamp` to keep the `current_ts` reference.
+            current_ts = datetime.now(tz=timezone.utc)
+            fmt = (
+                "{d.year:04d}-{d.month:02d}-{d.day:02d}T{d.hour:02d}"
+                ".{d.minute:02d}.{d.second:02d}.{ms:03d}Z"
+            )
+            version = fmt.format(d=current_ts, ms=current_ts.microsecond // 1000)
+
+            catalog = KedroDataCatalog.from_config(
+                **correct_config,
+                load_versions={"boats": version},
+                save_version=version,
+            )
+
+            catalog.save("boats", dummy_dataframe)
+            path = Path(correct_config["catalog"]["boats"]["filepath"])
+            path = path / version / path.name
+            assert path.is_file()
+
+            reloaded_df = catalog.load("boats")
+            assert_frame_equal(reloaded_df, dummy_dataframe)
+
+            reloaded_df_version = catalog.load("boats", version=version)
+            assert_frame_equal(reloaded_df_version, dummy_dataframe)
+
+            # Verify that `VERSION_FORMAT` can help regenerate `current_ts`.
+            actual_timestamp = datetime.strptime(
+                catalog.datasets["boats"].resolve_load_version(),
+                VERSION_FORMAT,
+            )
+            expected_timestamp = current_ts.replace(
+                microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None
+            )
+            assert actual_timestamp == expected_timestamp
+
+        @pytest.mark.parametrize("versioned", [True, False])
+        def test_from_correct_config_versioned_warn(
+            self, caplog, correct_config, versioned
+        ):
+            """Check the warning if `version` attribute was added
+            to the dataset config"""
+            correct_config["catalog"]["boats"]["versioned"] = versioned
+            correct_config["catalog"]["boats"]["version"] = True
+            KedroDataCatalog.from_config(**correct_config)
+            log_record = caplog.records[0]
+            expected_log_message = (
+                "'version' attribute removed from data set configuration since it "
+                "is a reserved word and cannot be directly specified"
+            )
+            assert log_record.levelname == "WARNING"
+            assert expected_log_message in log_record.message
+
+        def test_from_correct_config_load_versions_warn(self, correct_config):
+            correct_config["catalog"]["boats"]["versioned"] = True
+            version = generate_timestamp()
+            load_version = {"non-boat": version}
+            pattern = (
+                r"\'load_versions\' keys \[non-boat\] are not found in the catalog\."
+            )
+            with pytest.raises(DatasetNotFoundError, match=pattern):
+                KedroDataCatalog.from_config(
+                    **correct_config, load_versions=load_version
+                )
+
+        def test_compare_tracking_and_other_dataset_versioned(
+            self, correct_config_with_tracking_ds, dummy_dataframe
+        ):
+            """Test that saving tracking datasets from config results in the same
+            save version as other versioned datasets."""
+
+            catalog = KedroDataCatalog.from_config(**correct_config_with_tracking_ds)
+
+            catalog.save("boats", dummy_dataframe)
+            dummy_data = {"col1": 1, "col2": 2, "col3": 3}
+            catalog.save("planes", dummy_data)
+
+            # Verify that the saved version on the tracking dataset is the same as on the CSV dataset
+            csv_timestamp = datetime.strptime(
+                catalog.datasets["boats"].resolve_save_version(),
+                VERSION_FORMAT,
+            )
+            tracking_timestamp = datetime.strptime(
+                catalog.datasets["planes"].resolve_save_version(),
+                VERSION_FORMAT,
+            )
+
+            assert tracking_timestamp == csv_timestamp
+
+        def test_load_version(self, correct_config, dummy_dataframe, mocker):
+            """Test loading versioned datasets from config"""
+            new_dataframe = pd.DataFrame(
+                {"col1": [0, 0], "col2": [0, 0], "col3": [0, 0]}
+            )
+            correct_config["catalog"]["boats"]["versioned"] = True
+            mocker.patch(
+                "kedro.io.kedro_data_catalog.generate_timestamp",
+                side_effect=["first", "second"],
+            )
+
+            # Save the first version of the dataset.
+            catalog = KedroDataCatalog.from_config(**correct_config)
+            catalog.save("boats", dummy_dataframe)
+
+            # Save the second version of the dataset.
+            catalog = KedroDataCatalog.from_config(**correct_config)
+            catalog.save("boats", new_dataframe)
+
+            assert_frame_equal(catalog.load("boats", version="first"), dummy_dataframe)
+            assert_frame_equal(catalog.load("boats", version="second"), new_dataframe)
+            assert_frame_equal(catalog.load("boats"), new_dataframe)
+
+        def test_load_version_on_unversioned_dataset(
+            self, correct_config, dummy_dataframe, mocker
+        ):
+            mocker.patch(
+                "kedro.io.kedro_data_catalog.generate_timestamp", return_value="first"
+            )
+
+            catalog = KedroDataCatalog.from_config(**correct_config)
+            catalog.save("boats", dummy_dataframe)
+
+            with pytest.raises(DatasetError):
+                catalog.load("boats", version="first")
diff --git a/tests/io/test_shared_memory_dataset.py b/tests/io/test_shared_memory_dataset.py
index d135b3aadd..a4b3526aa4 100644
--- a/tests/io/test_shared_memory_dataset.py
+++ b/tests/io/test_shared_memory_dataset.py
@@ -79,3 +79,9 @@ def test_saving_none(self, shared_memory_dataset):
     def test_str_representation(self, shared_memory_dataset):
         """Test string representation of the dataset"""
         assert "MemoryDataset" in str(shared_memory_dataset)
+
+    def test_exists(self, shared_memory_dataset, input_data):
+        """Check that exists returns the expected values"""
+        assert not shared_memory_dataset.exists()
+        shared_memory_dataset.save(input_data)
+        assert shared_memory_dataset.exists()
diff --git a/tests/runner/test_resume_logic.py b/tests/runner/test_resume_logic.py
index bd1f8e8acb..c733c504c5 100644
--- a/tests/runner/test_resume_logic.py
+++ b/tests/runner/test_resume_logic.py
@@ -153,6 +153,6 @@ def test_suggestion_consistency(
         test_pipeline, remaining_nodes, persistent_dataset_catalog
     )
 
-    assert set(n.name for n in required_nodes) == set(
+    assert {n.name for n in required_nodes} == {
         n.name for n in test_pipeline.from_nodes(*resume_node_names).nodes
-    )
+    }
diff --git a/tests/runner/test_sequential_runner.py b/tests/runner/test_sequential_runner.py
index dbc73a30f0..4f22bab296 100644
--- a/tests/runner/test_sequential_runner.py
+++ b/tests/runner/test_sequential_runner.py
@@ -130,7 +130,9 @@ def test_conflict_feed_catalog(
 
     def test_unsatisfied_inputs(self, is_async, unfinished_outputs_pipeline, catalog):
        """ds1, ds2 and ds3 were not specified."""
-        with pytest.raises(ValueError, match=r"not found in the DataCatalog"):
+        with pytest.raises(
+            ValueError, match=rf"not found in the {catalog.__class__.__name__}"
+        ):
             SequentialRunner(is_async=is_async).run(
                 unfinished_outputs_pipeline, catalog
             )
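
For reference, a minimal usage sketch of the KedroDataCatalog behaviour exercised by the tests above. This is not part of the patch; it assumes kedro and kedro-datasets are installed, and the CSV file paths are hypothetical:

    import pandas as pd

    from kedro.io.kedro_data_catalog import KedroDataCatalog
    from kedro_datasets.pandas import CSVDataset

    # Datasets can be passed explicitly; raw objects are wrapped in a MemoryDataset.
    catalog = KedroDataCatalog(
        datasets={"cars": CSVDataset(filepath="data/cars.csv")},  # hypothetical path
        raw_data={"df": pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})},
    )
    assert "cars" in catalog and "df" in catalog

    # The calls exercised by the tests: save, exists, load and release.
    catalog.save("cars", catalog.load("df"))
    assert catalog.exists("cars")
    catalog.release("cars")

    # Catalogs can also be declared via config, optionally versioned,
    # as in the from_config and versioning tests.
    versioned_catalog = KedroDataCatalog.from_config(
        {"boats": {"type": "pandas.CSVDataset", "filepath": "data/boats.csv", "versioned": True}}
    )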