Skip to content

Commit

Permalink
Restore hacky workaround for tf dict wrapper errors (keras-team#18682)
Browse files Browse the repository at this point in the history
In keras-team#18507 we removed a hacky
workaround for functional models so dictionary output is not made a
trackable.

We actually still need this when calling `predict` on a functional with
dictionary output.

Added testing to plug the coverage gap. Still no idea what the status
of the underlying bug is :/, filed with tensorflow

tensorflow/tensorflow#62217
  • Loading branch information
mattdangerw authored Oct 25, 2023
1 parent b7152af commit d8a06ff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
10 changes: 10 additions & 0 deletions keras/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def test_functional_dict_outputs_dict_losses(self):
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Check dict outputs.
outputs = model.predict(x)
self.assertIsInstance(outputs, dict)
self.assertEqual(outputs["output_a"].shape, (8, 1))
self.assertEqual(outputs["output_b"].shape, (8, 1))
# Fit the model to make sure compile_metrics are built
hist = model.fit(
x,
Expand Down Expand Up @@ -332,6 +337,11 @@ def test_functional_list_outputs_dict_losses_metrics(self):
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Check list outputs.
outputs = model.predict(x)
self.assertIsInstance(outputs, list)
self.assertEqual(outputs[0].shape, (8, 1))
self.assertEqual(outputs[1].shape, (8, 1))
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
Expand Down
12 changes: 12 additions & 0 deletions keras/ops/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from keras.api_export import keras_export
from keras.backend import KerasTensor
from keras.backend.config import backend
from keras.ops.operation import Operation
from keras.utils.nest import pack_sequence_as

Expand Down Expand Up @@ -46,10 +47,21 @@ class Function(Operation):
def __init__(self, inputs, outputs, name=None):
super().__init__(name=name)

if backend() == "tensorflow":
# Temporary work around for
# https://github.com/keras-team/keras/issues/931
# This stop tensorflow from wrapping tf.function output in a
# _DictWrapper object.
_self_setattr_tracking = getattr(
self, "_self_setattr_tracking", True
)
self._self_setattr_tracking = False
self._inputs_struct = tree.map_structure(lambda x: x, inputs)
self._outputs_struct = tree.map_structure(lambda x: x, outputs)
self._inputs = tree.flatten(inputs)
self._outputs = tree.flatten(outputs)
if backend() == "tensorflow":
self._self_setattr_tracking = _self_setattr_tracking

(nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
self._inputs, self._outputs
Expand Down

0 comments on commit d8a06ff

Please sign in to comment.