From 144d2af38522cd82c52d9b8d2e827a21f4c53297 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 16 Nov 2023 10:19:51 -0800 Subject: [PATCH] Add shape validation for variable init --- keras/backend/common/variables.py | 13 ++++++++++++- keras/metrics/f_score_metrics.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/keras/backend/common/variables.py b/keras/backend/common/variables.py index 8eee4752c2c..24e0320c5fb 100644 --- a/keras/backend/common/variables.py +++ b/keras/backend/common/variables.py @@ -48,7 +48,7 @@ def __init__( if callable(initializer): self._value = None self._initializer = initializer - self._shape = standardize_shape(shape) + self._shape = self._validate_shape(shape) register_uninitialized_variable(self) else: raise ValueError( @@ -70,6 +70,7 @@ def __init__( ) else: if callable(initializer): + shape = self._validate_shape(shape) value = initializer(shape, dtype=dtype) else: value = initializer @@ -91,6 +92,16 @@ def _deferred_initialize(self): value = self._initializer(self._shape, dtype=self._dtype) self._initialize(value) + def _validate_shape(self, shape): + shape = standardize_shape(shape) + if None in shape: + raise ValueError( + "Shapes used to initialize variables must be " + "fully-defined (no `None` dimensions). Received: " + f"shape={shape} for variable path='{self.path}'" + ) + return shape + def _maybe_autocast(self, value): autocast_scope = get_autocast_scope() if autocast_scope is not None: diff --git a/keras/metrics/f_score_metrics.py b/keras/metrics/f_score_metrics.py index 2bac6d44372..4a03afffa87 100644 --- a/keras/metrics/f_score_metrics.py +++ b/keras/metrics/f_score_metrics.py @@ -135,7 +135,7 @@ def _build(self, y_true_shape, y_pred_shape): ) num_classes = y_pred_shape[-1] if self.average != "micro": - init_shape = num_classes + init_shape = (num_classes,) else: init_shape = ()