Gradient clipping #22823
-
How do I carry out gradient clipping in JAX's ecosystem? I haven't seen a clear canonical method.
-
Thanks for the question! Do you mean gradient clipping like this, or did you have something else in mind?
-
Dear @mattjj, for example, here is a segment of my code:
I imagine I would iterate recursively through grads and clip each entry? But is there any predefined functionality for this, i.e., a canonical, recommended way?
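A minimal sketch of the "iterate through grads and clip" idea, assuming `grads` is a pytree (nested dicts/lists/tuples of arrays) such as `jax.grad` returns for a params dict. Rather than writing the recursion by hand, `jax.tree_util.tree_map` walks the structure for you. The helper name `clip_grads_elementwise`, the `max_abs` threshold, and the example gradient values are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def clip_grads_elementwise(grads, max_abs):
    """Clip every array in a gradient pytree to [-max_abs, max_abs].

    Hypothetical helper: tree_map applies the lambda to each leaf and
    rebuilds the same nested structure, so no manual recursion is needed.
    """
    return jax.tree_util.tree_map(lambda g: jnp.clip(g, -max_abs, max_abs), grads)


# Hypothetical gradient pytree, shaped like the output of jax.grad on a params dict.
grads = {
    "w": jnp.array([[0.5, -3.0], [2.0, 0.1]]),
    "b": jnp.array([-0.2, 4.0]),
}
clipped = clip_grads_elementwise(grads, 1.0)
print(clipped)
```

This clips each element independently; it does not preserve the direction of the overall gradient, which is why norm-based clipping is often preferred in practice.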
@mattjj might be able to say more, but I think the typical approach would be to use something like the `clip` transformation from optax (as documented here), especially if you're already using optax for training. Even if you're not using optax, I expect that you could repurpose those functions.