Skip to content

Commit

Permalink
🐛 Replace get_all_metrics in MLflowMetricsHistoryDataset to be compatible with all Mlflow stores (#582) (#591)
Browse files Browse the repository at this point in the history

* 🐛 Replace get_all_metrics in MLflowMetricsHistoryDataset to be compatible with all Mlflow stores (#582)

* changelog

* bump changelog
  • Loading branch information
Galileo-Galilei authored Sep 24, 2024
1 parent f066d7d commit 089e7aa
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 43 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

- :sparkles: Add support for "modern" datasets ([introduced in Kedro 0.19.7](https://github.com/kedro-org/kedro/commit/52458c2addd1827623d06c20228b709052a5fdf3)) that expose `load` and `save` publicly ([#590, deepyaman](https://github.com/Galileo-Galilei/kedro-mlflow/pull/590))

### Fixed

- :bug: Refactor ``MlflowMetricsHistoryDataset`` to avoid using ``get_all_metrics`` internally, because this cannot save metrics on a remote server ([#582](https://github.com/Galileo-Galilei/kedro-mlflow/issues/582))

## [0.13.0] - 2024-09-01

### Added
Expand Down
104 changes: 61 additions & 43 deletions kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import partial, reduce
from functools import partial
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -71,20 +71,21 @@ def _load(self) -> MetricsDict:
Dict[str, Union[int, float]]: Dictionary with MLflow metrics dataset.
"""
client = MlflowClient()
run_id = self.run_id
all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
dataset_metrics = filter(self._is_dataset_metric, all_metrics)
dataset = reduce(
lambda xs, x: self._update_metric(
# get_all_metrics returns last saved values per metric key.
# All values are required here.
client.get_metric_history(run_id, x.key),
xs,
),
dataset_metrics,
{},
)
return dataset

all_metrics_keys = list(client.get_run(self.run_id).data.metrics.keys())

dataset_metrics_keys = [
key for key in all_metrics_keys if self._is_dataset_metric(key)
]

dataset_metrics = {
key: self._convert_metric_history_to_list_or_dict(
client.get_metric_history(self.run_id, key)
)
for key in dataset_metrics_keys
}

return dataset_metrics

def _save(self, data: MetricsDict) -> None:
"""Save given MLflow metrics dataset and log it in MLflow as metrics.
Expand Down Expand Up @@ -119,10 +120,11 @@ def _exists(self) -> bool:
bool: Is MLflow metrics dataset exists?
"""
client = MlflowClient()
all_metrics = client._tracking_client.store.get_all_metrics(
run_uuid=self.run_id
)
return any(self._is_dataset_metric(x) for x in all_metrics)
all_metrics_keys = client.get_run(self.run_id).data.metrics.keys()
# all_metrics = client._tracking_client.store.get_all_metrics(
# run_uuid=self.run_id
# )
return any(self._is_dataset_metric(x) for x in all_metrics_keys)

def _describe(self) -> Dict[str, Any]:
"""Describe MLflow metrics dataset.
Expand All @@ -135,39 +137,55 @@ def _describe(self) -> Dict[str, Any]:
"prefix": self._prefix,
}

def _is_dataset_metric(self, metric: mlflow.entities.Metric) -> bool:
def _is_dataset_metric(self, key: str) -> bool:
"""Check if given metric belongs to dataset.
Args:
metric (mlflow.entities.Metric): MLflow metric instance.
key str: The name of a mlflow metric registered in the run
"""
return self._prefix is None or (
self._prefix and metric.key.startswith(self._prefix)
)
return self._prefix is None or (self._prefix and key.startswith(self._prefix))

# @staticmethod
# def _update_metric(
# metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {}
# ) -> MetricsDict:
# """Update metric in given dataset.

# Args:
# metrics (List[mlflow.entities.Metric]): List with MLflow metric objects.
# dataset (MetricsDict): Dictionary contains MLflow metrics dataset.

# Returns:
# MetricsDict: Dictionary with MLflow metrics dataset.
# """
# for metric in metrics:
# metric_dict = {"step": metric.step, "value": metric.value}
# if metric.key in dataset:
# if isinstance(dataset[metric.key], list):
# dataset[metric.key].append(metric_dict)
# else:
# dataset[metric.key] = [dataset[metric.key], metric_dict]
# else:
# dataset[metric.key] = metric_dict
# return dataset

@staticmethod
def _update_metric(
metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {}
) -> MetricsDict:
"""Update metric in given dataset.
def _convert_metric_history_to_list_or_dict(
metrics: List[mlflow.entities.Metric],
) -> Dict[str, Dict[str, Union[float, List[float]]]]:
"""Convert Mlflow metrics objects from MlflowClient().get_metric_history(run_id, key)
to a list [{'step': x, 'value': y}, {'step': ..., 'value': ...}]
Args:
metrics (List[mlflow.entities.Metric]): List with MLflow metric objects.
dataset (MetricsDict): Dictionary contains MLflow metrics dataset.
Returns:
MetricsDict: Dictionary with MLflow metrics dataset.
metrics (List[mlflow.entities.Metric]): A list of MLflow Metrics retrieved from the run history
"""
for metric in metrics:
metric_dict = {"step": metric.step, "value": metric.value}
if metric.key in dataset:
if isinstance(dataset[metric.key], list):
dataset[metric.key].append(metric_dict)
else:
dataset[metric.key] = [dataset[metric.key], metric_dict]
else:
dataset[metric.key] = metric_dict
return dataset
metrics_as_list = [
{"step": metric.step, "value": metric.value} for metric in metrics
]
metrics_result = (
metrics_as_list[0] if len(metrics_as_list) == 1 else metrics_as_list
)
return metrics_result

def _build_args_list_from_metric_item(
self, key: str, value: MetricItem
Expand Down

0 comments on commit 089e7aa

Please sign in to comment.