-
Notifications
You must be signed in to change notification settings - Fork 30
/
_model.py
109 lines (91 loc) · 3.56 KB
/
_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from keras.layers import Dense
from kgcnn.layers.conv import SchNetInteraction
from kgcnn.layers.geom import NodePosition, NodeDistanceEuclidean, GaussBasisLayer, ShiftPeriodicLattice
from kgcnn.layers.mlp import GraphMLP, MLP
from kgcnn.layers.modules import Embedding
from kgcnn.layers.pooling import PoolingNodes
def model_disjoint(
        inputs,
        use_node_embedding: bool = None,
        input_node_embedding: dict = None,
        make_distance: bool = None,
        expand_distance: bool = None,
        gauss_args: dict = None,
        interaction_args: dict = None,
        node_pooling_args: dict = None,
        depth: int = None,
        last_mlp: dict = None,
        output_embedding: str = None,
        use_output_mlp: bool = None,
        output_mlp: dict = None):
    """Build the SchNet network from disjoint-representation input tensors.

    Args:
        inputs: Sequence of five tensors
            ``(n, x, disjoint_indices, batch_id_node, count_nodes)``.
            ``n`` holds node attributes. When ``make_distance`` is truthy, ``x``
            is fed to ``NodePosition`` together with ``disjoint_indices``
            (presumably node coordinates — confirm with caller). Otherwise ``x``
            is used directly as the edge distances.
        use_node_embedding: If truthy, apply ``Embedding(**input_node_embedding)`` to ``n``.
        input_node_embedding: Keyword arguments for ``Embedding``.
        make_distance: If truthy, compute Euclidean edge distances from ``x``.
        expand_distance: If truthy, expand distances with ``GaussBasisLayer(**gauss_args)``.
        gauss_args: Keyword arguments for ``GaussBasisLayer``.
        interaction_args: Keyword arguments for ``SchNetInteraction``. Must
            contain ``"units"``, which also sets the width of the initial ``Dense``.
        node_pooling_args: Keyword arguments for ``PoolingNodes`` (graph output only).
        depth: Number of stacked ``SchNetInteraction`` layers.
        last_mlp: Keyword arguments for the ``GraphMLP`` applied after the interactions.
        output_embedding: ``'graph'`` for a pooled per-graph output, ``'node'``
            for a per-node output.
        use_output_mlp: If truthy, apply an output MLP (``MLP`` for graph
            output, ``GraphMLP`` for node output).
        output_mlp: Keyword arguments for the output MLP.

    Returns:
        The graph-level output when ``output_embedding == 'graph'``, otherwise
        the node-level output.

    Raises:
        ValueError: If ``output_embedding`` is neither ``'graph'`` nor ``'node'``.
    """
    # Validate before building any layers so an invalid config fails fast
    # instead of after the whole network has been constructed.
    if output_embedding not in ("graph", "node"):
        raise ValueError(
            f"Unsupported output embedding '{output_embedding}' for mode `SchNet`.")

    n, x, disjoint_indices, batch_id_node, count_nodes = inputs

    # Optional Embedding.
    if use_node_embedding:
        n = Embedding(**input_node_embedding)(n)

    # Edge distances: either computed from per-edge node positions or given directly.
    if make_distance:
        pos1, pos2 = NodePosition()([x, disjoint_indices])
        ed = NodeDistanceEuclidean()([pos1, pos2])
    else:
        ed = x
    if expand_distance:
        ed = GaussBasisLayer(**gauss_args)(ed)

    # Model: linear projection to the interaction width, then `depth` interaction blocks.
    n = Dense(interaction_args["units"], activation='linear')(n)
    for _ in range(depth):
        n = SchNetInteraction(**interaction_args)([n, ed, disjoint_indices])
    n = GraphMLP(**last_mlp)([n, batch_id_node, count_nodes])

    # Output embedding choice (already validated above).
    if output_embedding == 'graph':
        out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
        if use_output_mlp:
            out = MLP(**output_mlp)(out)
    else:  # 'node'
        out = n
        if use_output_mlp:
            out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
    return out
