Skip to content

Commit

Permalink
change init and trafo for lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
davidruegamer committed Feb 13, 2024
1 parent 81afb0e commit 5cbe707
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions inst/python/psplines/psplines.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class LambdaLayer(tf.keras.layers.Layer):
def __init__(self, units, P, damping = 1.0, scale = 1.0, **kwargs):
super(LambdaLayer, self).__init__(**kwargs)
self.units = units
self.loglambda = self.add_weight(name='loglambda',
self.lambdasqrt = self.add_weight(name='lambdasqrt',
shape=(units,len(P)),
initializer=tf.keras.initializers.RandomNormal,
trainable=True)
Expand All @@ -259,9 +259,8 @@ def __init__(self, units, P, damping = 1.0, scale = 1.0, **kwargs):
self.P = P

def call(self, inputs, w):
# lmbda = tf.reshape(tf.math.exp(self.loglambda), [])
for i in range(len(self.P)):
lmbda = tf.reshape(self.loglambda[:,i], [])
lmbda = tf.reshape(tf.math.square(self.lambdasqrt[:,i]), [])
inf = 0.5 * tf.reduce_sum(vecmatvec(w, tf.cast(self.P[i], dtype="float32")))
damp_term = self.damping * inf**2 / 2
l_term = lmbda * inf
Expand All @@ -272,7 +271,7 @@ def get_config(self):
config = super().get_config().copy()
config.update({
'units': self.units,
'lambda': self.loglambda.numpy(),
'lambda': self.lambdasqrt.numpy(),
'P': self.P
})
return config
Expand Down

0 comments on commit 5cbe707

Please sign in to comment.