Skip to content

Commit

Permalink
Make ReactionModelProbPolicy work for AndNode
Browse files Browse the repository at this point in the history
  • Loading branch information
guoqingliu_microsoft committed Sep 13, 2023
1 parent 2fc93b5 commit df84065
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion syntheseus/search/algorithms/pdvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
See: https://arxiv.org/abs/2301.13755
In particular, this file contains code for the MCTS algorithm used to make training data for PDVN,
In particular, this file contains code for the And/OR-MCTS algorithm used to make training data for PDVN,
and the code to extract training data from a completed search graph.
"""

Expand Down
17 changes: 10 additions & 7 deletions syntheseus/search/node_evaluation/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Union

from syntheseus.search.chem import BackwardReaction
from syntheseus.search.graph.and_or import AndNode
Expand Down Expand Up @@ -37,15 +38,17 @@ def _evaluate_nodes(self, nodes, graph=None) -> Sequence[float]:
return [-v for v in super()._evaluate_nodes(nodes, graph)]


class ReactionModelProbPolicy(ReactionModelBasedEvaluator[MolSetNode]):
"""Evaluator that uses the reactions' probability to form a policy (useful for OR-MCTS)."""
class ReactionModelProbPolicy(ReactionModelBasedEvaluator[Union[MolSetNode, AndNode]]):
"""Evaluator that uses the reactions' probability to form a policy (useful for OR-MCTS and And/OR-MCTS)."""

def __init__(self, **kwargs) -> None:
kwargs["normalize"] = kwargs.get("normalize", True) # set `normalize = True` by default
super().__init__(return_log=False, **kwargs)

def _get_reaction(self, node: MolSetNode, graph) -> BackwardReaction:
parents = list(graph.predecessors(node))
assert len(parents) == 1, "Graph must be a tree"

return graph._graph.edges[parents[0], node]["reaction"]
def _get_reaction(self, node: Union[MolSetNode, AndNode], graph) -> BackwardReaction:
if isinstance(node, MolSetNode):
parents = list(graph.predecessors(node))
assert len(parents) == 1, "Graph must be a tree"
return graph._graph.edges[parents[0], node]["reaction"]
else:
return node.reaction

0 comments on commit df84065

Please sign in to comment.