Commit

add gatv2
akensert committed Jun 21, 2024
1 parent bb6aaed commit b1363e5
Showing 5 changed files with 215 additions and 14 deletions.
1 change: 1 addition & 0 deletions molexpress/layers/__init__.py
@@ -1,6 +1,7 @@
from molexpress.layers.base_layer import BaseLayer as BaseLayer
from molexpress.layers.gcn_conv import GCNConv as GCNConv
from molexpress.layers.gin_conv import GINConv as GINConv
from molexpress.layers.gatv2_conv import GATv2Conv as GATv2Conv
from molexpress.layers.peptide_readout import PeptideReadout as PeptideReadout
from molexpress.layers.residue_readout import ResidueReadout as ResidueReadout
from molexpress.layers.gather_incident import GatherIncident as GatherIncident
186 changes: 186 additions & 0 deletions molexpress/layers/gatv2_conv.py
@@ -0,0 +1,186 @@
from __future__ import annotations

import keras

from molexpress import types
from molexpress.layers.base_layer import BaseLayer
from molexpress.ops import gnn_ops


class GATv2Conv(BaseLayer):
def __init__(
self,
units: int,
heads: int,
activation: keras.layers.Activation = None,
use_bias: bool = True,
normalization: bool = True,
skip_connection: bool = True,
dropout_rate: float = 0,
kernel_initializer: keras.initializers.Initializer = "glorot_uniform",
bias_initializer: keras.initializers.Initializer = "zeros",
kernel_regularizer: keras.regularizers.Regularizer = None,
bias_regularizer: keras.regularizers.Regularizer = None,
activity_regularizer: keras.regularizers.Regularizer = None,
kernel_constraint: keras.constraints.Constraint = None,
bias_constraint: keras.constraints.Constraint = None,
**kwargs,
) -> None:
super().__init__(
units=units,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs,
)
self.heads = heads
self.dropout_rate = dropout_rate
self.skip_connection = skip_connection
self.normalization = normalization
self.attention_activation = keras.activations.get('leaky_relu')
if self.units % self.heads != 0:
raise ValueError(
f"units ({self.units}) needs to be divisble by heads {self.heads}")
else:
self.units_per_head = self.units // self.heads

def build(self, input_shape: dict[str, tuple[int, ...]]) -> None:

node_state_shape = input_shape["node_state"]
edge_state_shape = input_shape.get("edge_state")

node_dim = node_state_shape[-1]

if edge_state_shape is not None:
edge_dim = edge_state_shape[-1]
else:
edge_dim = 0

self._transform_residual = node_dim != self.units
if self._transform_residual:
self.residual_node_kernel = self.add_kernel(
name="residual_node_kernel", shape=(node_dim, self.units)
)

self.kernel = self.add_kernel(
name="kernel", shape=(node_dim * 2 + edge_dim, self.units_per_head, self.heads))
if self.use_bias:
self.bias = self.add_bias(
name="bias", shape=(self.units_per_head, self.heads))

self.attention_kernel = self.add_kernel(
name="attention_kernel", shape=(self.units_per_head, 1, self.heads))
if self.use_bias:
self.attention_bias = self.add_bias(
name="attention_bias", shape=(1, self.heads))


self.node_kernel = self.add_kernel(
name="node_kernel", shape=(node_dim, self.units_per_head, self.heads))
if self.use_bias:
self.node_bias = self.add_bias(
name="node_bias", shape=(self.units_per_head, self.heads))


if edge_state_shape is not None:
self.edge_kernel = self.add_kernel(
name="edge_kernel", shape=(
self.units_per_head, self.units_per_head, self.heads)
)
if self.use_bias:
self.edge_bias = self.add_bias(
name="edge_bias", shape=(self.units_per_head, self.heads))

if self.normalization:
self.normalize = keras.layers.BatchNormalization()

if self.dropout_rate:
self.dropout = keras.layers.Dropout(self.dropout_rate)


def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
x = inputs.copy()

node_state = x.pop("node_state")
edge_src = keras.ops.cast(x["edge_src"], "int32")
edge_dst = keras.ops.cast(x["edge_dst"], "int32")
edge_state = x.pop("edge_state", None)
edge_weight = x.get("edge_weight")

Check failure on line 114 in molexpress/layers/gatv2_conv.py (GitHub Actions / test, Python 3.9, 3.10, and 3.11): Ruff (F841), molexpress/layers/gatv2_conv.py:114:9: Local variable `edge_weight` is assigned to but never used

if edge_state is None:
attention_feature = keras.ops.concatenate([
gnn_ops.gather(node_state, edge_src),
gnn_ops.gather(node_state, edge_dst),
], axis=-1)
else:
attention_feature = keras.ops.concatenate([
gnn_ops.gather(node_state, edge_src),
gnn_ops.gather(node_state, edge_dst),
edge_state
], axis=-1)


node_state_updated = gnn_ops.transform(
node_state, self.node_kernel, self.node_bias)

attention_feature = gnn_ops.transform(
attention_feature, self.kernel, self.bias)

if edge_state is not None:
edge_state_updated = gnn_ops.transform(
attention_feature, self.edge_kernel, self.edge_bias)
edge_state_updated = keras.ops.reshape(
edge_state_updated, (-1, self.units))


attention_feature = self.attention_activation(attention_feature)
attention_feature = gnn_ops.transform(
attention_feature, self.attention_kernel, self.attention_bias
)
attention_score = gnn_ops.edge_softmax(attention_feature, edge_dst)

node_state_updated = gnn_ops.aggregate(
node_state=node_state_updated,
edge_src=edge_src,
edge_dst=edge_dst,
edge_state=None,
edge_weight=attention_score,
)

node_state_updated = keras.ops.reshape(
node_state_updated, (-1, self.units)
)
if self.activation is not None:
node_state_updated = self.activation(node_state_updated)

if self.skip_connection:
if self._transform_residual:
node_state = gnn_ops.transform(
node_state, self.residual_node_kernel)
node_state_updated = node_state_updated + node_state

if self.dropout_rate:
node_state_updated = self.dropout(node_state_updated)

        # Only return an updated edge state when an edge state was provided,
        # otherwise edge_state_updated is undefined.
        if edge_state is not None:
            x["edge_state"] = edge_state_updated

        return dict(node_state=node_state_updated, **x)

def get_config(self) -> dict[str, types.Any]:
config = super().get_config()
config.update(
{
"heads": self.heads,
"normalization": self.normalization,
"skip_connection": self.skip_connection,
"dropout_rate": self.dropout_rate,
}
)
return config
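
For context, a minimal usage sketch of the new layer. The graph below is made-up illustration data, and the sketch assumes the layer can be called directly on a dict of arrays, as its call signature suggests:

import numpy as np

from molexpress.layers import GATv2Conv

# Toy molecular graph: 3 nodes with 8 features each, 4 directed edges with 4 features each.
graph = {
    "node_state": np.random.rand(3, 8).astype("float32"),
    "edge_state": np.random.rand(4, 4).astype("float32"),
    "edge_src": np.array([0, 1, 1, 2]),
    "edge_dst": np.array([1, 0, 2, 1]),
}

# units must be divisible by heads (here 32 / 4 = 8 units per head).
layer = GATv2Conv(units=32, heads=4)
output = layer(graph)
print(output["node_state"].shape)  # expected: (3, 32)

The returned dict also carries the updated edge state and the untouched edge indices, so layers of this kind can be stacked.
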
4 changes: 2 additions & 2 deletions molexpress/layers/gcn_conv.py
@@ -72,8 +72,8 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
x = inputs.copy()

node_state = x.pop("node_state")
edge_src = x["edge_src"]
edge_dst = x["edge_dst"]
edge_src = keras.ops.cast(x["edge_src"], "int32")
edge_dst = keras.ops.cast(x["edge_dst"], "int32")
edge_state = x.get("edge_state")
edge_weight = x.get("edge_weight")

4 changes: 2 additions & 2 deletions molexpress/layers/gin_conv.py
@@ -84,8 +84,8 @@ def call(self, inputs: types.MolecularGraph) -> types.MolecularGraph:
x = inputs.copy()

node_state = x.pop("node_state")
edge_src = x["edge_src"]
edge_dst = x["edge_dst"]
edge_src = keras.ops.cast(x["edge_src"], "int32")
edge_dst = keras.ops.cast(x["edge_dst"], "int32")
edge_state = x.get("edge_state")
edge_weight = x.get("edge_weight")

34 changes: 24 additions & 10 deletions molexpress/ops/gnn_ops.py
@@ -23,8 +23,15 @@ def transform(
Returns:
A transformed node state.
"""

state_transformed = keras.ops.matmul(state, kernel)
if len(keras.ops.shape(kernel)) == 2:
# kernel.rank == state.rank == 2
state_transformed = keras.ops.matmul(state, kernel)
elif len(keras.ops.shape(kernel)) == len(keras.ops.shape(state)):
# kernel.rank == state.rank == 3
state_transformed = keras.ops.einsum('ijh,jkh->ikh', state, kernel)
else:
# kernel.rank == 3 and state.rank == 2
state_transformed = keras.ops.einsum('ij,jkh->ikh', state, kernel)
if bias is not None:
state_transformed += bias
return state_transformed
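
To illustrate the new per-head branches, a rough sketch with made-up shapes; transform is called without a bias here, mirroring the residual transform in GATv2Conv:

import numpy as np
import keras

from molexpress.ops import gnn_ops

state = np.random.rand(5, 16).astype("float32")      # rank 2: (num_nodes, node_dim)
kernel = np.random.rand(16, 8, 4).astype("float32")  # rank 3: (node_dim, units_per_head, heads)

# kernel.rank == 3 and state.rank == 2 -> einsum('ij,jkh->ikh')
per_head = gnn_ops.transform(state, kernel)
print(keras.ops.shape(per_head))  # (5, 8, 4)

# kernel.rank == state.rank == 3 -> einsum('ijh,jkh->ikh'), one kernel per head
head_kernel = np.random.rand(8, 8, 4).astype("float32")
print(keras.ops.shape(gnn_ops.transform(per_head, head_kernel)))  # (5, 8, 4)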
@@ -58,36 +65,43 @@ def aggregate(
"""
num_nodes = keras.ops.shape(node_state)[0]

# Instead of casting to int, throw an error if not int?
edge_src = keras.ops.cast(edge_src, "int32")
edge_dst = keras.ops.cast(edge_dst, "int32")

expected_rank = 2
expected_rank = len(keras.ops.shape(node_state))
current_rank = len(keras.ops.shape(edge_src))
for _ in range(expected_rank - current_rank):
edge_src = keras.ops.expand_dims(edge_src, axis=-1)
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)

node_state_src = keras.ops.take_along_axis(node_state, edge_src, axis=0)

if edge_weight is not None:
node_state_src *= edge_weight

if edge_state is not None:
node_state_src += edge_state

edge_dst = keras.ops.squeeze(edge_dst, axis=-1)
edge_dst = keras.ops.squeeze(edge_dst)

node_state_updated = keras.ops.segment_sum(
data=node_state_src, segment_ids=edge_dst, num_segments=num_nodes, sorted=False
)
return node_state_updated
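
A small sketch of the relaxed rank handling, with made-up shapes; the rank-3 node state and per-edge, per-head weights correspond to what GATv2Conv passes in:

import numpy as np
import keras

from molexpress.ops import gnn_ops

node_state = np.random.rand(3, 8, 4).astype("float32")   # (num_nodes, units_per_head, heads)
edge_src = np.array([0, 1, 2])
edge_dst = np.array([1, 2, 0])
edge_weight = np.random.rand(3, 1, 4).astype("float32")  # one weight per edge and head

aggregated = gnn_ops.aggregate(
    node_state=node_state,
    edge_src=edge_src,
    edge_dst=edge_dst,
    edge_state=None,
    edge_weight=edge_weight,
)
print(keras.ops.shape(aggregated))  # (3, 8, 4)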

def edge_softmax(score, edge_dst):
numerator = keras.ops.exp(score - keras.ops.max(score, axis=0, keepdims=True))
num_segments = keras.ops.max(edge_dst) + 1
denominator = keras.ops.segment_sum(numerator, edge_dst, num_segments, sorted=False)
expected_rank = len(keras.ops.shape(denominator))
current_rank = len(keras.ops.shape(edge_dst))
for _ in range(expected_rank - current_rank):
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)
denominator = keras.ops.take_along_axis(denominator, edge_dst, axis=0)
return numerator / denominator
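
A tiny sketch of what the new helper computes, with made-up scores; the weights for edges sharing a destination node sum to 1:

import numpy as np
import keras

from molexpress.ops import gnn_ops

score = np.array([[1.0], [2.0], [3.0]], dtype="float32")  # one raw score per edge
edge_dst = np.array([0, 0, 1], dtype="int32")             # edges 0 and 1 both enter node 0

weights = gnn_ops.edge_softmax(score, edge_dst)
print(keras.ops.convert_to_numpy(weights).round(3))
# [[0.269]   softmax over the two edges into node 0
#  [0.731]
#  [1.   ]]  the only edge into node 1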

def gather(
node_state: types.Array,
edge: types.Array,
) -> types.Array:
edge = keras.ops.cast(edge, "int32")
expected_rank = 2
expected_rank = len(keras.ops.shape(node_state))
current_rank = len(keras.ops.shape(edge))
for _ in range(expected_rank - current_rank):
edge = keras.ops.expand_dims(edge, axis=-1)
