Memory optimization for jax trainer. #888
Conversation
Codecov Report
Patch coverage:
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #888      +/-   ##
==========================================
- Coverage   79.75%   79.70%   -0.06%
==========================================
  Files         318      318
  Lines       28638    28657      +19
  Branches     5451     5460       +9
==========================================
  Hits        22841    22841
- Misses       4333     4351      +18
- Partials     1464     1465       +1
Let's park this PR for now, since there are some underlying issues (discovered in the unit tests) that need to be addressed.
OK, I think I've figured out the issue.
Hmm, it's trickier than I expected, especially for RNNs. In an RNN (LSTM/GRU) with dropout, the seed generator gets updated within the jax.lax.scan() call in keras-core/keras_core/backend/jax/rnn.py (line 188 at b4019bc).
I think the current repo is probably not working correctly, since the StatelessScope for the scan function was thrown away, which means the RNG seed update is lost. Will take a closer look tomorrow.
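For context, here is a minimal JAX sketch (plain jax.numpy, not Keras code) of the constraint being described: any state that the scan body needs to update, such as an RNG seed, has to be threaded through the scan carry; updates that are not part of the carry never reach the caller.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    total, seed = carry
    # Each step draws a dropout-style mask and advances the seed, the way an
    # RNN step with dropout advances the layer's seed generator.
    key = jax.random.fold_in(jax.random.PRNGKey(0), seed)
    keep = jax.random.bernoulli(key, 0.5, x.shape)
    return (total + jnp.where(keep, x, 0.0), seed + 1), None

xs = jnp.ones((4, 3))
init = (jnp.zeros(3), jnp.uint32(0))
(total, final_seed), _ = jax.lax.scan(step, init, xs)
print(final_seed)  # 4: the seed updates survive only because they ride in the carry
```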
We previously observed some weirdness in JAX RNNs (though not connected to dropout it seems). Maybe some relationship there? #322
I think we might want to add a stateless version of the step function / etc.
The particular issue we hit is for the RNG state, which gets updated when using dropout. It runs fine if I disable the dropout for the layer. Need to dig a bit more.
Having said that, I think my PR will actually fix the issue in #322. Due to the stateless scope in the jax.lax.scan function, it only reads the stale variable values, which is probably why the model is not trained properly.
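As a rough illustration of that mechanism (a simplified stand-in, not the actual keras_core StatelessScope API): under a stateless scope, assignments are recorded on the scope instead of being written to the variable, so if the scope built inside the scanned function is discarded, later reads fall back to the stale value.

```python
# Simplified stand-in for a stateless scope; not the real keras_core API.
class FakeStatelessScope:
    def __init__(self):
        self.updates = {}            # variable name -> pending new value

    def assign(self, name, value):
        self.updates[name] = value   # record the update, don't mutate anything

    def current_value(self, name, stale_value):
        return self.updates.get(name, stale_value)

seed = 42                            # stand-in for the RNG seed variable
scope = FakeStatelessScope()
scope.assign("seed", seed + 1)       # what dropout does inside the scanned step

print(scope.current_value("seed", seed))  # 43, but only while the scope is kept
# If the wrapper around jax.lax.scan throws `scope` away, the next read sees
# the stale 42 again, i.e. the RNG seed update is lost.
```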
So the stateless RNN test failure should be addressed by #924, and the test for this PR should pass now. |
The trainable vars are passed in and returned by the eval_step
LGTM -- thank you!
Update the JAX trainer to use less memory during fit/eval.
Currently we keep 3 copies of all the model variable state during training: the value attached to each KerasVariable (copy 1), plus the state passed into and returned by the compiled train function (copies 2 and 3).
Copies 2 and 3 keep getting updated during the training process, but copy 1 stays a stale copy that unnecessarily occupies heap space. For a large model this is huge.
This PR purges copy 1 and restores the KerasVariables at the end of the epoch, which saves about 33% of the memory. In some early internal test results, per-device memory usage dropped from 9.49G to 6.65G for an OPT2 model.
I will send a follow-up PR to address additional memory usage in the eval and predict functions.
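As a rough picture of the fit-time flow described above (a hypothetical sketch; the helper names, the state layout, and the `v._value = None` trick are illustrative assumptions, not the actual Keras implementation):

```python
# Hypothetical sketch, not the real keras_core trainer code. The point is that
# during the epoch the live state exists only as the arrays threaded through
# the compiled train function; the per-variable copies are released up front
# and written back once per epoch.
def run_epoch(model, train_function, state, dataset):
    # 1. Release the value held on each KerasVariable; copy 1 goes away.
    for v in model.variables:
        v._value = None  # hypothetical way to drop the stale buffer

    logs = None
    for batch in dataset:
        # 2. Only copies 2 and 3 exist here: the state fed in and the updated
        #    state handed back by the jitted train function.
        logs, state = train_function(state, batch)

    # 3. Restore the KerasVariables from the final state at epoch end.
    trainable, non_trainable, optimizer_vars = state
    for v, value in zip(model.trainable_variables, trainable):
        v.assign(value)
    for v, value in zip(model.non_trainable_variables, non_trainable):
        v.assign(value)
    return logs, state
```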