Skip to content

Commit

Permalink
Model version level metadata (#26)
Browse files Browse the repository at this point in the history
* model version metadata

* update zenml ref
  • Loading branch information
avishniakov authored Dec 15, 2023
1 parent 71663a3 commit 30a4146
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 15 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,5 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/OSS-2529-passing-pipeline-parameters-as-yaml-and-document' }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/OSS-2574-add-metadata-to-model-versions' }}
ref-template: ${{ inputs.ref-template || github.ref }}

4 changes: 1 addition & 3 deletions template/steps/deployment/deployment_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def deployment_deploy() -> (
# deploy predictor service
deployment_service = mlflow_model_registry_deployer_step.entrypoint(
registry_model_name=model_version.name,
registry_model_version=model_version.get_model_artifact("model")
.run_metadata["model_registry_version"]
.value,
registry_model_version=model_version.metadata["model_registry_version"],
replace_existing=True,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def promote_with_metric_compare(
logger.info(f"Current model version was promoted to '{target_env}'.")

# Promote in Model Registry
latest_version_model_registry_number = latest_version.get_model_artifact("model").run_metadata["model_registry_version"].value
latest_version_model_registry_number = latest_version.metadata["model_registry_version"]
if current_version_number is None:
current_version_model_registry_number = latest_version_model_registry_number
else:
current_version_model_registry_number = current_version.get_model_artifact("model").run_metadata["model_registry_version"].value
current_version_model_registry_number = current_version.metadata["model_registry_version"]
promote_in_model_registry(
latest_version=latest_version_model_registry_number,
current_version=current_version_model_registry_number,
Expand All @@ -87,7 +87,7 @@ def promote_with_metric_compare(
)
promoted_version = latest_version_model_registry_number
else:
promoted_version = current_version.get_model_artifact("model").run_metadata["model_registry_version"].value
promoted_version = current_version.metadata["model_registry_version"]

logger.info(
f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def promote_latest_version(
logger.info(f"Current model version was promoted to '{target_env}'.")

# Promote in Model Registry
latest_version_model_registry_number = latest_version.get_model_artifact("model").run_metadata["model_registry_version"].value
latest_version_model_registry_number = latest_version.metadata["model_registry_version"]
if current_version.number is None:
current_version_model_registry_number = latest_version_model_registry_number
else:
current_version_model_registry_number = current_version.get_model_artifact("model").run_metadata["model_registry_version"].value
current_version_model_registry_number = current_version.metadata["model_registry_version"]
promote_in_model_registry(
latest_version=latest_version_model_registry_number,
current_version=current_version_model_registry_number,
Expand Down
8 changes: 3 additions & 5 deletions template/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import mlflow
import pandas as pd
from sklearn.base import ClassifierMixin
from zenml import ArtifactConfig, log_artifact_metadata, step
from zenml import ArtifactConfig, log_artifact_metadata, step, get_step_context
from zenml.client import Client
from zenml.integrations.mlflow.experiment_trackers import MLFlowExperimentTracker
from zenml.integrations.mlflow.steps.mlflow_registry import mlflow_register_model_step
Expand Down Expand Up @@ -84,10 +84,8 @@ def model_trainer(
if model_registry:
versions = model_registry.list_model_versions(name=name)
if versions:
log_artifact_metadata(
metadata={"model_registry_version": versions[-1].version},
artifact_name="model",
)
model_version = get_step_context().model_version
model_version.log_metadata({"model_registry_version": versions[-1].version})
### YOUR CODE ENDS HERE ###

return model

0 comments on commit 30a4146

Please sign in to comment.