Update jax trainer function to save memory buffer. #897

Merged
merged 3 commits on Sep 18, 2023
keras_core/backend/jax/trainer.py: 28 changes (20 additions, 8 deletions)
@@ -1,3 +1,5 @@
+from functools import partial
+
 import jax
 import numpy as np
 import tree
@@ -237,8 +239,11 @@
             train_step = one_train_step

         if not self.run_eagerly and self.jit_compile:
-
-            @jax.jit
+            # Note that we mark the state to be donated to jax, so
+            # that jax will reuse its memory buffer for the outputs.
+            # This will reduce the memory usage of the training
+            # function by half.
+            @partial(jax.jit, donate_argnames="state")
             def compiled_train_step(state, data):
                 return train_step(state, data)
@@ -266,8 +271,11 @@
             test_step = one_test_step

         if not self.run_eagerly and self.jit_compile:
-
-            @jax.jit
+            # Note that we mark the state to be donated to jax, so
+            # that jax will reuse its memory buffer for the outputs.
+            # This will reduce the memory usage of the test
+            # function by half.
+            @partial(jax.jit, donate_argnames="state")
             def compiled_test_step(state, data):
                 return test_step(state, data)
@@ -578,15 +586,18 @@
                 )
                 data = self._distribute_data(data)
                 logs, state = self.test_function(state, data)
-                # Note that trainable variables are not returned since they're
-                # immutable here.
-                _, non_trainable_variables, metrics_variables = state
+                (
+                    trainable_variables,
+                    non_trainable_variables,
+                    metrics_variables,
+                ) = state

                 # Setting _jax_state enables callbacks to force a state sync
                 # if they need to.
                 self._jax_state = {
                     # I wouldn't recommend modifying non-trainable model state
                     # during evaluate(), but it's allowed.
+                    "trainable_variables": trainable_variables,
                     "non_trainable_variables": non_trainable_variables,
                     "metrics_variables": metrics_variables,
                 }
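
Because the state argument is donated, the variables the caller passed in can no longer be read once the jitted call returns; this is why the diff stops discarding the trainable variables (the old `_, ...` unpacking) and instead re-captures all three variable groups from the returned state. Below is a hedged sketch of that sync pattern under the same assumptions as the earlier example; eval_step and jax_state are illustrative names, not the real keras_core internals.

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnames="state")
def eval_step(state, data):
    trainable, non_trainable, metrics = state
    logs = {"loss": jnp.mean(data)}
    # Return the full state so the caller can rebind to live buffers.
    return logs, (trainable, non_trainable, metrics + 1.0)

state = (jnp.zeros(4), jnp.zeros(4), jnp.zeros(()))
logs, state = eval_step(state, jnp.ones(4))
# Mirror of the _jax_state sync above: hold references only to the
# returned arrays, never to the donated inputs.
jax_state = {
    "trainable_variables": state[0],
    "non_trainable_variables": state[1],
    "metrics_variables": state[2],
}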
@@ -764,8 +775,9 @@
         logs, state = self.test_function(state, [data])

         # State sync
-        _, non_trainable_variables, metrics_variables = state
+        trainable_variables, non_trainable_variables, metrics_variables = state
         self._jax_state = {
+            "trainable_variables": trainable_variables,
             "non_trainable_variables": non_trainable_variables,
             "metrics_variables": metrics_variables,
         }