From 664411d568417d414a8efc848a02b0fb08ce1c80 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Date: Wed, 2 Aug 2023 18:22:09 +0100 Subject: [PATCH] Allow registering of custom resolvers to `OmegaConfigLoader` (#2869) * Allow registering of custom resolvers to OCL Signed-off-by: Ankita Katiyar * Complete doc string Signed-off-by: Ankita Katiyar * Add test for overwritten resolvers Signed-off-by: Ankita Katiyar * Update test for overwritten resolvers Signed-off-by: Ankita Katiyar * Remove replace=True by default Signed-off-by: Ankita Katiyar * Update release notes Signed-off-by: Ankita Katiyar * Update release notes Signed-off-by: Ankita Katiyar * Add debug level log for registering new resolver Signed-off-by: Ankita Katiyar --------- Signed-off-by: Ankita Katiyar --- RELEASE.md | 1 + kedro/config/omegaconf_config.py | 17 ++++++++++++++++- tests/config/test_omegaconf_config.py | 22 ++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/RELEASE.md b/RELEASE.md index 2fcf553bcf..ae0714b8d9 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,7 @@ # Upcoming Release 0.18.13 ## Major features and improvements +* Allowed registering of custom resolvers to `OmegaConfigLoader` through `CONFIG_LOADER_ARGS`. ## Bug fixes and other changes diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index d7d9bd245b..4d2ace59d4 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -7,7 +7,7 @@ import logging import mimetypes from pathlib import Path -from typing import Any, Iterable +from typing import Any, Callable, Iterable import fsspec from omegaconf import OmegaConf @@ -82,6 +82,7 @@ def __init__( # noqa: too-many-arguments config_patterns: dict[str, list[str]] = None, base_env: str = "base", default_run_env: str = "local", + custom_resolvers: dict[str, Callable] = None, ): """Instantiates a ``OmegaConfigLoader``. @@ -97,6 +98,8 @@ def __init__( # noqa: too-many-arguments the configuration paths. default_run_env: Name of the default run environment. Defaults to `"local"`. Can be overridden by supplying the `env` argument. + custom_resolvers: A dictionary of custom resolvers to be registered. For more information, + see here: https://omegaconf.readthedocs.io/en/2.3_branch/custom_resolvers.html#custom-resolvers """ self.base_env = base_env self.default_run_env = default_run_env @@ -111,6 +114,9 @@ def __init__( # noqa: too-many-arguments # Deactivate oc.env built-in resolver for OmegaConf OmegaConf.clear_resolver("oc.env") + # Register user provided custom resolvers + if custom_resolvers: + self._register_new_resolvers(custom_resolvers) file_mimetype, _ = mimetypes.guess_type(conf_source) if file_mimetype == "application/x-tar": @@ -302,6 +308,15 @@ def _is_valid_config_path(self, path): ".json", ] + @staticmethod + def _register_new_resolvers(resolvers: dict[str, Callable]): + """Register custom resolvers""" + for name, resolver in resolvers.items(): + if not OmegaConf.has_resolver(name): + msg = f"Registering new custom resolver: {name}" + _config_logger.debug(msg) + OmegaConf.register_new_resolver(name=name, resolver=resolver) + @staticmethod def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]): duplicates = [] diff --git a/tests/config/test_omegaconf_config.py b/tests/config/test_omegaconf_config.py index dd49292019..af57b52224 100644 --- a/tests/config/test_omegaconf_config.py +++ b/tests/config/test_omegaconf_config.py @@ -649,3 +649,25 @@ def test_variable_interpolation_in_catalog_with_separate_templates_file( conf = OmegaConfigLoader(str(tmp_path)) conf.default_run_env = "" assert conf["catalog"]["companies"]["type"] == "pandas.CSVDataSet" + + def test_custom_resolvers(self, tmp_path): + base_params = tmp_path / _BASE_ENV / "parameters.yml" + param_config = { + "model_options": { + "param1": "${add: 3, 4}", + "param2": "${plus_2: 1}", + "param3": "${oc.env: VAR}", + } + } + _write_yaml(base_params, param_config) + custom_resolvers = { + "add": lambda *x: sum(x), + "plus_2": lambda x: x + 2, + "oc.env": oc.env, + } + os.environ["VAR"] = "my_env_variable" + conf = OmegaConfigLoader(tmp_path, custom_resolvers=custom_resolvers) + conf.default_run_env = "" + assert conf["parameters"]["model_options"]["param1"] == 7 + assert conf["parameters"]["model_options"]["param2"] == 3 + assert conf["parameters"]["model_options"]["param3"] == "my_env_variable"