Skip to content

Commit

Permalink
Allow registering of custom resolvers to OmegaConfigLoader (#2869)
Browse files Browse the repository at this point in the history
* Allow registering of custom resolvers to OCL

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Complete doc string

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Add test for overwritten resolvers

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update test for overwritten resolvers

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Remove replace=True by default

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update release notes

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update release notes

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Add debug level log for registering new resolver

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
  • Loading branch information
ankatiyar authored Aug 2, 2023
1 parent 73a35bb commit 664411d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Upcoming Release 0.18.13

## Major features and improvements
* Allowed registering of custom resolvers to `OmegaConfigLoader` through `CONFIG_LOADER_ARGS`.

## Bug fixes and other changes

Expand Down
17 changes: 16 additions & 1 deletion kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import mimetypes
from pathlib import Path
from typing import Any, Iterable
from typing import Any, Callable, Iterable

import fsspec
from omegaconf import OmegaConf
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__( # noqa: too-many-arguments
config_patterns: dict[str, list[str]] = None,
base_env: str = "base",
default_run_env: str = "local",
custom_resolvers: dict[str, Callable] = None,
):
"""Instantiates a ``OmegaConfigLoader``.
Expand All @@ -97,6 +98,8 @@ def __init__( # noqa: too-many-arguments
the configuration paths.
default_run_env: Name of the default run environment. Defaults to `"local"`.
Can be overridden by supplying the `env` argument.
custom_resolvers: A dictionary of custom resolvers to be registered. For more information,
see here: https://omegaconf.readthedocs.io/en/2.3_branch/custom_resolvers.html#custom-resolvers
"""
self.base_env = base_env
self.default_run_env = default_run_env
Expand All @@ -111,6 +114,9 @@ def __init__( # noqa: too-many-arguments

# Deactivate oc.env built-in resolver for OmegaConf
OmegaConf.clear_resolver("oc.env")
# Register user provided custom resolvers
if custom_resolvers:
self._register_new_resolvers(custom_resolvers)

file_mimetype, _ = mimetypes.guess_type(conf_source)
if file_mimetype == "application/x-tar":
Expand Down Expand Up @@ -302,6 +308,15 @@ def _is_valid_config_path(self, path):
".json",
]

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
"""Register custom resolvers"""
for name, resolver in resolvers.items():
if not OmegaConf.has_resolver(name):
msg = f"Registering new custom resolver: {name}"
_config_logger.debug(msg)
OmegaConf.register_new_resolver(name=name, resolver=resolver)

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
duplicates = []
Expand Down
22 changes: 22 additions & 0 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,25 @@ def test_variable_interpolation_in_catalog_with_separate_templates_file(
conf = OmegaConfigLoader(str(tmp_path))
conf.default_run_env = ""
assert conf["catalog"]["companies"]["type"] == "pandas.CSVDataSet"

def test_custom_resolvers(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
param_config = {
"model_options": {
"param1": "${add: 3, 4}",
"param2": "${plus_2: 1}",
"param3": "${oc.env: VAR}",
}
}
_write_yaml(base_params, param_config)
custom_resolvers = {
"add": lambda *x: sum(x),
"plus_2": lambda x: x + 2,
"oc.env": oc.env,
}
os.environ["VAR"] = "my_env_variable"
conf = OmegaConfigLoader(tmp_path, custom_resolvers=custom_resolvers)
conf.default_run_env = ""
assert conf["parameters"]["model_options"]["param1"] == 7
assert conf["parameters"]["model_options"]["param2"] == 3
assert conf["parameters"]["model_options"]["param3"] == "my_env_variable"

0 comments on commit 664411d

Please sign in to comment.