diff --git a/econpizza/solvers/shooting.py b/econpizza/solvers/shooting.py index 07a0d7f..9a41d7b 100644 --- a/econpizza/solvers/shooting.py +++ b/econpizza/solvers/shooting.py @@ -101,8 +101,7 @@ def solve_current(pars, shock, XLag, XLastGuess, XPrime): """Solves for one period. """ - # partial_func = jax.tree_util.Partial(func, XLag=XLag, XPrime=XPrime, XSS=stst, shocks=shock, pars=pars) - def partial_func(x): return func(XLag, x, XPrime, stst, shock, pars) + def partial_func(x): return func(XLag, x, XPrime, stst, pars, shock) jav = val_and_jacfwd(partial_func) partial_jav = jax.tree_util.Partial(jav) res = newton_jax_jit(partial_jav, XLastGuess, verbose=False)