Skip to content

Commit

Permalink
Update batch_normalization.py (keras-team#20057)
Browse files Browse the repository at this point in the history
* Refactoring for batch_normalization again

* Refactoring for batch_normalization again

* Refactoring for batch_normalization again
  • Loading branch information
shashaka committed Jul 28, 2024
1 parent f0e79cf commit 25391fe
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions keras/src/layers/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def __init__(
self.gamma_constraint = constraints.get(gamma_constraint)
self.supports_masking = True

self.gamma = None
self.beta = None
self.moving_mean = None
self.moving_variance = None
self._reduction_axes = None

def build(self, input_shape):
shape = (input_shape[self.axis],)
if self.scale:
Expand Down Expand Up @@ -202,9 +208,11 @@ def build(self, input_shape):
trainable=False,
autocast=False,
)

self.input_spec = InputSpec(
ndim=len(input_shape), axes={self.axis: input_shape[self.axis]}
)

reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
self._reduction_axes = reduction_axes
Expand All @@ -230,10 +238,12 @@ def call(self, inputs, training=None, mask=None):
# out BN for mixed precision.
inputs = ops.cast(inputs, "float32")

moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)

if training and self.trainable:
mean, variance = self._moments(inputs, mask)
moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)

self.moving_mean.assign(
moving_mean * self.momentum + mean * (1.0 - self.momentum)
)
Expand All @@ -242,8 +252,6 @@ def call(self, inputs, training=None, mask=None):
+ variance * (1.0 - self.momentum)
)
else:
moving_mean = ops.cast(self.moving_mean, inputs.dtype)
moving_variance = ops.cast(self.moving_variance, inputs.dtype)
mean = moving_mean
variance = moving_variance

Expand Down

0 comments on commit 25391fe

Please sign in to comment.