feat(datasets): Add CSVDataset to dask module (#627)
* Add CSVDataset to dask module

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Add tests to dask.CSVDataset

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Fix formatting issues in example usage

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Fix error in example usage that is causing test to fail

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Remove arguments from example usage

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Fix issue with folder used as path for CSV file

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Change number of partitions to fix failing assertion

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Fix syntax issue

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Remove temp path

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Add default save args

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Add to documentation and release notes

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>

* Fix lint

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Try fix netcdfdataset doctest

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Try fix netcdfdataset doctest pointing at file

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Fix moto mock_aws import

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>

* Fix lint and test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Mypy

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* docs test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* docs test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* docs test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Fix unit tests

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Remove extra comments

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Try fix test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Release notes + test

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Suggestion from code review

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Michael Sexton <michael.sexton@ga.gov.au>
Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
Signed-off-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Co-authored-by: Merel Theisen <merel.theisen@quantumblack.com>
Co-authored-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
Co-authored-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
5 people authored Jun 21, 2024
1 parent bf6596a commit 966d989
Showing 6 changed files with 296 additions and 1 deletion.
8 changes: 8 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -11,12 +11,19 @@
| `langchain.ChatOpenAIDataset` | A dataset for loading a ChatOpenAI langchain model. | `kedro_datasets_experimental.langchain` |
| `netcdf.NetCDFDataset` | A dataset for loading and saving "*.nc" files. | `kedro_datasets_experimental.netcdf` |
* `netcdf.NetCDFDataset` moved from `kedro_datasets` to `kedro_datasets_experimental`.

* Added the following new core datasets:
| Type | Description | Location |
|-------------------------------------|-----------------------------------------------------------|-----------------------------------------|
| `dask.CSVDataset` | A dataset for loading and saving CSV files using `dask`. | `kedro_datasets.dask` |

* Extended preview feature to `yaml.YAMLDataset`.

## Community contributions

Many thanks to the following Kedroids for contributing PRs to this release:
* [Lukas Innig](https://github.com/derluke)
* [Michael Sexton](https://github.com/michaelsexton)


# Release 3.0.1
@@ -58,6 +65,7 @@ Many thanks to the following Kedroids for contributing PRs to this release:
* [Eduardo Romero Lopez](https://github.com/eromerobilbomatica)
* [Jerome Asselin](https://github.com/jerome-asselin-buspatrol)


# Release 2.1.0
## Major features and improvements

1 change: 1 addition & 0 deletions kedro-datasets/docs/source/api/kedro_datasets.rst
@@ -13,6 +13,7 @@ kedro_datasets

kedro_datasets.api.APIDataset
kedro_datasets.biosequence.BioSequenceDataset
kedro_datasets.dask.CSVDataset
kedro_datasets.dask.ParquetDataset
kedro_datasets.databricks.ManagedTableDataset
kedro_datasets.email.EmailMessageDataset
4 changes: 3 additions & 1 deletion kedro-datasets/kedro_datasets/dask/__init__.py
@@ -6,7 +6,9 @@

# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
ParquetDataset: Any
CSVDataset: Any

__getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"parquet_dataset": ["ParquetDataset"]}
+    __name__,
+    submod_attrs={"parquet_dataset": ["ParquetDataset"], "csv_dataset": ["CSVDataset"]},
)
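For context, `lazy.attach` here presumably comes from the `lazy_loader` package, which defers importing a submodule until one of its attributes is first accessed. A rough sketch of the behaviour it wires up, not the library's actual implementation:

# Sketch only; assumes `lazy` is the lazy_loader package (PEP 562 module __getattr__).
import importlib

_submod_attrs = {"parquet_dataset": ["ParquetDataset"], "csv_dataset": ["CSVDataset"]}

def __getattr__(name):
    # Import the owning submodule only on first attribute access.
    for submod, attrs in _submod_attrs.items():
        if name in attrs:
            return getattr(importlib.import_module(f"{__name__}.{submod}"), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")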
125 changes: 125 additions & 0 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
@@ -0,0 +1,125 @@
"""``CSVDataset`` is a data set used to load and save data to CSV files using Dask
dataframe"""
from __future__ import annotations

from copy import deepcopy
from typing import Any

import dask.dataframe as dd
import fsspec
from kedro.io.core import AbstractDataset, get_protocol_and_path


class CSVDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]):
"""``CSVDataset`` loads and saves data to comma-separated value file(s). It uses Dask
remote data services to handle the corresponding load and save operations:
https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html
Example usage for the
`YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog_yaml_examples.html>`_:
.. code-block:: yaml
cars:
type: dask.CSVDataset
filepath: s3://bucket_name/path/to/folder
save_args:
compression: GZIP
credentials:
client_kwargs:
aws_access_key_id: YOUR_KEY
aws_secret_access_key: YOUR_SECRET
Example usage for the
`Python API <https://kedro.readthedocs.io/en/stable/data/\
advanced_data_catalog_usage.html>`_:
.. code-block:: pycon
>>> from kedro_datasets.dask import CSVDataset
>>> import pandas as pd
>>> import numpy as np
>>> import dask.dataframe as dd
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [[5, 6], [7, 8]]})
>>> ddf = dd.from_pandas(data, npartitions=1)
>>> dataset = CSVDataset(filepath="path/to/folder/*.csv")
>>> dataset.save(ddf)
>>> reloaded = dataset.load()
>>> assert np.array_equal(ddf.compute(), reloaded.compute())
"""

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
DEFAULT_SAVE_ARGS: dict[str, Any] = {"index": False}

def __init__( # noqa: PLR0913
self,
filepath: str,
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
credentials: dict[str, Any] | None = None,
fs_args: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``CSVDataset`` pointing to concrete
CSV files.
Args:
filepath: Filepath in POSIX format to a CSV file
CSV collection or the directory of a multipart CSV.
load_args: Additional loading options `dask.dataframe.read_csv`:
https://docs.dask.org/en/latest/generated/dask.dataframe.read_csv.html
save_args: Additional saving options for `dask.dataframe.to_csv`:
https://docs.dask.org/en/latest/generated/dask.dataframe.to_csv.html
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Optional parameters to the backend file system driver:
https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html#optional-parameters
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
self._filepath = filepath
self._fs_args = deepcopy(fs_args) or {}
self._credentials = deepcopy(credentials) or {}

self.metadata = metadata

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

@property
def fs_args(self) -> dict[str, Any]:
"""Property of optional file system parameters.
Returns:
A dictionary of backend file system parameters, including credentials.
"""
fs_args = deepcopy(self._fs_args)
fs_args.update(self._credentials)
return fs_args

def _describe(self) -> dict[str, Any]:
return {
"filepath": self._filepath,
"load_args": self._load_args,
"save_args": self._save_args,
}

def _load(self) -> dd.DataFrame:
return dd.read_csv(
self._filepath, storage_options=self.fs_args, **self._load_args
)

def _save(self, data: dd.DataFrame) -> None:
data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args)

def _exists(self) -> bool:
protocol = get_protocol_and_path(self._filepath)[0]
file_system = fsspec.filesystem(protocol=protocol, **self.fs_args)
files = file_system.glob(self._filepath)
return bool(files)
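Because `fs_args` merges the credentials into the options that `_load` and `_save` forward to Dask as `storage_options`, remote storage needs nothing beyond the constructor arguments. A minimal sketch of S3 usage; the bucket path and credentials below are placeholders:

from kedro_datasets.dask import CSVDataset

# Placeholder bucket and keys; any fsspec-supported protocol should work the same way.
dataset = CSVDataset(
    filepath="s3://my-bucket/data/*.csv",
    credentials={"key": "YOUR_KEY", "secret": "YOUR_SECRET"},
)
ddf = dataset.load()  # returns a lazy dask.dataframe.DataFrame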
1 change: 1 addition & 0 deletions kedro-datasets/kedro_datasets_experimental/netcdf/netcdf_dataset.py
@@ -62,6 +62,7 @@ class NetCDFDataset(AbstractDataset):
... )
>>> dataset.save(ds)
>>> reloaded = dataset.load()
>>> assert ds.equals(reloaded)
"""

DEFAULT_LOAD_ARGS: dict[str, Any] = {}
158 changes: 158 additions & 0 deletions kedro-datasets/tests/dask/test_csv_dataset.py
@@ -0,0 +1,158 @@
import boto3
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pytest
from kedro.io.core import DatasetError
from moto import mock_aws
from s3fs import S3FileSystem

from kedro_datasets.dask import CSVDataset

FILE_NAME = "*.csv"
BUCKET_NAME = "test_bucket"
AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"}

# Pathlib cannot be used since it strips out the second slash from "s3://"
S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}"


@pytest.fixture
def mocked_s3_bucket():
"""Create a bucket for testing using moto."""
with mock_aws():
conn = boto3.client(
"s3",
aws_access_key_id="fake_access_key",
aws_secret_access_key="fake_secret_key",
)
conn.create_bucket(Bucket=BUCKET_NAME)
yield conn


@pytest.fixture
def dummy_dd_dataframe() -> dd.DataFrame:
df = pd.DataFrame(
{"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
)
return dd.from_pandas(df, npartitions=1)


@pytest.fixture
def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFrame):
"""Creates test data and adds it to mocked S3 bucket."""
pandas_df = dummy_dd_dataframe.compute()
temporary_path = tmp_path / "test.csv"
    # Write without the index so loaded data compares cleanly against the original frame.
    pandas_df.to_csv(str(temporary_path), index=False)

mocked_s3_bucket.put_object(
Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes()
)
return mocked_s3_bucket


@pytest.fixture
def s3_dataset(load_args, save_args):
return CSVDataset(
filepath=S3_PATH,
credentials=AWS_CREDENTIALS,
load_args=load_args,
save_args=save_args,
)


@pytest.fixture()
def s3fs_cleanup():
# clear cache so we get a clean slate every time we instantiate a S3FileSystem
yield
S3FileSystem.cachable = False


@pytest.mark.usefixtures("s3fs_cleanup")
class TestCSVDataset:
def test_incorrect_credentials_load(self):
"""Test that incorrect credential keys won't instantiate dataset."""
pattern = r"unexpected keyword argument"
with pytest.raises(DatasetError, match=pattern):
CSVDataset(
filepath=S3_PATH,
credentials={
"client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"}
},
).load().compute()

@pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}])
def test_empty_credentials_load(self, bad_credentials):
csv_dataset = CSVDataset(filepath=S3_PATH, credentials=bad_credentials)
pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
with pytest.raises(DatasetError, match=pattern):
csv_dataset.load().compute()

@pytest.mark.xfail
def test_pass_credentials(self, mocker):
"""Test that AWS credentials are passed successfully into boto3
client instantiation on creating S3 connection."""
client_mock = mocker.patch("botocore.session.Session.create_client")
s3_dataset = CSVDataset(filepath=S3_PATH, credentials=AWS_CREDENTIALS)
pattern = r"Failed while loading data from data set CSVDataset\(.+\)"
with pytest.raises(DatasetError, match=pattern):
s3_dataset.load().compute()

assert client_mock.call_count == 1
args, kwargs = client_mock.call_args_list[0]
assert args == ("s3",)
assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"]
assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"]

def test_save_data(self, s3_dataset, mocked_s3_bucket):
"""Test saving the data to S3."""
pd_data = pd.DataFrame(
{"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}
)
dd_data = dd.from_pandas(pd_data, npartitions=1)
s3_dataset.save(dd_data)
loaded_data = s3_dataset.load()
        assert np.array_equal(loaded_data.compute(), dd_data.compute())

def test_load_data(self, s3_dataset, dummy_dd_dataframe, mocked_s3_object):
"""Test loading the data from S3."""
loaded_data = s3_dataset.load()
        assert np.array_equal(loaded_data.compute(), dummy_dd_dataframe.compute())

def test_exists(self, s3_dataset, dummy_dd_dataframe, mocked_s3_bucket):
"""Test `exists` method invocation for both existing and
nonexistent data set."""
assert not s3_dataset.exists()
s3_dataset.save(dummy_dd_dataframe)
assert s3_dataset.exists()

def test_save_load_locally(self, tmp_path, dummy_dd_dataframe):
"""Test loading the data locally."""
file_path = str(tmp_path / "some" / "dir" / FILE_NAME)
dataset = CSVDataset(filepath=file_path)

assert not dataset.exists()
dataset.save(dummy_dd_dataframe)
assert dataset.exists()
loaded_data = dataset.load()
        assert dummy_dd_dataframe.compute().equals(loaded_data.compute())

@pytest.mark.parametrize(
"load_args", [{"k1": "v1", "index": "value"}], indirect=True
)
def test_load_extra_params(self, s3_dataset, load_args):
"""Test overriding the default load arguments."""
for key, value in load_args.items():
assert s3_dataset._load_args[key] == value

@pytest.mark.parametrize(
"save_args", [{"k1": "v1", "index": "value"}], indirect=True
)
def test_save_extra_params(self, s3_dataset, save_args):
"""Test overriding the default save arguments."""

for key, value in save_args.items():
assert s3_dataset._save_args[key] == value

for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items():
assert s3_dataset._save_args[key] != value
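Note that the `load_args` and `save_args` fixtures used with `indirect=True` are not defined in this file; they are assumed to live in a shared `conftest.py`. A minimal sketch of what such fixtures could look like (the project's actual shared fixtures may differ):

import pytest

# Hypothetical conftest.py fixtures; indirect parametrization overrides request.param.
@pytest.fixture(params=[None])
def load_args(request):
    return request.param

@pytest.fixture(params=[None])
def save_args(request):
    return request.param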
