From 163a6da17d1ddf06508dd0db9a3f12841243f2e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20R=C3=BCgamer?= Date: Wed, 28 Feb 2024 15:34:25 +0100 Subject: [PATCH] add Laetitia as ctb --- DESCRIPTION | 3 ++- R/layers.R | 8 ++++++++ inst/python/psplines/psplines.py | 35 ++++++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 70798bb..5fae451 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,9 +1,10 @@ Package: deepregression Title: Fitting Deep Distributional Regression -Version: 2.0.0 +Version: 2.1.0 Authors@R: c( person("David", "Ruegamer", , "david.ruegamer@gmail.com", role = c("aut", "cre")), person("Christopher", "Marquardt", , "ch.marquardt@campus.lmu.de", role = c("ctb")), + person("Laetitia", "Frost", , "lae.frost@campus.lmu.de", role = c("ctb")), person("Florian", "Pfisterer", , "florian.pfisterer@stat.uni-muenchen.de", role = c("ctb")), person("Philipp", "Baumann", , "baumann@kof.ethz.ch", role = c("ctb")), person("Chris", "Kolb", , "chris.kolb@stat.uni-muenchen.de", role = c("ctb")), diff --git a/R/layers.R b/R/layers.R index 1ecdbd3..6f693c3 100644 --- a/R/layers.R +++ b/R/layers.R @@ -24,6 +24,14 @@ pen_layer = function(units, P, ...) { layers$CombinedModel(units = units, P = P, ...) } +update_factor_callback = function(model, weightnr = -1L, ...) { + python_path <- system.file("python", package = "deepregression") + layers <- reticulate::import_from_path("psplines", path = python_path) + layers$UpdateMultiplicationFactorFromWeight(model = model, + weightnr = weightnr, + ...) 
+} + #' Hadamard-type layers #' #' @param units integer; number of units diff --git a/inst/python/psplines/psplines.py b/inst/python/psplines/psplines.py index 05410d1..0577b44 100644 --- a/inst/python/psplines/psplines.py +++ b/inst/python/psplines/psplines.py @@ -250,20 +250,21 @@ class LambdaLayer(tf.keras.layers.Layer): def __init__(self, units, P, damping = 1.0, scale = 1.0, **kwargs): super(LambdaLayer, self).__init__(**kwargs) self.units = units - self.lambdasqrt = self.add_weight(name='lambdasqrt', + self.trafolambda = self.add_weight(name='trafolambda', shape=(units,len(P)), - initializer=tf.keras.initializers.RandomNormal, + initializer=tf.keras.initializers.Constant(value=0), trainable=True) + self.phi = tf.Variable(1.0, name = 'phimultiplier', trainable=False, dtype=tf.float32) self.damping = damping self.scale = scale self.P = P def call(self, inputs, w): for i in range(len(self.P)): - lmbda = tf.reshape(tf.math.square(self.lambdasqrt[:,i]), []) + lmbda = tf.reshape(tf.math.exp(self.trafolambda[:,i]), []) inf = 0.5 * tf.reduce_sum(vecmatvec(w, tf.cast(self.P[i], dtype="float32"))) damp_term = self.damping * inf**2 / 2 - l_term = lmbda * inf + l_term = lmbda * inf / self.phi self.add_loss(self.scale * (l_term + damp_term)) return inputs @@ -271,8 +272,9 @@ def get_config(self): config = super().get_config().copy() config.update({ 'units': self.units, - 'lambda': self.lambdasqrt.numpy(), - 'P': self.P + 'trafolambda': self.trafolambda.numpy(), + 'P': self.P, + 'phi': self.phi }) return config @@ -290,6 +292,27 @@ def call(self, inputs): def compute_output_shape(self, input_shape): output_shape = input_shape[:-1] + (self.units,) return output_shape + +class UpdateMultiplicationFactorFromWeight(tf.keras.callbacks.Callback): + def __init__(self, model, weightnr = -1, trafo = lambda x: tf.math.square(tf.math.exp(x))): + super().__init__() + self.model = model + self.weightnr = weightnr + self.trafo = trafo + + def on_batch_begin(self, epoch, logs=None): + # 
Extract the value of the last weight of the model + new_phi_value = self.model.weights[self.weightnr].numpy() + + # Iterate through all layers of the model + for layer in self.model.layers: + # Check if the layer is an instance of CombinedModel + if isinstance(layer, CombinedModel): + # Access the LambdaLayer within the CombinedModel + lambda_layer = layer.lambda_layer + + # Update the phi variable within the LambdaLayer + tf.keras.backend.set_value(lambda_layer.phi, tf.reshape(self.trafo(new_phi_value), [])) def get_masks(mod): masks = []