diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index f4529d56..1bbe0541 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ b/src/dcegm/egm/interpolate_marginal_utility.py @@ -3,10 +3,10 @@ from jax import numpy as jnp from jax import vmap -from dcegm.interpolation.interp1d import interp_value_and_policy_on_wealth +from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth -def interpolate_value_and_marg_utility_on_next_period_wealth( +def interpolate_value_and_marg_util( compute_marginal_utility: Callable, compute_utility: Callable, state_choice_vec: Dict[str, int], @@ -16,14 +16,12 @@ def interpolate_value_and_marg_utility_on_next_period_wealth( value_child_state_choice: jnp.ndarray, params: Dict[str, float], ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Interpolate value and policy in the child states and compute the marginal value. + """Interpolate value and policy for all child states and compute marginal utility. Args: compute_marginal_utility (callable): User-defined function to compute the - agent's marginal utility. The input ```params``` is already partialled in. - compute_value (callable): Function for calculating the value from consumption - level, discrete choice and expected value. The inputs ```discount_rate``` - and ```compute_utility``` are already partialled in. + agent's marginal utility of consumption. + compute_utility (callable): Function for calculating the utility of consumption. state_choice_vec (dict): Dictionary containing the state and choice of the agent. wealth_beginning_of_period (jnp.ndarray): 2d array of shape (n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of @@ -41,19 +39,20 @@ def interpolate_value_and_marg_utility_on_next_period_wealth( Returns: tuple: - - marg_utils (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks) - containing the interpolated marginal utilities for each wealth level and - income shock. - value_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks) containing the interpolated value function. + - marg_util_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks) + containing the interpolated marginal utilities for each wealth level and + income shock. """ - interp_for_state_choice = vmap( - interpolate_value_and_marg_utility_for_single_state_choice, + interp_for_single_state_choice = vmap( + interpolate_value_and_marg_util_for_single_state_choice, in_axes=(None, None, 0, 0, 0, 0, 0, None), ) - return interp_for_state_choice( + + return interp_for_single_state_choice( compute_marginal_utility, compute_utility, state_choice_vec, @@ -65,7 +64,7 @@ def interpolate_value_and_marg_utility_on_next_period_wealth( ) -def interpolate_value_and_marg_utility_for_single_state_choice( +def interpolate_value_and_marg_util_for_single_state_choice( compute_marginal_utility: Callable, compute_utility: Callable, state_choice_vec: Dict[str, int], @@ -75,24 +74,22 @@ def interpolate_value_and_marg_utility_for_single_state_choice( value_child_state_choice: jnp.ndarray, params: Dict[str, float], ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Interpolate value and policy in the child states and compute the marginal value. + """Interpolate value and policy for given child state and compute marginal utility. Args: compute_marginal_utility (callable): User-defined function to compute the - agent's marginal utility. The input ```params``` is already partialled in. - compute_value (callable): Function for calculating the value from consumption - level, discrete choice and expected value. The inputs ```discount_rate``` - and ```compute_utility``` are already partialled in. - wealth_next_period (jnp.ndarray): 2d array of shape + agent's marginal utility of consumption. + compute_utility (callable): Function for calculating the utility of consumption. + state_choice_vec (dict): Dictionary containing the state and choice of the agent. + wealth_beginning_of_period (jnp.ndarray): 2d array of shape (n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of period wealth. - choice (int): The agent's discrete choice. endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous wealth grid of the child state/choice pair. Shape (n_grid_wealth,). - choice_policies_child_state_choice (jnp.ndarray): 1d array containing the + policy_child_state_choice (jnp.ndarray): 1d array containing the corresponding policy function values of the endogenous wealth grid of the child state/choice pair. Shape (n_grid_wealth,). - choice_values_child_state_choice (jnp.ndarray): 1d array containing the + value_child_state_choice (jnp.ndarray): 1d array containing the corresponding value function values of the endogenous wealth grid of the child state/choice pair. Shape (n_grid_wealth,). params (dict): Dictionary containing the model parameters. @@ -108,10 +105,8 @@ def interpolate_value_and_marg_utility_for_single_state_choice( """ - # Generate interpolation function for a single wealth point where the endogenous grid, - # policy and value are fixed. - def interp_on_single_wealth(wealth): - policy_interp, value_interp = interp_value_and_policy_on_wealth( + def interp_on_single_wealth_point(wealth): + policy_interp, value_interp = interpolate_policy_and_value_on_wealth( wealth=wealth, endog_grid=endog_grid_child_state_choice, policy=policy_child_state_choice, @@ -120,14 +115,17 @@ def interp_on_single_wealth(wealth): state_choice_vec=state_choice_vec, params=params, ) - marg_utility_interp = compute_marginal_utility( + marg_util_interp = compute_marginal_utility( consumption=policy_interp, params=params, **state_choice_vec ) - return marg_utility_interp, value_interp + return value_interp, marg_util_interp - # Vectorize interp_on_single_wealth over savings and income shock dimension - vector_interp_func = vmap(vmap(interp_on_single_wealth)) - - marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period) + # Vectorize over savings and income shock dimension + interp_for_savings_point_and_income_shock_draw = vmap( + vmap(interp_on_single_wealth_point) + ) + value_interp, marg_util_interp = interp_for_savings_point_and_income_shock_draw( + wealth_beginning_of_period + ) - return marg_utils, value_interp + return value_interp, marg_util_interp diff --git a/src/dcegm/interface.py b/src/dcegm/interface.py index ddac3aae..13211802 100644 --- a/src/dcegm/interface.py +++ b/src/dcegm/interface.py @@ -2,8 +2,8 @@ from dcegm.interpolation.interp1d import ( interp_policy_on_wealth, - interp_value_and_policy_on_wealth, interp_value_on_wealth, + interpolate_policy_and_value_on_wealth, ) @@ -34,7 +34,7 @@ def policy_and_value_for_state_choice_vec( ) state_choice_index = map_state_choice_to_index[state_choice_tuple] - policy, value = interp_value_and_policy_on_wealth( + policy, value = interpolate_policy_and_value_on_wealth( wealth=wealth, endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), policy=jnp.take(policy_solved, state_choice_index, axis=0), diff --git a/src/dcegm/interpolation/interp1d.py b/src/dcegm/interpolation/interp1d.py index 49816573..6f6e77b3 100644 --- a/src/dcegm/interpolation/interp1d.py +++ b/src/dcegm/interpolation/interp1d.py @@ -36,7 +36,7 @@ def get_index_high_and_low(x, x_new): return ind_high, ind_high - 1 -def interp_value_and_policy_on_wealth( +def interpolate_policy_and_value_on_wealth( wealth: float | jnp.ndarray, endog_grid: jnp.ndarray, policy: jnp.ndarray, @@ -45,7 +45,7 @@ def interp_value_and_policy_on_wealth( state_choice_vec: Dict[str, int], params: Dict[str, float], ) -> Tuple[float, float]: - """Interpolate value and policy given a single wealth value. + """Interpolate policy and value function given a single wealth grid point. Args: wealth (float): Wealth value to interpolate. diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index f58899e8..d264b646 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -6,7 +6,7 @@ from dcegm.budget import calculate_resources_for_all_agents from dcegm.interface import get_state_choice_index_per_state -from dcegm.interpolation.interp1d import interp_value_and_policy_on_wealth +from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth def interpolate_policy_and_value_for_all_agents( @@ -172,7 +172,7 @@ def interpolate_policy_and_value_function( ): state_choice_vec = {**state, "choice": choice} - policy_interp, value_interp = interp_value_and_policy_on_wealth( + policy_interp, value_interp = interpolate_policy_and_value_on_wealth( wealth=resources_beginning_of_period, endog_grid=endog_grid_agent, policy=policy_agent, diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index ced2ee1a..a6ee7a34 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -1,9 +1,7 @@ from jax import vmap from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values -from dcegm.egm.interpolate_marginal_utility import ( - interpolate_value_and_marg_utility_on_next_period_wealth, -) +from dcegm.egm.interpolate_marginal_utility import interpolate_value_and_marg_util from dcegm.egm.solve_euler_equation import ( calculate_candidate_solutions_from_euler_equation, ) @@ -19,8 +17,7 @@ def solve_single_period( model_funcs, taste_shock_scale, ): - """This function solves a single period of the model using the discrete continuous - endogenous grid method (DCEGM).""" + """Solve a single period of the model using the DCEGM method.""" (value_solved, policy_solved, endog_grid_solved) = carry ( @@ -34,17 +31,15 @@ def solve_single_period( ) = xs # EGM step 1) - marginal_utility_interpolated, value_interpolated = ( - interpolate_value_and_marg_utility_on_next_period_wealth( - model_funcs["compute_marginal_utility"], - model_funcs["compute_utility"], - state_choice_mat_child, - resources_beginning_of_period[child_state_idxs, :, :], - endog_grid_solved[child_state_choice_idxs_to_interpolate, :], - policy_solved[child_state_choice_idxs_to_interpolate, :], - value_solved[child_state_choice_idxs_to_interpolate, :], - params, - ) + value_interpolated, marginal_utility_interpolated = interpolate_value_and_marg_util( + model_funcs["compute_marginal_utility"], + model_funcs["compute_utility"], + state_choice_mat_child, + resources_beginning_of_period[child_state_idxs, :, :], + endog_grid_solved[child_state_choice_idxs_to_interpolate, :], + policy_solved[child_state_choice_idxs_to_interpolate, :], + value_solved[child_state_choice_idxs_to_interpolate, :], + params, ) (