diff --git a/src/dcegm/budget.py b/src/dcegm/budget.py index f3fee9fa..0e3cd03a 100644 --- a/src/dcegm/budget.py +++ b/src/dcegm/budget.py @@ -77,8 +77,8 @@ def _transform_lagged_choice_to_working_hours(lagged_choice): def calculate_resources_for_second_continuous_state( discrete_states_beginning_of_next_period, continuous_state_beginning_of_next_period, - savings_end_of_last_period, - income_shocks_of_period, + savings_end_of_previous_period, + income_shocks_current_period, params, compute_beginning_of_period_resources, ): @@ -97,8 +97,8 @@ def calculate_resources_for_second_continuous_state( )( discrete_states_beginning_of_next_period, continuous_state_beginning_of_next_period, - savings_end_of_last_period, - income_shocks_of_period, + savings_end_of_previous_period, + income_shocks_current_period, params, compute_beginning_of_period_resources, ) @@ -110,8 +110,8 @@ def calculate_resources_for_second_continuous_state( def calculate_resources( discrete_states_beginning_of_period, - savings_end_of_last_period, - income_shocks_of_period, + savings_end_of_previous_period, + income_shocks_current_period, params, compute_beginning_of_period_resources, ): @@ -126,8 +126,8 @@ def calculate_resources( in_axes=(0, None, None, None, None), # discrete states )( discrete_states_beginning_of_period, - savings_end_of_last_period, - income_shocks_of_period, + savings_end_of_previous_period, + income_shocks_current_period, params, compute_beginning_of_period_resources, ) diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index 250e6918..cc378692 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ b/src/dcegm/egm/interpolate_marginal_utility.py @@ -2,7 +2,6 @@ from jax import numpy as jnp from jax import vmap -from numpy.testing import assert_array_almost_equal as aaae from dcegm.interpolation.interp1d import interp1d_policy_and_value_on_wealth from dcegm.interpolation.interp2d import ( @@ -14,6 +13,7 @@ def interpolate_value_and_marg_util( compute_marginal_utility: Callable, compute_utility: Callable, state_choice_vec: Dict[str, int], + exog_grids: Tuple[jnp.ndarray, jnp.ndarray], wealth_and_continuous_state_next: jnp.ndarray, endog_grid_child_state_choice: jnp.ndarray, policy_child_state_choice: jnp.ndarray, @@ -57,6 +57,7 @@ def interpolate_value_and_marg_util( if has_second_continuous_state: wealth_next, continuous_state_next = wealth_and_continuous_state_next + regular_grid = exog_grids[1] # interp_for_single_state_choice = vmap( # interp2d_value_and_marg_util_for_state_choice, @@ -69,15 +70,16 @@ def interpolate_value_and_marg_util( interp_for_single_state_choice = vmap( vmap( interp2d_value_and_marg_util_for_state_choice, - in_axes=(None, None, 0, 0, 0, 0, 0, 0, None), # state-choice + in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # state-choice ), - in_axes=(None, None, 0, 1, 1, 0, 0, 0, None), # continuous state + in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state ) return interp_for_single_state_choice( compute_marginal_utility, compute_utility, state_choice_vec, + regular_grid, wealth_next, continuous_state_next, endog_grid_child_state_choice, @@ -164,8 +166,8 @@ def interp_on_single_wealth_point(wealth_point): return value_interp, marg_util_interp - interp_over_single_wealth_and_income_shock_draw = vmap( # income shocks - vmap(interp_on_single_wealth_point) # wealth + interp_over_single_wealth_and_income_shock_draw = vmap( # outer: income shocks + vmap(interp_on_single_wealth_point) # inner: wealth ) value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw( diff --git a/src/dcegm/egm/solve_euler_equation.py b/src/dcegm/egm/solve_euler_equation.py index 51f3eae2..c991ea33 100644 --- a/src/dcegm/egm/solve_euler_equation.py +++ b/src/dcegm/egm/solve_euler_equation.py @@ -8,7 +8,7 @@ def calculate_candidate_solutions_from_euler_equation( - exogenous_savings_grid: np.ndarray, + exog_savings_grid: np.ndarray, marg_util: jnp.ndarray, emax: jnp.ndarray, state_choice_vec: np.ndarray, @@ -43,7 +43,7 @@ def calculate_candidate_solutions_from_euler_equation( )( feasible_marg_utils_child, feasible_emax_child, - exogenous_savings_grid, + exog_savings_grid, state_choice_vec, compute_inverse_marginal_utility, compute_utility, diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index 0c55d4f1..4a830170 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -16,7 +16,7 @@ def setup_model( options: Dict, - exog_savings_grid: jnp.ndarray, + exog_grids: jnp.ndarray, utility_functions: Dict[str, Callable], utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, @@ -47,9 +47,7 @@ def setup_model( budget_constraint (Callable): User supplied budget constraint. """ - options = check_options_and_set_defaults( - options, exog_savings_grid=exog_savings_grid - ) + options = check_options_and_set_defaults(options, exog_grids=exog_grids) model_funcs = process_model_functions( options, @@ -76,7 +74,7 @@ def setup_model( return { "options": options, - "exog_savings_grid": exog_savings_grid, + "exog_savings_grid": exog_grids, "model_funcs": model_funcs, "model_structure": model_structure, "batch_info": jax.tree.map(create_array_with_smallest_int_dtype, batch_info), @@ -102,7 +100,7 @@ def setup_and_save_model( model = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=exog_savings_grid, state_space_functions=state_space_functions, utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, @@ -132,7 +130,7 @@ def load_and_setup_model( model = pickle.load(open(path, "rb")) model["options"] = check_options_and_set_defaults( - options, exog_savings_grid=model["exog_savings_grid"] + options, exog_grids=model["exog_savings_grid"] ) model["model_funcs"] = process_model_functions( diff --git a/src/dcegm/pre_processing/state_space.py b/src/dcegm/pre_processing/state_space.py index 077fdfca..5e4155bc 100644 --- a/src/dcegm/pre_processing/state_space.py +++ b/src/dcegm/pre_processing/state_space.py @@ -513,9 +513,9 @@ def create_indexer_for_space(space): return map_vars_to_index -def check_options_and_set_defaults(options, exog_savings_grid): +def check_options_and_set_defaults(options, exog_grids): """Check if options are valid and set defaults.""" - n_grid_points = exog_savings_grid.shape[0] + n_savings_grid_points = exog_grids[0].shape[0] if not isinstance(options, dict): raise ValueError("Options must be a dictionary.") @@ -556,12 +556,14 @@ def check_options_and_set_defaults(options, exog_savings_grid): options["tuning_params"]["n_constrained_points_to_add"] = ( options["tuning_params"]["n_constrained_points_to_add"] if "n_constrained_points_to_add" in options["tuning_params"] - else n_grid_points // 10 + else n_savings_grid_points // 10 ) if ( - n_grid_points * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) - < n_grid_points + options["tuning_params"]["n_constrained_points_to_add"] + n_savings_grid_points + * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) + < n_savings_grid_points + + options["tuning_params"]["n_constrained_points_to_add"] ): raise ValueError( f"""\n\n @@ -572,7 +574,8 @@ def check_options_and_set_defaults(options, exog_savings_grid): the credit constrained part of the wealth grid. \n\n""" ) options["tuning_params"]["n_total_wealth_grid"] = int( - n_grid_points * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) + n_savings_grid_points + * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) ) options["has_second_continuous_state"] = False diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index e24cedd5..eaa140ec 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -24,7 +24,7 @@ def solve_dcegm( params: pd.DataFrame, options: Dict, - exog_savings_grid: jnp.ndarray, + exog_grids: Tuple[jnp.ndarray, jnp.ndarray], utility_functions: Dict[str, Callable], utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, @@ -60,7 +60,7 @@ def solve_dcegm( backward_jit = get_solve_function( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=exog_grids, state_space_functions=state_space_functions, utility_functions=utility_functions, budget_constraint=budget_constraint, @@ -74,7 +74,7 @@ def solve_dcegm( def get_solve_function( options: Dict[str, Any], - exog_savings_grid: jnp.ndarray, + exog_grids: jnp.ndarray, utility_functions: Dict[str, Callable], budget_constraint: Callable, utility_functions_final_period: Dict[str, Callable], @@ -108,7 +108,7 @@ def get_solve_function( model = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=exog_grids, state_space_functions=state_space_functions, utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, @@ -122,8 +122,8 @@ def get_solve_func_for_model(model): """Create a solve function, which only takes params as input.""" options = model["options"] - exog_savings_grid = model["exog_savings_grid"] has_second_continuous_state = options["has_second_continuous_state"] + exog_grids = model["exog_savings_grid"] # ToDo: Make interface with several draw possibilities. # ToDo: Some day make user supplied draw function. @@ -135,7 +135,7 @@ def get_solve_func_for_model(model): partial( backward_induction, options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=exog_grids, has_second_continuous_state=has_second_continuous_state, state_space_dict=model["model_structure"]["state_space_dict"], n_state_choices=model["model_structure"]["state_choice_space"].shape[0], @@ -157,7 +157,7 @@ def backward_induction( params: Dict[str, float], options: Dict[str, Any], has_second_continuous_state: bool, - exog_savings_grid: np.ndarray, + exog_grids: np.ndarray, state_space_dict: np.ndarray, n_state_choices: int, batch_info: Dict[str, np.ndarray], @@ -243,8 +243,9 @@ def backward_induction( calculate_resources_for_second_continuous_state( discrete_states_beginning_of_next_period=state_space_dict, continuous_state_beginning_of_next_period=continuous_state_next_period, - savings_end_of_last_period=exog_savings_grid, - income_shocks_of_period=income_shock_draws_unscaled * params["sigma"], + savings_end_of_previous_period=exog_grids, + income_shocks_current_period=income_shock_draws_unscaled + * params["sigma"], params=params, compute_beginning_of_period_resources=model_funcs[ "compute_beginning_of_period_resources" @@ -260,8 +261,8 @@ def backward_induction( else: wealth_and_continuous_state_next_period = calculate_resources( discrete_states_beginning_of_period=state_space_dict, - savings_end_of_last_period=exog_savings_grid, - income_shocks_of_period=income_shock_draws_unscaled * params["sigma"], + savings_end_of_previous_period=exog_grids[0], + income_shocks_current_period=income_shock_draws_unscaled * params["sigma"], params=params, compute_beginning_of_period_resources=model_funcs[ "compute_beginning_of_period_resources" @@ -293,7 +294,7 @@ def backward_induction( params=params, taste_shock_scale=taste_shock_scale, income_shock_weights=income_shock_weights, - exog_savings_grid=exog_savings_grid, + exog_savings_grid=exog_grids[0], model_funcs=model_funcs, batch_info=batch_info, value_solved=value_solved, @@ -311,7 +312,7 @@ def partial_single_period(carry, xs): xs=xs, has_second_continuous_state=has_second_continuous_state, params=params, - exog_savings_grid=exog_savings_grid, + exog_grids=exog_grids, wealth_and_continuous_state_next_period=wealth_and_continuous_state_next_period, income_shock_weights=income_shock_weights, model_funcs=model_funcs, diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index b0fc270c..7b08b3d4 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -1,4 +1,3 @@ -import jax from jax import vmap from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values @@ -13,7 +12,7 @@ def solve_single_period( xs, has_second_continuous_state, params, - exog_savings_grid, + exog_grids, income_shock_weights, wealth_and_continuous_state_next_period, model_funcs, @@ -37,6 +36,7 @@ def solve_single_period( compute_marginal_utility=model_funcs["compute_marginal_utility"], compute_utility=model_funcs["compute_utility"], state_choice_vec=state_choice_mat_child, + exog_grids=exog_grids, wealth_and_continuous_state_next=wealth_and_continuous_state_next_period[ child_state_idxs ], @@ -62,7 +62,7 @@ def solve_single_period( params, taste_shock_scale, income_shock_weights, - exog_savings_grid, + exog_grids[0], model_funcs, ) @@ -109,7 +109,7 @@ def solve_for_interpolated_values( policy_candidate, expected_values, ) = calculate_candidate_solutions_from_euler_equation( - exogenous_savings_grid=exog_savings_grid, + exog_savings_grid=exog_savings_grid, marg_util=marg_util, emax=emax, state_choice_vec=state_choice_mat, diff --git a/tests/conftest.py b/tests/conftest.py index 6dc71a1b..4e59bb9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -177,7 +177,7 @@ def toy_model_exog_ltc( out = {} model = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), @@ -191,7 +191,7 @@ def toy_model_exog_ltc( ) = solve_dcegm( params, options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), @@ -218,7 +218,7 @@ def toy_model_exog_ltc_and_job_offer( out = {} model = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), @@ -232,7 +232,7 @@ def toy_model_exog_ltc_and_job_offer( ) = solve_dcegm( params, options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), diff --git a/tests/test_changing_choice_set.py b/tests/test_changing_choice_set.py index 0f8d5d84..4dbe82d0 100644 --- a/tests/test_changing_choice_set.py +++ b/tests/test_changing_choice_set.py @@ -209,7 +209,7 @@ def test_extended_choice_set_model( state_space_functions=state_space_functions, utility_functions=utility_functions, budget_constraint=budget, - exog_savings_grid=savings_grid, + exog_grids=(savings_grid,), utility_functions_final_period=utility_functions_final_period, ) sol = solve_func(params) @@ -219,7 +219,7 @@ def test_extended_choice_set_model( ) model = setup_model( options=options, - exog_savings_grid=savings_grid, + exog_grids=(savings_grid,), state_space_functions=state_space_functions, utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, diff --git a/tests/test_pre_processing.py b/tests/test_pre_processing.py index fa0abc1d..871caf80 100644 --- a/tests/test_pre_processing.py +++ b/tests/test_pre_processing.py @@ -126,7 +126,7 @@ def test_load_and_save_model( model_setup = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), @@ -135,7 +135,7 @@ def test_load_and_save_model( model_after_saving = setup_and_save_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_savings_grid=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), @@ -200,7 +200,7 @@ def test_grid_parameters(): with pytest.raises(ValueError) as e: setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), diff --git a/tests/test_replication.py b/tests/test_replication.py index a2dab55a..aa2ea454 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -68,7 +68,7 @@ def test_benchmark_models( model = setup_model( options=options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=state_space_functions, utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, @@ -78,7 +78,7 @@ def test_benchmark_models( value, policy, endog_grid = solve_dcegm( params, options, - exog_savings_grid=exog_savings_grid, + exog_grids=(exog_savings_grid,), state_space_functions=state_space_functions, utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period,