diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index bd2e9ac15..52ba9fe51 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -5,6 +5,8 @@ | Type | Description | Location | |-------------------------------------|-----------------------------------------------------------|-----------------------------------------| | `pytorch.PyTorchDataset` | A dataset for securely saving and loading PyTorch models | `kedro_datasets_experimental.pytorch` | +| `prophet.ProphetModelDataset` | A dataset for Meta's Prophet model for time series forecasting | `kedro_datasets_experimental.prophet` | + * Added the following new core datasets: @@ -24,6 +26,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: * [yury-fedotov](https://github.com/yury-fedotov) * [gitgud5000](https://github.com/gitgud5000) * [janickspirig](https://github.com/janickspirig) +* [Galen Seilis](https://github.com/galenseilis) # Release 4.1.0 diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst index 0eb76c739..219510954 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst @@ -16,5 +16,6 @@ kedro_datasets_experimental langchain.ChatOpenAIDataset langchain.OpenAIEmbeddingsDataset netcdf.NetCDFDataset + prophet.ProphetModelDataset pytorch.PyTorchDataset rioxarray.GeoTIFFDataset diff --git a/kedro-datasets/docs/source/conf.py b/kedro-datasets/docs/source/conf.py index 70c6be3ae..09524612a 100644 --- a/kedro-datasets/docs/source/conf.py +++ b/kedro-datasets/docs/source/conf.py @@ -140,6 +140,8 @@ "xarray.core.dataset.Dataset", "xarray.core.dataarray.DataArray", "torch.nn.modules.module.Module", + "prophet.forecaster.Prophet", + "Prophet", ), "py:data": ( "typing.Any", diff --git a/kedro-datasets/kedro_datasets_experimental/prophet/__init__.py 
b/kedro-datasets/kedro_datasets_experimental/prophet/__init__.py new file mode 100644 index 000000000..93cd66d99 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/prophet/__init__.py @@ -0,0 +1,11 @@ +"""``ProphetModelDataset`` implementation to load/save a Prophet model from/to a JSON file.""" + +from typing import Any + +import lazy_loader as lazy + +ProphetModelDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"prophet_dataset": ["ProphetModelDataset"]} +) diff --git a/kedro-datasets/kedro_datasets_experimental/prophet/prophet_dataset.py b/kedro-datasets/kedro_datasets_experimental/prophet/prophet_dataset.py new file mode 100644 index 000000000..ca2cd1e75 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/prophet/prophet_dataset.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Any + +from kedro.io.core import Version, get_filepath_str +from prophet import Prophet +from prophet.serialize import model_from_json, model_to_json + +from kedro_datasets.json import JSONDataset + + +class ProphetModelDataset(JSONDataset): + """``ProphetModelDataset`` loads/saves Facebook Prophet models to a JSON file using an + underlying filesystem (e.g., local, S3, GCS). It uses Prophet's built-in + serialization to handle the JSON file. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + model: + type: kedro_datasets_experimental.prophet.ProphetModelDataset + filepath: gcs://your_bucket/model.json + fs_args: + project: my-project + credentials: my_gcp_credentials + + Example usage for the + `Python API `_: + + ..
code-block:: pycon + + >>> from kedro_datasets_experimental.prophet import ProphetModelDataset + >>> from prophet import Prophet + >>> import pandas as pd + >>> + >>> df = pd.DataFrame({ + ... "ds": ["2024-01-01", "2024-01-02", "2024-01-03"], + ... "y": [100, 200, 300] + ... }) + >>> + >>> model = Prophet() + >>> model.fit(df) + >>> dataset = ProphetModelDataset(filepath="path/to/model.json") + >>> dataset.save(model) + >>> reloaded_model = dataset.load() + + """ + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + save_args: dict[str, Any] | None = None, + version: Version | None = None, + credentials: dict[str, Any] | None = None, + fs_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + """Creates a new instance of ``ProphetModelDataset`` pointing to a concrete JSON file + on a specific filesystem. + + Args: + filepath: Filepath in POSIX format to a JSON file prefixed with a protocol like `s3://`. + If prefix is not provided, `file` protocol (local filesystem) will be used. + The prefix should be any protocol supported by ``fsspec``. + Note: `http(s)` doesn't support versioning. + save_args: json options for saving JSON files (arguments passed + into ``json.dump``). Here you can find all available arguments: + https://docs.python.org/3/library/json.html + All defaults are preserved, but "indent", which is set to 2. + version: If specified, should be an instance of + ``kedro.io.core.Version``. If its ``load`` attribute is + None, the latest version will be loaded. If its ``save`` + attribute is None, save version will be autogenerated. + credentials: Credentials required to get access to the underlying filesystem. + E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. + fs_args: Extra arguments to pass into underlying filesystem class constructor + (e.g.
`{"project": "my-project"}` for ``GCSFileSystem``), as well as + to pass to the filesystem's `open` method through nested keys + `open_args_load` and `open_args_save`. + Here you can find all available arguments for `open`: + https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + super().__init__( + filepath=filepath, + save_args=save_args, + version=version, + credentials=credentials, + fs_args=fs_args, + metadata=metadata, + ) + + def _load(self) -> Prophet: + """Loads a Prophet model from a JSON file. + + Returns: + Prophet: A deserialized Prophet model. + """ + load_path = get_filepath_str(self._get_load_path(), self._protocol) + + with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: + return model_from_json(fs_file.read()) + + def _save(self, data: Prophet) -> None: + """Saves a Prophet model to a JSON file. + + Args: + data: The Prophet model instance to be serialized and saved. + """ + save_path = get_filepath_str(self._get_save_path(), self._protocol) + + with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: + fs_file.write(model_to_json(data)) + + self._invalidate_cache() diff --git a/kedro-datasets/kedro_datasets_experimental/tests/conftest.py b/kedro-datasets/kedro_datasets_experimental/tests/conftest.py new file mode 100644 index 000000000..91d19f646 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/conftest.py @@ -0,0 +1,34 @@ +""" +This file contains the fixtures that are reusable by any tests within +this directory. You don't need to import the fixtures as pytest will +discover them automatically. 
More info here: +https://docs.pytest.org/en/latest/fixture.html +""" + +from kedro.io.core import generate_timestamp +from pytest import fixture + + +@fixture(params=[None]) +def load_version(request): + return request.param + + +@fixture(params=[None]) +def save_version(request): + return request.param or generate_timestamp() + + +@fixture(params=[None]) +def load_args(request): + return request.param + + +@fixture(params=[None]) +def save_args(request): + return request.param + + +@fixture(params=[None]) +def fs_args(request): + return request.param diff --git a/kedro-datasets/kedro_datasets_experimental/tests/prophet/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/prophet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/tests/prophet/test_prophet_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/prophet/test_prophet_dataset.py new file mode 100644 index 000000000..88510a99b --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/prophet/test_prophet_dataset.py @@ -0,0 +1,209 @@ +from pathlib import Path, PurePosixPath + +import pandas as pd +import pytest +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version +from prophet import Prophet +from s3fs.core import S3FileSystem + +from kedro_datasets_experimental.prophet import ProphetModelDataset + + +@pytest.fixture +def filepath_json(tmp_path): + return (tmp_path / "test_model.json").as_posix() + + +@pytest.fixture +def prophet_model_dataset(filepath_json, save_args, fs_args): + return ProphetModelDataset( + filepath=filepath_json, save_args=save_args, fs_args=fs_args + ) + + +@pytest.fixture +def versioned_prophet_model_dataset(filepath_json, load_version, save_version): + return ProphetModelDataset( + filepath=filepath_json, 
version=Version(load_version, save_version) + ) + + +@pytest.fixture +def dummy_model(): + df = pd.DataFrame({"ds": ["2024-01-01", "2024-01-02", "2024-01-03"], "y": [100, 200, 300]}) + model = Prophet() + # Fit the model with dummy data + model.fit(df) + return model + + +class TestProphetModelDataset: + def test_save_and_load(self, prophet_model_dataset, dummy_model): + """Test saving and reloading the Prophet model.""" + prophet_model_dataset.save(dummy_model) + reloaded = prophet_model_dataset.load() + assert isinstance(reloaded, Prophet) + assert prophet_model_dataset._fs_open_args_load == {} + assert prophet_model_dataset._fs_open_args_save == {"mode": "w"} + + def test_exists(self, prophet_model_dataset, dummy_model): + """Test `exists` method invocation for both existing and + nonexistent dataset.""" + assert not prophet_model_dataset.exists() + prophet_model_dataset.save(dummy_model) + assert prophet_model_dataset.exists() + + @pytest.mark.parametrize("save_args", [{"k1": "v1", "indent": 4}], indirect=True) + def test_save_extra_params(self, prophet_model_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert prophet_model_dataset._save_args[key] == value + + @pytest.mark.parametrize( + "fs_args", + [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], + indirect=True, + ) + def test_open_extra_args(self, prophet_model_dataset, fs_args): + assert prophet_model_dataset._fs_open_args_load == fs_args["open_args_load"] + assert prophet_model_dataset._fs_open_args_save == { + "mode": "w" + } # default unchanged + + def test_load_missing_file(self, prophet_model_dataset): + """Check the error when trying to load missing file.""" + pattern = r"Failed while loading data from data set ProphetModelDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + prophet_model_dataset.load() + + @pytest.mark.parametrize( + "filepath,instance_type", + [ + ("s3://bucket/model.json", 
S3FileSystem), + ("file:///tmp/test_model.json", LocalFileSystem), + ("/tmp/test_model.json", LocalFileSystem), #nosec: B108 + ("gcs://bucket/model.json", GCSFileSystem), + ("https://example.com/model.json", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, filepath, instance_type): + dataset = ProphetModelDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) + + path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] + + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + filepath = "test_model.json" + dataset = ProphetModelDataset(filepath=filepath) + dataset.release() + fs_mock.invalidate_cache.assert_called_once_with(filepath) + + +class TestProphetModelDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + """Test that version is in string representation of the class instance + when applicable.""" + filepath = "test_model.json" + ds = ProphetModelDataset(filepath=filepath) + ds_versioned = ProphetModelDataset( + filepath=filepath, version=Version(load_version, save_version) + ) + assert filepath in str(ds) + assert "version" not in str(ds) + + assert filepath in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "ProphetModelDataset" in str(ds_versioned) + assert "ProphetModelDataset" in str(ds) + assert "protocol" in str(ds_versioned) + assert "protocol" in str(ds) + # Default save_args + assert "save_args={'indent': 2}" in str(ds) + assert "save_args={'indent': 2}" in str(ds_versioned) + + def test_save_and_load(self, versioned_prophet_model_dataset, dummy_model): + """Test that saved and reloaded data matches the original one for + the versioned dataset.""" + versioned_prophet_model_dataset.save(dummy_model) + reloaded = versioned_prophet_model_dataset.load() + assert 
isinstance(reloaded, Prophet) + + def test_no_versions(self, versioned_prophet_model_dataset): + """Check the error if no versions are available for load.""" + pattern = r"Did not find any versions for ProphetModelDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_prophet_model_dataset.load() + + def test_exists(self, versioned_prophet_model_dataset, dummy_model): + """Test `exists` method invocation for versioned dataset.""" + assert not versioned_prophet_model_dataset.exists() + versioned_prophet_model_dataset.save(dummy_model) + assert versioned_prophet_model_dataset.exists() + + def test_prevent_overwrite(self, versioned_prophet_model_dataset, dummy_model): + """Check the error when attempting to override the dataset if the + corresponding json file for a given save version already exists.""" + versioned_prophet_model_dataset.save(dummy_model) + pattern = ( + r"Save path \'.+\' for ProphetModelDataset\(.+\) must " + r"not exist if versioning is enabled\." + ) + with pytest.raises(DatasetError, match=pattern): + versioned_prophet_model_dataset.save(dummy_model) + + @pytest.mark.parametrize( + "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True + ) + @pytest.mark.parametrize( + "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True + ) + def test_save_version_warning( + self, versioned_prophet_model_dataset, load_version, save_version, dummy_model + ): + """Check the warning when saving to the path that differs from + the subsequent load path.""" + pattern = ( + f"Save version '{save_version}' did not match " + f"load version '{load_version}' for " + r"ProphetModelDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + versioned_prophet_model_dataset.save(dummy_model) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." 
+ + with pytest.raises(DatasetError, match=pattern): + ProphetModelDataset( + filepath="https://example.com/model.json", version=Version(None, None) + ) + + def test_versioning_existing_dataset( + self, prophet_model_dataset, versioned_prophet_model_dataset, dummy_model + ): + """Check the error when attempting to save a versioned dataset on top of an + already existing (non-versioned) dataset.""" + prophet_model_dataset.save(dummy_model) + assert prophet_model_dataset.exists() + assert ( + prophet_model_dataset._filepath == versioned_prophet_model_dataset._filepath + ) + pattern = ( + f"(?=.*file with the same name already exists in the directory)" + f"(?=.*{versioned_prophet_model_dataset._filepath.parent.as_posix()})" + ) + with pytest.raises(DatasetError, match=pattern): + versioned_prophet_model_dataset.save(dummy_model) + + # Remove non-versioned dataset and try again + Path(prophet_model_dataset._filepath.as_posix()).unlink() + versioned_prophet_model_dataset.save(dummy_model) + assert versioned_prophet_model_dataset.exists() diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 895caebfd..da2c15b18 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -183,6 +183,8 @@ langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddi netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"] netcdf = ["kedro-datasets[netcdf-netcdfdataset]"] +prophet-dataset = ["prophet>=1.1.5"] +prophet = ["kedro-datasets[prophet-dataset]"] pytorch-dataset = ["torch"] pytorch = ["kedro-datasets[pytorch-dataset]"] @@ -290,7 +292,8 @@ experimental = [ "netcdf4>=1.6.4", "xarray>=2023.1.0", "rioxarray", - "torch" + "torch", + "prophet>=1.1.5", ] # All requirements