Skip to content

Commit

Permalink
Merge pull request #4 from compomics/cache-residue-encoding
Browse files Browse the repository at this point in the history
Add caching for encoding individual residues in PeptideGraphEncoder
  • Loading branch information
akensert committed May 16, 2024
2 parents 950bc83 + ef16038 commit a92e4d2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
lightning_logs/
notebooks/_*.ipynb

# vscode
Expand Down
56 changes: 32 additions & 24 deletions molexpress/datasets/encoders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Union
from functools import lru_cache
from typing import Dict, Tuple, Union

import numpy as np

Expand All @@ -23,18 +24,28 @@ def __call__(self, residues: list[types.Molecule | types.SMILES | types.InChI])
residue_graphs = []
residue_sizes = []
for residue in residues:
residue = chem_ops.get_molecule(residue)
residue_graph = {
**self.node_encoder(residue),
**self.edge_encoder(residue)
}
residue_graph, residue_size = self._encode_residue(
residue, self.node_encoder, self.edge_encoder
)
residue_graphs.append(residue_graph)
residue_sizes.append(residue.GetNumAtoms())
residue_sizes.append(residue_size)

disjoint_peptide_graph = self._merge_molecular_graphs(residue_graphs)
disjoint_peptide_graph["residue_size"] = np.array(residue_sizes)
disjoint_peptide_graph["peptide_size"] = np.array([len(residues)], dtype="int32")
return disjoint_peptide_graph

@staticmethod
@lru_cache(maxsize=None)
def _encode_residue(
    residue: types.Molecule | types.SMILES | types.InChI,
    node_encoder: MolecularNodeEncoder,
    edge_encoder: MolecularEdgeEncoder,
) -> Tuple[Dict[str, np.ndarray], int]:
    """Encode one residue as a graph dict and report its atom count.

    Results are memoized with ``lru_cache``, keyed on the full
    ``(residue, node_encoder, edge_encoder)`` argument triple, so all
    three arguments must be hashable. The cache is unbounded
    (``maxsize=None``) and holds references to its keys and results.

    Args:
        residue: Residue handed to ``chem_ops.get_molecule``.
        node_encoder: Callable applied to the resulting molecule; its
            entries are placed in the graph first.
        edge_encoder: Callable applied to the same molecule; its entries
            are merged in afterwards and win on any duplicate key.

    Returns:
        A ``(residue_graph, num_atoms)`` tuple, where ``num_atoms`` is
        ``GetNumAtoms()`` of the resolved molecule. A cache hit returns
        the very same dict object as earlier calls, so callers should
        treat it as read-only.
    """
    molecule = chem_ops.get_molecule(residue)

    # Node entries go in first, edge entries second (same precedence as
    # unpacking both mappings into a single dict literal).
    graph = dict(node_encoder(molecule))
    graph.update(edge_encoder(molecule))

    return graph, molecule.GetNumAtoms()

@staticmethod
def collate_fn(
data: list[Union[types.MolecularGraph, tuple[types.MolecularGraph, np.ndarray]]],
Expand All @@ -55,12 +66,12 @@ def collate_fn(
disjoint_peptide_batch_graph = PeptideGraphEncoder._merge_molecular_graphs(
disjoint_peptide_graphs
)
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate([
g["residue_size"].shape[:1] for g in disjoint_peptide_graphs
]).astype("int32")
disjoint_peptide_batch_graph["residue_size"] = np.concatenate([
g["residue_size"] for g in disjoint_peptide_graphs
]).astype("int32")
disjoint_peptide_batch_graph["peptide_size"] = np.concatenate(
[g["residue_size"].shape[:1] for g in disjoint_peptide_graphs]
).astype("int32")
disjoint_peptide_batch_graph["residue_size"] = np.concatenate(
[g["residue_size"] for g in disjoint_peptide_graphs]
).astype("int32")

if y is None:
return disjoint_peptide_batch_graph
Expand All @@ -71,21 +82,18 @@ def collate_fn(
def _merge_molecular_graphs(
molecular_graphs: list[types.MolecularGraph],
) -> types.MolecularGraph:

num_nodes = np.array([
g["node_state"].shape[0] for g in molecular_graphs
])
num_nodes = np.array([g["node_state"].shape[0] for g in molecular_graphs])

disjoint_molecular_graph = {}

disjoint_molecular_graph["node_state"] = np.concatenate([
g["node_state"] for g in molecular_graphs
])
disjoint_molecular_graph["node_state"] = np.concatenate(
[g["node_state"] for g in molecular_graphs]
)

if "edge_state" in molecular_graphs[0]:
disjoint_molecular_graph["edge_state"] = np.concatenate([
g["edge_state"] for g in molecular_graphs
])
disjoint_molecular_graph["edge_state"] = np.concatenate(
[g["edge_state"] for g in molecular_graphs]
)

edge_src = np.concatenate([graph["edge_src"] for graph in molecular_graphs])
edge_dst = np.concatenate([graph["edge_dst"] for graph in molecular_graphs])
Expand Down Expand Up @@ -147,7 +155,7 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray:
if molecule.GetNumBonds() == 0:
edge_state = np.zeros(
shape=(int(self.self_loops), self.output_dim + int(self.self_loops)),
dtype=self.output_dtype
dtype=self.output_dtype,
)
return {
"edge_src": edge_src,
Expand Down
2 changes: 1 addition & 1 deletion molexpress/layers/gin_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
node_state_updated = self.activation(node_state_updated)

if self.skip_connection:
node_state_updated += node_state
node_state_updated = node_state_updated + node_state

if self.dropout_rate:
node_state_updated = self.dropout(node_state_updated)
Expand Down

0 comments on commit a92e4d2

Please sign in to comment.