Skip to content

Commit

Permalink
Fix error handling in MEGAN (#29)
Browse files Browse the repository at this point in the history
In search experiments, it can be observed that MEGAN sometimes produces
odd molecules containing `C+` atoms, which violate valence despite being
accepted as valid by `rdkit`. These molecules fail sanitization through
MEGAN's `fix_explicit_hs` utility, and so far our wrapper has been
ignoring this and passing the molecule through to the model to propose
reactions (likely pushing the problematic `C+` atom to reactants). In
this PR, the underlying model is not called on inputs which fail
`fix_explicit_hs`, and rather an empty list of reactions is returned in
this case. Relatedly, I also silenced some of the other internal
warnings produced by the MEGAN model by applying a (now improved)
`suppress_outputs` context manager.
  • Loading branch information
kmaziarz authored Sep 12, 2023
1 parent 8072ac1 commit 8516d31
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
- Fix bug where standardizing MolSetGraphs crashed ([#24](https://github.com/microsoft/syntheseus/pull/24)) ([@austint])
- Change default node depth to infinity ([#16](https://github.com/microsoft/syntheseus/pull/16)) ([@austint])
- Adapt tutorials to the renaming from PR #9 ([#17](https://github.com/microsoft/syntheseus/pull/17)) ([@jagarridotorres])
- Fix error handling in MEGAN ([#29](https://github.com/microsoft/syntheseus/pull/29)) ([@kmaziarz])
- Pin `pydantic` version to `1.*` ([#10](https://github.com/microsoft/syntheseus/pull/10)) ([@kmaziarz])
- Fix compatibility with Python 3.7 ([#5](https://github.com/microsoft/syntheseus/pull/5)) ([@kmaziarz])
- Pin `zipp` version to `<3.16` ([#11](https://github.com/microsoft/syntheseus/pull/11)) ([@kmaziarz])
Expand Down
62 changes: 37 additions & 25 deletions syntheseus/reaction_prediction/inference/megan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
The original MEGAN code is released under the MIT license.
"""

from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import List, Union
from typing import Any, Optional, Union

from rdkit import Chem

Expand All @@ -20,6 +22,7 @@
get_unique_file_in_dir,
process_raw_smiles_outputs,
)
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class MEGANModel(BackwardReactionModel):
Expand Down Expand Up @@ -89,7 +92,7 @@ def __init__(
def get_parameters(self):
return self.model.parameters()

def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]:
def _mols_to_batch(self, inputs: list[Molecule]) -> list[Optional[Chem.Mol]]:
from src.feat.utils import fix_explicit_hs

# Inputs to the model are list of `rdkit` molecules.
Expand All @@ -102,43 +105,52 @@ def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]:
a.SetAtomMapNum(i + 1)

try:
mol = fix_explicit_hs(mol)
input_batch.append(fix_explicit_hs(mol))
except Exception:
# Sometimes `fix_explicit_hs` may fail with an `rdkit` error. In such cases we give
# up and use the molecule as-is.
# TODO(kmaziarz): Investigate these cases in more detail.
pass

input_batch.append(mol)
# MEGAN sometimes produces broken molecules containing C+ atoms which pass `rdkit`
# sanitization but fail in `fix_explicit_hs`. We block these here to avoid making
# predictions for them.
input_batch.append(None)

return input_batch

def __call__(self, inputs: List[Molecule], num_results: int) -> List[BackwardPredictionList]:
def __call__(self, inputs: list[Molecule], num_results: int) -> list[BackwardPredictionList]:
import torch
from src.model.beam_search import beam_search

# Get the inputs into the right form to call the underlying model.
batch = self._mols_to_batch(inputs)

with torch.no_grad():
beam_search_results = beam_search(
[self.model],
batch,
rdkit_cache=self.rdkit_cache,
max_steps=self.max_gen_steps,
beam_size=num_results,
batch_size=self.beam_batch_size,
base_action_masks=self.base_action_masks,
max_atoms=self.n_max_atoms,
reaction_types=None,
action_vocab=self.action_vocab,
) # returns a list of `beam_size` results for each input molecule.
batch_valid = [mol for mol in batch if mol is not None]
batch_valid_idxs = [idx for idx, mol in enumerate(batch) if mol is not None]

if batch_valid:
with torch.no_grad(), suppress_outputs():
beam_search_results = beam_search(
[self.model],
batch_valid,
rdkit_cache=self.rdkit_cache,
max_steps=self.max_gen_steps,
beam_size=num_results,
batch_size=self.beam_batch_size,
base_action_masks=self.base_action_masks,
max_atoms=self.n_max_atoms,
reaction_types=None,
action_vocab=self.action_vocab,
) # returns a list of `beam_size` results for each input molecule
else:
beam_search_results = []

assert len(batch_valid_idxs) == len(beam_search_results)

all_outputs: list[list[dict[str, Any]]] = [[] for _ in batch]
for idx, raw_outputs in zip(batch_valid_idxs, beam_search_results):
all_outputs[idx] = raw_outputs

return [
process_raw_smiles_outputs(
input=input,
output_list=[prediction["final_smi_unmapped"] for prediction in raw_outputs],
kwargs_list=[{"probability": prediction["prob"]} for prediction in raw_outputs],
)
for input, raw_outputs in zip(inputs, beam_search_results)
for input, raw_outputs in zip(inputs, all_outputs)
]
3 changes: 3 additions & 0 deletions syntheseus/reaction_prediction/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import multiprocessing
import os
import random
Expand Down Expand Up @@ -33,7 +34,9 @@ def suppress_outputs():
"""Suppress messages written to both stdout and stderr."""
with open(devnull, "w") as fnull:
with redirect_stderr(fnull), redirect_stdout(fnull):
logging.disable(logging.CRITICAL)
yield
logging.disable(logging.NOTSET)


def dictify(data: Any) -> Any:
Expand Down

0 comments on commit 8516d31

Please sign in to comment.