diff --git a/src/dcegm/budget.py b/src/dcegm/budget.py index 0e3cd03a..e1288c92 100644 --- a/src/dcegm/budget.py +++ b/src/dcegm/budget.py @@ -77,7 +77,7 @@ def _transform_lagged_choice_to_working_hours(lagged_choice): def calculate_resources_for_second_continuous_state( discrete_states_beginning_of_next_period, continuous_state_beginning_of_next_period, - savings_end_of_previous_period, + savings_grid, income_shocks_current_period, params, compute_beginning_of_period_resources, @@ -97,7 +97,7 @@ def calculate_resources_for_second_continuous_state( )( discrete_states_beginning_of_next_period, continuous_state_beginning_of_next_period, - savings_end_of_previous_period, + savings_grid, income_shocks_current_period, params, compute_beginning_of_period_resources, @@ -110,7 +110,7 @@ def calculate_resources_for_second_continuous_state( def calculate_resources( discrete_states_beginning_of_period, - savings_end_of_previous_period, + savings_grid, income_shocks_current_period, params, compute_beginning_of_period_resources, @@ -126,7 +126,7 @@ def calculate_resources( in_axes=(0, None, None, None, None), # discrete states )( discrete_states_beginning_of_period, - savings_end_of_previous_period, + savings_grid, income_shocks_current_period, params, compute_beginning_of_period_resources, diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index cc378692..96caff88 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ b/src/dcegm/egm/interpolate_marginal_utility.py @@ -70,9 +70,9 @@ def interpolate_value_and_marg_util( interp_for_single_state_choice = vmap( vmap( interp2d_value_and_marg_util_for_state_choice, - in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # state-choice + in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state ), - in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state + in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # discrete state-choice ) return interp_for_single_state_choice( @@ -91,7 +91,7 @@ def interpolate_value_and_marg_util( else: interp_for_single_state_choice = vmap( interp1d_value_and_marg_util_for_state_choice, - in_axes=(None, None, 0, 0, 0, 0, 0, None), + in_axes=(None, None, 0, 0, 0, 0, 0, None), # discrete state-choice ) return interp_for_single_state_choice( @@ -181,12 +181,12 @@ def interp2d_value_and_marg_util_for_state_choice( compute_marginal_utility: Callable, compute_utility: Callable, state_choice_vec: Dict[str, int], + regular_grid: jnp.ndarray, wealth_beginning_of_next_period: jnp.ndarray, - # regular_grid_child_state_choice: jnp.ndarray, + continuous_state_beginning_of_next_period: jnp.ndarray, endog_grid_child_state_choice: jnp.ndarray, policy_child_state_choice: jnp.ndarray, value_child_state_choice: jnp.ndarray, - has_second_continuous_state: bool, params: Dict[str, float], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Interpolate value and policy for given child state and compute marginal utility. @@ -231,7 +231,7 @@ def interp_on_single_wealth_point(regular_point, wealth_point): policy_interp, value_interp = ( interp2d_policy_and_value_on_wealth_and_regular_grid( - # regular_grid=regular_grid_child_state_choice, + regular_grid=regular_grid, wealth_grid=endog_grid_child_state_choice, policy_grid=policy_child_state_choice, value_grid=value_child_state_choice, diff --git a/src/dcegm/interpolation/interp1d.py b/src/dcegm/interpolation/interp1d.py index 0557f563..14dfea03 100644 --- a/src/dcegm/interpolation/interp1d.py +++ b/src/dcegm/interpolation/interp1d.py @@ -211,7 +211,7 @@ def interp_value_and_check_creditconstraint( x_new=new_wealth, ) - # Now recalculate the value when consumed all wealth + # Now recalculate the closed-form value of consuming all wealth utility = compute_utility( consumption=new_wealth, params=params, @@ -222,7 +222,6 @@ def interp_value_and_check_creditconstraint( # Check if we are in the credit constrained region credit_constraint = new_wealth <= endog_grid_min - # If so we return the value if all is consumed. value_interp = ( credit_constraint * value_interp_closed_form + (1 - credit_constraint) * value_interp_on_grid diff --git a/src/dcegm/interpolation/interp2d.py b/src/dcegm/interpolation/interp2d.py index bca45ecf..c2c1fa3b 100644 --- a/src/dcegm/interpolation/interp2d.py +++ b/src/dcegm/interpolation/interp2d.py @@ -1,6 +1,6 @@ """Jax implementation of 2D interpolation.""" -from typing import Callable +from typing import Callable, Dict import jax.lax import jax.numpy as jnp @@ -23,6 +23,7 @@ def interp2d_policy_and_value_on_wealth_and_regular_grid( regular_point_to_interp: jnp.ndarray | float, wealth_point_to_interp: jnp.ndarray | float, compute_utility: Callable, + state_choice_vec: Dict[str, int], params: dict, ): """Linear 2D interpolation on two grids where wealth has irregular spacing. @@ -85,6 +86,7 @@ def interp2d_policy_and_value_on_wealth_and_regular_grid( compute_utility=compute_utility, wealth_min_unconstrained=wealth_grid[:, 1], value_at_zero_wealth=value_grid[:, 0], + state_choice_vec=state_choice_vec, params=params, ) @@ -147,6 +149,7 @@ def interp2d_value_and_check_creditconstraint( compute_utility, wealth_min_unconstrained, value_at_zero_wealth, + state_choice_vec, params, ): """Interpolate the value function on a 2D grid and check for credit constraints. @@ -189,7 +192,9 @@ def interp2d_value_and_check_creditconstraint( # Now recalculate the closed-form value of consuming all wealth value_calc_left = ( - compute_utility(wealth_point_to_interp, params) + compute_utility( + consumption=wealth_point_to_interp, params=params, **state_choice_vec + ) + params["beta"] * value_at_zero_wealth[regular_idx_left] ) @@ -198,7 +203,9 @@ def interp2d_value_and_check_creditconstraint( wealth_point_to_interp <= wealth_min_unconstrained[regular_idx_right] ) value_calc_right = ( - compute_utility(wealth_point_to_interp, params) + compute_utility( + consumption=wealth_point_to_interp, params=params, **state_choice_vec + ) + params["beta"] * value_at_zero_wealth[regular_idx_right] ) diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index eaa140ec..3f8627c4 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -228,10 +228,10 @@ def backward_induction( taste_shock_scale = params["lambda"] if has_second_continuous_state: - continuous_grid = jnp.linspace(0, 1, 10) + # continuous_grid = jnp.linspace(0, 1, 10) continuous_state_next_period = calculate_continuous_state( discrete_states_beginning_of_period=state_space_dict, - continuous_grid=continuous_grid, + continuous_grid=exog_grids[1], params=params, compute_continuous_state=model_funcs[ "compute_beginning_of_period_continuous_state" @@ -243,7 +243,7 @@ def backward_induction( calculate_resources_for_second_continuous_state( discrete_states_beginning_of_next_period=state_space_dict, continuous_state_beginning_of_next_period=continuous_state_next_period, - savings_end_of_previous_period=exog_grids, + savings_grid=exog_grids[0], income_shocks_current_period=income_shock_draws_unscaled * params["sigma"], params=params, @@ -261,7 +261,7 @@ def backward_induction( else: wealth_and_continuous_state_next_period = calculate_resources( discrete_states_beginning_of_period=state_space_dict, - savings_end_of_previous_period=exog_grids[0], + savings_grid=exog_grids[0], income_shocks_current_period=income_shock_draws_unscaled * params["sigma"], params=params, compute_beginning_of_period_resources=model_funcs[ @@ -269,7 +269,6 @@ def backward_induction( ], ) - # breakpoint() # Create solution containers. The 20 percent extra in wealth grid needs to go # into tuning parameters ( diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index dce9e53a..60a85dfa 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -17,6 +17,10 @@ from numpy.testing import assert_array_almost_equal as aaae from scipy.interpolate import griddata, interp1d +from dcegm.egm.interpolate_marginal_utility import ( + interp2d_value_and_marg_util_for_state_choice, + interpolate_value_and_marg_util, +) from dcegm.interpolation.interp2d import ( interp2d_policy_and_value_on_wealth_and_regular_grid, ) @@ -28,14 +32,12 @@ custom_interp2d_quad, custom_interp2d_quad_value_function, ) +from toy_models.consumption_retirement_model.utility_functions import ( + marginal_utility_crra, + utility_crra, +) -PARAMS = {"beta": 0.95, "theta": 0.5} - - -def flow_util(cons, params): - """Flow utility in the Epstein Zin case.""" - theta = params["theta"] - return (cons ** (1 - theta)) / (1 - theta) +PARAMS = {"beta": 0.95, "rho": 0.5, "delta": -1} @pytest.fixture() @@ -120,7 +122,7 @@ def functional_form(x, y): regular_grid, values=value, points=test_points, - float_util=flow_util, + flow_util=utility_crra, params=PARAMS, ) @@ -139,7 +141,8 @@ def functional_form(x, y): value_grid=value_jax, wealth_point_to_interp=x_in, regular_point_to_interp=y_in, - compute_utility=flow_util, + compute_utility=utility_crra, + state_choice_vec={"choice": 0}, params=PARAMS, ) ) @@ -154,3 +157,8 @@ def functional_form(x, y): aaae(policy_interp_jax, policy_interp_scipy, decimal=7) aaae(value_interp_jax, value_interp_custom, decimal=7) + + +def test_interp2d_value_and_marg_util(): + + pass diff --git a/tests/utils/interp2d_auxiliary.py b/tests/utils/interp2d_auxiliary.py index 648bf11b..4532815b 100644 --- a/tests/utils/interp2d_auxiliary.py +++ b/tests/utils/interp2d_auxiliary.py @@ -77,7 +77,7 @@ def custom_interp2d_quad(x_grids, y_grid, values, points): def custom_interp2d_quad_value_function( - x_grids, y_grid, values, points, *, float_util, params + x_grids, y_grid, values, points, *, flow_util, params ): """This function is able to interpolate linearly on a two dimensional grid where one dimension has irregular spacing between the grid points. @@ -141,7 +141,7 @@ def custom_interp2d_quad_value_function( # z[i, 0] = transform(values[y_indx][0], theta) # else: - z[i, 0] = float_util(x, params) + params["beta"] * values[y_indx][0] + z[i, 0] = flow_util(x, params) + params["beta"] * values[y_indx][0] if x < x_grids[y_indx + 1][1]: x_cords[i, 1] = x @@ -149,7 +149,7 @@ def custom_interp2d_quad_value_function( # z[i, 1] = transform(values[y_indx + 1][0], params) # else: - z[i, 1] = float_util(x, params) + params["beta"] * values[y_indx + 1][0] + z[i, 1] = flow_util(x, params) + params["beta"] * values[y_indx + 1][0] alpha, beta = (x_cords @ A), (y_cords @ A) m, l = calculate_map_params(points[:, 0], points[:, 1], alpha, beta)