_layers.py
from keras.layers import Layer, GRU
from kgcnn.layers.casting import CastDisjointToBatchedAttributes


class PoolingNodesGRU(Layer):
    """Pool the node embeddings of each graph into a single graph embedding.

    The disjoint node representation is first cast back to a padded, batched
    tensor together with a boolean mask, and a GRU is then run over the node
    axis so that its final hidden state serves as the pooled graph vector.
    """

    def __init__(self, units, static_output_shape=None,
                 activation='tanh', recurrent_activation='sigmoid',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros', kernel_regularizer=None,
                 recurrent_regularizer=None, bias_regularizer=None,
                 kernel_constraint=None, recurrent_constraint=None,
                 bias_constraint=None, dropout=0.0,
                 recurrent_dropout=0.0, reset_after=True, seed=None,
                 **kwargs):
        super(PoolingNodesGRU, self).__init__(**kwargs)
        self.units = units
        self.static_output_shape = static_output_shape
        # Cast disjoint node attributes back to a batched tensor plus padding mask.
        self.cast_layer = CastDisjointToBatchedAttributes(
            static_output_shape=static_output_shape, return_mask=True)
        # GRU over the node axis; all cell hyperparameters are forwarded unchanged.
        self.gru = GRU(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            reset_after=reset_after,
            seed=seed
        )

    def call(self, inputs, **kwargs):
        # `inputs` is the disjoint graph representation expected by the cast layer.
        n, mask = self.cast_layer(inputs)
        # The GRU returns its last hidden state, i.e. one vector per graph.
        out = self.gru(n, mask=mask)
        return out

    def get_config(self):
        """Return a serializable config, merging in the wrapped GRU's config."""
        config = super(PoolingNodesGRU, self).get_config()
        config.update({"units": self.units, "static_output_shape": self.static_output_shape})
        conf_gru = self.gru.get_config()
        param_list = ["units", "activation", "recurrent_activation",
                      "use_bias", "kernel_initializer", "recurrent_initializer",
                      "bias_initializer", "kernel_regularizer",
                      "recurrent_regularizer", "bias_regularizer",
                      "kernel_constraint", "recurrent_constraint",
                      "bias_constraint", "dropout",
                      "recurrent_dropout", "reset_after"]
        for x in param_list:
            if x in conf_gru:
                config.update({x: conf_gru[x]})
        return config
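
# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal example of pooling two small graphs. It assumes the disjoint
# input format [node_attr, batch_id_node, node_id, count_nodes] that kgcnn's
# casting layers typically consume; the exact call signature of
# CastDisjointToBatchedAttributes may differ between kgcnn versions.
if __name__ == "__main__":
    import numpy as np

    # Two graphs with 3 and 2 nodes respectively; node features of width 4.
    node_attr = np.random.rand(5, 4).astype("float32")
    batch_id_node = np.array([0, 0, 0, 1, 1], dtype="int64")  # graph id per node
    node_id = np.array([0, 1, 2, 0, 1], dtype="int64")        # node index within its graph
    count_nodes = np.array([3, 2], dtype="int64")             # nodes per graph

    layer = PoolingNodesGRU(units=16)
    out = layer([node_attr, batch_id_node, node_id, count_nodes])
    print(out.shape)  # expected: (2, 16), one pooled embedding per graph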