keeping a state #4074
-
hi, a jax newbie here. thanks for creating a great framework! i'm trying to figure out the best way to keep running stats of a computation across jax functions. for instance, i'm trying to implement batch normalization as a class, and this class keeps the running averages of the mean and variance of the batch statistics. see, e.g., the code snippet below.
i understand this goes against jax's design and philosophy and that i need to think of a different way (e.g., return the new buffer content together with the output of the computation), but as a current pytorch user there's a huge inertia pulling me toward keeping state across computations. is there any way to keep state other than passing everything in as part of the arguments and returning everything out as part of the output? or, perhaps, is there any plan to introduce some kind of global tensor storage for jax that is easily accessible from any jax-based function? i totally understand this goes against the philosophy; i'm just thinking aloud. cheers!
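(editor's note: the original snippet wasn't preserved here; below is a minimal sketch, not the author's code, of the "return the new state together with the output" pattern the question alludes to, with illustrative names and shapes)

```python
import jax
import jax.numpy as jnp

def init_bn_state(num_features):
    # Running statistics live in an explicit pytree instead of on an object.
    return {"mean": jnp.zeros(num_features), "var": jnp.ones(num_features)}

def batch_norm(x, state, momentum=0.9, eps=1e-5, training=True):
    if training:
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        # Compute the updated running statistics as a *new* pytree.
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * batch_mean,
            "var": momentum * state["var"] + (1 - momentum) * batch_var,
        }
        out = (x - batch_mean) / jnp.sqrt(batch_var + eps)
    else:
        new_state = state
        out = (x - state["mean"]) / jnp.sqrt(state["var"] + eps)
    # The caller threads new_state into the next call; nothing is mutated.
    return out, new_state

state = init_bn_state(8)
x = jax.random.normal(jax.random.PRNGKey(0), (32, 8))
out, state = batch_norm(x, state)
```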
-
JAX doesn't (yet) have any built-in support for mutable state, but this is something you can find in a number of higher-level neural net libraries built on top of JAX. For two examples, see haiku.transform_with_state and flax.nn.stateful.
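(editor's note: a minimal sketch of the haiku.transform_with_state approach mentioned above; the forward function, shapes, and hyperparameters are illustrative, not from the thread)

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # hk.BatchNorm keeps its running mean/variance in Haiku "state".
    bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
    return bn(x, is_training=is_training)

# transform_with_state turns the function into a pair of pure functions.
init, apply = hk.transform_with_state(forward)

x = jnp.ones((4, 8))
params, state = init(jax.random.PRNGKey(0), x, is_training=True)

# apply returns the output together with the updated state, which the
# caller threads into the next call -- no hidden mutation anywhere.
out, new_state = apply(params, state, None, x, is_training=True)
```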
-
We've been thinking hard about this one for a while, with collaborators from the Flax, Haiku, Trax, and Oryx teams too. It's still a work in progress though. Basically, +1 to what @shoyer said, with the extra emphasis that we're interested in doing better here.