Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error handling in MEGAN #29

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
- Fix bug where standardizing MolSetGraphs crashed ([#24](https://github.com/microsoft/syntheseus/pull/24)) ([@austint])
- Change default node depth to infinity ([#16](https://github.com/microsoft/syntheseus/pull/16)) ([@austint])
- Adapt tutorials to the renaming from PR #9 ([#17](https://github.com/microsoft/syntheseus/pull/17)) ([@jagarridotorres])
- Fix error handling in MEGAN ([#29](https://github.com/microsoft/syntheseus/pull/29)) ([@kmaziarz])
- Pin `pydantic` version to `1.*` ([#10](https://github.com/microsoft/syntheseus/pull/10)) ([@kmaziarz])
- Fix compatibility with Python 3.7 ([#5](https://github.com/microsoft/syntheseus/pull/5)) ([@kmaziarz])
- Pin `zipp` version to `<3.16` ([#11](https://github.com/microsoft/syntheseus/pull/11)) ([@kmaziarz])
Expand Down
58 changes: 34 additions & 24 deletions syntheseus/reaction_prediction/inference/megan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import sys
from pathlib import Path
from typing import List, Union
from typing import Any, Dict, List, Optional, Union
kmaziarz marked this conversation as resolved.
Show resolved Hide resolved

from rdkit import Chem

Expand All @@ -20,6 +20,7 @@
get_unique_file_in_dir,
process_raw_smiles_outputs,
)
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class MEGANModel(BackwardReactionModel):
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(
def get_parameters(self):
return self.model.parameters()

def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]:
def _mols_to_batch(self, inputs: List[Molecule]) -> List[Optional[Chem.Mol]]:
from src.feat.utils import fix_explicit_hs

# Inputs to the model are list of `rdkit` molecules.
Expand All @@ -102,14 +103,12 @@ def _mols_to_batch(self, inputs: List[Molecule]) -> List[Chem.Mol]:
a.SetAtomMapNum(i + 1)

try:
mol = fix_explicit_hs(mol)
input_batch.append(fix_explicit_hs(mol))
except Exception:
# Sometimes `fix_explicit_hs` may fail with an `rdkit` error. In such cases we give
# up and use the molecule as-is.
# TODO(kmaziarz): Investigate these cases in more detail.
pass

input_batch.append(mol)
# MEGAN sometimes produces broken molecules containing C+ atoms which pass `rdkit`
# sanitization but fail in `fix_explicit_hs`. We block these here to avoid making
# predictions for them.
input_batch.append(None)

return input_batch

Expand All @@ -119,26 +118,37 @@ def __call__(self, inputs: List[Molecule], num_results: int) -> List[BackwardPre

# Get the inputs into the right form to call the underlying model.
batch = self._mols_to_batch(inputs)

with torch.no_grad():
beam_search_results = beam_search(
[self.model],
batch,
rdkit_cache=self.rdkit_cache,
max_steps=self.max_gen_steps,
beam_size=num_results,
batch_size=self.beam_batch_size,
base_action_masks=self.base_action_masks,
max_atoms=self.n_max_atoms,
reaction_types=None,
action_vocab=self.action_vocab,
) # returns a list of `beam_size` results for each input molecule.
batch_valid = [mol for mol in batch if mol is not None]
batch_valid_idxs = [idx for idx, mol in enumerate(batch) if mol is not None]

if batch_valid:
with torch.no_grad(), suppress_outputs():
beam_search_results = beam_search(
[self.model],
batch_valid,
rdkit_cache=self.rdkit_cache,
max_steps=self.max_gen_steps,
beam_size=num_results,
batch_size=self.beam_batch_size,
base_action_masks=self.base_action_masks,
max_atoms=self.n_max_atoms,
reaction_types=None,
action_vocab=self.action_vocab,
) # returns a list of `beam_size` results for each input molecule
else:
beam_search_results = []

assert len(batch_valid_idxs) == len(beam_search_results)

all_outputs: List[List[Dict[str, Any]]] = [[] for _ in batch]
for idx, raw_outputs in zip(batch_valid_idxs, beam_search_results):
all_outputs[idx] = raw_outputs

return [
process_raw_smiles_outputs(
input=input,
output_list=[prediction["final_smi_unmapped"] for prediction in raw_outputs],
kwargs_list=[{"probability": prediction["prob"]} for prediction in raw_outputs],
)
for input, raw_outputs in zip(inputs, beam_search_results)
for input, raw_outputs in zip(inputs, all_outputs)
]
3 changes: 3 additions & 0 deletions syntheseus/reaction_prediction/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import multiprocessing
import os
import random
Expand Down Expand Up @@ -33,7 +34,9 @@ def suppress_outputs():
"""Suppress messages written to both stdout and stderr."""
with open(devnull, "w") as fnull:
with redirect_stderr(fnull), redirect_stdout(fnull):
logging.disable(logging.CRITICAL)
yield
logging.disable(logging.NOTSET)


def dictify(data: Any) -> Any:
Expand Down
Loading