From e424ba94dc5983a2fdfe58f0d0fc373d7d8e801f Mon Sep 17 00:00:00 2001 From: Gregor Boehl Date: Tue, 29 Oct 2024 14:27:15 +0100 Subject: [PATCH] add more flexible transformations --- econpizza/__init__.py | 26 ++++++---- econpizza/parser/__init__.py | 29 +++++++---- econpizza/parser/build_generic_functions.py | 8 +-- econpizza/parser/checks.py | 10 ++-- econpizza/parser/compile_model_functions.py | 16 +++--- econpizza/parser/het_agent_base_funcs.py | 12 +++-- econpizza/solvers/stacking.py | 24 ++++----- econpizza/solvers/steady_state.py | 54 ++++++++++----------- 8 files changed, 97 insertions(+), 82 deletions(-) diff --git a/econpizza/__init__.py b/econpizza/__init__.py index a334be3..49bb970 100644 --- a/econpizza/__init__.py +++ b/econpizza/__init__.py @@ -18,7 +18,8 @@ # set number of cores for XLA -os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}" +os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={ + os.cpu_count()}" jax.config.update("jax_enable_x64", True) @@ -57,13 +58,14 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None): a dictionary of the distributions """ + # get transformers (experimental) + transform_back = self['options'].get('transform_back') or (lambda x: x) + dist0 = jnp.array(init_dist) if init_dist is not None else jnp.array( self['steady_state'].get('distributions')) - if self.get('exp_all'): - pars = jnp.log(jnp.array(list(self['pars'].values())) if pars is None else pars) - trajectory = jnp.log(trajectory) - else: - pars = jnp.array(list(self['pars'].values())) if pars is None else pars + pars = transform_back( + jnp.array(list(self['pars'].values())) if pars is None else pars) + trajectory = transform_back(trajectory) shocks = self.get("shocks") or () dist_names = list(self['distributions'].keys()) decisions_inputs = self['decisions']['inputs'] @@ -81,13 +83,17 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None): # get functions and execute backwards_sweep = self['context']['backwards_sweep'] forwards_sweep = self['context']['forwards_sweep'] - wf_storage, decisions_output_storage = backwards_sweep(x, x0, shock_series.T, pars, return_wf=True) + wf_storage, decisions_output_storage = backwards_sweep( + x, x0, shock_series.T, pars, return_wf=True) dists_storage = forwards_sweep(decisions_output_storage, dist0) # store this - rdict = {oput: wf_storage[i] for i, oput in enumerate(decisions_inputs)} - rdict.update({oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)}) - rdict.update({oput: dists_storage[i] for i, oput in enumerate(dist_names)}) + rdict = {oput: wf_storage[i] + for i, oput in enumerate(decisions_inputs)} + rdict.update( + {oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)}) + rdict.update({oput: dists_storage[i] + for i, oput in enumerate(dist_names)}) return rdict diff --git a/econpizza/parser/__init__.py b/econpizza/parser/__init__.py index 3ade1ca..7c0e3d5 100644 --- a/econpizza/parser/__init__.py +++ b/econpizza/parser/__init__.py @@ -188,6 +188,16 @@ def _define_function(func_str, context): return tmpf.name +def wrap_with_transform(func, transform): + if transform: + def outfunc(XLag, X, XPrime, XSS, pars, *args, **kwargs): + Xandpar = (transform(y) for y in (XLag, X, XPrime, XSS, pars)) + return func(*Xandpar, *args, **kwargs) + return outfunc + else: + return func + + def _get_pre_stst_mapping(init_vals, fixed_values, evars, par_names): """Define the mapping from init & fixed vals to model variables & parameters """ @@ -329,6 +339,11 @@ def load( raise TypeError( f'parameters must be a list and not {type(par_names)}.') + # get and evaluate options + _ = _define_subdict_if_absent(model, "options") + model['options'], _ = _eval_strs( + model['options'], context=model['context']) + # get function strings for decisions and distributions, if they exist if model.get('decisions'): decisions_outputs = model['decisions']['outputs'] @@ -337,11 +352,8 @@ def load( evars, par_names, shocks, decisions_inputs, decisions_outputs, model['decisions']['calls']) _define_function(model['func_strings'] ['func_backw'], model['context']) - if model.get('exp_all'): - model['context']['func_backw'] = lambda xl, xc, xp, XSS, WFPrime, shocks, pars: model['context']['func_backw_raw']( - jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), WFPrime, shocks, jnp.exp(pars)) - else: - model['context']['func_backw'] = model['context']['func_backw_raw'] + model['context']['func_backw'] = wrap_with_transform( + model['context']['func_backw_raw'], model['options'].get('transform_to')) else: decisions_outputs = [] decisions_inputs = [] @@ -359,11 +371,8 @@ def load( # writing to tempfiles helps to get nice debug traces if the model does not work _define_function(model['func_strings']['func_eqns'], model['context']) - if model.get('exp_all'): - model['context']['func_eqns'] = lambda xl, xc, xp, XSS, shocks, pars, distributions, decisions_outputs: model['context']['func_eqns_raw']( - jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), shocks, jnp.exp(pars), distributions, decisions_outputs) - else: - model['context']['func_eqns'] = model['context']['func_eqns_raw'] + model['context']['func_eqns'] = wrap_with_transform( + model['context']['func_eqns_raw'], model['options'].get('transform_to')) # compile fixed and initial values stst_inputs = compile_stst_inputs(model) # try if function works on initvals diff --git a/econpizza/parser/build_generic_functions.py b/econpizza/parser/build_generic_functions.py index ad61eb9..9839213 100644 --- a/econpizza/parser/build_generic_functions.py +++ b/econpizza/parser/build_generic_functions.py @@ -132,7 +132,8 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver combined_sweep = model['context']['combined_sweep'] distSS = jnp.array(model['steady_state']['distributions']) - decisions_outputSS = (jnp.array(d)[..., None] for d in list(model['steady_state']['decisions'].values())) + decisions_outputSS = (jnp.array(d)[..., None] for d in list( + model['steady_state']['decisions'].values())) # basis for steady state jacobian construction basis = jnp.zeros((nvars*(horizon-1), nvars)) @@ -148,7 +149,7 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver # get steady state jacobians for direct effects x on f jacrev_func_eqns = jax.jacrev(func_eqns, argnums=(0, 1, 2)) f2X = jacrev_func_eqns(stst[:, None], stst[:, None], stst[:, None], - stst, zshocks[:, 0], pars, distSS[..., None], decisions_outputSS) + stst, pars, zshocks[:, 0], distSS[..., None], decisions_outputSS) if verbose: duration = time.time() - st @@ -174,7 +175,8 @@ def get_stacked_func_het_agents(func_backw, func_forw, func_eqns, stst, wfSS, ho forwards_sweep, horizon=horizon, func_forw=partial_forw) final_step_local = jax.tree_util.Partial( final_step, stst=stst, horizon=horizon, nshpe=nshpe, func_eqns=func_eqns) - combined_sweep_local = jax.tree_util.Partial(combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local) + combined_sweep_local = jax.tree_util.Partial( + combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local) stacked_func_local = jax.tree_util.Partial( stacked_func_het_agents, backwards_sweep=backwards_sweep_local, combined_sweep=combined_sweep_local) diff --git a/econpizza/parser/checks.py b/econpizza/parser/checks.py index 9e1b5d1..727f418 100644 --- a/econpizza/parser/checks.py +++ b/econpizza/parser/checks.py @@ -53,7 +53,8 @@ def check_determinancy(evars, eqns): sorted_evars = evars[:] = sorted(list(set(evars)), key=str.lower) if len(sorted_evars) != len(eqns): raise Exception( - f"Model has {len(sorted_evars)} variables but {len(eqns)} equations." + f"Model has {len(sorted_evars)} variables but { + len(eqns)} equations." ) return sorted_evars @@ -71,7 +72,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre mess = '' if model.get('decisions'): # make a test backward and forward run - _, decisions_output_init = model['context']['func_backw'](init_vals, init_vals, init_vals, init_vals, init_wf, jnp.zeros(len(shocks)), par) + _, decisions_output_init = model['context']['func_backw']( + init_vals, init_vals, init_vals, init_vals, par, init_wf, jnp.zeros(len(shocks))) dists_init, _ = model['context']['func_forw_stst']( decisions_output_init, 1e-8, 1) @@ -91,8 +93,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre # final test of main function init_vals = init_vals[..., None] - test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, jnp.zeros( - len(shocks)), par, jnp.array(dists_init)[..., None], (doi[...,None] for doi in decisions_output_init)) + test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, par, jnp.zeros( + len(shocks)), jnp.array(dists_init)[..., None], (doi[..., None] for doi in decisions_output_init)) if mess: pass diff --git a/econpizza/parser/compile_model_functions.py b/econpizza/parser/compile_model_functions.py index e834432..f7e0750 100644 --- a/econpizza/parser/compile_model_functions.py +++ b/econpizza/parser/compile_model_functions.py @@ -25,7 +25,7 @@ def compile_backw_func_str(evars, par, shocks, inputs, outputs, calls): if isinstance(calls, str): calls = calls.splitlines() - func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, WFPrime, shocks, pars): + func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, pars, WFPrime, shocks): {compile_func_basics_str(evars, par, shocks)} \n ({"".join(v + ", " for v in inputs)}) = WFPrime \n %s @@ -49,8 +49,10 @@ def get_forw_funcs(model): dist = distributions[dist_name] # *_generic should be depreciated at some point - implemented_endo = ('exogenous', 'exogenous_rouwenhorst', 'exogenous_generic', 'exogenous_custom') - implemented_exo = ('endogenous', 'endogenous_log', 'endogenous_generic', 'endogenous_custom') + implemented_endo = ('exogenous', 'exogenous_rouwenhorst', + 'exogenous_generic', 'exogenous_custom') + implemented_exo = ('endogenous', 'endogenous_log', + 'endogenous_generic', 'endogenous_custom') exog = [v for v in dist if dist[v]['type'] in implemented_endo] endo = [v for v in dist if dist[v]['type'] in implemented_exo] other = [dist[v]['type'] for v in dist if dist[v] @@ -68,13 +70,15 @@ def get_forw_funcs(model): # for each object, check if it is provided in decisions_outputs try: - transition = model['decisions']['outputs'].index(dist[exog[0]]['transition_name']) + transition = model['decisions']['outputs'].index( + dist[exog[0]]['transition_name']) except ValueError: transition = model['context'][dist[exog[0]]['transition_name']] grids = [] for i in endo: try: - grids.append(model['decisions']['outputs'].index(dist[i]['grid_name'])) + grids.append(model['decisions'] + ['outputs'].index(dist[i]['grid_name'])) except ValueError: grids.append(model['context'][dist[i]['grid_name']]) indices = [model['decisions']['outputs'].index(i) for i in endo] @@ -108,7 +112,7 @@ def compile_eqn_func_str(evars, eqns, par, eqns_aux, shocks, distributions, deci eqns_stack = "\n ".join(eqns) # compile the final function string - func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, shocks, pars, distributions=[], decisions_outputs=[]): + func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, pars, shocks, distributions=[], decisions_outputs=[]): {compile_func_basics_str(evars, par, shocks)} \n ({"".join(d+', ' for d in distributions)}) = distributions \n ({"".join(d+', ' for d in decisions_outputs)}) = decisions_outputs diff --git a/econpizza/parser/het_agent_base_funcs.py b/econpizza/parser/het_agent_base_funcs.py index d0f3771..aacd614 100644 --- a/econpizza/parser/het_agent_base_funcs.py +++ b/econpizza/parser/het_agent_base_funcs.py @@ -16,7 +16,7 @@ def _backwards_stst_cond(carry): def _backwards_stst_body(carry): (x, par), (wf, _), (_, cnt), (func, tol, maxit) = carry - return (x, par), func(x, x, x, x, wf, pars=par), (wf, cnt + 1), (func, tol, maxit) + return (x, par), func(x, x, x, x, pars=par, WFPrime=wf), (wf, cnt + 1), (func, tol, maxit) def backwards_sweep_stst(x, par, carry): @@ -29,7 +29,7 @@ def _backwards_step(carry, i): wf, X, shocks, func_backw, stst, pars = carry wf, decisions_output = func_backw( - X[:, i], X[:, i+1], X[:, i+2], WFPrime=wf, shocks=shocks[:, i], pars=pars) + X[:, i], X[:, i+1], X[:, i+2], pars=pars, WFPrime=wf, shocks=shocks[:, i]) return (wf, X, shocks, func_backw, stst, pars), (wf, decisions_output) @@ -40,7 +40,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array _, (wf_storage, decisions_output_storage) = jax.lax.scan( _backwards_step, (wfSS, X, shocks, func_backw, stst, pars), jnp.arange(horizon-1), reverse=True) - decisions_output_storage = [jnp.moveaxis(dos, 0, -1) for dos in decisions_output_storage] + decisions_output_storage = [jnp.moveaxis( + dos, 0, -1) for dos in decisions_output_storage] wf_storage = jnp.moveaxis(wf_storage, 0, -1) if return_wf: @@ -51,7 +52,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array def _forwards_step(carry, i): dist_old, decisions_output_storage, func_forw = carry - dist = func_forw(dist_old, [dos[..., i] for dos in decisions_output_storage]) + dist = func_forw(dist_old, [dos[..., i] + for dos in decisions_output_storage]) return (dist, decisions_output_storage, func_forw), dist_old @@ -69,7 +71,7 @@ def final_step(x: Array, dists_storage: Array, decisions_output_storage: Array, X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T out = func_eqns(X[:, :-2].reshape(nshpe), X[:, 1:-1].reshape(nshpe), X[:, 2:].reshape( - nshpe), stst, shocks, pars, dists_storage, decisions_output_storage) + nshpe), stst, pars, shocks, dists_storage, decisions_output_storage) return out diff --git a/econpizza/solvers/stacking.py b/econpizza/solvers/stacking.py index 44ad46a..b0096ee 100644 --- a/econpizza/solvers/stacking.py +++ b/econpizza/solvers/stacking.py @@ -72,22 +72,19 @@ def find_path_stacking( # only skip jacobian calculation if it exists skip_jacobian = skip_jacobian if self['cache'].get( 'jac_factorized') else False + # get transformers (experimental) + transform_to = self['options'].get('transform_to') or (lambda x: x) + transform_back = self['options'].get('transform_back') or (lambda x: x) # get variables nvars = len(self["var_names"]) - if self.get('exp_all'): - stst = jnp.log(d2jnp(self["stst"])) - pars = jnp.log(d2jnp(pars if pars is not None else self["pars"])) - else: - stst = d2jnp(self["stst"]) - pars = d2jnp(pars if pars is not None else self["pars"]) + stst = transform_back(d2jnp(self["stst"])) + pars = transform_back(d2jnp(pars if pars is not None else self["pars"])) shocks = self.get("shocks") or () # get initial guess - if self.get('exp_all'): - x0 = jnp.log(jnp.array(list(init_state))) if init_state is not None else stst - else: - x0 = jnp.array(list(init_state)) if init_state is not None else stst + x0 = transform_back(jnp.array(list(init_state)) + ) if init_state is not None else stst init_dist = init_dist if init_dist is not None else self['steady_state'].get( 'distributions') dist0 = jnp.array(init_dist if init_dist is not None else jnp.nan) @@ -110,7 +107,7 @@ def find_path_stacking( func_eqns = self['context']["func_eqns"] jav_func_eqns = val_and_jacrev(func_eqns, (0, 1, 2)) jav_func_eqns_partial = jax.tree_util.Partial( - jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[]) + jav_func_eqns, pars=pars, XSS=stst, distributions=[], decisions_outputs=[]) self['context']['jav_func'] = jav_func_eqns_partial # mark as compiled write_cache(self, horizon, pars, stst) @@ -161,7 +158,4 @@ def find_path_stacking( elif verbose: print(mess) - if self.get('exp_all'): - return jnp.exp(x_out), (flag, f) - else: - return x_out, (flag, f) + return transform_to(x_out), (flag, f) diff --git a/econpizza/solvers/steady_state.py b/econpizza/solvers/steady_state.py index fa51f0f..40b7dce 100644 --- a/econpizza/solvers/steady_state.py +++ b/econpizza/solvers/steady_state.py @@ -31,7 +31,8 @@ def _get_stst_dist_objs(self, res, maxit_backwards, maxit_forwards): elif jnp.isnan(distSS).any(): mess += f"Forward iteration returns NaNs. " elif distSS.min() < 0: - mess += f"Distribution contains negative values ({distSS.min():0.1e}). " + mess += f"Distribution contains negative values ({ + distSS.min():0.1e}). " if cnt_backwards == maxit_backwards: mess += f'Maximum of {maxit_backwards} backwards calls reached. ' if cnt_forwards == maxit_forwards: @@ -99,30 +100,26 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 func_backw = self['context'].get('func_backw') func_forw_stst = self['context'].get('func_forw_stst') func_pre_stst = self['context']['func_pre_stst'] + # get transformers (experimental) + transform_to = self['options'].get('transform_to') or (lambda x: x) + transform_back = self['options'].get('transform_back') or (lambda x: x) # get initial values for heterogenous agents decisions_output_init = self['context']['init_run'].get( 'decisions_output') # get the actual steady state function - if self.get('exp_all'): - func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=jnp.log(d2jnp(fixed_vals)), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) - else: - func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=d2jnp(fixed_vals), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) + func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=transform_back(d2jnp( + fixed_vals)), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards) # store jitted stst function that returns jacobian and func. value self["context"]['func_stst'] = func_stst if not self['steady_state'].get('skip'): # actual root finding - if self.get('exp_all'): - res = newton_jax(func_stst, jnp.log(d2jnp(init_vals)), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) - else: - res = newton_jax(func_stst, d2jnp(init_vals), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) + res = newton_jax(func_stst, transform_back( + d2jnp(init_vals)), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs) else: - if self.get('exp_all'): - f, jac, aux = func_stst(jnp.log(d2jnp(init_vals))) - else: - f, jac, aux = func_stst(d2jnp(init_vals)) + f, jac, aux = func_stst(transform_back(d2jnp(init_vals))) res = {'x': d2jnp(init_vals), 'fun': f, 'jac': jac, @@ -132,22 +129,18 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 } # exchange those values that are identified via stst_equations - if self.get('exp_all'): - stst_vals, par_vals = func_pre_stst(res['x'], jnp.log(d2jnp(fixed_vals)), pre_stst_mapping) - else: - stst_vals, par_vals = func_pre_stst(res['x'], d2jnp(fixed_vals), pre_stst_mapping) - res['initial_values'] = {'guesses': init_vals, 'fixed': fixed_vals, 'value_functions': wf_init, 'decisions': decisions_output_init} + stst_vals, par_vals = func_pre_stst( + res['x'], jnp.log(d2jnp(fixed_vals)), pre_stst_mapping) + res['initial_values'] = {'guesses': init_vals, 'fixed': fixed_vals, + 'value_functions': wf_init, 'decisions': decisions_output_init} # store results self['steady_state']['root_finding_result'] = res - if self.get('exp_all'): - self['steady_state']['found_values'] = dict(zip(init_vals.keys(),jnp.exp(res['x']))) - self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, jnp.exp(stst_vals))) - self['pars'] = dict(zip(par_names, jnp.exp(par_vals))) - else: - self['steady_state']['found_values'] = dict(zip(init_vals.keys(),res['x'])) - self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, stst_vals)) - self['pars'] = dict(zip(par_names, par_vals)) + self['steady_state']['found_values'] = dict( + zip(init_vals.keys(), transform_to(res['x']))) + self['stst'] = self['steady_state']['all_values'] = dict( + zip(evars, transform_to(stst_vals))) + self['pars'] = dict(zip(par_names, transform_to(par_vals))) # calculate dist objects and compile message mess = _get_stst_dist_objs(self, res, maxit_backwards, @@ -163,13 +156,15 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 nvars = len(evars)+len(par_names) nfixed = len(fixed_vals) if rank != nvars - nfixed: - mess += f"Jacobian has rank {rank} for {nvars - nfixed} degrees of freedom ({nfixed} out of a total of {nvars} variables/parameters were fixed). " + mess += f"Jacobian has rank {rank} for {nvars - nfixed} degrees of freedom ({ + nfixed} out of a total of {nvars} variables/parameters were fixed). " # check if any of the fixed variables are neither a parameter nor variable if mess: not_var_nor_par = list( set(self['steady_state']['fixed_values']) - set(evars) - set(par_names)) - mess += f"Fixed value(s) ``{'``, ``'.join(not_var_nor_par)}`` not declared. " if not_var_nor_par else '' + mess += f"Fixed value(s) ``{'``, ``'.join(not_var_nor_par) + }`` not declared. " if not_var_nor_par else '' if err > tol or not res['success']: if not res["success"] or raise_errors: @@ -177,7 +172,8 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200 err) else f" (max. error is {err:1.2e} in eqn. {errarg})" mess = f"Steady state FAILED{location}. {res['message']} {mess}" else: - mess = f"{res['message']} WARNING: Steady state error is {err:1.2e} in eqn. {errarg}. {mess}" + mess = f"{res['message']} WARNING: Steady state error is { + err:1.2e} in eqn. {errarg}. {mess}" if raise_errors: raise Exception(mess) elif verbose: