From 3351137c1398cea05db070bbff3769d0bb17da42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yolan=20Honor=C3=A9-Roug=C3=A9?= Date: Tue, 27 Aug 2024 23:17:47 +0200 Subject: [PATCH] :sparkles: Add support for loading model with alias in MlflowModelRegistryDataset (#553) --- CHANGELOG.md | 4 ++ docs/source/07_python_objects/01_DataSets.md | 13 ++-- .../models/mlflow_model_registry_dataset.py | 22 ++++++- .../test_mlflow_model_registry_dataset.py | 61 ++++++++++++++++++- 4 files changed, 89 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abf386f8..21538081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +- :sparkles: Add support for loading model with alias in ``MlflowModelRegistryDataset`` [#553](https://github.com/Galileo-Galilei/kedro-mlflow/issues/553) + ### Changed - :boom: :pushpin: Officially drop support for ``mlflow<1.29.0`` which was implicit since the introduction of ``km.random_name`` resolver in [#481](https://github.com/Galileo-Galilei/kedro-mlflow/issues/481) ([#571](https://github.com/Galileo-Galilei/kedro-mlflow/issues/571)) diff --git a/docs/source/07_python_objects/01_DataSets.md b/docs/source/07_python_objects/01_DataSets.md index a4b6a2e7..a1da1c6a 100644 --- a/docs/source/07_python_objects/01_DataSets.md +++ b/docs/source/07_python_objects/01_DataSets.md @@ -117,7 +117,7 @@ my_model: ### ``MlflowModelLocalFileSystemDataset`` -The ``MlflowModelTrackingDataset`` accepts the following arguments: +The ``MlflowModelLocalFileSystemDataset`` accepts the following arguments: - flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable. - filepath (str): Path to store the dataset locally. @@ -163,11 +163,12 @@ my_model: The ``MlflowModelRegistryDataset`` accepts the following arguments: -- model_name (str): The name of the registered model is the mlflow registry -- stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registred model.Default to "latest" which fetch the last version and the higher "stage" available. -- flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable. -- pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows). -- load_args (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None. +- ``model_name`` (str): The name of the registered model is the mlflow registry +- ``stage_or_version`` (str): A valid stage (either "staging" or "production") or version number for the registred model.Default to None,(internally converted to "latest" if no alias si provided) which fetch the last version and the higher "stage" available. +- ``alias`` (str): A valid alias, which is used instead of stage to filter model since mlflow 2.9.0. Will raise an error if both ``stage_or_version`` and ``alias`` are provided. +- ``flavor`` (str): Built-in or custom MLflow model flavor module. Must be Python-importable. +- ``pyfunc_workflow`` (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows). +- ``load_args`` (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None. We assume you have registered a mlflow model first, either [with the ``MlflowClient``](https://mlflow.org/docs/latest/model-registry.html#adding-an-mlflow-model-to-the-model-registry) or [within the mlflow ui](https://mlflow.org/docs/latest/model-registry.html#ui-workflow), e.g. : diff --git a/kedro_mlflow/io/models/mlflow_model_registry_dataset.py b/kedro_mlflow/io/models/mlflow_model_registry_dataset.py index 8306dd33..615a5bc5 100644 --- a/kedro_mlflow/io/models/mlflow_model_registry_dataset.py +++ b/kedro_mlflow/io/models/mlflow_model_registry_dataset.py @@ -1,5 +1,7 @@ from typing import Any, Dict, Optional, Union +from kedro.io.core import DatasetError + from kedro_mlflow.io.models.mlflow_abstract_model_dataset import ( MlflowAbstractModelDataSet, ) @@ -11,7 +13,8 @@ class MlflowModelRegistryDataset(MlflowAbstractModelDataSet): def __init__( self, model_name: str, - stage_or_version: Union[str, int] = "latest", + stage_or_version: Union[str, int, None] = None, + alias: Optional[str] = None, flavor: Optional[str] = "mlflow.pyfunc", pyfunc_workflow: Optional[str] = "python_model", load_args: Optional[Dict[str, Any]] = None, @@ -46,9 +49,23 @@ def __init__( version=None, ) + if alias is None and stage_or_version is None: + # reassign stage_or_version to "latest" + stage_or_version = "latest" + + if alias and stage_or_version: + raise DatasetError( + f"You cannot specify 'alias' and 'stage_or_version' simultaneously ({alias=} and {stage_or_version=})" + ) + self.model_name = model_name self.stage_or_version = stage_or_version - self.model_uri = f"models:/{model_name}/{stage_or_version}" + self.alias = alias + self.model_uri = ( + f"models:/{model_name}@{alias}" + if alias + else f"models:/{model_name}/{stage_or_version}" + ) def _load(self) -> Any: """Loads an MLflow model from local path or from MLflow run. @@ -74,6 +91,7 @@ def _describe(self) -> Dict[str, Any]: model_uri=self.model_uri, model_name=self.model_name, stage_or_version=self.stage_or_version, + alias=self.alias, flavor=self._flavor, pyfunc_workflow=self._pyfunc_workflow, # load_args=self._load_args, diff --git a/tests/io/models/test_mlflow_model_registry_dataset.py b/tests/io/models/test_mlflow_model_registry_dataset.py index 1c58f70e..a8b42b37 100644 --- a/tests/io/models/test_mlflow_model_registry_dataset.py +++ b/tests/io/models/test_mlflow_model_registry_dataset.py @@ -1,11 +1,18 @@ +import re + import mlflow import pytest from kedro.io.core import DatasetError from mlflow import MlflowClient +from mlflow import __version__ as mlflow_version from sklearn.tree import DecisionTreeClassifier from kedro_mlflow.io.models import MlflowModelRegistryDataset +MLFLOW_VERSION_TUPLE = tuple( + int(x) for x in re.findall("([0-9]+)\.([0-9]+)\.([0-9]+)", mlflow_version)[0] +) + def test_mlflow_model_registry_save_not_implemented(tmp_path): ml_ds = MlflowModelRegistryDataset(model_name="demo_model") @@ -16,14 +23,25 @@ def test_mlflow_model_registry_save_not_implemented(tmp_path): ml_ds.save(DecisionTreeClassifier()) +def test_mlflow_model_registry_alias_and_stage_or_version_fails(tmp_path): + with pytest.raises( + DatasetError, + match=r"You cannot specify 'alias' and 'stage_or_version' simultaneously", + ): + MlflowModelRegistryDataset( + model_name="demo_model", alias="my_alias", stage_or_version="my_stage" + ) + + def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch): # we must change the working directory because when # using mlflow with a local database tracking, the artifacts # are stored in a relative mlruns/ folder so we need to have # the same working directory that the one of the tracking uri monkeypatch.chdir(tmp_path) - tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix() - mlflow.set_tracking_uri(tracking_uri) + tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix() + mlflow.set_tracking_uri(tracking_and_registry_uri) + mlflow.set_registry_uri(tracking_and_registry_uri) # setup: we train 3 version of a model under a single # registered model and stage the 2nd one @@ -36,7 +54,9 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch ) runs[i + 1] = mlflow.active_run().info.run_id - client = MlflowClient(tracking_uri=tracking_uri) + client = MlflowClient( + tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri + ) client.transition_model_version_stage(name="demo_model", version=2, stage="Staging") # case 1: no version is provided, we take the last one @@ -55,3 +75,38 @@ def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch ml_ds = MlflowModelRegistryDataset(model_name="demo_model", stage_or_version="1") loaded_model = ml_ds.load() assert loaded_model.metadata.run_id == runs[1] + + +@pytest.mark.skipif( + MLFLOW_VERSION_TUPLE < (2, 9, 0), reason="Requires mlflow 2.9.0 or higher" +) +def test_mlflow_model_registry_load_given_alias(tmp_path, monkeypatch): + # we must change the working directory because when + # using mlflow with a local database tracking, the artifacts + # are stored in a relative mlruns/ folder so we need to have + # the same working directory that the one of the tracking uri + monkeypatch.chdir(tmp_path) + tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns4.db").as_posix() + mlflow.set_tracking_uri(tracking_and_registry_uri) + mlflow.set_registry_uri(tracking_and_registry_uri) + + # setup: we train 3 version of a model under a single + # registered model and alias the 2nd one + runs = {} + for i in range(2): + with mlflow.start_run(): + model = DecisionTreeClassifier() + mlflow.sklearn.log_model( + model, artifact_path="demo_model", registered_model_name="demo_model" + ) + runs[i + 1] = mlflow.active_run().info.run_id + + client = MlflowClient( + tracking_uri=tracking_and_registry_uri, registry_uri=tracking_and_registry_uri + ) + client.set_registered_model_alias(name="demo_model", alias="champion", version=1) + + # case 2: an alias is provided, we take the last model with this stage + ml_ds = MlflowModelRegistryDataset(model_name="demo_model", alias="champion") + loaded_model = ml_ds.load() + assert loaded_model.metadata.run_id == runs[1]