
Memory optimization for jax trainer. #888

Merged: 7 commits merged into keras-team:main on Sep 20, 2023

Conversation

@qlzh727 (Member) commented Sep 14, 2023

Update the JAX trainer to use less memory during fit/eval.

Currently we keep 3 copies of all the model variable state during training:

  1. jax.array attached to the KerasVariable
  2. jax.array as input to the jax.jit() train/eval function
  3. jax.array returned from jax.jit() train/eval function, which is also attached to the trainer.jax_state.

Copies 2 and 3 keep getting updated during the training process, but copy 1 becomes a stale copy that unnecessarily occupies heap space. For large models, this overhead is huge.

This PR purges copy 1 and restores the KerasVariable values at the end of the epoch, which cuts the memory footprint by about 33%. In early internal tests, per-device memory usage dropped from 9.49 GB to 6.65 GB for an OPT2 model.

I will send a follow-up PR to address additional memory usage in the eval and predict functions.
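For illustration, here is a minimal, self-contained sketch of the memory pattern described above. It is toy code, not the actual keras_core trainer: the weight values live as a single pytree of jax.Arrays threaded through a jitted step function, and the stateful variables are only synchronized once at the end of the epoch.

```python
# Toy sketch only: keep a single live copy of the weights as jax.Arrays during
# the epoch, instead of also keeping them attached to (Keras-style) variables
# on every step.
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, x, y, lr=0.1):
    # Pure function: takes the current parameter pytree, returns the updated one.
    def loss_fn(p):
        pred = x @ p["w"] + p["b"]
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Hypothetical stand-in for the values held by KerasVariables.
variables = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}

# Pull the values out once; during the epoch this pytree is the only live copy.
params = dict(variables)
for _ in range(100):
    x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
    params, loss = train_step(params, x, y)

# Write the trained values back to the variables at the end of the epoch.
variables.update(params)
```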

@codecov bot commented Sep 14, 2023

Codecov Report

Patch coverage: 12.50% and project coverage change: -0.06% ⚠️

Comparison is base (1704ecf) 79.75% compared to head (3fd4ad4) 79.70%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #888      +/-   ##
==========================================
- Coverage   79.75%   79.70%   -0.06%     
==========================================
  Files         318      318              
  Lines       28638    28657      +19     
  Branches     5451     5460       +9     
==========================================
  Hits        22841    22841              
- Misses       4333     4351      +18     
- Partials     1464     1465       +1     
| Flag | Coverage Δ |
|------|------------|
| keras_core | 79.62% <12.50%> (-0.06%) ⬇️ |
| keras_core-numpy | 60.37% <8.33%> (-0.04%) ⬇️ |
| keras_core-tensorflow | 66.78% <12.50%> (-0.05%) ⬇️ |
| keras_core-torch | 69.21% <12.50%> (-0.05%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ |
|---------------|------------|
| keras_core/backend/jax/trainer.py | 0.00% <0.00%> (ø) |
| keras_core/layers/layer.py | 86.64% <0.00%> (-0.28%) ⬇️ |
| keras_core/backend/common/variables.py | 75.42% <75.00%> (ø) |


@qlzh727 (Member, Author) commented Sep 15, 2023

Let's park this PR for now, since there are some underlying issues (discovered in the unit tests) that need to be addressed.

@fchollet changed the title from "Momery optimization for jax trainer." to "Memory optimization for jax trainer." on Sep 16, 2023
@qlzh727 (Member, Author) commented Sep 19, 2023

OK, I think I figured out the issue.

  1. The regularizer failure is actually a bug in the existing code, where a stale version of the variable value was used. layer.py is updated to retrieve the latest value when inside a stateless scope.
  2. The RNN test failed because it was trying to create a new stateless scope under the hood, which doesn't carry any variable mapping. I added conditional creation of the StatelessScope to fix the issue (rough sketch below).
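A rough, self-contained sketch of what those two fixes could look like in toy form (this is not the keras_core implementation; the class and helper names here are made up for illustration):

```python
# Toy sketch of the two fixes above; not the actual keras_core code.

_CURRENT_SCOPE = None  # module-level pointer to the active stateless scope


class StatelessScope:
    """Records new variable values without mutating the variables themselves."""

    def __init__(self):
        self.value_map = {}  # variable -> latest (uncommitted) value

    def __enter__(self):
        global _CURRENT_SCOPE
        self._previous = _CURRENT_SCOPE
        _CURRENT_SCOPE = self
        return self

    def __exit__(self, *exc):
        global _CURRENT_SCOPE
        _CURRENT_SCOPE = self._previous


class Variable:
    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Fix 1: inside a stateless scope, return the latest value recorded in
        # the scope rather than the possibly stale value attached to the variable.
        if _CURRENT_SCOPE is not None and self in _CURRENT_SCOPE.value_map:
            return _CURRENT_SCOPE.value_map[self]
        return self._value


def call_stateless(fn, *args):
    # Fix 2: only open a new StatelessScope if there isn't one already, so an
    # inner call does not discard the outer scope's variable-value mapping.
    if _CURRENT_SCOPE is not None:
        return fn(*args)
    with StatelessScope():
        return fn(*args)
```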

@qlzh727 (Member, Author) commented Sep 19, 2023

Hmm, it's trickier than I expected, especially for RNNs. In an RNN (LSTM/GRU) with dropout, the seed generator gets updated within the jax.lax.scan() call:

new_states, outputs = lax.scan(

(we even have a comment saying that function needs to be stateless). Since the step function doesn't return the RNG state as an explicit result, when we capture the updated RNG from the train function and reuse it, JAX complains about this leaked state from the scan function.

I think the current repo is probably not working correctly, since the StatelessScope for the scan function was thrown away, which means the RNG seed update is lost.

Will take a closer look tomorrow.
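For context, a generic JAX example of the constraint being described (not the Keras RNN code): the function passed to jax.lax.scan has to be pure, so any RNG state it advances must be threaded through the carry and returned explicitly rather than written to outside state.

```python
# Generic JAX sketch: thread the RNG key through the scan carry so the
# advanced RNG state comes back out of scan instead of "leaking".
import jax
import jax.numpy as jnp

def step(carry, x):
    h, key = carry
    key, dropout_key = jax.random.split(key)          # advance the RNG in-graph
    keep = jax.random.bernoulli(dropout_key, 0.5, x.shape)
    h = h + jnp.where(keep, x, 0.0)
    # The updated key is part of the carry, so nothing escapes the scan.
    return (h, key), h

xs = jnp.ones((10, 4))
init = (jnp.zeros((4,)), jax.random.PRNGKey(0))
(final_h, final_key), outputs = jax.lax.scan(step, init, xs)
```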

@fchollet (Member) commented:

> I think the current repo is probably not working correctly, since the StatelessScope for the scan function was thrown away, which means the RNG seed update is lost.

We previously observed some weirdness in JAX RNNs (though not connected to dropout, it seems). Maybe there's some relationship there? #322

@fchollet (Member) commented:

> Since the step function doesn't return the RNG state as an explicit result, when we capture the updated RNG from the train function and reuse it, JAX complains about this leaked state from the scan function.

I think we might want to add a stateless version of the step function / etc.
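One possible shape for such a stateless step, sketched generically (a hypothetical signature, not an existing keras_core API): every piece of state, including the seed/RNG state, is an explicit input and an explicit output, so jax.jit and jax.lax.scan see a pure function.

```python
# Hypothetical sketch of a stateless step signature; not an existing API.
import jax
import jax.numpy as jnp

def stateless_train_step(trainable_vars, rng_key, x, y, lr=0.1):
    rng_key, dropout_key = jax.random.split(rng_key)

    def loss_fn(params):
        keep = jax.random.bernoulli(dropout_key, 0.8, x.shape)
        pred = jnp.where(keep, x, 0.0) @ params["w"]
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(trainable_vars)
    new_vars = jax.tree_util.tree_map(lambda p, g: p - lr * g, trainable_vars, grads)
    # The advanced RNG state is returned to the caller, never stored on the side.
    return new_vars, rng_key, loss

params = {"w": jnp.zeros((3, 1))}
new_params, new_key, loss = jax.jit(stateless_train_step)(
    params, jax.random.PRNGKey(0), jnp.ones((8, 3)), jnp.ones((8, 1))
)
```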

@qlzh727 (Member, Author) commented Sep 19, 2023

> I think the current repo is probably not working correctly, since the StatelessScope for the scan function was thrown away, which means the RNG seed update is lost.
>
> We previously observed some weirdness in JAX RNNs (though not connected to dropout, it seems). Maybe there's some relationship there? #322

The particular issue we hit is with the RNG state, which gets updated when using dropout. It runs fine if I disable dropout for the layer. Need to dig a bit more.

@qlzh727 (Member, Author) commented Sep 19, 2023

Having said that, I think my PR will actually fix the issue in #322. Due to the stateless scope in the jax.lax.scan function, it would only read the stale variable values, which is probably why it was not training properly.

@qlzh727 (Member, Author) commented Sep 19, 2023

So the stateless RNN test failure should be addressed by #924, and the test for this PR should pass now.

@fchollet (Member) left a review:

LGTM -- thank you!

@fchollet fchollet merged commit cbc2e47 into keras-team:main Sep 20, 2023
5 of 8 checks passed