Skip to content

Commit

Permalink
feat(node_evaluation): Add node evaluators that use single-step model…
Browse files Browse the repository at this point in the history
… probability
  • Loading branch information
kmaziarz committed Aug 30, 2023
1 parent 8879c0e commit 6e874b3
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 2 deletions.
53 changes: 53 additions & 0 deletions syntheseus/search/node_evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from collections.abc import Sequence
from typing import Generic, Optional, TypeVar

import numpy as np

from syntheseus.search.chem import BackwardReaction
from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph
from syntheseus.search.graph.node import BaseGraphNode

Expand Down Expand Up @@ -81,3 +84,53 @@ def _evaluate_nodes(
) -> Sequence[float]:
"""Override this method to just evaluate the nodes, without counting the number of calls."""
pass


class ReactionModelBasedEvaluator(NoCacheNodeEvaluator[NodeType]):
"""Evaluator that computes its value based on the probability from the single-step model."""

def __init__(
self,
return_log: bool,
return_negated: bool,
temperature: float = 1.0,
clip_probability_min: float = 1e-10,
clip_probability_max: float = 0.999,
) -> None:
super().__init__()

assert 0.0 <= clip_probability_min <= clip_probability_max <= 1.0

if return_log and clip_probability_min == 0.0:
raise ValueError("Disabling clipping can lead to NaNs when computing log probability")

self._return_log = return_log
self._return_negated = return_negated
self._temperature = temperature
self._clip_probability_min = clip_probability_min
self._clip_probability_max = clip_probability_max

@abc.abstractmethod
def _get_reaction(self, node, graph) -> BackwardReaction:
pass

def _get_probability(self, node, graph) -> float:
metadata = self._get_reaction(node, graph).metadata

if "probability" not in metadata:
raise ValueError("Cannot call node evaluator as reaction model probability is not set")
return metadata["probability"] # type: ignore

def _evaluate_nodes(self, nodes, graph=None) -> Sequence[float]:
probs = np.asarray([self._get_probability(n, graph) for n in nodes])
probs = np.clip(probs, a_min=self._clip_probability_min, a_max=self._clip_probability_max)

if self._return_log:
outputs = np.log(probs) / self._temperature
else:
outputs = probs ** (1.0 / self._temperature)

if self._return_negated:
outputs *= -1.0

return outputs.tolist()
32 changes: 31 additions & 1 deletion syntheseus/search/node_evaluation/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Common node evaluation functions."""

from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.chem import BackwardReaction
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.graph.molset import MolSetNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator, ReactionModelBasedEvaluator


class ConstantNodeEvaluator(NoCacheNodeEvaluator):
Expand All @@ -15,3 +18,30 @@ def _evaluate_nodes(self, nodes, graph=None):
class HasSolutionValueFunction(NoCacheNodeEvaluator):
def _evaluate_nodes(self, nodes, graph=None):
return [float(n.has_solution) for n in nodes]


class ReactionModelLogProbCost(ReactionModelBasedEvaluator[AndNode]):
def __init__(self, **kwargs) -> None:
super().__init__(return_log=True, return_negated=True, **kwargs)

def _get_reaction(self, node: AndNode, graph) -> BackwardReaction:
return node.reaction


class ReactionModelProbPolicy(ReactionModelBasedEvaluator[MolSetNode]):
def __init__(
self, clip_probability_min: float = 0.0, clip_probability_max: float = 1.0, **kwargs
) -> None:
super().__init__(
return_log=False,
return_negated=False,
clip_probability_min=clip_probability_min,
clip_probability_max=clip_probability_max,
**kwargs,
)

def _get_reaction(self, node: MolSetNode, graph) -> BackwardReaction:
parents = list(graph.predecessors(node))
assert len(parents) == 1, "Graph must be a tree"

return graph._graph.edges[parents[0], node]["reaction"]
95 changes: 94 additions & 1 deletion syntheseus/tests/search/node_evaluation/test_common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import math

import pytest

from syntheseus.search.graph.and_or import AndOrGraph
from syntheseus.search.graph.and_or import AndNode, AndOrGraph
from syntheseus.search.graph.molset import MolSetGraph
from syntheseus.search.node_evaluation.common import (
ConstantNodeEvaluator,
HasSolutionValueFunction,
ReactionModelLogProbCost,
ReactionModelProbPolicy,
)


Expand All @@ -27,3 +32,91 @@ def test_values(self, andor_graph_non_minimal: AndOrGraph) -> None:
assert val_fn.num_calls == len(
andor_graph_non_minimal
) # should have been called once per node


class TestReactionModelLogProbCost:
@pytest.mark.parametrize("temperature", [1.0, 2.0])
@pytest.mark.parametrize("clip_probability_min", [0.1, 0.5])
@pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
def test_values(
self,
andor_graph_non_minimal: AndOrGraph,
temperature: float,
clip_probability_min: float,
clip_probability_max: float,
) -> None:
val_fn = ReactionModelLogProbCost(
temperature=temperature,
clip_probability_min=clip_probability_min,
clip_probability_max=clip_probability_max,
)
nodes = [node for node in andor_graph_non_minimal.nodes() if isinstance(node, AndNode)]

# The toy model does not set reaction probabilities, so set these manually.
node_prob = {}
for idx, node in enumerate(nodes):
node_prob[node] = idx / (len(nodes) - 1)
node.reaction.metadata["probability"] = node_prob[node] # type: ignore

vals = val_fn(nodes)
for val_computed, node in zip(vals, nodes): # values should match
prob = node_prob[node]
val_expected = (
-math.log(min(clip_probability_max, max(clip_probability_min, prob))) / temperature
)

assert math.isclose(val_computed, val_expected)

assert val_fn.num_calls == len(nodes) # should have been called once per AND node

def test_enforces_min_clipping(self) -> None:
with pytest.raises(ValueError):
ReactionModelLogProbCost(clip_probability_min=0.0)


class TestReactionModelProbPolicy:
@pytest.mark.parametrize("temperature", [1.0, 2.0])
@pytest.mark.parametrize("clip_probability_min", [0.0, 0.5])
@pytest.mark.parametrize("clip_probability_max", [0.5, 1.0])
def test_values(
self,
molset_tree_non_minimal: MolSetGraph,
temperature: float,
clip_probability_min: float,
clip_probability_max: float,
) -> None:
val_fn = ReactionModelProbPolicy(
temperature=temperature,
clip_probability_min=clip_probability_min,
clip_probability_max=clip_probability_max,
)
nodes = [
node
for node in molset_tree_non_minimal.nodes()
if node != molset_tree_non_minimal.root_node
]

# The toy model does not set reaction probabilities, so set these manually.
node_prob = {}
for idx, node in enumerate(nodes):
[parent] = molset_tree_non_minimal.predecessors(node)
reaction = molset_tree_non_minimal._graph.edges[parent, node]["reaction"]

# Be careful not to overwrite things as some reactions in the graph are repeated.
if "probability" not in reaction.metadata:
reaction.metadata["probability"] = node_prob[node] = idx / (len(nodes) - 1)
else:
node_prob[node] = reaction.metadata["probability"]

vals = val_fn(nodes, graph=molset_tree_non_minimal)
for val_computed, node in zip(vals, nodes): # values should match
prob = node_prob[node]
val_expected = min(clip_probability_max, max(clip_probability_min, prob)) ** (
1.0 / temperature
)

assert math.isclose(val_computed, val_expected)

assert (
val_fn.num_calls == len(molset_tree_non_minimal) - 1
) # should have been called once per non-root node

0 comments on commit 6e874b3

Please sign in to comment.