Skip to content

Commit

Permalink
Remove replace=True by default
Browse files Browse the repository at this point in the history
Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
  • Loading branch information
ankatiyar committed Aug 1, 2023
1 parent b03b1e4 commit 7420b28
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 28 deletions.
3 changes: 2 additions & 1 deletion kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def _is_valid_config_path(self, path):
def _register_new_resolvers(resolvers: dict[str, Callable]):
"""Register custom resolvers"""
for name, resolver in resolvers.items():
OmegaConf.register_new_resolver(name=name, resolver=resolver, replace=True)
if not OmegaConf.has_resolver("name"):
OmegaConf.register_new_resolver(name=name, resolver=resolver)

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
Expand Down
36 changes: 9 additions & 27 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,38 +654,20 @@ def test_custom_resolvers(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
param_config = {
"model_options": {
"test_size": "${add: 3, 4}",
"random_state": "${plus_2: 1}",
"param1": "${add: 3, 4}",
"param2": "${plus_2: 1}",
"param3": "${oc.env: VAR}",
}
}
_write_yaml(base_params, param_config)
custom_resolvers = {
"add": lambda *x: sum(x),
"plus_2": lambda x: x + 2,
"oc.env": oc.env,
}
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers)
os.environ["VAR"] = "my_env_variable"
conf = OmegaConfigLoader(tmp_path, custom_resolvers=custom_resolvers)
conf.default_run_env = ""
assert conf["parameters"]["model_options"]["test_size"] == 7
assert conf["parameters"]["model_options"]["random_state"] == 3

def test_overwrite_resolvers(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
# OmegaConf is a singleton, register a resolver to be overwritten
OmegaConf.register_new_resolver("custom", lambda x: x + 10)

param_config = {
"model_options": {
"test_size": "${custom: 10}",
}
}
_write_yaml(base_params, param_config)
conf_original = OmegaConf.load(base_params)
# test_size should be calculated using custom resolver (x + 10)
assert conf_original["model_options"]["test_size"] == 20
custom_resolvers = {
"custom": lambda x: x + 20,
}
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers)
conf.default_run_env = ""
# test_size should be calculated using overwritten custom resolver (x + 20)
assert conf["parameters"]["model_options"]["test_size"] == 30
assert conf["parameters"]["model_options"]["param1"] == 7
assert conf["parameters"]["model_options"]["param2"] == 3
assert conf["parameters"]["model_options"]["param3"] == "my_env_variable"

0 comments on commit 7420b28

Please sign in to comment.