diff --git a/keras/src/backend/tensorflow/optimizer.py b/keras/src/backend/tensorflow/optimizer.py
index ca40a1f42b3..eb615f7d230 100644
--- a/keras/src/backend/tensorflow/optimizer.py
+++ b/keras/src/backend/tensorflow/optimizer.py
@@ -116,18 +116,18 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
             v.value if isinstance(v, backend.Variable) else v
             for v in trainable_variables
         ]
+        grads_and_vars = list(zip(grads, trainable_variables))
+        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)
         tf.__internal__.distribute.interim.maybe_merge_call(
             self._distributed_tf_update_step,
             self._distribution_strategy,
-            list(zip(grads, trainable_variables)),
+            grads_and_vars,
             learning_rate,
         )
 
     def _distributed_tf_update_step(
         self, distribution, grads_and_vars, learning_rate
     ):
-        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)
-
         def apply_grad_to_update_var(var, grad, learning_rate):
             return self.update_step(grad, var, learning_rate)
diff --git a/keras/src/backend/tensorflow/optimizer_distribute_test.py b/keras/src/backend/tensorflow/optimizer_distribute_test.py
index 10ad17d2364..2d55c21796c 100644
--- a/keras/src/backend/tensorflow/optimizer_distribute_test.py
+++ b/keras/src/backend/tensorflow/optimizer_distribute_test.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest
 import tensorflow as tf
+from absl.testing import parameterized
 from tensorflow.python.eager import context
 
 from keras.src import backend
@@ -14,7 +15,7 @@
     backend.backend() != "tensorflow",
     reason="The distribute test can only run with TF backend.",
 )
-class OptimizerDistributeTest(testing.TestCase):
+class OptimizerDistributeTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
         super().setUp()
         # Need at least 2 devices for distribution related tests.
@@ -39,20 +40,32 @@ def test_config(self):
         )
         self.run_class_serialization_test(optimizer)
 
-    def test_single_step(self):
+    @parameterized.parameters([("keras_sgd",), ("tf_keras_sgd",)])
+    def test_single_step(self, optimizer_type):
+        if optimizer_type == "tf_keras_sgd":
+            try:
+                import tf_keras
+
+                optimizer_fn = tf_keras.optimizers.SGD
+            except (ImportError, AttributeError):
+                self.skipTest("tf_keras not installed")
+        else:
+            optimizer_fn = SGD
         with self.strategy.scope():
-            optimizer = SGD(
+            optimizer = optimizer_fn(
                 learning_rate=0.5,
                 momentum=0.06,
             )
-            grads = tf.constant([1.0, 6.0, 7.0, 2.0])
-            vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
+            # use tf variable to work both in k2 & k3.
+            vars = tf.Variable([1.0, 2.0, 3.0, 4.0])
 
-            self.strategy.run(
-                lambda: optimizer.apply_gradients(zip([grads], [vars]))
-            )
+            def update():
+                grads = tf.constant([1.0, 6.0, 7.0, 2.0])
+                optimizer.apply_gradients(zip([grads], [vars]))
+
+            self.strategy.run(update)
             self.assertAllClose(
-                vars, [0.5, -1.0, -0.5, 3.0], rtol=1e-4, atol=1e-4
+                vars, [0.0, -4.0, -4.0, 2.0], rtol=1e-4, atol=1e-4
             )
 
     def test_weight_decay(self):
@@ -91,31 +104,32 @@ def opt3_run():
     def test_correctness_with_golden(self):
         with self.strategy.scope():
             optimizer = SGD(nesterov=True)
-        x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+            x = backend.Variable(np.ones([10]))
+
+        def update_grads():
+            grads = backend.convert_to_tensor(np.arange(0.1, 1.1, 0.1))
+            optimizer.apply_gradients(zip([grads], [x]))
+
+        def update_first_grads():
+            first_grads = backend.convert_to_tensor(np.full((10,), 0.01))
+            optimizer.apply_gradients(zip([first_grads], [x]))
 
         # fmt: off
         golden = np.array(
-            [[0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999,
-            0.9999, 0.9999], [0.9989, 0.9979, 0.9969, 0.9959, 0.9949, 0.9939,
-            0.9929, 0.9919, 0.9909, 0.9899], [0.9979, 0.9959, 0.9939, 0.9919,
-            0.9899, 0.9879, 0.9859, 0.9839, 0.9819, 0.9799], [0.9969, 0.9939,
-            0.9909, 0.9879, 0.9849, 0.9819, 0.9789, 0.9759, 0.9729, 0.9699],
-            [0.9959, 0.9919, 0.9879, 0.9839, 0.9799, 0.9759, 0.9719, 0.9679,
-            0.9639, 0.9599]]
+            [
+                [0.9980, 0.9960, 0.9940, 0.9920, 0.9900, 0.9880, 0.9860, 0.9840, 0.9820, 0.9800],
+                [0.9978, 0.9958, 0.9938, 0.9918, 0.9898, 0.9878, 0.9858, 0.9838, 0.9818, 0.9798],
+                [0.9976, 0.9956, 0.9936, 0.9916, 0.9896, 0.9876, 0.9856, 0.9836, 0.9816, 0.9796],
+                [0.9974, 0.9954, 0.9934, 0.9914, 0.9894, 0.9874, 0.9854, 0.9834, 0.9814, 0.9794],
+                [0.9972, 0.9952, 0.9932, 0.9912, 0.9892, 0.9872, 0.9852, 0.9832, 0.9812, 0.9792],
+            ]
         )
         # fmt: on
 
-        self.strategy.run(
-            lambda: optimizer.apply_gradients(zip([first_grads], [x]))
-        )
+        self.strategy.run(update_grads)
         for i in range(5):
             self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
-            self.strategy.run(
-                lambda: optimizer.apply_gradients(zip([grads], [x]))
-            )
+            self.strategy.run(update_first_grads)
 
     def test_clip_norm(self):
         with self.strategy.scope():
@@ -190,7 +204,7 @@ def test_gradient_accumulation(self):
         self.assertAllClose(optimizer._iterations, 2)
         self.assertAllClose(optimizer.iterations, 0)
         self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))
-        self.assertAllClose(v, [[0.0, 1.0], [1.0, 2.0]])
+        self.assertAllClose(v, [[-1.0, 0.0], [-1.0, 0.0]])
         self.assertAllClose(
             optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]
         )
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index c6226d666f7..d38c0be05db 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -291,6 +291,8 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
+        self._maybe_symbolic_build(iterator=epoch_iterator)
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -406,6 +408,8 @@ def evaluate(
             steps_per_execution=self.steps_per_execution,
         )
 
+        self._maybe_symbolic_build(iterator=epoch_iterator)
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -525,7 +529,6 @@ def train_on_batch(
         return_dict=False,
     ):
         self._assert_compile_called("train_on_batch")
-        self.make_train_function()
         if class_weight is not None:
             if sample_weight is not None:
                 raise ValueError(
@@ -538,6 +541,10 @@ def train_on_batch(
                 y, class_weight
             )
 
+        # Maybe build model
+        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))
+        self.make_train_function()
+
         def data():
             yield (x, y, sample_weight)
 
@@ -555,11 +562,14 @@ def test_on_batch(
         return_dict=False,
     ):
         self._assert_compile_called("test_on_batch")
-        self.make_test_function()
 
         def data():
             yield (x, y, sample_weight)
 
+        # Maybe build model
+        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))
+        self.make_test_function()
+
         logs = self.test_function(data())
         logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
@@ -621,6 +631,35 @@ def loss(self, y, y_pred, sample_weight=None):
             x=None, y=y, y_pred=y_pred, sample_weight=sample_weight
         )
 
+    def _maybe_symbolic_build(self, iterator=None, data_batch=None):
+        # Only symbolic build when distribute strategy is created in tf trainer
+        if self._distribute_strategy is None:
+            # When no distribution strategy is set, defer building
+            # to when the train/test/predict function gets traced.
+            # This maximizes backwards compatibility.
+            return
+
+        # Unlike jax/torch iterator, tf iterator returns an iterator instead
+        # of data batch in `iterator.enumerate_epoch()`.
+        if iterator is not None:
+            for _, it in iterator.enumerate_epoch():
+                maybe_distributed_data_batch = next(it)
+                has_distributed_values = tree.map_structure(
+                    lambda x: isinstance(x, tf.distribute.DistributedValues),
+                    maybe_distributed_data_batch,
+                )
+                if all(tree.flatten(has_distributed_values)):
+                    data_batch = self.distribute_strategy.reduce(
+                        "MEAN",
+                        maybe_distributed_data_batch,
+                        axis=None,
+                    )
+                else:
+                    data_batch = maybe_distributed_data_batch
+                break
+        with self.distribute_strategy.scope():
+            self._symbolic_build(data_batch=data_batch)
+
 
 class TFEpochIterator(EpochIterator):
     def __init__(self, distribute_strategy=None, *args, **kwargs):
diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py
index e9f52011e9d..65dadc6e298 100644
--- a/keras/src/ops/nn_test.py
+++ b/keras/src/ops/nn_test.py
@@ -1913,8 +1913,12 @@ def test_moments_sync(self):
         reason="synchronized=True only implemented for TF backend",
     )
     def test_moments_sync_with_distribution_strategy(self, dtype):
+        from tensorflow.python.eager import context
+
         from keras.src.utils.module_utils import tensorflow as tf
 
+        context._reset_context()
+
         # Config 2 CPUs for testing.
         logical_cpus = tf.config.list_logical_devices("CPU")
         if len(logical_cpus) == 1:
@@ -1944,6 +1948,8 @@ def test_on_moments(inputs):
         self.assertEqual(variance.values[0], 8.75)
         self.assertEqual(variance.values[0], 8.75)
 
+        context._reset_context()
+
     def test_batch_normalization(self):
         x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
         mean = np.array([0.2, 0.3, 0.4])
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index 064df23adc2..fee4fa3f282 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -1712,15 +1712,13 @@ def test_end_to_end_tf_distribute(self):
                 loss="sparse_categorical_crossentropy",
                 metrics=["sparse_categorical_accuracy"],
             )
-            x = (np.arange(512) / 128).reshape((256, 2))
-            y = (np.arange(256) % 2).reshape((256, 1))
-            out_fit = model.fit(x, y)
-            self.assertLess(
-                out_fit.history["sparse_categorical_accuracy"][0], 0.6
-            )
-            out_eval = model.evaluate(x, y)
-            self.assertLess(out_eval[1], 0.6)
-            out_predict = model.predict(x)
-            self.assertEqual(out_predict.shape, (256, 2))
+        x = (np.arange(512) / 128).reshape((256, 2))
+        y = (np.arange(256) % 2).reshape((256, 1))
+        out_fit = model.fit(x, y)
+        self.assertLess(out_fit.history["sparse_categorical_accuracy"][0], 0.6)
+        out_eval = model.evaluate(x, y)
+        self.assertLess(out_eval[1], 0.6)
+        out_predict = model.predict(x)
+        self.assertEqual(out_predict.shape, (256, 2))
         context._reset_context()
diff --git a/requirements.txt b/requirements.txt
index 17c94a9a8fc..30356a67b13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 # Tensorflow.
 tensorflow-cpu~=2.16.2;sys_platform != 'darwin' # Pin to TF 2.16
 tensorflow~=2.16.2;sys_platform == 'darwin'
+tf_keras
 
 # Torch.
 # TODO: Pin to < 2.3.0 (GitHub issue #19602)
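
Note on the updated numeric expectations: with `_all_reduce_sum_gradients` now running before `maybe_merge_call`, per-replica gradients are summed across the two test replicas before `update_step` is applied, so a single step moves each variable by `learning_rate * num_replicas * grad`. The sketch below is illustrative only, not part of the patch; it assumes the plain SGD-with-momentum rule `v = momentum * v - lr * g; x += v` with zero initial velocity, and reproduces the new `[0.0, -4.0, -4.0, 2.0]` assertion in `test_single_step`:

    import numpy as np

    # Mirrors test_single_step: 2 replicas, SGD(learning_rate=0.5, momentum=0.06),
    # and the same gradient [1.0, 6.0, 7.0, 2.0] reported by every replica.
    num_replicas = 2
    learning_rate = 0.5
    momentum = 0.06

    x = np.array([1.0, 2.0, 3.0, 4.0])
    g = np.array([1.0, 6.0, 7.0, 2.0])

    # Gradients are summed across replicas before the update step runs.
    g_total = num_replicas * g

    # One momentum step from zero velocity reduces to x - learning_rate * g_total.
    velocity = momentum * 0.0 - learning_rate * g_total
    x = x + velocity

    print(x)  # [ 0. -4. -4.  2.]  -> matches the updated assertion

The new golden values in `test_correctness_with_golden` follow the same pattern (for example, 0.9980 = 1.0 - 2 * 0.01 * 0.1 for the first entry).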