Skip to content

Commit

Permalink
fix(chem): Make the stereo information removal utility retain the rea…
Browse files Browse the repository at this point in the history
…ction type
  • Loading branch information
kmaziarz committed Sep 12, 2024
1 parent 67c1491 commit 0bf0c86
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
23 changes: 15 additions & 8 deletions syntheseus/reaction_prediction/chem/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional
from typing import Dict, Optional, Union

from rdkit import Chem

from syntheseus import Bag, Molecule, Reaction
from syntheseus import Bag, Molecule, Reaction, SingleProductReaction
from syntheseus.interface.models import ReactionType
from syntheseus.interface.molecule import SMILES_SEPARATOR

ATOM_MAPPING_PROP_NAME = "molAtomMapNumber"
Expand Down Expand Up @@ -33,12 +34,18 @@ def remove_stereo_information(mol: Molecule) -> Molecule:
return Molecule(Chem.MolToSmiles(mol.rdkit_mol, isomericSmiles=False))


def remove_stereo_information_from_reaction(reaction: Reaction) -> Reaction:
return Reaction(
reactants=Bag([remove_stereo_information(mol) for mol in reaction.reactants]),
products=Bag([remove_stereo_information(mol) for mol in reaction.products]),
identifier=reaction.identifier,
metadata=reaction.metadata,
def remove_stereo_information_from_reaction(reaction: ReactionType) -> ReactionType:
mol_kwargs: Dict[str, Union[Molecule, Bag[Molecule]]] = {
"reactants": Bag([remove_stereo_information(mol) for mol in reaction.reactants])
}

if isinstance(reaction, SingleProductReaction):
mol_kwargs["product"] = remove_stereo_information(reaction.product)
else:
mol_kwargs["products"] = Bag([remove_stereo_information(mol) for mol in reaction.products])

return reaction.__class__(
**mol_kwargs, identifier=reaction.identifier, metadata=reaction.metadata # type: ignore[arg-type]
)


Expand Down
25 changes: 24 additions & 1 deletion syntheseus/tests/reaction_prediction/chem/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from rdkit import Chem

from syntheseus import Molecule
from syntheseus import Bag, Molecule, Reaction, SingleProductReaction
from syntheseus.reaction_prediction.chem.utils import (
remove_atom_mapping,
remove_atom_mapping_from_mol,
remove_stereo_information,
remove_stereo_information_from_reaction,
)


Expand All @@ -27,3 +28,25 @@ def test_remove_stereo_information() -> None:

assert len(set([mol] + mols_chiral)) == 3
assert len(set([mol] + [remove_stereo_information(m) for m in mols_chiral])) == 1


def test_remove_stereo_information_from_reaction() -> None:
reactants = Bag([Molecule("CCC"), Molecule("CC(N)C#N")])
reactants_chiral = Bag([Molecule("CCC"), Molecule("C[C@H](N)C#N")])

product = Molecule("CC(N)C#N")
product_chiral = Molecule("C[C@H](N)C#N")

rxn = Reaction(reactants=reactants, products=Bag([product]))
rxn_chiral = Reaction(reactants=reactants_chiral, products=Bag([product_chiral]))
rxn_stereo_removed = remove_stereo_information_from_reaction(rxn_chiral)

assert type(rxn_stereo_removed) == Reaction
assert rxn_stereo_removed == rxn

sp_rxn = SingleProductReaction(reactants=reactants, product=product)
sp_rxn_chiral = SingleProductReaction(reactants=reactants_chiral, product=product_chiral)
sp_rxn_stero_removed = remove_stereo_information_from_reaction(sp_rxn_chiral)

assert type(sp_rxn_stero_removed) == SingleProductReaction
assert sp_rxn_stero_removed == sp_rxn

0 comments on commit 0bf0c86

Please sign in to comment.