Skip to content

Commit

Permalink
feat(cli): Select search hyperparameters based on algorithm and singl…
Browse files Browse the repository at this point in the history
…e-step model
  • Loading branch information
kmaziarz committed Sep 12, 2023
1 parent f7bc9b8 commit 68f25f9
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 6 deletions.
40 changes: 39 additions & 1 deletion syntheseus/cli/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pprint import pformat
from typing import Any, Dict, Iterator, List, Optional, cast

import yaml
from omegaconf import MISSING, DictConfig, OmegaConf
from tqdm import tqdm

Expand Down Expand Up @@ -73,7 +74,7 @@ class MCTSConfig:
policy_class: str = "ReactionModelProbPolicy"
policy_kwargs: Dict[str, Any] = field(default_factory=dict)

bound_constant: float = 1e2
bound_constant: float = 1.0
bound_function_class: str = "pucb_bound"


Expand Down Expand Up @@ -307,6 +308,43 @@ def build_node_evaluator(key: str) -> None:

def main(argv: Optional[List[str]]) -> None:
config: SearchConfig = cli_get_config(argv=argv, config_cls=SearchConfig)

def _warn_will_not_use_defaults(message: str) -> None:
logger.warning(f"{message}; no model-specific search hyperparameters will be used")

defaults_file_path = Path(__file__).parent / "search_config.yml"
if not defaults_file_path.exists():
_warn_will_not_use_defaults(f"File {defaults_file_path} does not exist")
else:
with open(defaults_file_path, "rt") as f_defaults:
defaults = yaml.safe_load(f_defaults)

if config.search_algorithm not in defaults:
_warn_will_not_use_defaults(
f"Hyperparameter defaults file has no entry for {config.search_algorithm}"
)
else:
search_algorithm_defaults = defaults[config.search_algorithm]

model_name = config.model_class.name
if model_name not in search_algorithm_defaults:
_warn_will_not_use_defaults(
f"Hyperparameter defaults file has no entry for {model_name}"
)
else:
relevant_defaults = search_algorithm_defaults[model_name]
logger.info(
f"Using hyperparameter defaults from {defaults_file_path}: {relevant_defaults}"
)

# We now parse the config again (we could not have included the defaults earlier as
# we did not know the search algorithm and model class before the first parsing).
config = cli_get_config(
argv=argv,
config_cls=SearchConfig,
defaults={f"{config.search_algorithm}_config": relevant_defaults},
)

run_from_config(config)


Expand Down
86 changes: 86 additions & 0 deletions syntheseus/cli/search_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
mcts:
Chemformer:
bound_constant: 1
policy_kwargs:
clip_probability_max: 0.9
clip_probability_min: 1.0e-06
temperature: 8.0
value_function_kwargs:
constant: 0.75
GLN:
bound_constant: 100
policy_kwargs:
clip_probability_max: 0.9
clip_probability_min: 1.0e-07
temperature: 4.0
value_function_kwargs:
constant: 0.5
LocalRetro:
bound_constant: 1
policy_kwargs:
clip_probability_max: 0.99
clip_probability_min: 1.0e-05
temperature: 0.5
value_function_kwargs:
constant: 0.5
MEGAN:
bound_constant: 1
policy_kwargs:
clip_probability_max: 0.9999
clip_probability_min: 1.0e-05
temperature: 2.0
value_function_kwargs:
constant: 0.75
MHNreact:
bound_constant: 1
policy_kwargs:
clip_probability_max: 0.9
clip_probability_min: 1.0e-11
temperature: 8.0
value_function_kwargs:
constant: 0.5
RetroKNN:
bound_constant: 1
policy_kwargs:
clip_probability_max: 0.9
clip_probability_min: 1.0e-07
temperature: 8.0
value_function_kwargs:
constant: 0.75
RootAligned:
bound_constant: 10
policy_kwargs:
clip_probability_max: 0.999
clip_probability_min: 1.0e-05
temperature: 8.0
value_function_kwargs:
constant: 0.5
retro_star:
Chemformer:
and_node_cost_fn_kwargs:
clip_probability_max: 0.999
clip_probability_min: 1.0e-09
GLN:
and_node_cost_fn_kwargs:
clip_probability_max: 0.999
clip_probability_min: 1.0e-05
LocalRetro:
and_node_cost_fn_kwargs:
clip_probability_max: 0.999
clip_probability_min: 1.0e-05
MEGAN:
and_node_cost_fn_kwargs:
clip_probability_max: 0.99
clip_probability_min: 1.0e-05
MHNreact:
and_node_cost_fn_kwargs:
clip_probability_max: 0.9999
clip_probability_min: 1.0e-08
RetroKNN:
and_node_cost_fn_kwargs:
clip_probability_max: 0.999
clip_probability_min: 1.0e-06
RootAligned:
and_node_cost_fn_kwargs:
clip_probability_max: 0.99
clip_probability_min: 1.0e-10
18 changes: 13 additions & 5 deletions syntheseus/reaction_prediction/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import argparse
import sys
from typing import Callable, List, Optional, TypeVar, cast
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast

from omegaconf import OmegaConf
from omegaconf import DictConfig, ListConfig, OmegaConf

R = TypeVar("R")


def get_config(argv: Optional[List[str]], config_cls: Callable[..., R]) -> R:
def get_config(
argv: Optional[List[str]],
config_cls: Callable[..., R],
defaults: Optional[Dict[str, Any]] = None,
) -> R:
"""
Utility function to get `OmegaConf` config options.
Expand Down Expand Up @@ -37,8 +41,12 @@ def get_config(argv: Optional[List[str]], config_cls: Callable[..., R]) -> R:
)
args, config_changes = parser.parse_known_args(argv)

# Read configs from file and command line
conf_yamls = [OmegaConf.load(c) for c in args.config]
# Read configs from defaults, file and command line
conf_yamls: List[Union[DictConfig, ListConfig]] = []
if defaults:
conf_yamls = [OmegaConf.create(defaults)]

conf_yamls += [OmegaConf.load(c) for c in args.config]
conf_cli = OmegaConf.from_cli(config_changes)

# Make merged config options
Expand Down

0 comments on commit 68f25f9

Please sign in to comment.