Skip to content

Commit

Permalink
✨ Add support for loading model with alias in MlflowModelRegistryData…
Browse files Browse the repository at this point in the history
…set (#553)
  • Loading branch information
Galileo-Galilei committed Aug 27, 2024
1 parent 61864f0 commit a1e248c
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Added

- :sparkles: Add support for loading model with alias in ``MlflowModelRegistryDataset`` [#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553)

### Changed

- :boom: :pushpin: Officially drop support for ``mlflow<1.29.0`` which was implicit since the introduction of ``km.random_name`` resolver in [#481](https://github.com/Galileo-Galilei/kedro-mlflow/issues/481) ([#571](https://github.com/Galileo-Galilei/kedro-mlflow/issues/571))
Expand Down
22 changes: 20 additions & 2 deletions kedro_mlflow/io/models/mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, Union

from kedro.io.core import DatasetError

from kedro_mlflow.io.models.mlflow_abstract_model_dataset import (
MlflowAbstractModelDataSet,
)
Expand All @@ -11,7 +13,8 @@ class MlflowModelRegistryDataset(MlflowAbstractModelDataSet):
def __init__(
self,
model_name: str,
stage_or_version: Union[str, int] = "latest",
stage_or_version: Union[str, int, None] = None,
alias: Optional[str] = None,
flavor: Optional[str] = "mlflow.pyfunc",
pyfunc_workflow: Optional[str] = "python_model",
load_args: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -46,9 +49,23 @@ def __init__(
version=None,
)

if alias is None and stage_or_version is None:
# reassign stage_or_version to "latest"
stage_or_version = "latest"

if alias and stage_or_version:
raise DatasetError(
f"You cannot specify 'alias' and 'stage_or_version' simultaneously ({alias=} and {stage_or_version=})"
)

self.model_name = model_name
self.stage_or_version = stage_or_version
self.model_uri = f"models:/{model_name}/{stage_or_version}"
self.alias = alias
self.model_uri = (
f"models:/{model_name}@{alias}"
if alias
else f"models:/{model_name}/{stage_or_version}"
)

def _load(self) -> Any:
"""Loads an MLflow model from local path or from MLflow run.
Expand All @@ -74,6 +91,7 @@ def _describe(self) -> Dict[str, Any]:
model_uri=self.model_uri,
model_name=self.model_name,
stage_or_version=self.stage_or_version,
alias=self.alias,
flavor=self._flavor,
pyfunc_workflow=self._pyfunc_workflow,
# load_args=self._load_args,
Expand Down
61 changes: 58 additions & 3 deletions tests/io/models/test_mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import re

import mlflow
import pytest
from kedro.io.core import DatasetError
from mlflow import MlflowClient
from mlflow import __version__ as mlflow_version
from sklearn.tree import DecisionTreeClassifier

from kedro_mlflow.io.models import MlflowModelRegistryDataset

MLFLOW_VERSION_TUPLE = tuple(
int(x) for x in re.findall("([0-9]+)\.([0-9]+)\.([0-9]+)", mlflow_version)[0]
)


def test_mlflow_model_registry_save_not_implemented(tmp_path):
ml_ds = MlflowModelRegistryDataset(model_name="demo_model")
Expand All @@ -16,14 +23,25 @@ def test_mlflow_model_registry_save_not_implemented(tmp_path):
ml_ds.save(DecisionTreeClassifier())


def test_mlflow_model_registry_alias_and_stage_or_version_fails(tmp_path):
with pytest.raises(
DatasetError,
match=r"You cannot specify 'alias' and 'stage_or_version' simultaneously",
):
MlflowModelRegistryDataset(
model_name="demo_model", alias="my_alias", stage_or_version="my_stage"
)


def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory that the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_uri)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 3 version of a model under a single
# registered model and stage the 2nd one
Expand All @@ -36,7 +54,9 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch
)
runs[i + 1] = mlflow.active_run().info.run_id

client = MlflowClient(tracking_uri=tracking_uri)
client = MlflowClient(
tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri
)
client.transition_model_version_stage(name="demo_model", version=2, stage="Staging")

# case 1: no version is provided, we take the last one
Expand All @@ -55,3 +75,38 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", stage_or_version="1")
loaded_model = ml_ds.load()
assert loaded_model.metadata.run_id == runs[1]


@pytest.mark.skipif(
MLFLOW_VERSION_TUPLE < (2, 9, 0), reason="Requires mlflow 2.9.0 or higher"
)
def test_mlflow_model_registry_load_given_alias(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory that the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns4.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 3 version of a model under a single
# registered model and alias the 2nd one
runs = {}
for i in range(2):
with mlflow.start_run():
model = DecisionTreeClassifier()
mlflow.sklearn.log_model(
model, artifact_path="demo_model", registered_model_name="demo_model"
)
runs[i + 1] = mlflow.active_run().info.run_id

client = MlflowClient(
tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri
)
client.set_registered_model_alias(name="demo_model", alias="champion", version=1)

# case 2: an alias is provided, we take the last model with this stage
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", alias="champion")
loaded_model = ml_ds.load()
assert loaded_model.metadata.run_id == runs[1]

0 comments on commit a1e248c

Please sign in to comment.