Skip to content

Commit

Permalink
Refactor and adjust docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Aug 8, 2024
1 parent c5647ba commit 719b281
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 56 deletions.
66 changes: 32 additions & 34 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from jax import numpy as jnp
from jax import vmap

from dcegm.interpolation.interp1d import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_value_and_marg_utility_on_next_period_wealth(
def interpolate_value_and_marg_util(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -16,14 +16,12 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy in the child states and compute the marginal value.
"""Interpolate value and policy for all child states and compute marginal utility.
Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility. The input ```params``` is already partialled in.
compute_value (callable): Function for calculating the value from consumption
level, discrete choice and expected value. The inputs ```discount_rate```
and ```compute_utility``` are already partialled in.
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
Expand All @@ -41,19 +39,20 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
Returns:
tuple:
- marg_utils (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated marginal utilities for each wealth level and
income shock.
- value_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated value function.
- marg_util_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated marginal utilities for each wealth level and
income shock.
"""

interp_for_state_choice = vmap(
interpolate_value_and_marg_utility_for_single_state_choice,
interp_for_single_state_choice = vmap(
interpolate_value_and_marg_util_for_single_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)
return interp_for_state_choice(

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
Expand All @@ -65,7 +64,7 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
)


def interpolate_value_and_marg_utility_for_single_state_choice(
def interpolate_value_and_marg_util_for_single_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -75,24 +74,22 @@ def interpolate_value_and_marg_utility_for_single_state_choice(
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy in the child states and compute the marginal value.
"""Interpolate value and policy for given child state and compute marginal utility.
Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility. The input ```params``` is already partialled in.
compute_value (callable): Function for calculating the value from consumption
level, discrete choice and expected value. The inputs ```discount_rate```
and ```compute_utility``` are already partialled in.
wealth_next_period (jnp.ndarray): 2d array of shape
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
choice (int): The agent's discrete choice.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
choice_policies_child_state_choice (jnp.ndarray): 1d array containing the
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
choice_values_child_state_choice (jnp.ndarray): 1d array containing the
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
params (dict): Dictionary containing the model parameters.
Expand All @@ -108,10 +105,8 @@ def interpolate_value_and_marg_utility_for_single_state_choice(
"""

# Generate interpolation function for a single wealth point where the endogenous grid,
# policy and value are fixed.
def interp_on_single_wealth(wealth):
policy_interp, value_interp = interp_value_and_policy_on_wealth(
def interp_on_single_wealth_point(wealth):
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
Expand All @@ -120,14 +115,17 @@ def interp_on_single_wealth(wealth):
state_choice_vec=state_choice_vec,
params=params,
)
marg_utility_interp = compute_marginal_utility(
marg_util_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)
return marg_utility_interp, value_interp
return value_interp, marg_util_interp

# Vectorize interp_on_single_wealth over savings and income shock dimension
vector_interp_func = vmap(vmap(interp_on_single_wealth))

marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period)
# Vectorize over savings and income shock dimension
interp_for_savings_point_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point)
)
value_interp, marg_util_interp = interp_for_savings_point_and_income_shock_draw(
wealth_beginning_of_period
)

return marg_utils, value_interp
return value_interp, marg_util_interp
4 changes: 2 additions & 2 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from dcegm.interpolation.interp1d import (
interp_policy_on_wealth,
interp_value_and_policy_on_wealth,
interp_value_on_wealth,
interpolate_policy_and_value_on_wealth,
)


Expand Down Expand Up @@ -34,7 +34,7 @@ def policy_and_value_for_state_choice_vec(
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]
policy, value = interp_value_and_policy_on_wealth(
policy, value = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/interpolation/interp1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_index_high_and_low(x, x_new):
return ind_high, ind_high - 1


def interp_value_and_policy_on_wealth(
def interpolate_policy_and_value_on_wealth(
wealth: float | jnp.ndarray,
endog_grid: jnp.ndarray,
policy: jnp.ndarray,
Expand All @@ -45,7 +45,7 @@ def interp_value_and_policy_on_wealth(
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate value and policy given a single wealth value.
"""Interpolate policy and value function given a single wealth grid point.
Args:
wealth (float): Wealth value to interpolate.
Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dcegm.budget import calculate_resources_for_all_agents
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation.interp1d import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_policy_and_value_for_all_agents(
Expand Down Expand Up @@ -172,7 +172,7 @@ def interpolate_policy_and_value_function(
):
state_choice_vec = {**state, "choice": choice}

policy_interp, value_interp = interp_value_and_policy_on_wealth(
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=resources_beginning_of_period,
endog_grid=endog_grid_agent,
policy=policy_agent,
Expand Down
27 changes: 11 additions & 16 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from jax import vmap

from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values
from dcegm.egm.interpolate_marginal_utility import (
interpolate_value_and_marg_utility_on_next_period_wealth,
)
from dcegm.egm.interpolate_marginal_utility import interpolate_value_and_marg_util
from dcegm.egm.solve_euler_equation import (
calculate_candidate_solutions_from_euler_equation,
)
Expand All @@ -19,8 +17,7 @@ def solve_single_period(
model_funcs,
taste_shock_scale,
):
"""This function solves a single period of the model using the discrete continuous
endogenous grid method (DCEGM)."""
"""Solve a single period of the model using the DCEGM method."""
(value_solved, policy_solved, endog_grid_solved) = carry

(
Expand All @@ -34,17 +31,15 @@ def solve_single_period(
) = xs

# EGM step 1)
marginal_utility_interpolated, value_interpolated = (
interpolate_value_and_marg_utility_on_next_period_wealth(
model_funcs["compute_marginal_utility"],
model_funcs["compute_utility"],
state_choice_mat_child,
resources_beginning_of_period[child_state_idxs, :, :],
endog_grid_solved[child_state_choice_idxs_to_interpolate, :],
policy_solved[child_state_choice_idxs_to_interpolate, :],
value_solved[child_state_choice_idxs_to_interpolate, :],
params,
)
value_interpolated, marginal_utility_interpolated = interpolate_value_and_marg_util(
model_funcs["compute_marginal_utility"],
model_funcs["compute_utility"],
state_choice_mat_child,
resources_beginning_of_period[child_state_idxs, :, :],
endog_grid_solved[child_state_choice_idxs_to_interpolate, :],
policy_solved[child_state_choice_idxs_to_interpolate, :],
value_solved[child_state_choice_idxs_to_interpolate, :],
params,
)

(
Expand Down

0 comments on commit 719b281

Please sign in to comment.