diff --git a/RELEASE.md b/RELEASE.md
index 34e75ffb74..548b49a109 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,6 +1,8 @@
 # Upcoming Release
 
 ## Major features and improvements
+* Refactored `kedro run` and `kedro catalog` commands.
+* Moved pattern resolution logic from `DataCatalog` to a separate component, `CatalogConfigResolver`, and updated `DataCatalog` to use it internally.
 * Made packaged Kedro projects return `session.run()` output to be used when running it in the interactive environment.
 * Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking.
 ## Bug fixes and other changes
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 562f5a4b0e..2c3a2c4c00 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -127,6 +127,7 @@
         "typing.Type",
         "typing.Set",
         "kedro.config.config.ConfigLoader",
+        "kedro.io.catalog_config_resolver.CatalogConfigResolver",
         "kedro.io.core.AbstractDataset",
         "kedro.io.core.AbstractVersionedDataset",
         "kedro.io.core.DatasetError",
@@ -168,6 +169,7 @@
         "D[k] if k in D, else d. d defaults to None.",
         "None. Update D from mapping/iterable E and F.",
         "Patterns",
+        "CatalogConfigResolver",
     ),
     "py:data": (
         "typing.Any",
diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py
index 223980dade..7bd0197e5b 100644
--- a/kedro/framework/cli/catalog.py
+++ b/kedro/framework/cli/catalog.py
@@ -2,9 +2,8 @@
 
 from __future__ import annotations
 
-import copy
 from collections import defaultdict
-from itertools import chain
+from itertools import chain, filterfalse
 from typing import TYPE_CHECKING, Any
 
 import click
@@ -28,6 +27,11 @@ def _create_session(package_name: str, **kwargs: Any) -> KedroSession:
     return KedroSession.create(**kwargs)
 
 
+def is_parameter(dataset_name: str) -> bool:
+    """Check if a dataset name refers to a parameter."""
+    return dataset_name.startswith("params:") or dataset_name == "parameters"
+
+
 @click.group(name="Kedro")
 def catalog_cli() -> None:  # pragma: no cover
     pass
@@ -88,21 +92,15 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
 
         # resolve any factory datasets in the pipeline
         factory_ds_by_type = defaultdict(list)
+
        for ds_name in default_ds:
-            matched_pattern = data_catalog._match_pattern(
-                data_catalog._dataset_patterns, ds_name
-            ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
-            if matched_pattern:
-                ds_config_copy = copy.deepcopy(
-                    data_catalog._dataset_patterns.get(matched_pattern)
-                    or data_catalog._default_pattern.get(matched_pattern)
-                    or {}
+            if data_catalog.config_resolver.match_pattern(ds_name):
+                ds_config = data_catalog.config_resolver.resolve_dataset_pattern(
+                    ds_name
                 )
-
-                ds_config = data_catalog._resolve_config(
-                    ds_name, matched_pattern, ds_config_copy
+                factory_ds_by_type[ds_config.get("type", "DefaultDataset")].append(
+                    ds_name
                 )
-                factory_ds_by_type[ds_config["type"]].append(ds_name)
 
         default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
 
@@ -128,12 +126,10 @@ def _map_type_to_datasets(
     datasets of the specific type as a value.
""" mapping = defaultdict(list) # type: ignore[var-annotated] - for dataset in datasets: - is_param = dataset.startswith("params:") or dataset == "parameters" - if not is_param: - ds_type = datasets_meta[dataset].__class__.__name__ - if dataset not in mapping[ds_type]: - mapping[ds_type].append(dataset) + for dataset_name in filterfalse(is_parameter, datasets): + ds_type = datasets_meta[dataset_name].__class__.__name__ + if dataset_name not in mapping[ds_type]: + mapping[ds_type].append(dataset_name) return mapping @@ -170,20 +166,12 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N f"'{pipeline_name}' pipeline not found! Existing pipelines: {existing_pipelines}" ) - pipe_datasets = { - ds_name - for ds_name in pipeline.datasets() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + pipeline_datasets = set(filterfalse(is_parameter, pipeline.datasets())) - catalog_datasets = { - ds_name - for ds_name in context.catalog._datasets.keys() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + catalog_datasets = set(filterfalse(is_parameter, context.catalog.list())) # Datasets that are missing in Data Catalog - missing_ds = sorted(pipe_datasets - catalog_datasets) + missing_ds = sorted(pipeline_datasets - catalog_datasets) if missing_ds: catalog_path = ( context.project_path @@ -221,12 +209,9 @@ def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None: session = _create_session(metadata.package_name, env=env) context = session.load_context() - catalog_factories = { - **context.catalog._dataset_patterns, - **context.catalog._default_pattern, - } + catalog_factories = context.catalog.config_resolver.list_patterns() if catalog_factories: - click.echo(yaml.dump(list(catalog_factories.keys()))) + click.echo(yaml.dump(catalog_factories)) else: click.echo("There are no dataset factories in the catalog.") @@ -250,35 +235,25 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: explicit_datasets = { ds_name: ds_config for ds_name, ds_config in catalog_config.items() - if not data_catalog._is_pattern(ds_name) + if not data_catalog.config_resolver.is_pattern(ds_name) } target_pipelines = pipelines.keys() - datasets = set() + pipeline_datasets = set() for pipe in target_pipelines: pl_obj = pipelines.get(pipe) if pl_obj: - datasets.update(pl_obj.datasets()) + pipeline_datasets.update(pl_obj.datasets()) - for ds_name in datasets: - is_param = ds_name.startswith("params:") or ds_name == "parameters" - if ds_name in explicit_datasets or is_param: + for ds_name in pipeline_datasets: + if ds_name in explicit_datasets or is_parameter(ds_name): continue - matched_pattern = data_catalog._match_pattern( - data_catalog._dataset_patterns, ds_name - ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name) - if matched_pattern: - ds_config_copy = copy.deepcopy( - data_catalog._dataset_patterns.get(matched_pattern) - or data_catalog._default_pattern.get(matched_pattern) - or {} - ) + ds_config = data_catalog.config_resolver.resolve_dataset_pattern(ds_name) - ds_config = data_catalog._resolve_config( - ds_name, matched_pattern, ds_config_copy - ) + # Exclude MemoryDatasets not set in the catalog explicitly + if ds_config: explicit_datasets[ds_name] = ds_config secho(yaml.dump(explicit_datasets)) diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 91928f7c4b..caa3553954 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -394,13 
+394,11 @@ def run( # noqa: PLR0913 run_params=record_data, pipeline=filtered_pipeline, catalog=catalog ) + if isinstance(runner, ThreadRunner): + for ds in filtered_pipeline.datasets(): + if catalog.config_resolver.match_pattern(ds): + _ = catalog._get_dataset(ds) try: - if isinstance(runner, ThreadRunner): - for ds in filtered_pipeline.datasets(): - if catalog._match_pattern( - catalog._dataset_patterns, ds - ) or catalog._match_pattern(catalog._default_pattern, ds): - _ = catalog._get_dataset(ds) run_result = runner.run( filtered_pipeline, catalog, hook_manager, session_id ) diff --git a/kedro/io/__init__.py b/kedro/io/__init__.py index aba59827e9..4b4a2e1b52 100644 --- a/kedro/io/__init__.py +++ b/kedro/io/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations from .cached_dataset import CachedDataset +from .catalog_config_resolver import CatalogConfigResolver from .core import ( AbstractDataset, AbstractVersionedDataset, @@ -23,6 +24,7 @@ "AbstractVersionedDataset", "CachedDataset", "DataCatalog", + "CatalogConfigResolver", "DatasetAlreadyExistsError", "DatasetError", "DatasetNotFoundError", diff --git a/kedro/io/catalog_config_resolver.py b/kedro/io/catalog_config_resolver.py new file mode 100644 index 0000000000..97ffbadd5f --- /dev/null +++ b/kedro/io/catalog_config_resolver.py @@ -0,0 +1,259 @@ +"""``CatalogConfigResolver`` resolves dataset configurations and datasets' +patterns based on catalog configuration and credentials provided. +""" + +from __future__ import annotations + +import copy +import logging +import re +from typing import Any, Dict + +from parse import parse + +from kedro.io.core import DatasetError + +Patterns = Dict[str, Dict[str, Any]] + +CREDENTIALS_KEY = "credentials" + + +class CatalogConfigResolver: + """Resolves dataset configurations based on patterns and credentials.""" + + def __init__( + self, + config: dict[str, dict[str, Any]] | None = None, + credentials: dict[str, dict[str, Any]] | None = None, + ): + self._runtime_patterns: Patterns = {} + self._dataset_patterns, self._default_pattern = self._extract_patterns( + config, credentials + ) + self._resolved_configs = self._resolve_config_credentials(config, credentials) + + @property + def config(self) -> dict[str, dict[str, Any]]: + return self._resolved_configs + + @property + def _logger(self) -> logging.Logger: + return logging.getLogger(__name__) + + @staticmethod + def is_pattern(pattern: str) -> bool: + """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" + return "{" in pattern + + @staticmethod + def _pattern_specificity(pattern: str) -> int: + """Calculate the specificity of a pattern based on characters outside curly brackets.""" + # Remove all the placeholders from the pattern and count the number of remaining chars + result = re.sub(r"\{.*?\}", "", pattern) + return len(result) + + @classmethod + def _sort_patterns(cls, dataset_patterns: Patterns) -> Patterns: + """Sort a dictionary of dataset patterns according to parsing rules. + + In order: + 1. Decreasing specificity (number of characters outside the curly brackets) + 2. Decreasing number of placeholders (number of curly bracket pairs) + 3. 
Alphabetically + """ + sorted_keys = sorted( + dataset_patterns, + key=lambda pattern: ( + -(cls._pattern_specificity(pattern)), + -pattern.count("{"), + pattern, + ), + ) + catch_all = [ + pattern for pattern in sorted_keys if cls._pattern_specificity(pattern) == 0 + ] + if len(catch_all) > 1: + raise DatasetError( + f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras." + ) + return {key: dataset_patterns[key] for key in sorted_keys} + + @staticmethod + def _fetch_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: + """Fetch the specified credentials from the provided credentials dictionary. + + Args: + credentials_name: Credentials name. + credentials: A dictionary with all credentials. + + Returns: + The set of requested credentials. + + Raises: + KeyError: When a data set with the given name has not yet been + registered. + + """ + try: + return credentials[credentials_name] + except KeyError as exc: + raise KeyError( + f"Unable to find credentials '{credentials_name}': check your data " + "catalog and credentials configuration. See " + "https://kedro.readthedocs.io/en/stable/kedro.io.DataCatalog.html " + "for an example." + ) from exc + + @classmethod + def _resolve_credentials( + cls, config: dict[str, Any], credentials: dict[str, Any] + ) -> dict[str, Any]: + """Return the dataset configuration where credentials are resolved using + credentials dictionary provided. + + Args: + config: Original dataset config, which may contain unresolved credentials. + credentials: A dictionary with all credentials. + + Returns: + The dataset config, where all the credentials are successfully resolved. + """ + config = copy.deepcopy(config) + + def _resolve_value(key: str, value: Any) -> Any: + if key == CREDENTIALS_KEY and isinstance(value, str): + return cls._fetch_credentials(value, credentials) + if isinstance(value, dict): + return {k: _resolve_value(k, v) for k, v in value.items()} + return value + + return {k: _resolve_value(k, v) for k, v in config.items()} + + @classmethod + def _resolve_dataset_config( + cls, + ds_name: str, + pattern: str, + config: Any, + ) -> Any: + """Resolve dataset configuration based on the provided pattern.""" + resolved_vars = parse(pattern, ds_name) + # Resolve the factory config for the dataset + if isinstance(config, dict): + for key, value in config.items(): + config[key] = cls._resolve_dataset_config(ds_name, pattern, value) + elif isinstance(config, (list, tuple)): + config = [ + cls._resolve_dataset_config(ds_name, pattern, value) for value in config + ] + elif isinstance(config, str) and "}" in config: + try: + config = config.format_map(resolved_vars.named) + except KeyError as exc: + raise DatasetError( + f"Unable to resolve '{config}' from the pattern '{pattern}'. Keys used in the configuration " + f"should be present in the dataset factory pattern." 
+                ) from exc
+        return config
+
+    def list_patterns(self) -> list[str]:
+        """List all patterns available in the catalog."""
+        return (
+            list(self._dataset_patterns.keys())
+            + list(self._default_pattern.keys())
+            + list(self._runtime_patterns.keys())
+        )
+
+    def match_pattern(self, ds_name: str) -> str | None:
+        """Match a dataset name against patterns in a dictionary."""
+        all_patterns = self.list_patterns()
+        matches = (pattern for pattern in all_patterns if parse(pattern, ds_name))
+        return next(matches, None)
+
+    def _get_pattern_config(self, pattern: str) -> dict[str, Any]:
+        return (
+            self._dataset_patterns.get(pattern)
+            or self._default_pattern.get(pattern)
+            or self._runtime_patterns.get(pattern)
+            or {}
+        )
+
+    @classmethod
+    def _extract_patterns(
+        cls,
+        config: dict[str, dict[str, Any]] | None,
+        credentials: dict[str, dict[str, Any]] | None,
+    ) -> tuple[Patterns, Patterns]:
+        """Extract and sort patterns from the configuration."""
+        config = config or {}
+        credentials = credentials or {}
+        dataset_patterns = {}
+        user_default = {}
+
+        for ds_name, ds_config in config.items():
+            if cls.is_pattern(ds_name):
+                dataset_patterns[ds_name] = cls._resolve_credentials(
+                    ds_config, credentials
+                )
+
+        sorted_patterns = cls._sort_patterns(dataset_patterns)
+        if sorted_patterns:
+            # If the last pattern is a catch-all pattern, pop it and set it as the default
+            if cls._pattern_specificity(list(sorted_patterns.keys())[-1]) == 0:
+                last_pattern = sorted_patterns.popitem()
+                user_default = {last_pattern[0]: last_pattern[1]}
+
+        return sorted_patterns, user_default
+
+    def _resolve_config_credentials(
+        self,
+        config: dict[str, dict[str, Any]] | None,
+        credentials: dict[str, dict[str, Any]] | None,
+    ) -> dict[str, dict[str, Any]]:
+        """Initialize the dataset configuration with resolved credentials."""
+        config = config or {}
+        credentials = credentials or {}
+        resolved_configs = {}
+
+        for ds_name, ds_config in config.items():
+            if not isinstance(ds_config, dict):
+                raise DatasetError(
+                    f"Catalog entry '{ds_name}' is not a valid dataset configuration. "
+                    "\nHint: If this catalog entry is intended for variable interpolation, "
+                    "make sure that the key is preceded by an underscore."
+                )
+            if not self.is_pattern(ds_name):
+                resolved_configs[ds_name] = self._resolve_credentials(
+                    ds_config, credentials
+                )
+
+        return resolved_configs
+
+    def resolve_dataset_pattern(self, ds_name: str) -> dict[str, Any]:
+        """Resolve dataset patterns and return resolved configurations based on the existing patterns."""
+        matched_pattern = self.match_pattern(ds_name)
+
+        if matched_pattern and ds_name not in self._resolved_configs:
+            pattern_config = self._get_pattern_config(matched_pattern)
+            ds_config = self._resolve_dataset_config(
+                ds_name, matched_pattern, copy.deepcopy(pattern_config)
+            )
+
+            if (
+                self._pattern_specificity(matched_pattern) == 0
+                and matched_pattern in self._default_pattern
+            ):
+                self._logger.warning(
+                    "Config from the dataset factory pattern '%s' in the catalog will be used to "
+                    "override the default dataset creation for '%s'",
+                    matched_pattern,
+                    ds_name,
+                )
+            return ds_config  # type: ignore[no-any-return]
+
+        return self._resolved_configs.get(ds_name, {})
+
+    def add_runtime_patterns(self, dataset_patterns: Patterns) -> None:
+        """Add new runtime patterns and re-sort them."""
+        self._runtime_patterns = {**self._runtime_patterns, **dataset_patterns}
+        self._runtime_patterns = self._sort_patterns(self._runtime_patterns)
diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py
index d3fd163230..420f8857c8 100644
--- a/kedro/io/data_catalog.py
+++ b/kedro/io/data_catalog.py
@@ -7,15 +7,17 @@
 
 from __future__ import annotations
 
-import copy
 import difflib
 import logging
 import pprint
 import re
-from typing import Any, Dict
-
-from parse import parse
+from typing import Any
 
+from kedro.io.catalog_config_resolver import (
+    CREDENTIALS_KEY,  # noqa: F401
+    CatalogConfigResolver,
+    Patterns,
+)
 from kedro.io.core import (
     AbstractDataset,
     AbstractVersionedDataset,
@@ -28,64 +30,10 @@
 from kedro.io.memory_dataset import MemoryDataset
 from kedro.utils import _format_rich, _has_rich_handler
 
-Patterns = Dict[str, Dict[str, Any]]
-
-CATALOG_KEY = "catalog"
-CREDENTIALS_KEY = "credentials"
+CATALOG_KEY = "catalog"  # Kept to avoid a breaking change
 
 WORDS_REGEX_PATTERN = re.compile(r"\W+")
 
 
-def _get_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any:
-    """Return a set of credentials from the provided credentials dict.
-
-    Args:
-        credentials_name: Credentials name.
-        credentials: A dictionary with all credentials.
-
-    Returns:
-        The set of requested credentials.
-
-    Raises:
-        KeyError: When a data set with the given name has not yet been
-            registered.
-
-    """
-    try:
-        return credentials[credentials_name]
-    except KeyError as exc:
-        raise KeyError(
-            f"Unable to find credentials '{credentials_name}': check your data "
-            "catalog and credentials configuration. See "
-            "https://docs.kedro.org/en/stable/api/kedro.io.DataCatalog.html "
-            "for an example."
-        ) from exc
-
-
-def _resolve_credentials(
-    config: dict[str, Any], credentials: dict[str, Any]
-) -> dict[str, Any]:
-    """Return the dataset configuration where credentials are resolved using
-    credentials dictionary provided.
-
-    Args:
-        config: Original dataset config, which may contain unresolved credentials.
-        credentials: A dictionary with all credentials.
-
-    Returns:
-        The dataset config, where all the credentials are successfully resolved.
- """ - config = copy.deepcopy(config) - - def _map_value(key: str, value: Any) -> Any: - if key == CREDENTIALS_KEY and isinstance(value, str): - return _get_credentials(value, credentials) - if isinstance(value, dict): - return {k: _map_value(k, v) for k, v in value.items()} - return value - - return {k: _map_value(k, v) for k, v in config.items()} - - def _sub_nonword_chars(dataset_name: str) -> str: """Replace non-word characters in data set names since Kedro 0.16.2. @@ -103,13 +51,15 @@ class _FrozenDatasets: def __init__( self, - *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset], + *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset] | None, ): """Return a _FrozenDatasets instance from some datasets collections. Each collection could either be another _FrozenDatasets or a dictionary. """ self._original_names: dict[str, str] = {} for collection in datasets_collections: + if collection is None: + continue if isinstance(collection, _FrozenDatasets): self.__dict__.update(collection.__dict__) self._original_names.update(collection._original_names) @@ -161,10 +111,11 @@ def __init__( # noqa: PLR0913 self, datasets: dict[str, AbstractDataset] | None = None, feed_dict: dict[str, Any] | None = None, - dataset_patterns: Patterns | None = None, + dataset_patterns: Patterns | None = None, # Kept for interface compatibility load_versions: dict[str, str] | None = None, save_version: str | None = None, - default_pattern: Patterns | None = None, + default_pattern: Patterns | None = None, # Kept for interface compatibility + config_resolver: CatalogConfigResolver | None = None, ) -> None: """``DataCatalog`` stores instances of ``AbstractDataset`` implementations to provide ``load`` and ``save`` capabilities from @@ -195,6 +146,8 @@ def __init__( # noqa: PLR0913 sorted in lexicographical order. default_pattern: A dictionary of the default catch-all pattern that overrides the default pattern provided through the runners. + config_resolver: An instance of CatalogConfigResolver to resolve dataset patterns and configurations. + Example: :: @@ -206,14 +159,21 @@ def __init__( # noqa: PLR0913 >>> save_args={"index": False}) >>> catalog = DataCatalog(datasets={'cars': cars}) """ - self._datasets = dict(datasets or {}) - self.datasets = _FrozenDatasets(self._datasets) - # Keep a record of all patterns in the catalog. 
-        # {dataset pattern name : dataset pattern body}
-        self._dataset_patterns = dataset_patterns or {}
+        self._config_resolver = config_resolver or CatalogConfigResolver()
+
+        # Kept to avoid breaking changes
+        if not config_resolver:
+            self._config_resolver._dataset_patterns = dataset_patterns or {}
+            self._config_resolver._default_pattern = default_pattern or {}
+
+        self._datasets: dict[str, AbstractDataset] = {}
+        self.datasets: _FrozenDatasets | None = None
+
+        self.add_all(datasets or {})
+
         self._load_versions = load_versions or {}
         self._save_version = save_version
-        self._default_pattern = default_pattern or {}
+
         self._use_rich_markup = _has_rich_handler()
 
         if feed_dict:
@@ -222,6 +182,23 @@ def __init__(  # noqa: PLR0913
     def __repr__(self) -> str:
         return self.datasets.__repr__()
 
+    def __contains__(self, dataset_name: str) -> bool:
+        """Check if an item is in the catalog as a materialised dataset or pattern"""
+        return (
+            dataset_name in self._datasets
+            or self._config_resolver.match_pattern(dataset_name) is not None
+        )
+
+    def __eq__(self, other) -> bool:  # type: ignore[no-untyped-def]
+        return (self._datasets, self._config_resolver.list_patterns()) == (
+            other._datasets,
+            other.config_resolver.list_patterns(),
+        )
+
+    @property
+    def config_resolver(self) -> CatalogConfigResolver:
+        return self._config_resolver
+
     @property
     def _logger(self) -> logging.Logger:
         return logging.getLogger(__name__)
@@ -303,44 +280,28 @@ class to be loaded is specified with the key ``type`` and their
             >>> df = catalog.load("cars")
             >>> catalog.save("boats", df)
         """
+        catalog = catalog or {}
         datasets = {}
-        dataset_patterns = {}
-        catalog = copy.deepcopy(catalog) or {}
-        credentials = copy.deepcopy(credentials) or {}
+        config_resolver = CatalogConfigResolver(catalog, credentials)
         save_version = save_version or generate_timestamp()
-        load_versions = copy.deepcopy(load_versions) or {}
-        user_default = {}
-
-        for ds_name, ds_config in catalog.items():
-            if not isinstance(ds_config, dict):
-                raise DatasetError(
-                    f"Catalog entry '{ds_name}' is not a valid dataset configuration. "
-                    "\nHint: If this catalog entry is intended for variable interpolation, "
-                    "make sure that the key is preceded by an underscore."
-                )
+        load_versions = load_versions or {}
 
-            ds_config = _resolve_credentials(  # noqa: PLW2901
-                ds_config, credentials
-            )
-            if cls._is_pattern(ds_name):
-                # Add each factory to the dataset_patterns dict.
-                dataset_patterns[ds_name] = ds_config
-
-            else:
+        for ds_name in catalog:
+            if not config_resolver.is_pattern(ds_name):
                 datasets[ds_name] = AbstractDataset.from_config(
-                    ds_name, ds_config, load_versions.get(ds_name), save_version
+                    ds_name,
+                    config_resolver.config.get(ds_name, {}),
+                    load_versions.get(ds_name),
+                    save_version,
                 )
-        sorted_patterns = cls._sort_patterns(dataset_patterns)
-        if sorted_patterns:
-            # If the last pattern is a catch-all pattern, pop it and set it as the default
-            if cls._specificity(list(sorted_patterns.keys())[-1]) == 0:
-                last_pattern = sorted_patterns.popitem()
-                user_default = {last_pattern[0]: last_pattern[1]}
 
         missing_keys = [
-            key
-            for key in load_versions.keys()
-            if not (key in catalog or cls._match_pattern(sorted_patterns, key))
+            ds_name
+            for ds_name in load_versions
+            if not (
+                ds_name in config_resolver.config
+                or config_resolver.match_pattern(ds_name)
+            )
         ]
         if missing_keys:
             raise DatasetNotFoundError(
@@ -350,107 +311,29 @@ class to be loaded is specified with the key ``type`` and their
 
         return cls(
             datasets=datasets,
-            dataset_patterns=sorted_patterns,
+            dataset_patterns=config_resolver._dataset_patterns,
             load_versions=load_versions,
             save_version=save_version,
-            default_pattern=user_default,
+            default_pattern=config_resolver._default_pattern,
+            config_resolver=config_resolver,
         )
 
-    @staticmethod
-    def _is_pattern(pattern: str) -> bool:
-        """Check if a given string is a pattern. Assume that any name with '{' is a pattern."""
-        return "{" in pattern
-
-    @staticmethod
-    def _match_pattern(dataset_patterns: Patterns, dataset_name: str) -> str | None:
-        """Match a dataset name against patterns in a dictionary."""
-        matches = (
-            pattern
-            for pattern in dataset_patterns.keys()
-            if parse(pattern, dataset_name)
-        )
-        return next(matches, None)
-
-    @classmethod
-    def _sort_patterns(cls, dataset_patterns: Patterns) -> dict[str, dict[str, Any]]:
-        """Sort a dictionary of dataset patterns according to parsing rules.
-
-        In order:
-
-        1. Decreasing specificity (number of characters outside the curly brackets)
-        2. Decreasing number of placeholders (number of curly bracket pairs)
-        3. Alphabetically
-        """
-        sorted_keys = sorted(
-            dataset_patterns,
-            key=lambda pattern: (
-                -(cls._specificity(pattern)),
-                -pattern.count("{"),
-                pattern,
-            ),
-        )
-        catch_all = [
-            pattern for pattern in sorted_keys if cls._specificity(pattern) == 0
-        ]
-        if len(catch_all) > 1:
-            raise DatasetError(
-                f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras."
-            )
-        return {key: dataset_patterns[key] for key in sorted_keys}
-
-    @staticmethod
-    def _specificity(pattern: str) -> int:
-        """Helper function to check the length of exactly matched characters not inside brackets.
-
-        Example:
-        ::
-
-            >>> specificity("{namespace}.companies") = 10
-            >>> specificity("{namespace}.{dataset}") = 1
-            >>> specificity("france.companies") = 16
-        """
-        # Remove all the placeholders from the pattern and count the number of remaining chars
-        result = re.sub(r"\{.*?\}", "", pattern)
-        return len(result)
-
     def _get_dataset(
         self,
         dataset_name: str,
         version: Version | None = None,
         suggest: bool = True,
     ) -> AbstractDataset:
-        matched_pattern = self._match_pattern(
-            self._dataset_patterns, dataset_name
-        ) or self._match_pattern(self._default_pattern, dataset_name)
-        if dataset_name not in self._datasets and matched_pattern:
-            # If the dataset is a patterned dataset, materialise it and add it to
-            # the catalog
-            config_copy = copy.deepcopy(
-                self._dataset_patterns.get(matched_pattern)
-                or self._default_pattern.get(matched_pattern)
-                or {}
-            )
-            dataset_config = self._resolve_config(
-                dataset_name, matched_pattern, config_copy
-            )
-            dataset = AbstractDataset.from_config(
+        ds_config = self._config_resolver.resolve_dataset_pattern(dataset_name)
+
+        if dataset_name not in self._datasets and ds_config:
+            ds = AbstractDataset.from_config(
                 dataset_name,
-                dataset_config,
+                ds_config,
                 self._load_versions.get(dataset_name),
                 self._save_version,
             )
-            if (
-                self._specificity(matched_pattern) == 0
-                and matched_pattern in self._default_pattern
-            ):
-                self._logger.warning(
-                    "Config from the dataset factory pattern '%s' in the catalog will be used to "
-                    "override the default dataset creation for '%s'",
-                    matched_pattern,
-                    dataset_name,
-                )
-
-            self.add(dataset_name, dataset)
+            self.add(dataset_name, ds)
 
         if dataset_name not in self._datasets:
             error_msg = f"Dataset '{dataset_name}' not found in the catalog"
@@ -462,7 +345,9 @@ def _get_dataset(
             suggestions = ", ".join(matches)
             error_msg += f" - did you mean one of these instead: {suggestions}"
             raise DatasetNotFoundError(error_msg)
+
         dataset = self._datasets[dataset_name]
+
         if version and isinstance(dataset, AbstractVersionedDataset):
             # we only want to return a similar-looking dataset,
             # not modify the one stored in the current catalog
@@ -470,41 +355,6 @@ def _get_dataset(
 
         return dataset
 
-    def __contains__(self, dataset_name: str) -> bool:
-        """Check if an item is in the catalog as a materialised dataset or pattern"""
-        matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name)
-        if dataset_name in self._datasets or matched_pattern:
-            return True
-        return False
-
-    @classmethod
-    def _resolve_config(
-        cls,
-        dataset_name: str,
-        matched_pattern: str,
-        config: dict,
-    ) -> dict[str, Any]:
-        """Get resolved AbstractDataset from a factory config"""
-        result = parse(matched_pattern, dataset_name)
-        # Resolve the factory config for the dataset
-        if isinstance(config, dict):
-            for key, value in config.items():
-                config[key] = cls._resolve_config(dataset_name, matched_pattern, value)
-        elif isinstance(config, (list, tuple)):
-            config = [
-                cls._resolve_config(dataset_name, matched_pattern, value)
-                for value in config
-            ]
-        elif isinstance(config, str) and "}" in config:
-            try:
-                config = str(config).format_map(result.named)
-            except KeyError as exc:
-                raise DatasetError(
-                    f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration "
-                    f"should be present in the dataset factory pattern."
-                ) from exc
-        return config
-
     def load(self, name: str, version: str | None = None) -> Any:
         """Loads a registered data set.
@@ -619,7 +469,10 @@ def release(self, name: str) -> None:
         dataset.release()
 
     def add(
-        self, dataset_name: str, dataset: AbstractDataset, replace: bool = False
+        self,
+        dataset_name: str,
+        dataset: AbstractDataset,
+        replace: bool = False,
     ) -> None:
         """Adds a new ``AbstractDataset`` object to the ``DataCatalog``.
 
@@ -657,7 +510,9 @@ def add(
         self.datasets = _FrozenDatasets(self.datasets, {dataset_name: dataset})
 
     def add_all(
-        self, datasets: dict[str, AbstractDataset], replace: bool = False
+        self,
+        datasets: dict[str, AbstractDataset],
+        replace: bool = False,
     ) -> None:
         """Adds a group of new data sets to the ``DataCatalog``.
 
@@ -688,8 +543,8 @@ def add_all(
         >>>
         >>> assert catalog.list() == ["cars", "planes", "boats"]
         """
-        for name, dataset in datasets.items():
-            self.add(name, dataset, replace)
+        for ds_name, ds in datasets.items():
+            self.add(ds_name, ds, replace)
 
     def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None:
         """Add datasets to the ``DataCatalog`` using the data provided through the `feed_dict`.
 
@@ -726,13 +581,13 @@ def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> Non
         >>>
         >>> assert catalog.load("data_csv_dataset").equals(df)
         """
-        for dataset_name in feed_dict:
-            if isinstance(feed_dict[dataset_name], AbstractDataset):
-                dataset = feed_dict[dataset_name]
-            else:
-                dataset = MemoryDataset(data=feed_dict[dataset_name])  # type: ignore[abstract]
-
-            self.add(dataset_name, dataset, replace)
+        for ds_name, ds_data in feed_dict.items():
+            dataset = (
+                ds_data
+                if isinstance(ds_data, AbstractDataset)
+                else MemoryDataset(data=ds_data)  # type: ignore[abstract]
+            )
+            self.add(ds_name, dataset, replace)
 
     def list(self, regex_search: str | None = None) -> list[str]:
         """
@@ -777,7 +632,7 @@ def list(self, regex_search: str | None = None) -> list[str]:
             raise SyntaxError(
                 f"Invalid regular expression provided: '{regex_search}'"
             ) from exc
-        return [dset_name for dset_name in self._datasets if pattern.search(dset_name)]
+        return [ds_name for ds_name in self._datasets if pattern.search(ds_name)]
 
     def shallow_copy(
         self, extra_dataset_patterns: Patterns | None = None
@@ -787,26 +642,15 @@ def shallow_copy(
 
         Returns:
             Copy of the current object.
""" - if not self._default_pattern and extra_dataset_patterns: - unsorted_dataset_patterns = { - **self._dataset_patterns, - **extra_dataset_patterns, - } - dataset_patterns = self._sort_patterns(unsorted_dataset_patterns) - else: - dataset_patterns = self._dataset_patterns + if extra_dataset_patterns: + self._config_resolver.add_runtime_patterns(extra_dataset_patterns) return self.__class__( datasets=self._datasets, - dataset_patterns=dataset_patterns, + dataset_patterns=self._config_resolver._dataset_patterns, + default_pattern=self._config_resolver._default_pattern, load_versions=self._load_versions, save_version=self._save_version, - default_pattern=self._default_pattern, - ) - - def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] - return (self._datasets, self._dataset_patterns) == ( - other._datasets, - other._dataset_patterns, + config_resolver=self._config_resolver, ) def confirm(self, name: str) -> None: diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 2ffd0389e4..6f165e87c0 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -83,7 +83,6 @@ def run( """ hook_or_null_manager = hook_manager or _NullPluginManager() - catalog = catalog.shallow_copy() # Check which datasets used in the pipeline are in the catalog or match # a pattern in the catalog diff --git a/tests/framework/cli/test_catalog.py b/tests/framework/cli/test_catalog.py index f34034296e..7a61c9e7a0 100644 --- a/tests/framework/cli/test_catalog.py +++ b/tests/framework/cli/test_catalog.py @@ -490,7 +490,6 @@ def test_rank_catalog_factories( mocked_context.catalog = DataCatalog.from_config( fake_catalog_with_overlapping_factories ) - print("!!!!", mocked_context.catalog._dataset_patterns) result = CliRunner().invoke( fake_project_cli, ["catalog", "rank"], obj=fake_metadata ) @@ -547,7 +546,7 @@ def test_catalog_resolve( mocked_context.catalog = DataCatalog.from_config( catalog=fake_catalog_config, credentials=fake_credentials_config ) - placeholder_ds = mocked_context.catalog._dataset_patterns.keys() + placeholder_ds = mocked_context.catalog.config_resolver.list_patterns() pipeline_datasets = {"csv_example", "parquet_example", "explicit_dataset"} mocker.patch.object( diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index 83550f3a56..086d581045 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -693,7 +693,7 @@ def test_run_thread_runner( } mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret) mocker.patch( - "kedro.io.data_catalog.DataCatalog._match_pattern", + "kedro.io.data_catalog.CatalogConfigResolver.match_pattern", return_value=match_pattern, ) diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index dbec57e64d..db777cc634 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -846,7 +846,7 @@ def test_match_added_to_datasets_on_get(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert "{brand}_cars" not in catalog._datasets assert "tesla_cars" not in catalog._datasets - assert "{brand}_cars" in catalog._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns tesla_cars = catalog._get_dataset("tesla_cars") assert isinstance(tesla_cars, CSVDataset) @@ -875,8 +875,8 @@ def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert 
"audi_cars" in catalog._datasets assert "{brand}_cars" not in catalog._datasets - assert "audi_cars" not in catalog._dataset_patterns - assert "{brand}_cars" in catalog._dataset_patterns + assert "audi_cars" not in catalog.config_resolver._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns def test_explicit_entry_not_overwritten(self, config_with_dataset_factories): """Check that the existing catalog entry is not overwritten by config in pattern""" @@ -909,11 +909,7 @@ def test_sorting_order_patterns(self, config_with_dataset_factories_only_pattern "{dataset}s", "{user_default}", ] - assert ( - list(catalog._dataset_patterns.keys()) - + list(catalog._default_pattern.keys()) - == sorted_keys_expected - ) + assert catalog.config_resolver.list_patterns() == sorted_keys_expected def test_multiple_catch_all_patterns_not_allowed( self, config_with_dataset_factories @@ -953,13 +949,13 @@ def test_sorting_order_with_other_dataset_through_extra_pattern( ) sorted_keys_expected = [ "{country}_companies", - "{another}#csv", "{namespace}_{dataset}", "{dataset}s", + "{another}#csv", "{default}", ] assert ( - list(catalog_with_default._dataset_patterns.keys()) == sorted_keys_expected + catalog_with_default.config_resolver.list_patterns() == sorted_keys_expected ) def test_user_default_overwrites_runner_default(self): @@ -988,11 +984,15 @@ def test_user_default_overwrites_runner_default(self): sorted_keys_expected = [ "{dataset}s", "{a_default}", + "{another}#csv", + "{default}", ] - assert "{a_default}" in catalog_with_runner_default._default_pattern assert ( - list(catalog_with_runner_default._dataset_patterns.keys()) - + list(catalog_with_runner_default._default_pattern.keys()) + "{a_default}" + in catalog_with_runner_default.config_resolver._default_pattern + ) + assert ( + catalog_with_runner_default.config_resolver.list_patterns() == sorted_keys_expected )