Skip to content

Commit

Permalink
add more flexible transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Oct 29, 2024
1 parent d0f531e commit e424ba9
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 82 deletions.
26 changes: 16 additions & 10 deletions econpizza/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@


# set number of cores for XLA
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={
os.cpu_count()}"

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -57,13 +58,14 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None):
a dictionary of the distributions
"""

# get transformers (experimental)
transform_back = self['options'].get('transform_back') or (lambda x: x)

dist0 = jnp.array(init_dist) if init_dist is not None else jnp.array(
self['steady_state'].get('distributions'))
if self.get('exp_all'):
pars = jnp.log(jnp.array(list(self['pars'].values())) if pars is None else pars)
trajectory = jnp.log(trajectory)
else:
pars = jnp.array(list(self['pars'].values())) if pars is None else pars
pars = transform_back(
jnp.array(list(self['pars'].values())) if pars is None else pars)
trajectory = transform_back(trajectory)
shocks = self.get("shocks") or ()
dist_names = list(self['distributions'].keys())
decisions_inputs = self['decisions']['inputs']
Expand All @@ -81,13 +83,17 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None):
# get functions and execute
backwards_sweep = self['context']['backwards_sweep']
forwards_sweep = self['context']['forwards_sweep']
wf_storage, decisions_output_storage = backwards_sweep(x, x0, shock_series.T, pars, return_wf=True)
wf_storage, decisions_output_storage = backwards_sweep(
x, x0, shock_series.T, pars, return_wf=True)
dists_storage = forwards_sweep(decisions_output_storage, dist0)

# store this
rdict = {oput: wf_storage[i] for i, oput in enumerate(decisions_inputs)}
rdict.update({oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)})
rdict.update({oput: dists_storage[i] for i, oput in enumerate(dist_names)})
rdict = {oput: wf_storage[i]
for i, oput in enumerate(decisions_inputs)}
rdict.update(
{oput: decisions_output_storage[i] for i, oput in enumerate(decisions_outputs)})
rdict.update({oput: dists_storage[i]
for i, oput in enumerate(dist_names)})

return rdict

Expand Down
29 changes: 19 additions & 10 deletions econpizza/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ def _define_function(func_str, context):
return tmpf.name


def wrap_with_transform(func, transform):
if transform:
def outfunc(XLag, X, XPrime, XSS, pars, *args, **kwargs):
Xandpar = (transform(y) for y in (XLag, X, XPrime, XSS, pars))
return func(*Xandpar, *args, **kwargs)
return outfunc
else:
return func


def _get_pre_stst_mapping(init_vals, fixed_values, evars, par_names):
"""Define the mapping from init & fixed vals to model variables & parameters
"""
Expand Down Expand Up @@ -329,6 +339,11 @@ def load(
raise TypeError(
f'parameters must be a list and not {type(par_names)}.')

# get and evaluate options
_ = _define_subdict_if_absent(model, "options")
model['options'], _ = _eval_strs(
model['options'], context=model['context'])

# get function strings for decisions and distributions, if they exist
if model.get('decisions'):
decisions_outputs = model['decisions']['outputs']
Expand All @@ -337,11 +352,8 @@ def load(
evars, par_names, shocks, decisions_inputs, decisions_outputs, model['decisions']['calls'])
_define_function(model['func_strings']
['func_backw'], model['context'])
if model.get('exp_all'):
model['context']['func_backw'] = lambda xl, xc, xp, XSS, WFPrime, shocks, pars: model['context']['func_backw_raw'](
jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), WFPrime, shocks, jnp.exp(pars))
else:
model['context']['func_backw'] = model['context']['func_backw_raw']
model['context']['func_backw'] = wrap_with_transform(
model['context']['func_backw_raw'], model['options'].get('transform_to'))
else:
decisions_outputs = []
decisions_inputs = []
Expand All @@ -359,11 +371,8 @@ def load(

# writing to tempfiles helps to get nice debug traces if the model does not work
_define_function(model['func_strings']['func_eqns'], model['context'])
if model.get('exp_all'):
model['context']['func_eqns'] = lambda xl, xc, xp, XSS, shocks, pars, distributions, decisions_outputs: model['context']['func_eqns_raw'](
jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), shocks, jnp.exp(pars), distributions, decisions_outputs)
else:
model['context']['func_eqns'] = model['context']['func_eqns_raw']
model['context']['func_eqns'] = wrap_with_transform(
model['context']['func_eqns_raw'], model['options'].get('transform_to'))
# compile fixed and initial values
stst_inputs = compile_stst_inputs(model)
# try if function works on initvals
Expand Down
8 changes: 5 additions & 3 deletions econpizza/parser/build_generic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver
combined_sweep = model['context']['combined_sweep']

distSS = jnp.array(model['steady_state']['distributions'])
decisions_outputSS = (jnp.array(d)[..., None] for d in list(model['steady_state']['decisions'].values()))
decisions_outputSS = (jnp.array(d)[..., None] for d in list(
model['steady_state']['decisions'].values()))

# basis for steady state jacobian construction
basis = jnp.zeros((nvars*(horizon-1), nvars))
Expand All @@ -148,7 +149,7 @@ def get_stst_derivatives(model, nvars, pars, stst, x_stst, zshocks, horizon, ver
# get steady state jacobians for direct effects x on f
jacrev_func_eqns = jax.jacrev(func_eqns, argnums=(0, 1, 2))
f2X = jacrev_func_eqns(stst[:, None], stst[:, None], stst[:, None],
stst, zshocks[:, 0], pars, distSS[..., None], decisions_outputSS)
stst, pars, zshocks[:, 0], distSS[..., None], decisions_outputSS)

if verbose:
duration = time.time() - st
Expand All @@ -174,7 +175,8 @@ def get_stacked_func_het_agents(func_backw, func_forw, func_eqns, stst, wfSS, ho
forwards_sweep, horizon=horizon, func_forw=partial_forw)
final_step_local = jax.tree_util.Partial(
final_step, stst=stst, horizon=horizon, nshpe=nshpe, func_eqns=func_eqns)
combined_sweep_local = jax.tree_util.Partial(combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local)
combined_sweep_local = jax.tree_util.Partial(
combined_sweep, forwards_sweep=forwards_sweep_local, final_step=final_step_local)
stacked_func_local = jax.tree_util.Partial(
stacked_func_het_agents, backwards_sweep=backwards_sweep_local, combined_sweep=combined_sweep_local)

Expand Down
10 changes: 6 additions & 4 deletions econpizza/parser/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def check_determinancy(evars, eqns):
sorted_evars = evars[:] = sorted(list(set(evars)), key=str.lower)
if len(sorted_evars) != len(eqns):
raise Exception(
f"Model has {len(sorted_evars)} variables but {len(eqns)} equations."
f"Model has {len(sorted_evars)} variables but {
len(eqns)} equations."
)
return sorted_evars

Expand All @@ -71,7 +72,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre
mess = ''
if model.get('decisions'):
# make a test backward and forward run
_, decisions_output_init = model['context']['func_backw'](init_vals, init_vals, init_vals, init_vals, init_wf, jnp.zeros(len(shocks)), par)
_, decisions_output_init = model['context']['func_backw'](
init_vals, init_vals, init_vals, init_vals, par, init_wf, jnp.zeros(len(shocks)))
dists_init, _ = model['context']['func_forw_stst'](
decisions_output_init, 1e-8, 1)

Expand All @@ -91,8 +93,8 @@ def check_initial_values(model, shocks, init_guesses, fixed_values, init_wf, pre

# final test of main function
init_vals = init_vals[..., None]
test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, jnp.zeros(
len(shocks)), par, jnp.array(dists_init)[..., None], (doi[...,None] for doi in decisions_output_init))
test = model['context']['func_eqns'](init_vals, init_vals, init_vals, init_vals, par, jnp.zeros(
len(shocks)), jnp.array(dists_init)[..., None], (doi[..., None] for doi in decisions_output_init))

if mess:
pass
Expand Down
16 changes: 10 additions & 6 deletions econpizza/parser/compile_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def compile_backw_func_str(evars, par, shocks, inputs, outputs, calls):
if isinstance(calls, str):
calls = calls.splitlines()

func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, WFPrime, shocks, pars):
func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, pars, WFPrime, shocks):
{compile_func_basics_str(evars, par, shocks)}
\n ({"".join(v + ", " for v in inputs)}) = WFPrime
\n %s
Expand All @@ -49,8 +49,10 @@ def get_forw_funcs(model):
dist = distributions[dist_name]

# *_generic should be depreciated at some point
implemented_endo = ('exogenous', 'exogenous_rouwenhorst', 'exogenous_generic', 'exogenous_custom')
implemented_exo = ('endogenous', 'endogenous_log', 'endogenous_generic', 'endogenous_custom')
implemented_endo = ('exogenous', 'exogenous_rouwenhorst',
'exogenous_generic', 'exogenous_custom')
implemented_exo = ('endogenous', 'endogenous_log',
'endogenous_generic', 'endogenous_custom')
exog = [v for v in dist if dist[v]['type'] in implemented_endo]
endo = [v for v in dist if dist[v]['type'] in implemented_exo]
other = [dist[v]['type'] for v in dist if dist[v]
Expand All @@ -68,13 +70,15 @@ def get_forw_funcs(model):

# for each object, check if it is provided in decisions_outputs
try:
transition = model['decisions']['outputs'].index(dist[exog[0]]['transition_name'])
transition = model['decisions']['outputs'].index(
dist[exog[0]]['transition_name'])
except ValueError:
transition = model['context'][dist[exog[0]]['transition_name']]
grids = []
for i in endo:
try:
grids.append(model['decisions']['outputs'].index(dist[i]['grid_name']))
grids.append(model['decisions']
['outputs'].index(dist[i]['grid_name']))
except ValueError:
grids.append(model['context'][dist[i]['grid_name']])
indices = [model['decisions']['outputs'].index(i) for i in endo]
Expand Down Expand Up @@ -108,7 +112,7 @@ def compile_eqn_func_str(evars, eqns, par, eqns_aux, shocks, distributions, deci
eqns_stack = "\n ".join(eqns)

# compile the final function string
func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, shocks, pars, distributions=[], decisions_outputs=[]):
func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, pars, shocks, distributions=[], decisions_outputs=[]):
{compile_func_basics_str(evars, par, shocks)}
\n ({"".join(d+', ' for d in distributions)}) = distributions
\n ({"".join(d+', ' for d in decisions_outputs)}) = decisions_outputs
Expand Down
12 changes: 7 additions & 5 deletions econpizza/parser/het_agent_base_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _backwards_stst_cond(carry):

def _backwards_stst_body(carry):
(x, par), (wf, _), (_, cnt), (func, tol, maxit) = carry
return (x, par), func(x, x, x, x, wf, pars=par), (wf, cnt + 1), (func, tol, maxit)
return (x, par), func(x, x, x, x, pars=par, WFPrime=wf), (wf, cnt + 1), (func, tol, maxit)


def backwards_sweep_stst(x, par, carry):
Expand All @@ -29,7 +29,7 @@ def _backwards_step(carry, i):

wf, X, shocks, func_backw, stst, pars = carry
wf, decisions_output = func_backw(
X[:, i], X[:, i+1], X[:, i+2], WFPrime=wf, shocks=shocks[:, i], pars=pars)
X[:, i], X[:, i+1], X[:, i+2], pars=pars, WFPrime=wf, shocks=shocks[:, i])

return (wf, X, shocks, func_backw, stst, pars), (wf, decisions_output)

Expand All @@ -40,7 +40,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array

_, (wf_storage, decisions_output_storage) = jax.lax.scan(
_backwards_step, (wfSS, X, shocks, func_backw, stst, pars), jnp.arange(horizon-1), reverse=True)
decisions_output_storage = [jnp.moveaxis(dos, 0, -1) for dos in decisions_output_storage]
decisions_output_storage = [jnp.moveaxis(
dos, 0, -1) for dos in decisions_output_storage]
wf_storage = jnp.moveaxis(wf_storage, 0, -1)

if return_wf:
Expand All @@ -51,7 +52,8 @@ def backwards_sweep(x: Array, x0: Array, shocks: Array, pars: Array, stst: Array
def _forwards_step(carry, i):

dist_old, decisions_output_storage, func_forw = carry
dist = func_forw(dist_old, [dos[..., i] for dos in decisions_output_storage])
dist = func_forw(dist_old, [dos[..., i]
for dos in decisions_output_storage])

return (dist, decisions_output_storage, func_forw), dist_old

Expand All @@ -69,7 +71,7 @@ def final_step(x: Array, dists_storage: Array, decisions_output_storage: Array,

X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T
out = func_eqns(X[:, :-2].reshape(nshpe), X[:, 1:-1].reshape(nshpe), X[:, 2:].reshape(
nshpe), stst, shocks, pars, dists_storage, decisions_output_storage)
nshpe), stst, pars, shocks, dists_storage, decisions_output_storage)

return out

Expand Down
24 changes: 9 additions & 15 deletions econpizza/solvers/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,19 @@ def find_path_stacking(
# only skip jacobian calculation if it exists
skip_jacobian = skip_jacobian if self['cache'].get(
'jac_factorized') else False
# get transformers (experimental)
transform_to = self['options'].get('transform_to') or (lambda x: x)
transform_back = self['options'].get('transform_back') or (lambda x: x)

# get variables
nvars = len(self["var_names"])
if self.get('exp_all'):
stst = jnp.log(d2jnp(self["stst"]))
pars = jnp.log(d2jnp(pars if pars is not None else self["pars"]))
else:
stst = d2jnp(self["stst"])
pars = d2jnp(pars if pars is not None else self["pars"])
stst = transform_back(d2jnp(self["stst"]))
pars = transform_back(d2jnp(pars if pars is not None else self["pars"]))
shocks = self.get("shocks") or ()

# get initial guess
if self.get('exp_all'):
x0 = jnp.log(jnp.array(list(init_state))) if init_state is not None else stst
else:
x0 = jnp.array(list(init_state)) if init_state is not None else stst
x0 = transform_back(jnp.array(list(init_state))
) if init_state is not None else stst
init_dist = init_dist if init_dist is not None else self['steady_state'].get(
'distributions')
dist0 = jnp.array(init_dist if init_dist is not None else jnp.nan)
Expand All @@ -110,7 +107,7 @@ def find_path_stacking(
func_eqns = self['context']["func_eqns"]
jav_func_eqns = val_and_jacrev(func_eqns, (0, 1, 2))
jav_func_eqns_partial = jax.tree_util.Partial(
jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[])
jav_func_eqns, pars=pars, XSS=stst, distributions=[], decisions_outputs=[])
self['context']['jav_func'] = jav_func_eqns_partial
# mark as compiled
write_cache(self, horizon, pars, stst)
Expand Down Expand Up @@ -161,7 +158,4 @@ def find_path_stacking(
elif verbose:
print(mess)

if self.get('exp_all'):
return jnp.exp(x_out), (flag, f)
else:
return x_out, (flag, f)
return transform_to(x_out), (flag, f)
Loading

0 comments on commit e424ba9

Please sign in to comment.