diff --git a/RELEASE.md b/RELEASE.md index a5e34a6ba8..f4b10035b7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,6 +1,10 @@ # Upcoming Release ## Major features and improvements +* Implemented dict-like interface for `KedroDataCatalog`. + +**Note:** ``KedroDataCatalog`` is an experimental feature and is under active development. Therefore, it is possible we'll introduce breaking changes to this class, so be mindful of that if you decide to use it already. Let us know if you have any feedback about the ``KedroDataCatalog`` or ideas for new features. + ## Bug fixes and other changes ## Breaking changes to the API ## Documentation changes diff --git a/kedro/io/kedro_data_catalog.py b/kedro/io/kedro_data_catalog.py index d07de8151a..c3d216abcd 100644 --- a/kedro/io/kedro_data_catalog.py +++ b/kedro/io/kedro_data_catalog.py @@ -14,7 +14,7 @@ import difflib import logging import re -from typing import Any +from typing import Any, Iterator, List # noqa: UP035 from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns from kedro.io.core import ( @@ -84,10 +84,12 @@ def __init__( @property def datasets(self) -> dict[str, Any]: + # TODO: remove when removing old catalog return copy.copy(self._datasets) @datasets.setter def datasets(self, value: Any) -> None: + # TODO: remove when removing old catalog raise AttributeError( "Operation not allowed. Please use KedroDataCatalog.add() instead." ) @@ -112,6 +114,49 @@ def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] other.config_resolver.list_patterns(), ) + def keys(self) -> List[str]: # noqa: UP006 + return list(self.__iter__()) + + def values(self) -> List[AbstractDataset]: # noqa: UP006 + return [self._datasets[key] for key in self] + + def items(self) -> List[tuple[str, AbstractDataset]]: # noqa: UP006 + return [(key, self._datasets[key]) for key in self] + + def __iter__(self) -> Iterator[str]: + yield from self._datasets.keys() + + def __getitem__(self, ds_name: str) -> AbstractDataset: + return self.get_dataset(ds_name) + + def __setitem__(self, key: str, value: Any) -> None: + if key in self._datasets: + self._logger.warning("Replacing dataset '%s'", key) + if isinstance(value, AbstractDataset): + self._datasets[key] = value + else: + self._logger.info(f"Adding input data as a MemoryDataset - {key}") + self._datasets[key] = MemoryDataset(data=value) # type: ignore[abstract] + + def __len__(self) -> int: + return len(self.keys()) + + def get( + self, key: str, default: AbstractDataset | None = None + ) -> AbstractDataset | None: + """Get a dataset by name from an internal collection of datasets.""" + if key not in self._datasets: + ds_config = self._config_resolver.resolve_pattern(key) + if ds_config: + self._add_from_config(key, ds_config) + + dataset = self._datasets.get(key, None) + + return dataset or default + + def _ipython_key_completions_(self) -> list[str]: + return list(self._datasets.keys()) + @property def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @@ -178,6 +223,7 @@ def _add_from_config(self, ds_name: str, ds_config: dict[str, Any]) -> None: def get_dataset( self, ds_name: str, version: Version | None = None, suggest: bool = True ) -> AbstractDataset: + # TODO: remove when removing old catalog """Get a dataset by name from an internal collection of datasets. If a dataset is not in the collection but matches any pattern @@ -197,12 +243,7 @@ def get_dataset( DatasetNotFoundError: When a dataset with the given name is not in the collection and do not match patterns. """ - if ds_name not in self._datasets: - ds_config = self._config_resolver.resolve_pattern(ds_name) - if ds_config: - self._add_from_config(ds_name, ds_config) - - dataset = self._datasets.get(ds_name, None) + dataset = self.get(ds_name) if dataset is None: error_msg = f"Dataset '{ds_name}' not found in the catalog" @@ -231,40 +272,71 @@ def _get_dataset( def add( self, ds_name: str, dataset: AbstractDataset, replace: bool = False ) -> None: + # TODO: remove when removing old catalog """Adds a new ``AbstractDataset`` object to the ``KedroDataCatalog``.""" - if ds_name in self._datasets: - if replace: - self._logger.warning("Replacing dataset '%s'", ds_name) - else: - raise DatasetAlreadyExistsError( - f"Dataset '{ds_name}' has already been registered" - ) - self._datasets[ds_name] = dataset - - def list(self, regex_search: str | None = None) -> list[str]: + if ds_name in self._datasets and not replace: + raise DatasetAlreadyExistsError( + f"Dataset '{ds_name}' has already been registered" + ) + self.__setitem__(ds_name, dataset) + + def list( + self, regex_search: str | None = None, regex_flags: int | re.RegexFlag = 0 + ) -> List[str]: # noqa: UP006 + # TODO: rename depending on the solution for https://github.com/kedro-org/kedro/issues/3917 """ List of all dataset names registered in the catalog. This can be filtered by providing an optional regular expression which will only return matching keys. """ - if regex_search is None: - return list(self._datasets.keys()) + return self.keys() - if not regex_search.strip(): + if regex_search == "": self._logger.warning("The empty string will not match any datasets") return [] + if not regex_flags: + regex_flags = re.IGNORECASE + try: - pattern = re.compile(regex_search, flags=re.IGNORECASE) + pattern = re.compile(regex_search, flags=regex_flags) except re.error as exc: raise SyntaxError( f"Invalid regular expression provided: '{regex_search}'" ) from exc - return [ds_name for ds_name in self._datasets if pattern.search(ds_name)] + return [ds_name for ds_name in self.__iter__() if pattern.search(ds_name)] def save(self, name: str, data: Any) -> None: - """Save data to a registered dataset.""" + # TODO: rename input argument when breaking change: name -> ds_name + """Save data to a registered dataset. + + Args: + name: A dataset to be saved to. + data: A data object to be saved as configured in the registered + dataset. + + Raises: + DatasetNotFoundError: When a dataset with the given name + has not yet been registered. + + Example: + :: + + >>> import pandas as pd + >>> + >>> from kedro_datasets.pandas import CSVDataset + >>> + >>> cars = CSVDataset(filepath="cars.csv", + >>> load_args=None, + >>> save_args={"index": False}) + >>> catalog = DataCatalog(datasets={'cars': cars}) + >>> + >>> df = pd.DataFrame({'col1': [1, 2], + >>> 'col2': [4, 5], + >>> 'col3': [5, 6]}) + >>> catalog.save("cars", df) + """ dataset = self.get_dataset(name) self._logger.info( @@ -277,7 +349,35 @@ def save(self, name: str, data: Any) -> None: dataset.save(data) def load(self, name: str, version: str | None = None) -> Any: - """Loads a registered dataset.""" + # TODO: rename input argument when breaking change: name -> ds_name + # TODO: remove version from input arguments when breaking change + """Loads a registered dataset. + + Args: + name: A dataset to be loaded. + version: Optional argument for concrete data version to be loaded. + Works only with versioned datasets. + + Returns: + The loaded data as configured. + + Raises: + DatasetNotFoundError: When a dataset with the given name + has not yet been registered. + + Example: + :: + + >>> from kedro.io import DataCatalog + >>> from kedro_datasets.pandas import CSVDataset + >>> + >>> cars = CSVDataset(filepath="cars.csv", + >>> load_args=None, + >>> save_args={"index": False}) + >>> catalog = DataCatalog(datasets={'cars': cars}) + >>> + >>> df = catalog.load("cars") + """ load_version = Version(version, None) if version else None dataset = self.get_dataset(name, version=load_version) diff --git a/tests/io/test_kedro_data_catalog.py b/tests/io/test_kedro_data_catalog.py index efa993bb0e..a53717f8ba 100644 --- a/tests/io/test_kedro_data_catalog.py +++ b/tests/io/test_kedro_data_catalog.py @@ -379,7 +379,7 @@ def test_config_invalid_dataset_config(self, correct_config): def test_empty_config(self): """Test empty config""" - assert KedroDataCatalog.from_config(None) + assert len(KedroDataCatalog.from_config(None)) == 0 def test_missing_credentials(self, correct_config): """Check the error if credentials can't be located""" @@ -502,6 +502,39 @@ def test_bad_confirm(self, correct_config, dataset_name, pattern): with pytest.raises(DatasetError, match=re.escape(pattern)): data_catalog.confirm(dataset_name) + def test_iteration(self, correct_config): + """Test iterate through keys, values and items.""" + data_catalog = KedroDataCatalog.from_config(**correct_config) + + for ds_name_cat, ds_name_config in zip( + data_catalog, correct_config["catalog"] + ): + assert ds_name_cat == ds_name_config + + for ds_name_cat, ds_name_config in zip( + data_catalog.keys(), correct_config["catalog"] + ): + assert ds_name_cat == ds_name_config + + for ds in data_catalog.values(): + assert isinstance(ds, CSVDataset) + + for ds_name, ds in data_catalog.items(): + assert isinstance(ds, CSVDataset) + assert ds_name in correct_config["catalog"] + + def test_getitem_setitem(self, correct_config): + """Test get and set item.""" + data_catalog = KedroDataCatalog.from_config(**correct_config) + data_catalog["test"] = 123 + assert isinstance(data_catalog["test"], MemoryDataset) + + def test_ipython_key_completions(self, correct_config): + data_catalog = KedroDataCatalog.from_config(**correct_config) + assert data_catalog._ipython_key_completions_() == list( + correct_config["catalog"].keys() + ) + class TestDataCatalogVersioned: def test_from_correct_config_versioned(self, correct_config, dummy_dataframe): """Test load and save of versioned datasets from config"""