As one step in my objective function I am solving a nonlinear system with Newton's method, and I'd like to compute both gradients and Hessians of the resulting objective. Following the JAX online examples for custom VJPs/JVPs (one of which has a sign error in it, by the way), I can get both a VJP and a JVP to work on a simple example (simpler than the training example, I'd argue). I have one main question: why can't we apply forward-mode differentiation to a `custom_vjp` function? Here is the VJP:
If I uncomment the last line, I get `TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.`

I have also implemented a JVP for this example (I can include it if helpful). That works fine, but my understanding is that in reverse mode it will have to recompute the Newton iterations on the backward pass. Ideally I'd save information from the forward pass for use in the reverse pass (so as not to redo those calculations and only apply the linearized 'adjoint'), while still having the function work under forward-mode differentiation so that I can call `jax.hessian` or use Hessian-vector products. Is there a way to set up something like this?
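A minimal sketch of the setup described, assuming a hypothetical scalar system g(x, theta) = x**2 - theta = 0 solved with a fixed number of Newton steps. The names `residual`, `newton`, `solve`, `solve_fwd` and `solve_bwd` are illustrative, not the poster's code. The forward rule saves the converged root as a residual and the backward rule applies the implicit function theorem, so the backward pass does not rerun Newton. `jax.grad` works, while `jax.jvp` should raise the `TypeError` quoted above.

```python
import jax


def residual(x, theta):
    # Hypothetical system g(x, theta) = x**2 - theta = 0, root x* = sqrt(theta)
    return x**2 - theta


def newton(theta, num_iters=20):
    # Fixed number of Newton steps (assumes theta > 0); dg/dx = 2x here
    x = theta
    for _ in range(num_iters):
        x = x - residual(x, theta) / (2.0 * x)
    return x


@jax.custom_vjp
def solve(theta):
    return newton(theta)


def solve_fwd(theta):
    x_star = solve(theta)
    # Save the converged root so the backward pass never reruns Newton
    return x_star, (x_star, theta)


def solve_bwd(res, x_bar):
    x_star, theta = res
    # Implicit function theorem at x*: dx*/dtheta = -(dg/dtheta) / (dg/dx)
    dg_dx = jax.grad(residual, argnums=0)(x_star, theta)
    dg_dtheta = jax.grad(residual, argnums=1)(x_star, theta)
    return (-x_bar * dg_dtheta / dg_dx,)


solve.defvjp(solve_fwd, solve_bwd)

print(jax.grad(solve)(4.0))  # approximately 0.25, i.e. 1 / (2 * sqrt(4))

try:
    jax.jvp(solve, (4.0,), (1.0,))  # forward mode on a custom_vjp function
except TypeError as err:
    print("TypeError:", err)
```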
Thanks for the question!

The reason is that a `custom_vjp` defines how a function should behave only under reverse-mode autodiff, using a custom rule. It leaves undefined how the function should behave under forward-mode autodiff.

If you define a `custom_jvp`, then you can use both forward- and reverse-mode autodiff: even though your rule only specifies how the function should behave under forward-mode autodiff, JAX will attempt to automatically transpose the forward-mode rule to get reverse mode (see the PDF on this page for a description of how automatic transposition relates to autodiff).

We could in principle attempt automatic transposition of a `custom_vjp` rule. We haven't yet only because that'd be a bi…

I think with your example you can define a …

Alternatively, if you want to define only a …

```python
import jax
import jax.numpy as jnp

def hessian(f):
    return jax.jacrev(jax.jacrev(f))  # reverse-over-reverse Hessian

def hvp(f, x, v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)  # reverse-over-reverse HVP
```

The autodiff cookbook has some discussion about these options. What do you think?
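A minimal sketch of the two options discussed above, assuming a hypothetical scalar system g(x, theta) = x**2 - theta = 0 (root sqrt(theta)) in place of the poster's Newton solve; every function name here is illustrative. `solve_jvp` uses a `custom_jvp` whose tangent comes from the implicit function theorem, so `jax.hessian` (forward-over-reverse) works. `solve_vjp` keeps a `custom_vjp` and is differentiated twice in reverse mode with the reply's helpers (its `hessian` is renamed `hessian_rev_rev` here to avoid confusion with `jax.hessian`). All second-derivative results should match the exact value for sqrt at theta = 4, which is -1/32.

```python
import jax
import jax.numpy as jnp


def residual(x, theta):
    # Hypothetical system g(x, theta) = x**2 - theta = 0, root x* = sqrt(theta)
    return x**2 - theta


def newton(theta, num_iters=20):
    # Fixed number of Newton steps (assumes theta > 0); dg/dx = 2x here
    x = theta
    for _ in range(num_iters):
        x = x - residual(x, theta) / (2.0 * x)
    return x


def implicit_slope(x_star, theta):
    # Implicit function theorem at the root: dx*/dtheta = -(dg/dtheta) / (dg/dx)
    dg_dx = jax.grad(residual, argnums=0)(x_star, theta)
    dg_dtheta = jax.grad(residual, argnums=1)(x_star, theta)
    return -dg_dtheta / dg_dx


# Option 1: custom_jvp. The tangent is linear in theta_dot, so JAX can
# transpose it for reverse mode; grad, jvp and hessian all work.
@jax.custom_jvp
def solve_jvp(theta):
    return newton(theta)


@solve_jvp.defjvp
def solve_jvp_rule(primals, tangents):
    (theta,) = primals
    (theta_dot,) = tangents
    x_star = solve_jvp(theta)  # at higher orders this reuses the rule
    return x_star, implicit_slope(x_star, theta) * theta_dot


# Option 2: custom_vjp only. Forward mode is unavailable, so second
# derivatives are built reverse-over-reverse.
@jax.custom_vjp
def solve_vjp(theta):
    return newton(theta)


def solve_vjp_fwd(theta):
    x_star = solve_vjp(theta)  # at higher orders this reuses the rule
    return x_star, (x_star, theta)


def solve_vjp_bwd(res, x_bar):
    x_star, theta = res
    return (x_bar * implicit_slope(x_star, theta),)


solve_vjp.defvjp(solve_vjp_fwd, solve_vjp_bwd)


def hessian_rev_rev(f):
    # The reply's `hessian`: reverse-over-reverse, no forward mode needed
    return jax.jacrev(jax.jacrev(f))


def hvp(f, x, v):
    # The reply's Hessian-vector product, also reverse-over-reverse
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)


theta = 4.0  # exact values at 4: first derivative 1/4, second derivative -1/32
print(jax.grad(solve_jvp)(theta))         # approximately 0.25
print(jax.hessian(solve_jvp)(theta))      # approximately -0.03125
print(hessian_rev_rev(solve_vjp)(theta))  # approximately -0.03125
print(hvp(solve_vjp, theta, 1.0))         # approximately -0.03125
```

Calling the decorated function inside its own rule, as the JAX custom-derivatives guide does, means second-order passes hit the custom rule again instead of differentiating through the Newton loop. In the `custom_jvp` version, reverse mode should only need to transpose the multiplication by the implicit slope, so Newton runs once during linearization rather than again on the backward pass.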