Skip to content

Commit

Permalink
Add test for overwritten resolvers
Browse files Browse the repository at this point in the history
Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
  • Loading branch information
ankatiyar committed Jul 31, 2023
1 parent 89ebb12 commit 3163c5e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _is_valid_config_path(self, path):
def _register_new_resolvers(resolvers: dict[str, Callable]):
"""Register custom resolvers"""
for name, resolver in resolvers.items():
OmegaConf.register_new_resolver(name, resolver, replace=True)
OmegaConf.register_new_resolver(name=name, resolver=resolver, replace=True)

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
Expand Down
19 changes: 19 additions & 0 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,22 @@ def test_custom_resolvers(self, tmp_path):
conf.default_run_env = ""
assert conf["parameters"]["model_options"]["test_size"] == 7
assert conf["parameters"]["model_options"]["random_state"] == 3

def test_overwrite_resolvers(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
# OmegaConf is a singleton, register a resolver to be overwritten
OmegaConf.register_new_resolver("custom", lambda x: x + 10)

param_config = {
"model_options": {
"test_size": "${custom: 10}",
}
}
_write_yaml(base_params, param_config)
custom_resolvers = {
"custom": lambda x: x + 20,
}
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers)
conf.default_run_env = ""
# test_size should be calculated using overwritten custom resolver (x + 20)
assert conf["parameters"]["model_options"]["test_size"] == 30

0 comments on commit 3163c5e

Please sign in to comment.