From 25391fef9b44c807030f79ee008e11b6520739cb Mon Sep 17 00:00:00 2001 From: Ugeun Park <37043543+shashaka@users.noreply.github.com> Date: Sun, 28 Jul 2024 14:41:15 -0400 Subject: [PATCH] Update batch_normalization.py (#20057) * Refactoring for batch_normalization again * Refactoring for batch_normalization again * Refactoring for batch_normalization again --- .../layers/normalization/batch_normalization.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py index efc92cd160ad..8b0160344eec 100644 --- a/keras/src/layers/normalization/batch_normalization.py +++ b/keras/src/layers/normalization/batch_normalization.py @@ -166,6 +166,12 @@ def __init__( self.gamma_constraint = constraints.get(gamma_constraint) self.supports_masking = True + self.gamma = None + self.beta = None + self.moving_mean = None + self.moving_variance = None + self._reduction_axes = None + def build(self, input_shape): shape = (input_shape[self.axis],) if self.scale: @@ -202,9 +208,11 @@ def build(self, input_shape): trainable=False, autocast=False, ) + self.input_spec = InputSpec( ndim=len(input_shape), axes={self.axis: input_shape[self.axis]} ) + reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] self._reduction_axes = reduction_axes @@ -230,10 +238,12 @@ def call(self, inputs, training=None, mask=None): # out BN for mixed precision. inputs = ops.cast(inputs, "float32") + moving_mean = ops.cast(self.moving_mean, inputs.dtype) + moving_variance = ops.cast(self.moving_variance, inputs.dtype) + if training and self.trainable: mean, variance = self._moments(inputs, mask) - moving_mean = ops.cast(self.moving_mean, inputs.dtype) - moving_variance = ops.cast(self.moving_variance, inputs.dtype) + self.moving_mean.assign( moving_mean * self.momentum + mean * (1.0 - self.momentum) ) @@ -242,8 +252,6 @@ def call(self, inputs, training=None, mask=None): + variance * (1.0 - self.momentum) ) else: - moving_mean = ops.cast(self.moving_mean, inputs.dtype) - moving_variance = ops.cast(self.moving_variance, inputs.dtype) mean = moving_mean variance = moving_variance