From 8516d3108070f97a3e31f185fb6c2e3c38ff8d74 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 12 Sep 2023 21:14:42 +0100 Subject: [PATCH] Fix error handling in MEGAN (#29) In search experiments, it can be observed that MEGAN sometimes produces odd molecules containing `C+` atoms, which violate valence despite being accepted as valid by `rdkit`. These molecules fail sanitization through MEGAN's `fix_explicit_hs` utility, and so far our wrapper has been ignoring this and passing the molecule through to the model to propose reactions (likely pushing the problematic `C+` atom to reactants). In this PR, the underlying model is no longer called on inputs which fail `fix_explicit_hs`; instead, an empty list of reactions is returned for such inputs. Relatedly, I also silenced some of the other internal warnings produced by the MEGAN model by applying a (now improved) `suppress_outputs` context manager. --- CHANGELOG.md | 1 + .../reaction_prediction/inference/megan.py | 62 +++++++++++-------- syntheseus/reaction_prediction/utils/misc.py | 3 + 3 files changed, 41 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21894dc6..45eb1e07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. 
- Fix bug where standardizing MolSetGraphs crashed ([#24](https://github.com/microsoft/syntheseus/pull/24)) ([@austint]) - Change default node depth to infinity ([#16](https://github.com/microsoft/syntheseus/pull/16)) ([@austint]) - Adapt tutorials to the renaming from PR #9 ([#17](https://github.com/microsoft/syntheseus/pull/17)) ([@jagarridotorres]) +- Fix error handling in MEGAN ([#29](https://github.com/microsoft/syntheseus/pull/29)) ([@kmaziarz]) - Pin `pydantic` version to `1.*` ([#10](https://github.com/microsoft/syntheseus/pull/10)) ([@kmaziarz]) - Fix compatibility with Python 3.7 ([#5](https://github.com/microsoft/syntheseus/pull/5)) ([@kmaziarz]) - Pin `zipp` version to `<3.16` ([#11](https://github.com/microsoft/syntheseus/pull/11)) ([@kmaziarz]) diff --git a/syntheseus/reaction_prediction/inference/megan.py b/syntheseus/reaction_prediction/inference/megan.py index 856cc0a2..708e6a49 100644 --- a/syntheseus/reaction_prediction/inference/megan.py +++ b/syntheseus/reaction_prediction/inference/megan.py @@ -6,10 +6,12 @@ The original MEGAN code is released under the MIT license. """ +from __future__ import annotations + import os import sys from pathlib import Path -from typing import List, Union +from typing import Any, Optional, Union from rdkit import Chem @@ -20,6 +22,7 @@ get_unique_file_in_dir, process_raw_smiles_outputs, ) +from syntheseus.reaction_prediction.utils.misc import suppress_outputs class MEGANModel(BackwardReactionModel): @@ -89,7 +92,7 @@ def __init__( def get_parameters(self): return self.model.parameters() - def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]: + def _mols_to_batch(self, inputs: list[Molecule]) -> list[Optional[Chem.Mol]]: from src.feat.utils import fix_explicit_hs # Inputs to the model are list of `rdkit` molecules. 
@@ -102,37 +105,46 @@ def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]: a.SetAtomMapNum(i + 1) try: - mol = fix_explicit_hs(mol) + input_batch.append(fix_explicit_hs(mol)) except Exception: - # Sometimes `fix_explicit_hs` may fail with an `rdkit` error. In such cases we give - # up and use the molecule as-is. - # TODO(kmaziarz): Investigate these cases in more detail. - pass - - input_batch.append(mol) + # MEGAN sometimes produces broken molecules containing C+ atoms which pass `rdkit` + # sanitization but fail in `fix_explicit_hs`. We block these here to avoid making + # predictions for them. + input_batch.append(None) return input_batch - def __call__(self, inputs: List[Molecule], num_results: int) -> List[BackwardPredictionList]: + def __call__(self, inputs: list[Molecule], num_results: int) -> list[BackwardPredictionList]: import torch from src.model.beam_search import beam_search # Get the inputs into the right form to call the underlying model. batch = self._mols_to_batch(inputs) - - with torch.no_grad(): - beam_search_results = beam_search( - [self.model], - batch, - rdkit_cache=self.rdkit_cache, - max_steps=self.max_gen_steps, - beam_size=num_results, - batch_size=self.beam_batch_size, - base_action_masks=self.base_action_masks, - max_atoms=self.n_max_atoms, - reaction_types=None, - action_vocab=self.action_vocab, - ) # returns a list of `beam_size` results for each input molecule. 
+ batch_valid = [mol for mol in batch if mol is not None] + batch_valid_idxs = [idx for idx, mol in enumerate(batch) if mol is not None] + + if batch_valid: + with torch.no_grad(), suppress_outputs(): + beam_search_results = beam_search( + [self.model], + batch_valid, + rdkit_cache=self.rdkit_cache, + max_steps=self.max_gen_steps, + beam_size=num_results, + batch_size=self.beam_batch_size, + base_action_masks=self.base_action_masks, + max_atoms=self.n_max_atoms, + reaction_types=None, + action_vocab=self.action_vocab, + ) # returns a list of `beam_size` results for each input molecule + else: + beam_search_results = [] + + assert len(batch_valid_idxs) == len(beam_search_results) + + all_outputs: list[list[dict[str, Any]]] = [[] for _ in batch] + for idx, raw_outputs in zip(batch_valid_idxs, beam_search_results): + all_outputs[idx] = raw_outputs return [ process_raw_smiles_outputs( @@ -140,5 +152,5 @@ def __call__(self, inputs: List[Molecule], num_results: int) -> List[BackwardPre output_list=[prediction["final_smi_unmapped"] for prediction in raw_outputs], kwargs_list=[{"probability": prediction["prob"]} for prediction in raw_outputs], ) - for input, raw_outputs in zip(inputs, beam_search_results) + for input, raw_outputs in zip(inputs, all_outputs) ] diff --git a/syntheseus/reaction_prediction/utils/misc.py b/syntheseus/reaction_prediction/utils/misc.py index ae1a5e67..6172ff3f 100644 --- a/syntheseus/reaction_prediction/utils/misc.py +++ b/syntheseus/reaction_prediction/utils/misc.py @@ -1,3 +1,4 @@ +import logging import multiprocessing import os import random @@ -33,7 +34,9 @@ def suppress_outputs(): """Suppress messages written to both stdout and stderr.""" with open(devnull, "w") as fnull: with redirect_stderr(fnull), redirect_stdout(fnull): + logging.disable(logging.CRITICAL) yield + logging.disable(logging.NOTSET) def dictify(data: Any) -> Any: