Skip to content

Commit

Permalink
Try a different initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
gmontamat committed Sep 21, 2024
1 parent 6fd917d commit 0f7d24a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/gentun/models/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def build_model(
return Model(inputs=x_input, outputs=x, name=f"{self.name}")

def reset_weights(self):
    """Re-initialize model weights in place using a Glorot (Xavier) uniform initializer.

    Every layer kernel is re-drawn from ``GlorotUniform`` and every bias is
    reset to zeros, allowing the model to be retrained from scratch without
    rebuilding it.
    """
    initializer = tf.keras.initializers.GlorotUniform()
    for layer in self.model.layers:
        # Use `getattr(..., None) is not None` rather than `hasattr`: layers
        # built with use_bias=False expose a `bias` attribute that is None,
        # and hasattr alone would let None reach tf.shape() and raise.
        if getattr(layer, "kernel", None) is not None:
            layer.kernel.assign(initializer(tf.shape(layer.kernel)))
        if getattr(layer, "bias", None) is not None:
            layer.bias.assign(tf.zeros(tf.shape(layer.bias)))

def create_train_evaluate(
self, x_train: np.ndarray, y_train: np.ndarray, x_test: np.ndarray, y_test: np.ndarray
Expand Down

0 comments on commit 0f7d24a

Please sign in to comment.