Skip to content

Commit

Permalink
Add second continuous state variable (#118)
Browse files Browse the repository at this point in the history
Co-authored-by: Maximilian Blesch <maximilian.blesch@hu-berlin.de>
  • Loading branch information
segsell and MaxBlesch authored Oct 2, 2024
1 parent e14e47a commit 3799208
Show file tree
Hide file tree
Showing 50 changed files with 4,143 additions and 646 deletions.
63 changes: 0 additions & 63 deletions src/dcegm/budget.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/dcegm/egm/aggregate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def aggregate_marg_utils_and_exp_values(
)

log_sum_unsqueezed = max_value_per_state + taste_shock_scale * jnp.log(sum_exp)

# Because we kept the dimensions in the maximum and sum over choice specific objects
# to perform subtraction and division, we now need to squeeze the log_sum again
# to remove the redundant axis.
Expand Down Expand Up @@ -90,4 +91,5 @@ def calculate_choice_probs_and_unsqueezed_logsum(

sum_exp = jnp.nansum(rescaled_exponential, axis=1, keepdims=True)
choice_probs = rescaled_exponential / sum_exp

return choice_probs, max_value_per_state, sum_exp
185 changes: 157 additions & 28 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from jax import numpy as jnp
from jax import vmap

from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth
from dcegm.interpolation.interp1d import interp1d_policy_and_value_on_wealth
from dcegm.interpolation.interp2d import (
interp2d_policy_and_value_on_wealth_and_regular_grid,
)


def interpolate_value_and_marg_util(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
wealth_beginning_of_period: jnp.ndarray,
exog_grids: Tuple[jnp.ndarray, jnp.ndarray],
cont_grids_next_period: Dict[str, jnp.ndarray],
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
child_state_idxs: jnp.ndarray,
has_second_continuous_state: bool,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for all child states and compute marginal utility.
Expand All @@ -23,7 +29,7 @@ def interpolate_value_and_marg_util(
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
wealth_beginning_of_next_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
Expand All @@ -34,6 +40,9 @@ def interpolate_value_and_marg_util(
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
has_second_continuous_state (bool): Boolean indicating whether the model
features a second continuous state variable. If False, the only
continuous state variable is consumption/savings.
params (dict): Dictionary containing the model parameters.
Returns:
Expand All @@ -46,29 +55,55 @@ def interpolate_value_and_marg_util(
income shock.
"""
wealth_child_states = cont_grids_next_period["wealth"][child_state_idxs]

interp_for_single_state_choice = vmap(
interpolate_value_and_marg_util_for_single_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)
if has_second_continuous_state:
continuous_state_child_states = cont_grids_next_period["second_continuous"][
child_state_idxs
]
regular_grid = exog_grids["second_continuous"]

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
wealth_beginning_of_period,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)
interp_for_single_state_choice = vmap(
interp2d_value_and_marg_util_for_state_choice,
in_axes=(None, None, 0, None, 0, 0, 0, 0, 0, None), # discrete state-choice
)

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
regular_grid,
wealth_child_states,
continuous_state_child_states,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)

else:
interp_for_single_state_choice = vmap(
interp1d_value_and_marg_util_for_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None), # discrete state-choice
)

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
wealth_child_states,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)


def interpolate_value_and_marg_util_for_single_state_choice(
def interp1d_value_and_marg_util_for_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
wealth_beginning_of_period: jnp.ndarray,
wealth_beginning_of_next_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
Expand All @@ -81,7 +116,7 @@ def interpolate_value_and_marg_util_for_single_state_choice(
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
wealth_beginning_of_next_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
Expand All @@ -92,6 +127,9 @@ def interpolate_value_and_marg_util_for_single_state_choice(
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
has_second_continuous_state (bool): Boolean indicating whether the model
features a second continuous state variable. If False, the only
continuous state variable is consumption/savings.
params (dict): Dictionary containing the model parameters.
Returns:
Expand All @@ -105,9 +143,9 @@ def interpolate_value_and_marg_util_for_single_state_choice(
"""

def interp_on_single_wealth_point(wealth):
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=wealth,
def interp_on_single_wealth_point(wealth_point):
policy_interp, value_interp = interp1d_policy_and_value_on_wealth(
wealth=wealth_point,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
value=value_child_state_choice,
Expand All @@ -118,14 +156,105 @@ def interp_on_single_wealth_point(wealth):
marg_util_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)

return value_interp, marg_util_interp

interp_over_single_wealth_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point) # income shocks
) # wealth grid

value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw(
wealth_beginning_of_next_period
)

return value_interp, marg_util_interp


def interp2d_value_and_marg_util_for_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
regular_grid: jnp.ndarray,
wealth_beginning_of_next_period: jnp.ndarray,
continuous_state_beginning_of_next_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for given child state and compute marginal utility.
Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_next_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
has_second_continuous_state (bool): Boolean indicating whether the model
features a second continuous state variable. If False, the only
continuous state variable is consumption/savings.
params (dict): Dictionary containing the model parameters.
Returns:
tuple:
- marg_utils (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated marginal utilities for each wealth level and
income shock.
- value_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated value function.
"""

def interp_on_single_wealth_point(wealth_point, second_cont_grid_point):

policy_interp, value_interp = (
interp2d_policy_and_value_on_wealth_and_regular_grid(
regular_grid=regular_grid,
wealth_grid=endog_grid_child_state_choice,
policy_grid=policy_child_state_choice,
value_grid=value_child_state_choice,
wealth_point_to_interp=wealth_point,
regular_point_to_interp=second_cont_grid_point,
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
)
marg_util_interp = compute_marginal_utility(
consumption=policy_interp,
continuous_state=second_cont_grid_point,
params=params,
**state_choice_vec
)

return value_interp, marg_util_interp

# Vectorize over savings and income shock dimension
interp_for_savings_point_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point)
# Outer vmap applies first
interp_over_single_wealth_and_income_shock_draw = vmap(
vmap(
vmap(
interp_on_single_wealth_point,
in_axes=(0, None), # income shocks
),
in_axes=(0, None), # wealth grid
),
in_axes=(0, 0), # continuous state grid
)
value_interp, marg_util_interp = interp_for_savings_point_and_income_shock_draw(
wealth_beginning_of_period
# Old points: regular grid and endog grid
# New points: continuous state next period and wealth next period
value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw(
wealth_beginning_of_next_period, continuous_state_beginning_of_next_period
)

return value_interp, marg_util_interp
Loading

0 comments on commit 3799208

Please sign in to comment.