def model_disjoint_crystal(
        inputs,
        use_node_embedding: bool = None,
        input_node_embedding: dict = None,
        make_distance: bool = None,
        expand_distance: bool = None,
        gauss_args: dict = None,
        interaction_args: dict = None,
        node_pooling_args: dict = None,
        depth: int = None,
        last_mlp: dict = None,
        output_embedding: str = None,
        use_output_mlp: bool = None,
        output_mlp: dict = None):
    """Build the periodic (crystal) SchNet network from disjoint-representation inputs.

    Same as the non-periodic variant, except that when distances are computed
    the second node position of each edge is shifted by ``ShiftPeriodicLattice``
    using the edge image and lattice.

    Args:
        inputs: Sequence of eight tensors
            ``(n, x, disjoint_indices, edge_image, lattice, batch_id_node,
            batch_id_edge, count_nodes)``.
            ``n`` holds node attributes. When ``make_distance`` is truthy, ``x``
            is fed to ``NodePosition`` (presumably node coordinates — confirm
            with caller). Otherwise ``x`` is used directly as the edge distances.
        use_node_embedding: If truthy, apply ``Embedding(**input_node_embedding)`` to ``n``.
        input_node_embedding: Keyword arguments for ``Embedding``.
        make_distance: If truthy, compute lattice-shifted Euclidean edge distances from ``x``.
        expand_distance: If truthy, expand distances with ``GaussBasisLayer(**gauss_args)``.
        gauss_args: Keyword arguments for ``GaussBasisLayer``.
        interaction_args: Keyword arguments for ``SchNetInteraction``. Must
            contain ``"units"``, which also sets the width of the initial ``Dense``.
        node_pooling_args: Keyword arguments for ``PoolingNodes`` (graph output only).
        depth: Number of stacked ``SchNetInteraction`` layers.
        last_mlp: Keyword arguments for the ``GraphMLP`` applied after the interactions.
        output_embedding: ``'graph'`` for a pooled per-graph output, ``'node'``
            for a per-node output.
        use_output_mlp: If truthy, apply an output MLP (``MLP`` for graph
            output, ``GraphMLP`` for node output).
        output_mlp: Keyword arguments for the output MLP.

    Returns:
        The graph-level output when ``output_embedding == 'graph'``, otherwise
        the node-level output.

    Raises:
        ValueError: If ``output_embedding`` is neither ``'graph'`` nor ``'node'``.
    """
    # Validate before building any layers so an invalid config fails fast
    # instead of after the whole network has been constructed.
    if output_embedding not in ("graph", "node"):
        raise ValueError(
            f"Unsupported output embedding '{output_embedding}' for mode `SchNet`.")

    n, x, disjoint_indices, edge_image, lattice, batch_id_node, batch_id_edge, count_nodes = inputs

    # Optional Embedding.
    if use_node_embedding:
        n = Embedding(**input_node_embedding)(n)

    if make_distance:
        pos1, pos2 = NodePosition()([x, disjoint_indices])
        # Shift the receiving position by its periodic image before measuring the distance.
        pos2 = ShiftPeriodicLattice()([pos2, edge_image, lattice, batch_id_edge])
        ed = NodeDistanceEuclidean()([pos1, pos2])
    else:
        # `x` is a single tensor of precomputed distances; image and lattice are
        # separate inputs. The previous `ed, _, _, _ = x` unpacked it along its
        # first axis, which silently kept one row (or failed) instead of all edges.
        ed = x
    if expand_distance:
        ed = GaussBasisLayer(**gauss_args)(ed)

    # Model: linear projection to the interaction width, then `depth` interaction blocks.
    n = Dense(interaction_args["units"], activation='linear')(n)
    for _ in range(depth):
        n = SchNetInteraction(**interaction_args)([n, ed, disjoint_indices])
    n = GraphMLP(**last_mlp)([n, batch_id_node, count_nodes])

    # Output embedding choice (already validated above).
    if output_embedding == 'graph':
        out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
        if use_output_mlp:
            out = MLP(**output_mlp)(out)
    else:  # 'node'
        out = n
        if use_output_mlp:
            out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
    return out