diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 666b83a58..9bd3cb0ff 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -11,12 +11,19 @@
 | `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` |
 | `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` |
 * `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`.
+
+* Added the following new core datasets:
+| Type | Description | Location |
+|-------------------------------------|-----------------------------------------------------------|-----------------------------------------|
+| `dask.CSVDataset` | A dataset for loading CSV files using `dask`. | `kedro_datasets.dask` |
+
 * Extended preview feature to `yaml.YAMLDataset`.

 ## Community contributions
 Many thanks to the following Kedroids for contributing PRs to this release:
 * [Lukas Innig](https://github.com/derluke)
+* [Michael Sexton](https://github.com/michaelsexton)

 # Release 3.0.1
@@ -58,6 +65,7 @@ Many thanks to the following Kedroids for contributing PRs to this release:
 * [Eduardo Romero Lopez](https://github.com/eromerobilbomatica)
 * [Jerome Asselin](https://github.com/jerome-asselin-buspatrol)
+

 # Release 2.1.0

 ## Major features and improvements
diff --git a/kedro-datasets/docs/source/api/kedro_datasets.rst b/kedro-datasets/docs/source/api/kedro_datasets.rst
index 6d4047b53..0109ebefc 100644
--- a/kedro-datasets/docs/source/api/kedro_datasets.rst
+++ b/kedro-datasets/docs/source/api/kedro_datasets.rst
@@ -13,6 +13,7 @@ kedro_datasets
    kedro_datasets.api.APIDataset
    kedro_datasets.biosequence.BioSequenceDataset
+   kedro_datasets.dask.CSVDataset
    kedro_datasets.dask.ParquetDataset
    kedro_datasets.databricks.ManagedTableDataset
    kedro_datasets.email.EmailMessageDataset
diff --git a/kedro-datasets/kedro_datasets/dask/__init__.py b/kedro-datasets/kedro_datasets/dask/__init__.py
index 585115d54..ac293d1a6 100644
--- a/kedro-datasets/kedro_datasets/dask/__init__.py
+++ b/kedro-datasets/kedro_datasets/dask/__init__.py
@@ -6,7 +6,9 @@
 # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
 ParquetDataset: Any
+CSVDataset: Any

 __getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"parquet_dataset": ["ParquetDataset"]}
+    __name__,
+    submod_attrs={"parquet_dataset": ["ParquetDataset"], "csv_dataset": ["CSVDataset"]},
 )
diff --git a/kedro-datasets/kedro_datasets/dask/csv_dataset.py b/kedro-datasets/kedro_datasets/dask/csv_dataset.py
new file mode 100644
index 000000000..ce100bc37
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/dask/csv_dataset.py
@@ -0,0 +1,125 @@
+"""``CSVDataset`` is a data set used to load and save data to CSV files using Dask
+dataframes."""
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any
+
+import dask.dataframe as dd
+import fsspec
+from kedro.io.core import AbstractDataset, get_protocol_and_path
+
+
+class CSVDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]):
+    """``CSVDataset`` loads and saves data to comma-separated value file(s). It uses Dask
+    remote data services to handle the corresponding load and save operations:
+    https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html
+
+    Example usage for the
+    `YAML API `_:
+
+    .. code-block:: yaml
+
+        cars:
+          type: dask.CSVDataset
+          filepath: s3://bucket_name/path/to/folder
+          save_args:
+            compression: gzip
+          credentials:
+            client_kwargs:
+              aws_access_key_id: YOUR_KEY
+              aws_secret_access_key: YOUR_SECRET
+
+    Example usage for the
+    `Python API `_:
+
+    .. code-block:: pycon
+
+        >>> from kedro_datasets.dask import CSVDataset
+        >>> import pandas as pd
+        >>> import numpy as np
+        >>> import dask.dataframe as dd
+        >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
+        >>> ddf = dd.from_pandas(data, npartitions=1)
+        >>> dataset = CSVDataset(filepath="path/to/folder/*.csv")
+        >>> dataset.save(ddf)
+        >>> reloaded = dataset.load()
+        >>> assert np.array_equal(ddf.compute(), reloaded.compute())
+    """
+
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}
+
+    def __init__(  # noqa: PLR0913
+        self,
+        filepath: str,
+        load_args: dict[str, Any] | None = None,
+        save_args: dict[str, Any] | None = None,
+        credentials: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        """Creates a new instance of ``CSVDataset`` pointing to concrete
+        CSV files.
+
+        Args:
+            filepath: Filepath in POSIX format to a CSV file,
+                a CSV collection or the directory of a multipart CSV.
+            load_args: Additional loading options for `dask.dataframe.read_csv`:
+                https://docs.dask.org/en/latest/generated/dask.dataframe.read_csv.html
+            save_args: Additional saving options for `dask.dataframe.to_csv`:
+                https://docs.dask.org/en/latest/generated/dask.dataframe.to_csv.html
+            credentials: Credentials required to get access to the underlying filesystem.
+                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
+            fs_args: Optional parameters to the backend file system driver:
+                https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters
+            metadata: Any arbitrary metadata.
+                This is ignored by Kedro, but may be consumed by users or external plugins.
+        """
+        self._filepath = filepath
+        self._fs_args = deepcopy(fs_args) or {}
+        self._credentials = deepcopy(credentials) or {}
+
+        self.metadata = metadata
+
+        # Handle default load and save arguments
+        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
+
+    @property
+    def fs_args(self) -> dict[str, Any]:
+        """Property of optional file system parameters.
+
+        Returns:
+            A dictionary of backend file system parameters, including credentials.
+ """ + fs_args = deepcopy(self._fs_args) + fs_args.update(self._credentials) + return fs_args + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._filepath, + "load_args": self._load_args, + "save_args": self._save_args, + } + + def _load(self) -> dd.DataFrame: + return dd.read_csv( + self._filepath, storage_options=self.fs_args, **self._load_args + ) + + def _save(self, data: dd.DataFrame) -> None: + data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args) + + def _exists(self) -> bool: + protocol = get_protocol_and_path(self._filepath)[0] + file_system = fsspec.filesystem(protocol=protocol, **self.fs_args) + files = file_system.glob(self._filepath) + return bool(files) diff --git a/kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py b/kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py index 1f24e681a..6ef568223 100644 --- a/kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py @@ -62,6 +62,7 @@ class NetCDFDataset(AbstractDataset): ... ) >>> dataset.save(ds) >>> reloaded = dataset.load() + >>> assert ds.equals(reloaded) """ DEFAULT_LOAD_ARGS: dict[str, Any] = {} diff --git a/kedro-datasets/tests/dask/test_csv_dataset.py b/kedro-datasets/tests/dask/test_csv_dataset.py new file mode 100644 index 000000000..898606ad3 --- /dev/null +++ b/kedro-datasets/tests/dask/test_csv_dataset.py @@ -0,0 +1,158 @@ +import boto3 +import dask.dataframe as dd +import numpy as np +import pandas as pd +import pytest +from kedro.io.core import DatasetError +from moto import mock_aws +from s3fs import S3FileSystem + +from kedro_datasets.dask import CSVDataset + +FILE_NAME = "*.csv" +BUCKET_NAME = "test_bucket" +AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + +# Pathlib cannot be used since it strips out the second slash from "s3://" +S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}" + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing using moto.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id="fake_access_key", + aws_secret_access_key="fake_secret_key", + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def dummy_dd_dataframe() -> dd.DataFrame: + df = pd.DataFrame( + {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]} + ) + return dd.from_pandas(df, npartitions=1) + + +@pytest.fixture +def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFrame): + """Creates test data and adds it to mocked S3 bucket.""" + pandas_df = dummy_dd_dataframe.compute() + temporary_path = tmp_path / "test.csv" + pandas_df.to_csv(str(temporary_path)) + + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes() + ) + return mocked_s3_bucket + + +@pytest.fixture +def s3_dataset(load_args, save_args): + return CSVDataset( + filepath=S3_PATH, + credentials=AWS_CREDENTIALS, + load_args=load_args, + save_args=save_args, + ) + + +@pytest.fixture() +def s3fs_cleanup(): + # clear cache so we get a clean slate every time we instantiate a S3FileSystem + yield + S3FileSystem.cachable = False + + +@pytest.mark.usefixtures("s3fs_cleanup") +class TestCSVDataset: + def test_incorrect_credentials_load(self): + """Test that incorrect credential keys won't instantiate dataset.""" + pattern = r"unexpected keyword argument" + with pytest.raises(DatasetError, match=pattern): + CSVDataset( + filepath=S3_PATH, + 
+                credentials={
+                    "client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"}
+                },
+            ).load().compute()
+
+    @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}])
+    def test_empty_credentials_load(self, bad_credentials):
+        csv_dataset = CSVDataset(filepath=S3_PATH, credentials=bad_credentials)
+        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
+        with pytest.raises(DatasetError, match=pattern):
+            csv_dataset.load().compute()
+
+    @pytest.mark.xfail
+    def test_pass_credentials(self, mocker):
+        """Test that AWS credentials are passed successfully into boto3
+        client instantiation on creating S3 connection."""
+        client_mock = mocker.patch("botocore.session.Session.create_client")
+        s3_dataset = CSVDataset(filepath=S3_PATH, credentials=AWS_CREDENTIALS)
+        pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
+        with pytest.raises(DatasetError, match=pattern):
+            s3_dataset.load().compute()
+
+        assert client_mock.call_count == 1
+        args, kwargs = client_mock.call_args_list[0]
+        assert args == ("s3",)
+        assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"]
+        assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"]
+
+    def test_save_data(self, s3_dataset, mocked_s3_bucket):
+        """Test saving the data to S3."""
+        pd_data = pd.DataFrame(
+            {"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}
+        )
+        dd_data = dd.from_pandas(pd_data, npartitions=1)
+        s3_dataset.save(dd_data)
+        loaded_data = s3_dataset.load()
+        assert np.array_equal(loaded_data.compute(), dd_data.compute())
+
+    def test_load_data(self, s3_dataset, dummy_dd_dataframe, mocked_s3_object):
+        """Test loading the data from S3."""
+        loaded_data = s3_dataset.load()
+        assert np.array_equal(loaded_data.compute(), dummy_dd_dataframe.compute())
+
+    def test_exists(self, s3_dataset, dummy_dd_dataframe, mocked_s3_bucket):
+        """Test `exists` method invocation for both existing and
+        nonexistent data set."""
+        assert not s3_dataset.exists()
+        s3_dataset.save(dummy_dd_dataframe)
+        assert s3_dataset.exists()
+
+    def test_save_load_locally(self, tmp_path, dummy_dd_dataframe):
+        """Test saving and reloading the data locally."""
+        file_path = str(tmp_path / "some" / "dir" / FILE_NAME)
+        dataset = CSVDataset(filepath=file_path)
+
+        assert not dataset.exists()
+        dataset.save(dummy_dd_dataframe)
+        assert dataset.exists()
+        loaded_data = dataset.load()
+        assert dummy_dd_dataframe.compute().equals(loaded_data.compute())
+
+    @pytest.mark.parametrize(
+        "load_args", [{"k1": "v1", "index": "value"}], indirect=True
+    )
+    def test_load_extra_params(self, s3_dataset, load_args):
+        """Test overriding the default load arguments."""
+        for key, value in load_args.items():
+            assert s3_dataset._load_args[key] == value
+
+    @pytest.mark.parametrize(
+        "save_args", [{"k1": "v1", "index": "value"}], indirect=True
+    )
+    def test_save_extra_params(self, s3_dataset, save_args):
+        """Test overriding the default save arguments."""
+
+        for key, value in save_args.items():
+            assert s3_dataset._save_args[key] == value
+
+        for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items():
+            assert s3_dataset._save_args[key] != value
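Note: the `s3_dataset` fixture and the `indirect=True` parametrizations in `test_csv_dataset.py` depend on `load_args` and `save_args` fixtures that do not appear in this diff; they may be provided by a shared `conftest.py`. If they are not, a minimal sketch that would satisfy both uses could look like the following. The fixture bodies here are an assumption based on a common pytest pattern, not part of this PR:

```python
import pytest


@pytest.fixture(params=[None])
def load_args(request):
    # Default to no extra load options; tests override this
    # via @pytest.mark.parametrize(..., indirect=True).
    return request.param


@pytest.fixture(params=[None])
def save_args(request):
    # Default to no extra save options; tests override this
    # via @pytest.mark.parametrize(..., indirect=True).
    return request.param
```

With these defaults, `CSVDataset` falls back to `DEFAULT_LOAD_ARGS`/`DEFAULT_SAVE_ARGS`, while the indirect parametrizations route the test dictionaries through `request.param` into the dataset under test.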