Skip to content

Commit

Permalink
Add shape validation for variable init
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 16, 2023
1 parent 00aad90 commit 144d2af
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
13 changes: 12 additions & 1 deletion keras/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
if callable(initializer):
self._value = None
self._initializer = initializer
self._shape = standardize_shape(shape)
self._shape = self._validate_shape(shape)
register_uninitialized_variable(self)
else:
raise ValueError(
Expand All @@ -70,6 +70,7 @@ def __init__(
)
else:
if callable(initializer):
shape = self._validate_shape(shape)
value = initializer(shape, dtype=dtype)
else:
value = initializer
Expand All @@ -91,6 +92,16 @@ def _deferred_initialize(self):
value = self._initializer(self._shape, dtype=self._dtype)
self._initialize(value)

def _validate_shape(self, shape):
shape = standardize_shape(shape)
if None in shape:
raise ValueError(
"Shapes used to initialize variables must be "
"fully-defined (no `None` dimensions). Received: "
f"shape={shape} for variable path='{self.path}'"
)
return shape

def _maybe_autocast(self, value):
autocast_scope = get_autocast_scope()
if autocast_scope is not None:
Expand Down
2 changes: 1 addition & 1 deletion keras/metrics/f_score_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _build(self, y_true_shape, y_pred_shape):
)
num_classes = y_pred_shape[-1]
if self.average != "micro":
init_shape = num_classes
init_shape = (num_classes,)
else:
init_shape = ()

Expand Down

0 comments on commit 144d2af

Please sign in to comment.