Skip to content

Commit

Permalink
Merge pull request #8 from ecrl/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
tjkessler authored Jun 15, 2021
2 parents bac5dda + 7ff945d commit 94361ee
Show file tree
Hide file tree
Showing 14 changed files with 1,014 additions and 131 deletions.
Binary file removed examples/model.enc
Binary file not shown.
Binary file removed examples/model.pt
Binary file not shown.
Binary file added examples/model_cn.enc
Binary file not shown.
Binary file added examples/model_cn.pt
Binary file not shown.
Binary file added examples/model_mon.enc
Binary file not shown.
Binary file added examples/model_mon.pt
Binary file not shown.
Binary file added examples/model_ron.enc
Binary file not shown.
Binary file added examples/model_ron.pt
Binary file not shown.
282 changes: 160 additions & 122 deletions examples/predict_cn.ipynb

Large diffs are not rendered by default.

397 changes: 397 additions & 0 deletions examples/predict_mon.ipynb

Large diffs are not rendered by default.

437 changes: 437 additions & 0 deletions examples/predict_ron.ipynb

Large diffs are not rendered by default.

17 changes: 11 additions & 6 deletions graphchem/models/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def __init__(self, node_dim: int, edge_dim: int, output_dim: int,
self._n_messages = n_messages

self.lin0 = nn.Linear(node_dim, node_dim)
self.lin0_edge = nn.Linear(edge_dim, edge_dim)

# Construct message passing layers for node, edge networks
self.node_conv = pyg_nn.MFConv(node_dim, node_dim)
self.edge_conv = pyg_nn.EdgeConv(nn.Sequential(
nn.Linear(2 * edge_dim, edge_dim)
))
self.node_gru = nn.GRU(node_dim, node_dim)
self.edge_gru = nn.GRU(edge_dim, edge_dim)

# Construct post-message passing layers
self.post_conv = nn.ModuleList()
Expand Down Expand Up @@ -82,7 +84,9 @@ def forward(self, data: 'torch_geometric.data.Data') -> Tuple[
x = torch.ones(data.num_nodes, 1)

out = F.relu(self.lin0(x))
out_edge = F.relu(self.lin0_edge(edge_attr))
h = out.unsqueeze(0)
h_edge = out_edge.unsqueeze(0)

# Feed forward, node and edge messages
for i in range(self._n_messages):
Expand All @@ -91,14 +95,15 @@ def forward(self, data: 'torch_geometric.data.Data') -> Tuple[
m = F.dropout(m, p=self._dropout, training=self.training)
out, h = self.node_gru(m.unsqueeze(0), h)
out = out.squeeze(0)
edge_attr = self.edge_conv(edge_attr, edge_index)
emb_edge = edge_attr
edge_attr = F.relu(edge_attr)
edge_attr = F.dropout(edge_attr, p=self._dropout,
training=self.training)

m_edge = F.relu(self.edge_conv(out_edge, edge_index))
emb_edge = m_edge
m_edge = F.dropout(m_edge, p=self._dropout, training=self.training)
out_edge, h_edge = self.edge_gru(m_edge.unsqueeze(0), h_edge)
out_edge = out_edge.squeeze(0)

# Concatenate node network and edge network output tensors
out = torch.cat([out[row], edge_attr[col]], dim=1)
out = torch.cat([out[row], out_edge[col]], dim=1)

# Perform scatter add, reshape to original node dimensionality
out = scatter_add(out, col, dim=0, dim_size=x.size(0))
Expand Down
10 changes: 8 additions & 2 deletions graphchem/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,15 @@ def encode(self, smiles: str) -> Tuple['torch.tensor', 'torch.tensor',
connectivity = np.zeros((2, 2 * mol.GetNumBonds()))
bond_index = 0
for atom in mol.GetAtoms():
start_index = atom.GetIdx()
for bond in atom.GetBonds():
connectivity[0, bond_index] = bond.GetBeginAtomIdx()
connectivity[1, bond_index] = bond.GetEndAtomIdx()
rev = bond.GetBeginAtomIdx() != start_index
if not rev:
connectivity[0, bond_index] = bond.GetBeginAtomIdx()
connectivity[1, bond_index] = bond.GetEndAtomIdx()
else:
connectivity[0, bond_index] = bond.GetEndAtomIdx()
connectivity[1, bond_index] = bond.GetBeginAtomIdx()
bond_index += 1
connectivity = torch.from_numpy(connectivity).type(torch.long)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='graphchem',
version='1.1.0',
version='1.2.0',
description='Graph-based machine learning for chemical property prediction',
url='https://github.com/ecrl/graphchem',
author='Travis Kessler',
Expand Down

0 comments on commit 94361ee

Please sign in to comment.