Dataset factories #2635

Merged: 29 commits merged into main from feature/dataset-factories-for-catalog on Jul 6, 2023.
Changes from 8 commits.

Commits (29)
ec66a12
Cleaned up and up to date version of dataset factories code
merelcht May 22, 2023
5e6c15d
Add some simple tests
merelcht May 23, 2023
0fca72c
Add parsing rules
ankatiyar Jun 2, 2023
06ed1a4
Refactor
ankatiyar Jun 8, 2023
b0e3fb9
Add some tests
ankatiyar Jun 8, 2023
0833af2
Add unit tests
ankatiyar Jun 12, 2023
8fc80f9
Fix test + refactor runner
ankatiyar Jun 12, 2023
091f794
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 12, 2023
8c192ee
Add comments + update specificity fn
ankatiyar Jun 13, 2023
3e2642c
Update function names
ankatiyar Jun 15, 2023
c2635d0
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 15, 2023
d310486
Update test
ankatiyar Jun 15, 2023
573c67f
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 19, 2023
9d80de4
Release notes + update resume scenario fix
ankatiyar Jun 19, 2023
549823f
revert change to suggest resume scenario
ankatiyar Jun 19, 2023
e052ae6
Update tests DataSet->Dataset
ankatiyar Jun 19, 2023
96c219f
Small refactor + move parsing rules to a new fn
ankatiyar Jun 20, 2023
c2634e5
Fix problem with load_version + refactor
ankatiyar Jul 3, 2023
eee606a
linting + small fix _get_datasets
ankatiyar Jul 3, 2023
635510a
Remove check for existence
ankatiyar Jul 4, 2023
394f37b
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 4, 2023
a1c602d
Add updated tests + Release notes
ankatiyar Jul 5, 2023
978d0a5
change classmethod to staticmethod for _match_patterns
ankatiyar Jul 5, 2023
b4fe7a7
Add test for layer
ankatiyar Jul 5, 2023
2782dca
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 5, 2023
85d3df1
Minor change from code review
ankatiyar Jul 5, 2023
8904ce3
Remove type conversion
ankatiyar Jul 6, 2023
bdc953d
Add warning for catch-all patterns [dataset factories] (#2774)
ankatiyar Jul 6, 2023
fa6c256
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 6, 2023
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse~=1.19.0
pip-tools~=6.5
pluggy~=1.0.0
PyYAML>=4.2, <7.0
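
The `parse` dependency added above powers the pattern matching used throughout this PR. A minimal illustration of how it behaves (hypothetical dataset names, not taken from this diff):

from parse import parse

# A catalog entry name with {placeholders} acts as the pattern; a pipeline dataset name
# is matched against it. A match returns a Result whose .named dict holds the captured
# fields, while no match returns None.
print(parse("{namespace}.companies", "france.companies").named)  # {'namespace': 'france'}
print(parse("{namespace}.companies", "france.shuttles"))         # None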
148 changes: 123 additions & 25 deletions kedro/io/data_catalog.py
@@ -11,7 +11,9 @@
import logging
import re
from collections import defaultdict
from typing import Any
from typing import Any, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
@@ -94,6 +96,23 @@ def _sub_nonword_chars(data_set_name: str) -> str:
return re.sub(WORDS_REGEX_PATTERN, "__", data_set_name)


def _specificity(pattern: str) -> int:
"""Helper function to check length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16
Args:
pattern:
Returns:
"""
pattern_variables = parse(pattern, pattern).named
for k in pattern_variables:
pattern_variables[k] = ""
specific_characters = pattern.format(**pattern_variables)
return -len(specific_characters)
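
A rough sketch of the ordering this helper is meant to induce, using the docstring's example patterns (the negated value lets a plain ascending sort put the most specific pattern first):

from parse import parse

def specificity(pattern: str) -> int:
    # Blank out the named fields and count the characters that remain outside brackets.
    named = parse(pattern, pattern).named
    return -len(pattern.format(**{key: "" for key in named}))

patterns = ["{namespace}.{dataset}", "{namespace}.companies", "france.companies"]
print(sorted(patterns, key=lambda p: (specificity(p), -p.count("{"), p)))
# ['france.companies', '{namespace}.companies', '{namespace}.{dataset}']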


class _FrozenDatasets:
"""Helper class to access underlying loaded datasets"""

@@ -141,6 +160,7 @@ def __init__(
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, Any] = None,
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
@@ -170,7 +190,14 @@ def __init__(
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers

# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self.dataset_patterns = dict(dataset_patterns or {})
self._sorted_dataset_patterns = sorted(
self.dataset_patterns.keys(),
key=lambda x: (_specificity(x), -x.count("{"), x),
)
self._pattern_name_matches_cache: dict[str, str] = {}
# import the feed dict
if feed_dict:
self.add_feed_dict(feed_dict)
@@ -257,6 +284,7 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
@@ -271,35 +299,54 @@ class to be loaded is specified with the key ``type`` and their

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
# Assume that any name with } in it is a dataset factory to be matched.
if "}" in ds_name:
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=dataset_patterns,
)
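
A minimal sketch of the split performed above, using plain dicts and hypothetical entries: any entry whose name contains "}" is recorded as a pattern rather than instantiated straight away.

catalog_config = {
    "france.companies": {"type": "pandas.CSVDataSet", "filepath": "data/france_companies.csv"},
    "{namespace}.shuttles": {"type": "pandas.CSVDataSet", "filepath": "data/{namespace}_shuttles.csv"},
}
# Entries with "}" in the name become patterns; the rest are materialised as datasets.
dataset_patterns = {name: cfg for name, cfg in catalog_config.items() if "}" in name}
concrete_datasets = {name: cfg for name, cfg in catalog_config.items() if "}" not in name}
print(list(dataset_patterns))   # ['{namespace}.shuttles']
print(list(concrete_datasets))  # ['france.companies']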

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
if data_set_name not in self._data_sets:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"

raise DataSetNotFoundError(error_msg)
# When a dataset used in the pipeline is not among the recorded catalog datasets,
# try to match it against the dataset factory patterns in the catalog. If there is
# a match, resolve it to a dataset instance and add it to the catalog, so it only
# needs to be matched once and not every time the dataset is used in the pipeline.
if self.exists_in_catalog(data_set_name):
pattern = self._pattern_name_matches_cache[data_set_name]
matched_dataset = self._get_resolved_dataset(data_set_name, pattern)
self.add(data_set_name, matched_dataset)
else:
error_msg = f"DataSet '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
# slow down plugins like `kedro-viz`
if suggest:
matches = difflib.get_close_matches(
data_set_name, self._data_sets.keys()
)
if matches:
suggestions = ", ".join(matches)
error_msg += (
f" - did you mean one of these instead: {suggestions}"
)

raise DataSetNotFoundError(error_msg)

data_set = self._data_sets[data_set_name]
if version and isinstance(data_set, AbstractVersionedDataSet):
@@ -311,6 +358,28 @@ def _get_dataset(

return data_set

def _get_resolved_dataset(
self, dataset_name: str, matched_pattern: str
) -> AbstractDataSet:
result = parse(matched_pattern, dataset_name)
template_copy = copy.deepcopy(self.dataset_patterns[matched_pattern])
for key, value in template_copy.items():
# Find all dataset fields that need to be resolved with
# the values that were matched.
if isinstance(value, Iterable) and "}" in value:
string_value = str(value)
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
template_copy[key] = string_value.format_map(result.named)
except KeyError as exc:
raise DataSetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
# Create dataset from catalog template.
return AbstractDataSet.from_config(dataset_name, template_copy)
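
The resolution step boils down to str.format_map over the fields captured by the match. A standalone sketch with a hypothetical pattern and template:

from parse import parse

pattern = "{namespace}.shuttles"
template = {"type": "pandas.CSVDataSet", "filepath": "data/{namespace}_shuttles.csv"}

result = parse(pattern, "france.shuttles")
# Only string values that still contain "}" need the captured fields filled in.
resolved = {
    key: value.format_map(result.named) if isinstance(value, str) and "}" in value else value
    for key, value in template.items()
}
print(resolved)  # {'type': 'pandas.CSVDataSet', 'filepath': 'data/france_shuttles.csv'}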

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.

@@ -567,13 +636,42 @@ def list(self, regex_search: str | None = None) -> list[str]:
) from exc
return [dset_name for dset_name in self._data_sets if pattern.search(dset_name)]

def exists_in_catalog(self, dataset_name: str) -> bool:
"""Check if a dataset exists in the catalog as an exact match or if it matches a pattern."""
if (
dataset_name in self._data_sets
or dataset_name in self._pattern_name_matches_cache
):
return True
matched_pattern = self.match_name_against_pattern(dataset_name)
if self.dataset_patterns and matched_pattern:
# cache the "dataset_name -> pattern" match
self._pattern_name_matches_cache[dataset_name] = matched_pattern
return True
return False

def match_name_against_pattern(self, dataset_name: str) -> str | None:
"""Match a dataset name against existing patterns"""
# Loop through all dataset patterns and check if the given dataset name has a match.
for pattern in self._sorted_dataset_patterns:
result = parse(pattern, dataset_name)
# If there's a match, return the pattern so it can be resolved into a dataset
# instance. `parse` returns None when there is no match, otherwise a result
# holding the matched fields:
if result:
return pattern
return None
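
Because the patterns are iterated in their sorted order, the first match wins. A quick sketch with hypothetical names:

from parse import parse

sorted_patterns = ["{namespace}.companies", "{namespace}.{dataset}"]
dataset_name = "france.companies"
# The more specific pattern is tried first, so it is the one returned.
matched = next((p for p in sorted_patterns if parse(p, dataset_name)), None)
print(matched)  # {namespace}.companies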

def shallow_copy(self) -> DataCatalog:
"""Returns a shallow copy of the current object.

Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets,
layers=self.layers,
dataset_patterns=self.dataset_patterns,
)

def __eq__(self, other):
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
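
Taken together, and assuming the behaviour on this branch, a catalog built from a factory pattern can serve concrete dataset names on demand. A small end-to-end sketch that uses an in-memory dataset so no files are needed:

from kedro.io import DataCatalog

config = {"{namespace}.companies": {"type": "MemoryDataSet"}}
catalog = DataCatalog.from_config(config)

# "france.companies" is not an explicit entry; it is resolved from the pattern on
# first use and then added to the catalog.
catalog.save("france.companies", {"rows": 3})
print(catalog.load("france.companies"))  # {'rows': 3}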
25 changes: 22 additions & 3 deletions kedro/runner/runner.py
@@ -74,14 +74,33 @@ def run(
hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
# Check which datasets used in the pipeline aren't in the catalog and don't match
# a pattern in the catalog
unregistered_ds = [
ds for ds in pipeline.data_sets() if not catalog.exists_in_catalog(ds)
]

# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = [
input_name
for input_name in pipeline.inputs()
if input_name in unregistered_ds
]
Review comment from a project member:
Using comprehension here makes it look more complicated than it is. Could we keep the original code here where we are using simple set difference?

Suggested change
# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = [
input_name
for input_name in pipeline.inputs()
if input_name in unregistered_ds
]
unsatisfied = pipeline.inputs() - set(registered_ds)

This reads in English as "give me all inputs minus all the registered ones".

if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
# Check if there are any output datasets that aren't in the catalog and don't match
# a pattern in the catalog.
free_outputs = [
output_name
for output_name in pipeline.outputs()
if output_name in unregistered_ds
]

# Create a default dataset for unregistered datasets
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))
