✨ Support modern datasets (Kedro 0.19.7+) #590

Merged · 3 commits · Sep 9, 2024 · Changes from all commits
CHANGELOG.md · 6 changes: 5 additions & 1 deletion
@@ -2,11 +2,15 @@
 
 ## [Unreleased]
 
+### Added
+
+- :sparkles: Add support for "modern" datasets ([introduced in Kedro 0.19.7](https://github.com/kedro-org/kedro/commit/52458c2addd1827623d06c20228b709052a5fdf3)) that expose `load` and `save` publicly ([#590, deepyaman](https://github.com/Galileo-Galilei/kedro-mlflow/pull/590))
+
 ## [0.13.0] - 2024-09-01
 
 ### Added
 
-- :sparkles: Add support for loading model with alias in `MlflowModelRegistryDataset` [#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553)
+- :sparkles: Add support for loading model with alias in `MlflowModelRegistryDataset` ([#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553))
 
 ### Changed
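For context on the entry above: before Kedro 0.19.7, dataset authors implemented the private `_load`/`_save` hooks and `AbstractDataset` exposed them through its own `load`/`save`; the linked Kedro commit additionally lets authors override the public `load`/`save` directly. A minimal sketch of the two conventions (class names are illustrative; the new tests further down exercise exactly this distinction):

    from kedro.io import AbstractDataset

    class LegacyStyle(AbstractDataset):  # pre-0.19.7 convention: private hooks
        def _load(self): ...
        def _save(self, data) -> None: ...
        def _describe(self) -> dict: return {}

    class ModernStyle(AbstractDataset):  # Kedro >= 0.19.7 convention: public methods
        def load(self): ...
        def save(self, data) -> None: ...
        def _describe(self) -> dict: return {}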
kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py · 10 changes: 8 additions & 2 deletions
Expand Up @@ -62,7 +62,10 @@ def _save(self, data: Any):
# for logging on remote storage like Azure S3
local_path = local_path.as_posix()

super()._save(data)
if getattr(super().save, "__savewrapped__", False): # modern dataset
super().save.__wrapped__(self, data)
else: # legacy dataset
super()._save(data)

if self._logging_activated:
if self.run_id:
@@ -131,7 +134,10 @@ def _load(self) -> Any:  # pragma: no cover
                     shutil.copy(src=temp_download_filepath, dst=local_path)
 
                 # finally, read locally
-                return super()._load()
+                if getattr(super().load, "__loadwrapped__", False):  # modern dataset
+                    return super().load.__wrapped__(self)
+                else:  # legacy dataset
+                    return super()._load()
 
         # rename the class
         parent_name = dataset_obj.__name__
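How the branches above distinguish the two styles: from Kedro 0.19.7 on, Kedro wraps a dataset's public `load`/`save` when the subclass is defined and marks the wrapper with `__loadwrapped__`/`__savewrapped__`; since the wrapping uses `functools.wraps`-style machinery, the original implementation stays reachable as `__wrapped__`. A self-contained sketch of the pattern (a loose approximation for illustration, not Kedro's actual code):

    import functools

    def save_wrapper(save_func):
        # loosely mimics the wrapping Kedro >= 0.19.7 applies to a public `save`
        @functools.wraps(save_func)  # sets save.__wrapped__ = save_func
        def save(self, data):
            # Kedro's real wrapper adds logging and error handling around the call
            save_func(self, data)

        save.__savewrapped__ = True  # the marker the diff above checks
        return save

    class ModernDataset:
        def save(self, data):
            print(f"writing {data!r}")

    # Kedro applies this automatically when the subclass is defined;
    # done by hand here for illustration.
    ModernDataset.save = save_wrapper(ModernDataset.save)

    # detection, as in `_save` above: the wrapped method carries the marker ...
    assert getattr(ModernDataset.save, "__savewrapped__", False)
    # ... and `__wrapped__` calls the original implementation directly,
    # which is what `super().save.__wrapped__(self, data)` does.
    ModernDataset.save.__wrapped__(ModernDataset(), "some data")

Calling `__wrapped__` runs the inner dataset's implementation without stacking Kedro's wrapper bookkeeping a second time; for a legacy dataset no marker is present, so the `_save`/`_load` hooks are called as before.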
tests/io/artifacts/test_mlflow_artifact_dataset.py · 81 changes: 81 additions & 0 deletions
@@ -3,6 +3,7 @@
 import mlflow
 import pandas as pd
 import pytest
+from kedro.io import AbstractDataset
 from kedro_datasets.pandas import CSVDataset
 from kedro_datasets.partitions import PartitionedDataset
 from kedro_datasets.pickle import PickleDataset
@@ -289,3 +290,83 @@ def test_partitioned_dataset_save_and_reload(
     reloaded_data = {k: loader() for k, loader in mlflow_dataset.load().items()}
     for k, df in data.items():
         pd.testing.assert_frame_equal(df, reloaded_data[k])
+
+
+def test_modern_dataset(tmp_path, mlflow_client, df1):
+    class MyOwnDatasetWithoutUnderscoreMethods(AbstractDataset):
+        def __init__(self, filepath):
+            self._filepath = Path(filepath)
+
+        def load(self) -> pd.DataFrame:
+            return pd.read_csv(self._filepath)
+
+        def save(self, df: pd.DataFrame) -> None:
+            df.to_csv(str(self._filepath), index=False)
+
+        def _exists(self) -> bool:
+            return Path(self._filepath.as_posix()).exists()
+
+        def _describe(self):
+            return dict(param1=self._filepath)
+
+    filepath = tmp_path / "data.csv"
+
+    mlflow_dataset = MlflowArtifactDataset(
+        artifact_path="artifact_dir",
+        dataset=dict(
+            type=MyOwnDatasetWithoutUnderscoreMethods, filepath=filepath.as_posix()
+        ),
+    )
+
+    with mlflow.start_run():
+        mlflow_dataset.save(df1)
+        run_id = mlflow.active_run().info.run_id
+
+    # the artifact must be properly uploaded to "mlruns" and reloadable
+    run_artifacts = [
+        fileinfo.path
+        for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
+    ]
+    remote_path = (Path("artifact_dir") / filepath.name).as_posix()
+    assert remote_path in run_artifacts
+    assert df1.equals(mlflow_dataset.load())
+
+
+def test_legacy_dataset(tmp_path, mlflow_client, df1):
+    class MyOwnDatasetWithUnderscoreMethods(AbstractDataset):
+        def __init__(self, filepath):
+            self._filepath = Path(filepath)
+
+        def _load(self) -> pd.DataFrame:
+            return pd.read_csv(self._filepath)
+
+        def _save(self, df: pd.DataFrame) -> None:
+            df.to_csv(str(self._filepath), index=False)
+
+        def _exists(self) -> bool:
+            return Path(self._filepath.as_posix()).exists()
+
+        def _describe(self):
+            return dict(param1=self._filepath)
+
+    filepath = tmp_path / "data.csv"
+
+    mlflow_dataset = MlflowArtifactDataset(
+        artifact_path="artifact_dir",
+        dataset=dict(
+            type=MyOwnDatasetWithUnderscoreMethods, filepath=filepath.as_posix()
+        ),
+    )
+
+    with mlflow.start_run():
+        mlflow_dataset.save(df1)
+        run_id = mlflow.active_run().info.run_id
+
+    # the artifact must be properly uploaded to "mlruns" and reloadable
+    run_artifacts = [
+        fileinfo.path
+        for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
+    ]
+    remote_path = (Path("artifact_dir") / filepath.name).as_posix()
+    assert remote_path in run_artifacts
+    assert df1.equals(mlflow_dataset.load())