Memory optimization for jax trainer. #888

Merged: 7 commits, Sep 20, 2023
10 changes: 5 additions & 5 deletions keras_core/backend/common/variables.py
@@ -115,7 +115,7 @@ def value(self):

     def assign(self, value):
         value = self._convert_to_tensor(value, dtype=self.dtype)
-        if not shape_equal(value, self.value):
+        if not shape_equal(value.shape, self.shape):
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
@@ -446,11 +446,11 @@ def standardize_shape(shape):
     return shape


-def shape_equal(a, b):
-    """Return whether a.shape == b.shape (allows None entries)."""
-    if len(a.shape) != len(b.shape):
+def shape_equal(a_shape, b_shape):
+    """Return whether a_shape == b_shape (allows None entries)."""
+    if len(a_shape) != len(b_shape):
         return False
-    for e1, e2 in zip(a.shape, b.shape):
+    for e1, e2 in zip(a_shape, b_shape):
         if e1 is not None and e2 is not None and e1 != e2:
             return False
     return True
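For context, a quick sketch (not part of the diff) of how the updated helper treats None entries as wildcard dimensions, using the module path from the hunk above:

    from keras_core.backend.common.variables import shape_equal

    # None entries act as wildcards; only concrete, mismatching dims fail.
    assert shape_equal((None, 3), (8, 3))      # dynamic dim matches anything
    assert not shape_equal((4, 3), (8, 3))     # concrete dims must match
    assert not shape_equal((4, 3), (4, 3, 1))  # ranks must match
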
41 changes: 40 additions & 1 deletion keras_core/backend/jax/trainer.py
@@ -44,7 +44,14 @@
             return_losses=True,
             **kwargs,
         )
-        loss = self.compute_loss(x, y, y_pred, sample_weight, allow_empty=True)
+
+        trainable_mapping = zip(self.trainable_variables, trainable_variables)
+        with backend.StatelessScope(state_mapping=trainable_mapping):
+            # Note that this is needed for the regularization loss, which
+            # needs the latest values of the trainable/non-trainable variables.
+            loss = self.compute_loss(
+                x, y, y_pred, sample_weight, allow_empty=True
+            )
         if losses:
             loss += ops.sum(losses)
         unscaled_loss = loss
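The StatelessScope mapping above is what lets loss code observe the values being threaded through the jitted step. A rough, standalone sketch of the idea (not part of the diff; the Dense layer and the zero "latest values" are made up for illustration):

    import numpy as np
    from keras_core import backend, layers

    layer = layers.Dense(2)
    layer.build((None, 3))

    # Stand-ins for the latest trainable values produced by a jitted step.
    latest_values = [
        np.zeros(v.shape, dtype="float32") for v in layer.trainable_weights
    ]

    mapping = zip(layer.trainable_weights, latest_values)
    with backend.StatelessScope(state_mapping=mapping):
        # Inside the scope, reading a mapped variable yields the latest value,
        # not the stale value stored on the variable object itself.
        kernel_now = layer.trainable_weights[0].value  # all zeros here
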
@@ -418,6 +425,7 @@
         optimizer_variables = [v.value for v in self.optimizer.variables]
         metrics_variables = [v.value for v in self.metrics_variables]

+        self._purge_model_variables()
         for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
             # Callbacks
             callbacks.on_train_batch_begin(step)
@@ -576,6 +584,7 @@
         ]
         metrics_variables = [v.value for v in self.metrics_variables]

+        self._purge_model_variables(optimizer_variables=False)
         for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
             callbacks.on_test_batch_begin(step)

@@ -907,3 +916,33 @@
             optimizer_variables,
             metrics_variables,
         )
+
+    def _purge_model_variables(
+        self,
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=True,
+        metric_variables=True,
+    ):
+        """Remove all the model variables for memory saving.
+
+        During JAX training, since the training function is stateless, we
+        have to pass in and get back the model weights over and over,
+        during which the copies of the weights attached to the
+        KerasVariables still occupy extra memory. We remove those values
+        at the beginning of the epoch to save memory (for better memory
+        utilization), and reattach the values to the variables at the end
+        of the epoch, via `jax_state_sync()`.
+        """
+        if trainable_variables:
+            for v in self.trainable_variables:
+                v._value = None
+        if non_trainable_variables:
+            for v in self.non_trainable_variables:
+                v._value = None
+        if optimizer_variables:
+            for v in self.optimizer.variables:
+                v._value = None
+        if metric_variables:
+            for v in self.metrics_variables:
+                v._value = None
2 changes: 2 additions & 0 deletions keras_core/layers/layer.py
@@ -1040,6 +1040,8 @@
             losses.extend(layer._get_own_losses())
         weight_regularization_losses = []
         for v in self.trainable_weights:
+            if backend.in_stateless_scope():
+                v = backend.get_stateless_scope().get_current_value(v)
             regularizer = getattr(v, "regularizer", None)
             if regularizer:
                 weight_regularization_losses.append(regularizer(v))
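A minimal sketch of what `get_current_value` does here (illustrative only; `weight` and `update` are hypothetical names, and the exact `Variable` constructor arguments are an assumption):

    import numpy as np
    from keras_core import backend

    weight = backend.Variable(np.ones((2, 2), dtype="float32"), name="weight")
    update = np.full((2, 2), 3.0, dtype="float32")

    with backend.StatelessScope(state_mapping=[(weight, update)]):
        # Returns the tensor mapped to `weight` in this scope (or None if the
        # variable has no entry), which is what the regularizer should see.
        current = backend.get_stateless_scope().get_current_value(weight)
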
57 changes: 57 additions & 0 deletions keras_core/trainers/trainer_test.py
@@ -262,6 +262,63 @@ def test_predict_flow(self, run_eagerly, jit_compile):
         self.assertAllClose(outputs["y_one"], 4 * np.ones((100, 3)))
         self.assertAllClose(outputs["y_two"], 4 * np.ones((100, 3)))

+    @pytest.mark.skipif(
+        backend.backend() != "jax",
+        reason="Memory optimization is only implemented in JAX",
+    )
+    def test_fit_eval_flow_for_jax_model_weights(self):
+        model = ExampleModel(units=3)
+        epochs = 3
+        batch_size = 20
+        steps_per_epoch = 7
+        dataset_size = batch_size * (steps_per_epoch - 2)
+        x = np.ones((dataset_size, 4))
+        y = np.zeros((dataset_size, 3))
+
+        class ModelWeightCheck(Callback):
+            def __init__(self):
+                super().__init__()
+
+            # Note that we access model via self._model since self.model
+            # will trigger a sync of the jax training state back to the model.
+            def on_train_batch_begin(self, batch, logs=None):
+                for v in self._model.trainable_variables:
+                    assert v._value is None
+                for v in self._model.non_trainable_variables:
+                    assert v._value is None
+                for v in self._model.optimizer.variables:
+                    assert v._value is None
+                for v in self._model.metrics_variables:
+                    assert v._value is None
+
+            def on_test_batch_begin(self, batch, logs=None):
+                for v in self._model.non_trainable_variables:
+                    assert v._value is None
+                for v in self._model.metrics_variables:
+                    assert v._value is None
+
+        model.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+        )
+
+        model.fit(
+            x,
+            y,
+            batch_size=batch_size,
+            steps_per_epoch=steps_per_epoch,
+            epochs=epochs,
+            callbacks=[ModelWeightCheck()],
+        )
+
+        model.evaluate(
+            x,
+            y,
+            batch_size=batch_size,
+            callbacks=[ModelWeightCheck()],
+        )
+
     @pytest.mark.requires_trainable_backend
     @pytest.mark.skipif(
         backend.backend() == "torch",