Reland tf distribute fix (keras-team#20051)
* Revert "Rollback tf distribute change (keras-team#20017)"

This reverts commit 36a0628.

* Support model.fit outside of the distribution strategy scope in the TF backend
haohuanw committed Jul 28, 2024
1 parent a226835 commit f0e79cf
Showing 6 changed files with 99 additions and 41 deletions.
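The user-facing effect of the reland: with the TF backend, compile/fit no longer have to run inside the tf.distribute strategy scope. A minimal sketch of the now-supported pattern (the toy model and data below are illustrative stand-ins, not taken from the commit):

    import numpy as np
    import tensorflow as tf
    import keras

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # Variable creation still happens under the scope.
        model = keras.Sequential([keras.Input(shape=(2,)), keras.layers.Dense(2)])

    # With this fix, compile/fit (and evaluate/predict) may run outside it:
    model.compile(optimizer="sgd", loss="mse")
    model.fit(np.ones((8, 2)), np.ones((8, 2)), epochs=1, verbose=0)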
6 changes: 3 additions & 3 deletions keras/src/backend/tensorflow/optimizer.py
@@ -116,18 +116,18 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
             v.value if isinstance(v, backend.Variable) else v
             for v in trainable_variables
         ]
+        grads_and_vars = list(zip(grads, trainable_variables))
+        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)
         tf.__internal__.distribute.interim.maybe_merge_call(
             self._distributed_tf_update_step,
             self._distribution_strategy,
-            list(zip(grads, trainable_variables)),
+            grads_and_vars,
             learning_rate,
         )
 
     def _distributed_tf_update_step(
         self, distribution, grads_and_vars, learning_rate
     ):
-        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)
-
         def apply_grad_to_update_var(var, grad, learning_rate):
             return self.update_step(grad, var, learning_rate)
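The substance of the optimizer change: the gradient all-reduce now happens in replica context, before maybe_merge_call, instead of inside the merged update step. For readers unfamiliar with the helper being moved, a rough sketch of the idea (simplified; the real _all_reduce_sum_gradients in the Keras optimizer base class also handles None gradients and the non-distributed case):

    import tensorflow as tf

    def all_reduce_sum_gradients_sketch(grads_and_vars):
        # Must be called in replica context (e.g. inside strategy.run).
        # Sums each replica's gradients so all replicas apply the same update.
        grads, variables = zip(*grads_and_vars)
        replica_ctx = tf.distribute.get_replica_context()
        reduced = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, list(grads))
        return list(zip(reduced, variables))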
66 changes: 40 additions & 26 deletions keras/src/backend/tensorflow/optimizer_distribute_test.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest
 import tensorflow as tf
+from absl.testing import parameterized
 from tensorflow.python.eager import context
 
 from keras.src import backend
@@ -14,7 +15,7 @@
     backend.backend() != "tensorflow",
     reason="The distribute test can only run with TF backend.",
 )
-class OptimizerDistributeTest(testing.TestCase):
+class OptimizerDistributeTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
         super().setUp()
         # Need at least 2 devices for distribution related tests.
@@ -39,20 +40,32 @@ def test_config(self):
         )
         self.run_class_serialization_test(optimizer)
 
-    def test_single_step(self):
+    @parameterized.parameters([("keras_sgd",), ("tf_keras_sgd",)])
+    def test_single_step(self, optimizer_type):
+        if optimizer_type == "tf_keras_sgd":
+            try:
+                import tf_keras
+
+                optimizer_fn = tf_keras.optimizers.SGD
+            except (ImportError, AttributeError):
+                self.skipTest("tf_keras not installed")
+        else:
+            optimizer_fn = SGD
         with self.strategy.scope():
-            optimizer = SGD(
+            optimizer = optimizer_fn(
                 learning_rate=0.5,
                 momentum=0.06,
             )
-            grads = tf.constant([1.0, 6.0, 7.0, 2.0])
-            vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
+        # use tf variable to work both in k2 & k3.
+        vars = tf.Variable([1.0, 2.0, 3.0, 4.0])
 
-            self.strategy.run(
-                lambda: optimizer.apply_gradients(zip([grads], [vars]))
-            )
+        def update():
+            grads = tf.constant([1.0, 6.0, 7.0, 2.0])
+            optimizer.apply_gradients(zip([grads], [vars]))
+
+        self.strategy.run(update)
         self.assertAllClose(
-            vars, [0.5, -1.0, -0.5, 3.0], rtol=1e-4, atol=1e-4
+            vars, [0.0, -4.0, -4.0, 2.0], rtol=1e-4, atol=1e-4
         )
 
     def test_weight_decay(self):
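The new expected values in test_single_step above follow from the SUM all-reduce over the two test devices: each replica contributes the same gradient, so the effective gradient is doubled (momentum starts at zero, so the first SGD step reduces to plain gradient descent). A quick check:

    import numpy as np

    lr, num_replicas = 0.5, 2
    v = np.array([1.0, 2.0, 3.0, 4.0])
    g = np.array([1.0, 6.0, 7.0, 2.0])
    print(v - lr * num_replicas * g)  # [ 0. -4. -4.  2.], the asserted values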
@@ -91,31 +104,32 @@ def opt3_run():
     def test_correctness_with_golden(self):
         with self.strategy.scope():
             optimizer = SGD(nesterov=True)
 
         x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+
+        def update_grads():
+            grads = backend.convert_to_tensor(np.arange(0.1, 1.1, 0.1))
+            optimizer.apply_gradients(zip([grads], [x]))
+
+        def update_first_grads():
+            first_grads = backend.convert_to_tensor(np.full((10,), 0.01))
+            optimizer.apply_gradients(zip([first_grads], [x]))
 
         # fmt: off
         golden = np.array(
-            [[0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999,
-            0.9999, 0.9999], [0.9989, 0.9979, 0.9969, 0.9959, 0.9949, 0.9939,
-            0.9929, 0.9919, 0.9909, 0.9899], [0.9979, 0.9959, 0.9939, 0.9919,
-            0.9899, 0.9879, 0.9859, 0.9839, 0.9819, 0.9799], [0.9969, 0.9939,
-            0.9909, 0.9879, 0.9849, 0.9819, 0.9789, 0.9759, 0.9729, 0.9699],
-            [0.9959, 0.9919, 0.9879, 0.9839, 0.9799, 0.9759, 0.9719, 0.9679,
-            0.9639, 0.9599]]
+            [
+                [0.9980, 0.9960, 0.9940, 0.9920, 0.9900, 0.9880, 0.9860, 0.9840, 0.9820, 0.9800],
+                [0.9978, 0.9958, 0.9938, 0.9918, 0.9898, 0.9878, 0.9858, 0.9838, 0.9818, 0.9798],
+                [0.9976, 0.9956, 0.9936, 0.9916, 0.9896, 0.9876, 0.9856, 0.9836, 0.9816, 0.9796],
+                [0.9974, 0.9954, 0.9934, 0.9914, 0.9894, 0.9874, 0.9854, 0.9834, 0.9814, 0.9794],
+                [0.9972, 0.9952, 0.9932, 0.9912, 0.9892, 0.9872, 0.9852, 0.9832, 0.9812, 0.9792],
+            ]
         )
         # fmt: on
 
-        self.strategy.run(
-            lambda: optimizer.apply_gradients(zip([first_grads], [x]))
-        )
+        self.strategy.run(update_grads)
         for i in range(5):
             self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
-            self.strategy.run(
-                lambda: optimizer.apply_gradients(zip([grads], [x]))
-            )
+            self.strategy.run(update_first_grads)
 
     def test_clip_norm(self):
         with self.strategy.scope():
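The new golden matrix above is consistent with the same doubled effective gradient (assuming 2 replicas and SGD defaults, learning_rate=0.01 and momentum=0.0, in which case nesterov=True changes nothing): the first update applies grads = arange(0.1, 1.1, 0.1), and each later update applies first_grads = 0.01:

    import numpy as np

    lr, num_replicas = 0.01, 2
    row0 = 1.0 - lr * num_replicas * np.arange(0.1, 1.1, 0.1)
    print(row0)         # [0.998 0.996 ... 0.98], matching golden[0]
    step = lr * num_replicas * 0.01
    print(row0 - step)  # subtracts 0.0002, matching golden[1], and so on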
@@ -190,7 +204,7 @@ def test_gradient_accumulation(self):
         self.assertAllClose(optimizer._iterations, 2)
         self.assertAllClose(optimizer.iterations, 0)
         self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))
-        self.assertAllClose(v, [[0.0, 1.0], [1.0, 2.0]])
+        self.assertAllClose(v, [[-1.0, 0.0], [-1.0, 0.0]])
         self.assertAllClose(
             optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]
         )
43 changes: 41 additions & 2 deletions keras/src/backend/tensorflow/trainer.py
@@ -291,6 +291,8 @@ def fit(
             steps_per_execution=self.steps_per_execution,
         )
 
+        self._maybe_symbolic_build(iterator=epoch_iterator)
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -406,6 +408,8 @@ def evaluate(
             steps_per_execution=self.steps_per_execution,
         )
 
+        self._maybe_symbolic_build(iterator=epoch_iterator)
+
         # Container that configures and calls callbacks.
         if not isinstance(callbacks, callbacks_module.CallbackList):
             callbacks = callbacks_module.CallbackList(
@@ -525,7 +529,6 @@ def train_on_batch(
         return_dict=False,
     ):
         self._assert_compile_called("train_on_batch")
-        self.make_train_function()
         if class_weight is not None:
             if sample_weight is not None:
                 raise ValueError(
@@ -538,6 +541,10 @@
                 y, class_weight
             )
 
+        # Maybe build model
+        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))
+        self.make_train_function()
+
         def data():
             yield (x, y, sample_weight)
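Note the ordering in train_on_batch: make_train_function() now runs after _maybe_symbolic_build(...), so the model's variables are created under the distribution strategy scope before the train function is traced. test_on_batch below gets the same treatment.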

@@ -555,11 +562,14 @@ def test_on_batch(
         return_dict=False,
     ):
         self._assert_compile_called("test_on_batch")
-        self.make_test_function()
 
         def data():
             yield (x, y, sample_weight)
 
+        # Maybe build model
+        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))
+        self.make_test_function()
+
         logs = self.test_function(data())
         logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
@@ -621,6 +631,35 @@ def loss(self, y, y_pred, sample_weight=None):
             x=None, y=y, y_pred=y_pred, sample_weight=sample_weight
         )
 
+    def _maybe_symbolic_build(self, iterator=None, data_batch=None):
+        # Only symbolically build when a distribute strategy was created
+        # in the tf trainer.
+        if self._distribute_strategy is None:
+            # When no distribution strategy is set, defer building to
+            # when the train/test/predict function gets traced. This
+            # maximizes backwards compatibility.
+            return
+
+        # Unlike the jax/torch iterators, the tf iterator returns an
+        # iterator instead of a data batch in `iterator.enumerate_epoch()`.
+        if iterator is not None:
+            for _, it in iterator.enumerate_epoch():
+                maybe_distributed_data_batch = next(it)
+                has_distributed_values = tree.map_structure(
+                    lambda x: isinstance(x, tf.distribute.DistributedValues),
+                    maybe_distributed_data_batch,
+                )
+                if all(tree.flatten(has_distributed_values)):
+                    data_batch = self.distribute_strategy.reduce(
+                        "MEAN",
+                        maybe_distributed_data_batch,
+                        axis=None,
+                    )
+                else:
+                    data_batch = maybe_distributed_data_batch
+                break
+        with self.distribute_strategy.scope():
+            self._symbolic_build(data_batch=data_batch)
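A sketch of what the iterator branch above deals with (a hypothetical toy pipeline, which assumes two logical devices are configured as in the tests): a distributed dataset yields tf.distribute.DistributedValues, and strategy.reduce("MEAN", ..., axis=None) collapses them into a single ordinary tensor whose shapes _symbolic_build can trace:

    import numpy as np
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices(
        np.ones((8, 2), "float32")
    ).batch(4)
    batch = next(iter(strategy.experimental_distribute_dataset(dataset)))
    # `batch` is a DistributedValues with one per-replica component.
    representative = strategy.reduce("MEAN", batch, axis=None)  # one tf.Tensor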


class TFEpochIterator(EpochIterator):
def __init__(self, distribute_strategy=None, *args, **kwargs):
6 changes: 6 additions & 0 deletions keras/src/ops/nn_test.py
@@ -1913,8 +1913,12 @@ def test_moments_sync(self):
         reason="synchronized=True only implemented for TF backend",
     )
     def test_moments_sync_with_distribution_strategy(self, dtype):
+        from tensorflow.python.eager import context
+
         from keras.src.utils.module_utils import tensorflow as tf
 
+        context._reset_context()
+
         # Config 2 CPUs for testing.
         logical_cpus = tf.config.list_logical_devices("CPU")
         if len(logical_cpus) == 1:
@@ -1944,6 +1948,8 @@ def test_on_moments(inputs):
         self.assertEqual(variance.values[0], 8.75)
         self.assertEqual(variance.values[0], 8.75)
 
+        context._reset_context()
+
     def test_batch_normalization(self):
         x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
         mean = np.array([0.2, 0.3, 0.4])
18 changes: 8 additions & 10 deletions keras/src/trainers/trainer_test.py
@@ -1712,15 +1712,13 @@ def test_end_to_end_tf_distribute(self):
                 loss="sparse_categorical_crossentropy",
                 metrics=["sparse_categorical_accuracy"],
             )
-            x = (np.arange(512) / 128).reshape((256, 2))
-            y = (np.arange(256) % 2).reshape((256, 1))
-            out_fit = model.fit(x, y)
-            self.assertLess(
-                out_fit.history["sparse_categorical_accuracy"][0], 0.6
-            )
-            out_eval = model.evaluate(x, y)
-            self.assertLess(out_eval[1], 0.6)
-            out_predict = model.predict(x)
-            self.assertEqual(out_predict.shape, (256, 2))
+        x = (np.arange(512) / 128).reshape((256, 2))
+        y = (np.arange(256) % 2).reshape((256, 1))
+        out_fit = model.fit(x, y)
+        self.assertLess(out_fit.history["sparse_categorical_accuracy"][0], 0.6)
+        out_eval = model.evaluate(x, y)
+        self.assertLess(out_eval[1], 0.6)
+        out_predict = model.predict(x)
+        self.assertEqual(out_predict.shape, (256, 2))
 
         context._reset_context()
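(The substantive change in this test is the dedent: fit, evaluate, and predict now run outside the strategy.scope() block, which is exactly the behavior the trainer change enables.)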
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
 # Tensorflow.
 tensorflow-cpu~=2.16.2;sys_platform != 'darwin' # Pin to TF 2.16
 tensorflow~=2.16.2;sys_platform == 'darwin'
+tf_keras
 
 # Torch.
 # TODO: Pin to < 2.3.0 (GitHub issue #19602)
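(tf_keras is added so the new parameterized test_single_step can exercise the legacy tf_keras.optimizers.SGD alongside the Keras 3 SGD under the same distribute setup.)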
