Skip to content

Commit

Permalink
Combine savings and second grid in tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Aug 21, 2024
1 parent 817d33d commit 582a010
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 56 deletions.
16 changes: 8 additions & 8 deletions src/dcegm/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _transform_lagged_choice_to_working_hours(lagged_choice):
def calculate_resources_for_second_continuous_state(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
savings_end_of_last_period,
income_shocks_of_period,
savings_end_of_previous_period,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
):
Expand All @@ -97,8 +97,8 @@ def calculate_resources_for_second_continuous_state(
)(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
savings_end_of_last_period,
income_shocks_of_period,
savings_end_of_previous_period,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
)
Expand All @@ -110,8 +110,8 @@ def calculate_resources_for_second_continuous_state(

def calculate_resources(
discrete_states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
savings_end_of_previous_period,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
):
Expand All @@ -126,8 +126,8 @@ def calculate_resources(
in_axes=(0, None, None, None, None), # discrete states
)(
discrete_states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
savings_end_of_previous_period,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
)
Expand Down
12 changes: 7 additions & 5 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from jax import numpy as jnp
from jax import vmap
from numpy.testing import assert_array_almost_equal as aaae

from dcegm.interpolation.interp1d import interp1d_policy_and_value_on_wealth
from dcegm.interpolation.interp2d import (
Expand All @@ -14,6 +13,7 @@ def interpolate_value_and_marg_util(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
exog_grids: Tuple[jnp.ndarray, jnp.ndarray],
wealth_and_continuous_state_next: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
Expand Down Expand Up @@ -57,6 +57,7 @@ def interpolate_value_and_marg_util(

if has_second_continuous_state:
wealth_next, continuous_state_next = wealth_and_continuous_state_next
regular_grid = exog_grids[1]

Check warning on line 60 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L59-L60

Added lines #L59 - L60 were not covered by tests

# interp_for_single_state_choice = vmap(
# interp2d_value_and_marg_util_for_state_choice,
Expand All @@ -69,15 +70,16 @@ def interpolate_value_and_marg_util(
interp_for_single_state_choice = vmap(

Check warning on line 70 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L70

Added line #L70 was not covered by tests
vmap(
interp2d_value_and_marg_util_for_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, 0, None), # state-choice
in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # state-choice
),
in_axes=(None, None, 0, 1, 1, 0, 0, 0, None), # continuous state
in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state
)

return interp_for_single_state_choice(

Check warning on line 78 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L78

Added line #L78 was not covered by tests
compute_marginal_utility,
compute_utility,
state_choice_vec,
regular_grid,
wealth_next,
continuous_state_next,
endog_grid_child_state_choice,
Expand Down Expand Up @@ -164,8 +166,8 @@ def interp_on_single_wealth_point(wealth_point):

return value_interp, marg_util_interp

interp_over_single_wealth_and_income_shock_draw = vmap( # income shocks
vmap(interp_on_single_wealth_point) # wealth
interp_over_single_wealth_and_income_shock_draw = vmap( # outer: income shocks
vmap(interp_on_single_wealth_point) # inner: wealth
)

value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw(
Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/egm/solve_euler_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def calculate_candidate_solutions_from_euler_equation(
exogenous_savings_grid: np.ndarray,
exog_savings_grid: np.ndarray,
marg_util: jnp.ndarray,
emax: jnp.ndarray,
state_choice_vec: np.ndarray,
Expand Down Expand Up @@ -43,7 +43,7 @@ def calculate_candidate_solutions_from_euler_equation(
)(
feasible_marg_utils_child,
feasible_emax_child,
exogenous_savings_grid,
exog_savings_grid,
state_choice_vec,
compute_inverse_marginal_utility,
compute_utility,
Expand Down
12 changes: 5 additions & 7 deletions src/dcegm/pre_processing/setup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def setup_model(
options: Dict,
exog_savings_grid: jnp.ndarray,
exog_grids: jnp.ndarray,
utility_functions: Dict[str, Callable],
utility_functions_final_period: Dict[str, Callable],
budget_constraint: Callable,
Expand Down Expand Up @@ -47,9 +47,7 @@ def setup_model(
budget_constraint (Callable): User supplied budget constraint.
"""
options = check_options_and_set_defaults(
options, exog_savings_grid=exog_savings_grid
)
options = check_options_and_set_defaults(options, exog_grids=exog_grids)

model_funcs = process_model_functions(
options,
Expand All @@ -76,7 +74,7 @@ def setup_model(

return {
"options": options,
"exog_savings_grid": exog_savings_grid,
"exog_savings_grid": exog_grids,
"model_funcs": model_funcs,
"model_structure": model_structure,
"batch_info": jax.tree.map(create_array_with_smallest_int_dtype, batch_info),
Expand All @@ -102,7 +100,7 @@ def setup_and_save_model(

model = setup_model(
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=exog_savings_grid,
state_space_functions=state_space_functions,
utility_functions=utility_functions,
utility_functions_final_period=utility_functions_final_period,
Expand Down Expand Up @@ -132,7 +130,7 @@ def load_and_setup_model(
model = pickle.load(open(path, "rb"))

model["options"] = check_options_and_set_defaults(
options, exog_savings_grid=model["exog_savings_grid"]
options, exog_grids=model["exog_savings_grid"]
)

model["model_funcs"] = process_model_functions(
Expand Down
15 changes: 9 additions & 6 deletions src/dcegm/pre_processing/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,9 @@ def create_indexer_for_space(space):
return map_vars_to_index


def check_options_and_set_defaults(options, exog_savings_grid):
def check_options_and_set_defaults(options, exog_grids):
"""Check if options are valid and set defaults."""
n_grid_points = exog_savings_grid.shape[0]
n_savings_grid_points = exog_grids[0].shape[0]

if not isinstance(options, dict):
raise ValueError("Options must be a dictionary.")
Expand Down Expand Up @@ -556,12 +556,14 @@ def check_options_and_set_defaults(options, exog_savings_grid):
options["tuning_params"]["n_constrained_points_to_add"] = (
options["tuning_params"]["n_constrained_points_to_add"]
if "n_constrained_points_to_add" in options["tuning_params"]
else n_grid_points // 10
else n_savings_grid_points // 10
)

if (
n_grid_points * (1 + options["tuning_params"]["extra_wealth_grid_factor"])
< n_grid_points + options["tuning_params"]["n_constrained_points_to_add"]
n_savings_grid_points
* (1 + options["tuning_params"]["extra_wealth_grid_factor"])
< n_savings_grid_points
+ options["tuning_params"]["n_constrained_points_to_add"]
):
raise ValueError(
f"""\n\n
Expand All @@ -572,7 +574,8 @@ def check_options_and_set_defaults(options, exog_savings_grid):
the credit constrained part of the wealth grid. \n\n"""
)
options["tuning_params"]["n_total_wealth_grid"] = int(
n_grid_points * (1 + options["tuning_params"]["extra_wealth_grid_factor"])
n_savings_grid_points
* (1 + options["tuning_params"]["extra_wealth_grid_factor"])
)

options["has_second_continuous_state"] = False
Expand Down
27 changes: 14 additions & 13 deletions src/dcegm/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def solve_dcegm(
params: pd.DataFrame,
options: Dict,
exog_savings_grid: jnp.ndarray,
exog_grids: Tuple[jnp.ndarray, jnp.ndarray],
utility_functions: Dict[str, Callable],
utility_functions_final_period: Dict[str, Callable],
budget_constraint: Callable,
Expand Down Expand Up @@ -60,7 +60,7 @@ def solve_dcegm(

backward_jit = get_solve_function(
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=exog_grids,
state_space_functions=state_space_functions,
utility_functions=utility_functions,
budget_constraint=budget_constraint,
Expand All @@ -74,7 +74,7 @@ def solve_dcegm(

def get_solve_function(
options: Dict[str, Any],
exog_savings_grid: jnp.ndarray,
exog_grids: jnp.ndarray,
utility_functions: Dict[str, Callable],
budget_constraint: Callable,
utility_functions_final_period: Dict[str, Callable],
Expand Down Expand Up @@ -108,7 +108,7 @@ def get_solve_function(

model = setup_model(
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=exog_grids,
state_space_functions=state_space_functions,
utility_functions=utility_functions,
utility_functions_final_period=utility_functions_final_period,
Expand All @@ -122,8 +122,8 @@ def get_solve_func_for_model(model):
"""Create a solve function, which only takes params as input."""

options = model["options"]
exog_savings_grid = model["exog_savings_grid"]
has_second_continuous_state = options["has_second_continuous_state"]
exog_grids = model["exog_savings_grid"]

# ToDo: Make interface with several draw possibilities.
# ToDo: Some day make user supplied draw function.
Expand All @@ -135,7 +135,7 @@ def get_solve_func_for_model(model):
partial(
backward_induction,
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=exog_grids,
has_second_continuous_state=has_second_continuous_state,
state_space_dict=model["model_structure"]["state_space_dict"],
n_state_choices=model["model_structure"]["state_choice_space"].shape[0],
Expand All @@ -157,7 +157,7 @@ def backward_induction(
params: Dict[str, float],
options: Dict[str, Any],
has_second_continuous_state: bool,
exog_savings_grid: np.ndarray,
exog_grids: np.ndarray,
state_space_dict: np.ndarray,
n_state_choices: int,
batch_info: Dict[str, np.ndarray],
Expand Down Expand Up @@ -243,8 +243,9 @@ def backward_induction(
calculate_resources_for_second_continuous_state(
discrete_states_beginning_of_next_period=state_space_dict,
continuous_state_beginning_of_next_period=continuous_state_next_period,
savings_end_of_last_period=exog_savings_grid,
income_shocks_of_period=income_shock_draws_unscaled * params["sigma"],
savings_end_of_previous_period=exog_grids,
income_shocks_current_period=income_shock_draws_unscaled
* params["sigma"],
params=params,
compute_beginning_of_period_resources=model_funcs[
"compute_beginning_of_period_resources"
Expand All @@ -260,8 +261,8 @@ def backward_induction(
else:
wealth_and_continuous_state_next_period = calculate_resources(
discrete_states_beginning_of_period=state_space_dict,
savings_end_of_last_period=exog_savings_grid,
income_shocks_of_period=income_shock_draws_unscaled * params["sigma"],
savings_end_of_previous_period=exog_grids[0],
income_shocks_current_period=income_shock_draws_unscaled * params["sigma"],
params=params,
compute_beginning_of_period_resources=model_funcs[
"compute_beginning_of_period_resources"
Expand Down Expand Up @@ -293,7 +294,7 @@ def backward_induction(
params=params,
taste_shock_scale=taste_shock_scale,
income_shock_weights=income_shock_weights,
exog_savings_grid=exog_savings_grid,
exog_savings_grid=exog_grids[0],
model_funcs=model_funcs,
batch_info=batch_info,
value_solved=value_solved,
Expand All @@ -311,7 +312,7 @@ def partial_single_period(carry, xs):
xs=xs,
has_second_continuous_state=has_second_continuous_state,
params=params,
exog_savings_grid=exog_savings_grid,
exog_grids=exog_grids,
wealth_and_continuous_state_next_period=wealth_and_continuous_state_next_period,
income_shock_weights=income_shock_weights,
model_funcs=model_funcs,
Expand Down
8 changes: 4 additions & 4 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
from jax import vmap

from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values
Expand All @@ -13,7 +12,7 @@ def solve_single_period(
xs,
has_second_continuous_state,
params,
exog_savings_grid,
exog_grids,
income_shock_weights,
wealth_and_continuous_state_next_period,
model_funcs,
Expand All @@ -37,6 +36,7 @@ def solve_single_period(
compute_marginal_utility=model_funcs["compute_marginal_utility"],
compute_utility=model_funcs["compute_utility"],
state_choice_vec=state_choice_mat_child,
exog_grids=exog_grids,
wealth_and_continuous_state_next=wealth_and_continuous_state_next_period[
child_state_idxs
],
Expand All @@ -62,7 +62,7 @@ def solve_single_period(
params,
taste_shock_scale,
income_shock_weights,
exog_savings_grid,
exog_grids[0],
model_funcs,
)

Expand Down Expand Up @@ -109,7 +109,7 @@ def solve_for_interpolated_values(
policy_candidate,
expected_values,
) = calculate_candidate_solutions_from_euler_equation(
exogenous_savings_grid=exog_savings_grid,
exog_savings_grid=exog_savings_grid,
marg_util=marg_util,
emax=emax,
state_choice_vec=state_choice_mat,
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def toy_model_exog_ltc(
out = {}
model = setup_model(
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=(exog_savings_grid,),
state_space_functions=create_state_space_function_dict(),
utility_functions=create_utility_function_dict(),
utility_functions_final_period=create_final_period_utility_function_dict(),
Expand All @@ -191,7 +191,7 @@ def toy_model_exog_ltc(
) = solve_dcegm(
params,
options,
exog_savings_grid=exog_savings_grid,
exog_grids=(exog_savings_grid,),
state_space_functions=create_state_space_function_dict(),
utility_functions=create_utility_function_dict(),
utility_functions_final_period=create_final_period_utility_function_dict(),
Expand All @@ -218,7 +218,7 @@ def toy_model_exog_ltc_and_job_offer(
out = {}
model = setup_model(
options=options,
exog_savings_grid=exog_savings_grid,
exog_grids=(exog_savings_grid,),
state_space_functions=create_state_space_function_dict(),
utility_functions=create_utility_function_dict(),
utility_functions_final_period=create_final_period_utility_function_dict(),
Expand All @@ -232,7 +232,7 @@ def toy_model_exog_ltc_and_job_offer(
) = solve_dcegm(
params,
options,
exog_savings_grid=exog_savings_grid,
exog_grids=(exog_savings_grid,),
state_space_functions=create_state_space_function_dict(),
utility_functions=create_utility_function_dict(),
utility_functions_final_period=create_final_period_utility_function_dict(),
Expand Down
Loading

0 comments on commit 582a010

Please sign in to comment.