vjp of while_loop with dynamic checkpointing #12097
Unanswered
dionhaefner asked this question in Ideas
Replies: 1 comment 1 reply
-
Take a look at Diffrax's […]. You might also find Diffrax's other machinery for ODE solvers interesting.
-
I was thinking about enabling reverse-mode AD for adaptive solvers (which use `jax.lax.while_loop` internally). As I understand it, this is currently unsupported because JAX cannot know how much memory the backward pass needs before the forward pass has executed, which violates XLA's static memory constraints.

But what if we use a checkpointing scheme that stores a fixed number of checkpoints? In that case we would know in advance how much memory the backward pass needs.

I wanted to play around with this, but I have no idea how I would go about implementing something like that. I probably need to define a custom `vjp` for `while_loop`? Is there existing machinery of `jax.remat` that I can re-use?
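To make the idea concrete, here is a minimal sketch of one way to get a fixed memory footprint out of existing pieces. It is not a custom `vjp` for `while_loop`, and it swaps in a simpler technique than true dynamic checkpointing: it assumes an upper bound on the iteration count is known (an assumption not stated in the question), replaces the loop with two nested `jax.lax.scan` calls, and wraps the inner one in `jax.checkpoint` (the same function as `jax.remat`). The helper name `bounded_checkpointed_while` and the parameters `num_segments` and `segment_len` are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_checkpointed_while(cond_fn, body_fn, init, num_segments, segment_len):
    """Hypothetical stand-in for lax.while_loop that supports reverse-mode AD.

    Runs at most num_segments * segment_len iterations. The loop state must
    keep the same shape and dtype across iterations.
    """

    def guarded_step(state, _):
        # Apply body_fn only while cond_fn holds; afterwards pass state through.
        new_state = lax.cond(cond_fn(state), body_fn, lambda s: s, state)
        return new_state, None

    @jax.checkpoint  # alias: jax.remat
    def segment(state, _):
        # Intermediates inside a segment are recomputed during the backward
        # pass instead of being stored during the forward pass.
        state, _ = lax.scan(guarded_step, state, None, length=segment_len)
        return state, None

    # The outer scan keeps one saved state per segment (the "checkpoints").
    final, _ = lax.scan(segment, init, None, length=num_segments)
    return final


def f(x0):
    # Toy loop: keep doubling x while x < 100, with at most 4 * 4 = 16 steps.
    return bounded_checkpointed_while(
        lambda x: x < 100.0, lambda x: 2.0 * x, x0, num_segments=4, segment_len=4
    )


x0 = jnp.float32(1.0)
value = f(x0)               # 1 -> 2 -> ... -> 128 after 7 doublings
gradient = jax.grad(f)(x0)  # d(final)/d(x0) = 2**7 = 128
```

Under these assumptions the backward pass stores roughly `num_segments` saved states plus the intermediates of one segment at a time, so the memory requirement is known before the forward pass runs. The cost is that the loop always executes the full `num_segments * segment_len` guarded steps, even after `cond_fn` has become false, and the result is only correct if the bound is large enough for the true loop to terminate.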