def find_model_defaults(defaults: Any, search_algorithm: str, model_name: str) -> Dict[str, Any]:
    """Pick the hyperparameter defaults for one (search algorithm, model) pair.

    Args:
        defaults: Parsed content of the defaults file. Expected to be a nested
            mapping `{search_algorithm: {model_name: {option: value}}}`, but may
            be anything YAML can produce (e.g. `None` for an empty file).
        search_algorithm: Top-level key to look up (e.g. "mcts").
        model_name: Second-level key to look up (e.g. "GLN").

    Returns:
        The mapping of option overrides stored for the given pair.

    Raises:
        LookupError: If there is no usable entry for the search algorithm or for
            the model; the exception message names the missing key.
    """
    # An empty YAML file loads as `None`; treat any non-mapping like a missing entry
    # instead of letting the `in` check raise a `TypeError`.
    if not isinstance(defaults, dict) or search_algorithm not in defaults:
        raise LookupError(f"Hyperparameter defaults file has no entry for {search_algorithm}")

    per_model_defaults = defaults[search_algorithm]
    if not isinstance(per_model_defaults, dict) or model_name not in per_model_defaults:
        raise LookupError(f"Hyperparameter defaults file has no entry for {model_name}")

    model_defaults = per_model_defaults[model_name]
    if not isinstance(model_defaults, dict):
        # A bare `ModelName:` key loads as `None`, which must not be merged into the config.
        raise LookupError(f"Hyperparameter defaults file has no entry for {model_name}")

    return model_defaults


def main(argv: Optional[List[str]]) -> None:
    """Parse the search config, layer in model-specific defaults, and run the search.

    The config is parsed once to learn the search algorithm and model class. If
    `search_config.yml` (located next to this module) has an entry for that pair,
    the config is parsed a second time with those values passed as `defaults`.
    Otherwise a warning is logged and the first parse is used unchanged.

    Args:
        argv: Command line arguments forwarded to `cli_get_config`.
    """
    config: SearchConfig = cli_get_config(argv=argv, config_cls=SearchConfig)

    def _warn_will_not_use_defaults(message: str) -> None:
        logger.warning(f"{message}; no model-specific search hyperparameters will be used")

    defaults_file_path = Path(__file__).parent / "search_config.yml"
    relevant_defaults: Optional[Dict[str, Any]] = None

    if not defaults_file_path.exists():
        _warn_will_not_use_defaults(f"File {defaults_file_path} does not exist")
    else:
        with open(defaults_file_path, "rt") as f_defaults:
            defaults = yaml.safe_load(f_defaults)

        try:
            relevant_defaults = find_model_defaults(
                defaults, config.search_algorithm, config.model_class.name
            )
        except LookupError as err:
            _warn_will_not_use_defaults(str(err))

    if relevant_defaults is not None:
        logger.info(
            f"Using hyperparameter defaults from {defaults_file_path}: {relevant_defaults}"
        )

        # We now parse the config again (we could not have included the defaults earlier as
        # we did not know the search algorithm and model class before the first parsing).
        config = cli_get_config(
            argv=argv,
            config_cls=SearchConfig,
            defaults={f"{config.search_algorithm}_config": relevant_defaults},
        )

    run_from_config(config)
and_node_cost_fn_kwargs: + clip_probability_max: 0.9999 + clip_probability_min: 1.0e-08 + RetroKNN: + and_node_cost_fn_kwargs: + clip_probability_max: 0.999 + clip_probability_min: 1.0e-06 + RootAligned: + and_node_cost_fn_kwargs: + clip_probability_max: 0.99 + clip_probability_min: 1.0e-10 diff --git a/syntheseus/reaction_prediction/utils/config.py b/syntheseus/reaction_prediction/utils/config.py index 1a70075e..6d0d4801 100644 --- a/syntheseus/reaction_prediction/utils/config.py +++ b/syntheseus/reaction_prediction/utils/config.py @@ -1,13 +1,17 @@ import argparse import sys -from typing import Callable, List, Optional, TypeVar, cast +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast -from omegaconf import OmegaConf +from omegaconf import DictConfig, ListConfig, OmegaConf R = TypeVar("R") -def get_config(argv: Optional[List[str]], config_cls: Callable[..., R]) -> R: +def get_config( + argv: Optional[List[str]], + config_cls: Callable[..., R], + defaults: Optional[Dict[str, Any]] = None, +) -> R: """ Utility function to get `OmegaConf` config options. @@ -37,8 +41,12 @@ def get_config(argv: Optional[List[str]], config_cls: Callable[..., R]) -> R: ) args, config_changes = parser.parse_known_args(argv) - # Read configs from file and command line - conf_yamls = [OmegaConf.load(c) for c in args.config] + # Read configs from defaults, file and command line + conf_yamls: List[Union[DictConfig, ListConfig]] = [] + if defaults: + conf_yamls = [OmegaConf.create(defaults)] + + conf_yamls += [OmegaConf.load(c) for c in args.config] conf_cli = OmegaConf.from_cli(config_changes) # Make merged config options