diff --git a/CHANGELOG.md b/CHANGELOG.md
index 950ffcbc..3dc495f3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@
 
 - :sparkles: Add support for "modern" datasets ([introduced in Kedro 0.19.7](https://github.com/kedro-org/kedro/commit/52458c2addd1827623d06c20228b709052a5fdf3)) that expose `load` and `save` publicly ([#590, deepyaman](https://github.com/Galileo-Galilei/kedro-mlflow/pull/590))
 
+### Fixed
+
+- :bug: Refactor ``MlflowMetricsHistoryDataset`` to avoid using ``get_all_metrics`` internally because it cannot save metrics on a remote server ([#582](https://github.com/Galileo-Galilei/kedro-mlflow/issues/582))
+
 ## [0.13.0] - 2024-09-01
 
 ### Added
diff --git a/kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py b/kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py
index fb4d157f..6e9e84ae 100644
--- a/kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py
+++ b/kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py
@@ -1,4 +1,4 @@
-from functools import partial, reduce
+from functools import partial
 from itertools import chain
 from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
@@ -71,20 +71,22 @@ def _load(self) -> MetricsDict:
             Dict[str, Union[int, float]]: Dictionary with MLflow metrics dataset.
         """
         client = MlflowClient()
-        run_id = self.run_id
-        all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
-        dataset_metrics = filter(self._is_dataset_metric, all_metrics)
-        dataset = reduce(
-            lambda xs, x: self._update_metric(
-                # get_all_metrics returns last saved values per metric key.
-                # All values are required here.
-                client.get_metric_history(run_id, x.key),
-                xs,
-            ),
-            dataset_metrics,
-            {},
-        )
-        return dataset
+
+        all_metrics_keys = list(client.get_run(self.run_id).data.metrics.keys())
+
+        dataset_metrics_keys = [
+            key for key in all_metrics_keys if self._is_dataset_metric(key)
+        ]
+
+        # run.data.metrics only keeps the last value per key; get the full history
+        dataset_metrics = {
+            key: self._convert_metric_history_to_list_or_dict(
+                client.get_metric_history(self.run_id, key)
+            )
+            for key in dataset_metrics_keys
+        }
+
+        return dataset_metrics
 
     def _save(self, data: MetricsDict) -> None:
         """Save given MLflow metrics dataset and log it in MLflow as metrics.
@@ -119,10 +121,8 @@ def _exists(self) -> bool:
             bool: Is MLflow metrics dataset exists?
         """
         client = MlflowClient()
-        all_metrics = client._tracking_client.store.get_all_metrics(
-            run_uuid=self.run_id
-        )
-        return any(self._is_dataset_metric(x) for x in all_metrics)
+        all_metrics_keys = client.get_run(self.run_id).data.metrics.keys()
+        return any(self._is_dataset_metric(x) for x in all_metrics_keys)
 
     def _describe(self) -> Dict[str, Any]:
         """Describe MLflow metrics dataset.
@@ -135,39 +135,32 @@
             "prefix": self._prefix,
         }
 
-    def _is_dataset_metric(self, metric: mlflow.entities.Metric) -> bool:
+    def _is_dataset_metric(self, key: str) -> bool:
         """Check if given metric belongs to dataset.
 
         Args:
-            metric (mlflow.entities.Metric): MLflow metric instance.
+            key (str): The name of an MLflow metric registered in the run.
         """
-        return self._prefix is None or (
-            self._prefix and metric.key.startswith(self._prefix)
-        )
+        return self._prefix is None or (self._prefix and key.startswith(self._prefix))
 
     @staticmethod
-    def _update_metric(
-        metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {}
-    ) -> MetricsDict:
-        """Update metric in given dataset.
+    def _convert_metric_history_to_list_or_dict(
+        metrics: List[mlflow.entities.Metric],
+    ) -> Union[Dict[str, Union[int, float]], List[Dict[str, Union[int, float]]]]:
+        """Convert Metric objects returned by ``MlflowClient().get_metric_history``
+        to a single ``{"step": ..., "value": ...}`` dict when the history holds one
+        value, or to a list of such dicts otherwise.
 
         Args:
-            metrics (List[mlflow.entities.Metric]): List with MLflow metric objects.
-            dataset (MetricsDict): Dictionary contains MLflow metrics dataset.
-
-        Returns:
-            MetricsDict: Dictionary with MLflow metrics dataset.
+            metrics (List[mlflow.entities.Metric]): Metric objects from the run history.
         """
-        for metric in metrics:
-            metric_dict = {"step": metric.step, "value": metric.value}
-            if metric.key in dataset:
-                if isinstance(dataset[metric.key], list):
-                    dataset[metric.key].append(metric_dict)
-                else:
-                    dataset[metric.key] = [dataset[metric.key], metric_dict]
-            else:
-                dataset[metric.key] = metric_dict
-        return dataset
+        metrics_as_list = [
+            {"step": metric.step, "value": metric.value} for metric in metrics
+        ]
+        metrics_result = (
+            metrics_as_list[0] if len(metrics_as_list) == 1 else metrics_as_list
+        )
+        return metrics_result
 
     def _build_args_list_from_metric_item(
         self, key: str, value: MetricItem
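
For reference, the refactored load path can be reproduced with nothing but the public `MlflowClient` API (`get_run` and `get_metric_history`), which is what allows it to work against a remote tracking server. The sketch below is illustrative and not part of kedro-mlflow: the tracking URI, the metric keys and the helper name `load_prefixed_metrics` are assumptions.

```python
from typing import Dict, List, Optional, Union

import mlflow
from mlflow.tracking import MlflowClient

MetricPoint = Dict[str, Union[int, float]]


def load_prefixed_metrics(
    run_id: str, prefix: Optional[str] = None
) -> Dict[str, Union[MetricPoint, List[MetricPoint]]]:
    """Reproduce the dataset's load logic with public MLflow APIs only."""
    client = MlflowClient()

    # run.data.metrics only exposes the last value logged for each key
    all_keys = client.get_run(run_id).data.metrics.keys()
    dataset_keys = [key for key in all_keys if prefix is None or key.startswith(prefix)]

    metrics: Dict[str, Union[MetricPoint, List[MetricPoint]]] = {}
    for key in dataset_keys:
        # get_metric_history returns every logged step, not just the last one
        history = [
            {"step": metric.step, "value": metric.value}
            for metric in client.get_metric_history(run_id, key)
        ]
        metrics[key] = history[0] if len(history) == 1 else history
    return metrics


if __name__ == "__main__":
    mlflow.set_tracking_uri("http://localhost:5000")  # hypothetical remote server
    with mlflow.start_run() as run:
        mlflow.log_metric("my_model.accuracy", 0.92)  # one value -> a single dict
        for step, loss in enumerate([0.9, 0.5, 0.3]):
            mlflow.log_metric("my_model.loss", loss, step=step)  # history -> list
    print(load_prefixed_metrics(run.info.run_id, prefix="my_model"))
```

As in `_convert_metric_history_to_list_or_dict`, a key logged once maps to a single `{"step": ..., "value": ...}` dict, while a key with several logged steps maps to a list of such dicts.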