Skip to content

Commit

Permalink
Fix merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Oct 4, 2024
2 parents 5d94c41 + 2d17681 commit 93f2dfb
Show file tree
Hide file tree
Showing 17 changed files with 566 additions and 456 deletions.
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ ignore:
- tests/*
- tests/**/*
- .tox/**/*
- src/toy_models/*
- src/dcegm/likelihood.py
- src/dcegm/interface.py
2 changes: 1 addition & 1 deletion src/dcegm/egm/solve_euler_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def compute_optimal_policy_and_value_wrapper(
params: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Write second continuous grid point into state_choice_vec."""
state_choice_vec["second_continuous"] = second_continuous_grid
state_choice_vec["continuous_state"] = second_continuous_grid

return compute_optimal_policy_and_value(
marg_util_next,
Expand Down
99 changes: 69 additions & 30 deletions src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import jax.numpy as jnp
from jax import vmap
from numpy.testing import assert_array_almost_equal as aaae

from dcegm.law_of_motion import (
calc_resources_for_each_continuous_state_and_savings_grid_point,
)
from dcegm.solve_single_period import solve_for_interpolated_values


Expand Down Expand Up @@ -60,15 +62,14 @@ def solve_last_two_periods(
cont_grids_next_period=cont_grids_next_period,
exog_grids=exog_grids,
params=params,
compute_utility=model_funcs["compute_utility_final"],
compute_marginal_utility=model_funcs["compute_marginal_utility_final"],
model_funcs=model_funcs,
value_solved=value_solved,
policy_solved=policy_solved,
endog_grid_solved=endog_grid_solved,
has_second_continuous_state=has_second_continuous_state,
)

endog_grid, policy, value, marg_util, emax = solve_for_interpolated_values(
endog_grid, policy, value = solve_for_interpolated_values(
value_interpolated=value_interp_final_period,
marginal_utility_interpolated=marginal_utility_final_last_period,
state_choice_mat=batch_info["state_choice_mat_second_last_period"],
Expand Down Expand Up @@ -103,8 +104,7 @@ def solve_final_period(
cont_grids_next_period: Dict[str, jnp.ndarray],
exog_grids: Dict[str, jnp.ndarray],
params: Dict[str, float],
compute_utility: Callable,
compute_marginal_utility: Callable,
model_funcs: Dict[str, Callable],
value_solved,
policy_solved,
endog_grid_solved,
Expand Down Expand Up @@ -146,8 +146,7 @@ def solve_final_period(
cont_grids_next_period=cont_grids_next_period,
exog_grids=exog_grids,
params=params,
compute_utility=compute_utility,
compute_marginal_utility=compute_marginal_utility,
model_funcs=model_funcs,
value_solved=value_solved,
policy_solved=policy_solved,
endog_grid_solved=endog_grid_solved,
Expand All @@ -166,8 +165,8 @@ def solve_final_period(
cont_grids_next_period=cont_grids_next_period,
exog_grids=exog_grids,
params=params,
compute_utility=compute_utility,
compute_marginal_utility=compute_marginal_utility,
compute_utility=model_funcs["compute_utility_final"],
compute_marginal_utility=model_funcs["compute_marginal_utility_final"],
value_solved=value_solved,
policy_solved=policy_solved,
endog_grid_solved=endog_grid_solved,
Expand All @@ -182,9 +181,9 @@ def solve_final_period(
)


###############################################################
# =====================================================================================
# Solve final period discrete states only
###############################################################
# =====================================================================================


def solve_final_period_discrete(
Expand Down Expand Up @@ -230,7 +229,7 @@ def solve_final_period_discrete(
)
# Choose which draw we take for policy and value function as those are not
# saved with respect to the draws
middle_of_draws = int(value.shape[2] + 1 / 2)
middle_of_draws = int((value.shape[2] - 1) / 2)
# Select solutions to store
value_final = value[:, :, middle_of_draws]

Expand Down Expand Up @@ -268,9 +267,9 @@ def solve_final_period_discrete(
)


###############################################################
# =====================================================================================
# Solver final period with second continuous state
###############################################################
# =====================================================================================


def solve_final_period_second_continuous(
Expand All @@ -280,8 +279,7 @@ def solve_final_period_second_continuous(
cont_grids_next_period: Dict[str, jnp.ndarray],
exog_grids: Dict[str, jnp.ndarray],
params: Dict[str, float],
compute_utility: Callable,
compute_marginal_utility: Callable,
model_funcs: Dict[str, Callable],
value_solved,
policy_solved,
endog_grid_solved,
Expand Down Expand Up @@ -320,47 +318,54 @@ def solve_final_period_second_continuous(
resources_child_states_final_period,
continuous_state_final,
params,
compute_utility,
compute_marginal_utility,
model_funcs["compute_utility_final"],
model_funcs["compute_marginal_utility_final"],
)

# For the value to save in the second continuous case, we calculate the value
# at the exogenous wealth and second continuous points
value_regular = vmap(
value_regular, wealth_at_regular = vmap(
vmap(
vmap(
calc_value_for_each_gridpoint_second_continuous,
in_axes=(None, 0, None, None, None), # wealth
calc_value_and_budget_for_each_gridpoint,
in_axes=(None, 0, None, None, None, None), # wealth
),
in_axes=(None, None, 0, None, None), # second continuous_state
in_axes=(None, None, 0, None, None, None), # second continuous_state
),
in_axes=(0, None, None, None, None), # discrete state choices
in_axes=(0, None, None, None, None, None), # discrete state choices
)(
state_choice_mat_final_period,
exog_grids["wealth"],
exog_grids["second_continuous"],
params,
compute_utility,
model_funcs["compute_utility_final"],
model_funcs["compute_beginning_of_period_resources"],
)

sort_idx = jnp.argsort(wealth_at_regular, axis=2)
wealth_sorted = jnp.take_along_axis(wealth_at_regular, sort_idx, axis=2)
values_sorted = jnp.take_along_axis(value_regular, sort_idx, axis=2)

# Store results and add zero entry for the first column
zeros_to_append = jnp.zeros(value_regular.shape[:-1])
zeros_to_append = jnp.zeros(values_sorted.shape[:-1])

# Stack along the second-to-last axis (axis 1)
values_with_zeros = jnp.concatenate(
(zeros_to_append[..., None], value_regular), axis=2
(zeros_to_append[..., None], values_sorted), axis=2
)
wealth_with_zeros = jnp.concatenate(
(zeros_to_append[..., None], wealth_sorted), axis=2
)
exog_wealth_with_zero = jnp.append(0, exog_grids["wealth"])

value_solved = value_solved.at[
idx_state_choices_final_period, :, : n_wealth + 1
].set(values_with_zeros)
policy_solved = policy_solved.at[
idx_state_choices_final_period, :, : n_wealth + 1
].set(exog_wealth_with_zero)
].set(wealth_with_zeros)
endog_grid_solved = endog_grid_solved.at[
idx_state_choices_final_period, :, : n_wealth + 1
].set(exog_wealth_with_zero)
].set(wealth_with_zeros)

return (
value_solved,
Expand Down Expand Up @@ -410,12 +415,46 @@ def calc_value_and_marg_util_for_each_gridpoint_second_continuous(
marg_util = compute_marginal_utility(
**state_choice_vec,
resources=wealth_final_period,
continuous_state=second_continuous_state,
params=params,
)

return value, marg_util


def calc_value_and_budget_for_each_gridpoint(
state_choice_vec,
savings_grid_point,
second_continuous_state,
params,
compute_utility,
compute_beginning_of_period_resources,
):
state_vec = state_choice_vec.copy()
state_vec.pop("choice")

wealth_final_period = (
calc_resources_for_each_continuous_state_and_savings_grid_point(
state_vec=state_vec,
continuous_state_beginning_of_period=second_continuous_state,
exog_savings_grid_point=savings_grid_point,
income_shock_draw=jnp.array(0.0),
params=params,
compute_beginning_of_period_resources=compute_beginning_of_period_resources,
)
)

value = calc_value_for_each_gridpoint_second_continuous(
state_choice_vec=state_choice_vec,
wealth_final_period=wealth_final_period,
second_continuous_state=second_continuous_state,
params=params,
compute_utility=compute_utility,
)

return value, wealth_final_period


def calc_value_for_each_gridpoint_second_continuous(
state_choice_vec,
wealth_final_period,
Expand Down
10 changes: 8 additions & 2 deletions src/dcegm/interpolation/interp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ def interp2d_value_and_check_creditconstraint(
# Now recalculate the closed-form value of consuming all wealth
value_calc_left = (
compute_utility(
consumption=wealth_point_to_interp, params=params, **state_choice_vec
consumption=wealth_point_to_interp,
params=params,
continuous_state=regular_point_to_interp,
**state_choice_vec
)
+ params["beta"] * value_at_zero_wealth[regular_idx_left]
)
Expand All @@ -206,7 +209,10 @@ def interp2d_value_and_check_creditconstraint(
)
value_calc_right = (
compute_utility(
consumption=wealth_point_to_interp, params=params, **state_choice_vec
consumption=wealth_point_to_interp,
continuous_state=regular_point_to_interp,
params=params,
**state_choice_vec
)
+ params["beta"] * value_at_zero_wealth[regular_idx_right]
)
Expand Down
6 changes: 2 additions & 4 deletions src/dcegm/law_of_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ def calc_cont_grids_next_period(
discrete_states_beginning_of_period=state_space_dict,
continuous_grid=exog_grids["second_continuous"],
params=params,
compute_continuous_state=model_funcs[
"compute_beginning_of_period_continuous_state"
],
compute_continuous_state=model_funcs["update_continuous_state"],
)

# Extra dimension for continuous state
Expand Down Expand Up @@ -111,7 +109,7 @@ def calc_resources_for_each_continuous_state_and_savings_grid_point(
):
out = compute_beginning_of_period_resources(
**state_vec,
continuous_state_beginning_of_period=continuous_state_beginning_of_period,
continuous_state=continuous_state_beginning_of_period,
savings_end_of_previous_period=exog_savings_grid_point,
income_shock_previous_period=income_shock_draw,
params=params,
Expand Down
12 changes: 8 additions & 4 deletions src/dcegm/pre_processing/exog_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options


def create_exog_transition_function(options):
def create_exog_transition_function(options, continuous_state_name):
"""Create the exogenous process transition function.
The output function takes a state vector(also choice?), params and options as input.
Expand All @@ -18,15 +18,17 @@ def create_exog_transition_function(options):
compute_exog_transition_vec = return_dummy_exog_transition
processed_exog_funcs_dict = {}
else:
exog_funcs, processed_exog_funcs_dict = process_exog_funcs(options)
exog_funcs, processed_exog_funcs_dict = process_exog_funcs(
options, continuous_state_name
)

compute_exog_transition_vec = partial(
get_exog_transition_vec, exog_funcs=exog_funcs
)
return compute_exog_transition_vec, processed_exog_funcs_dict


def process_exog_funcs(options):
def process_exog_funcs(options, continuous_state_name):
"""Process exogenous functions.
Args:
Expand All @@ -45,7 +47,9 @@ def process_exog_funcs(options):
for exog_name, exog_dict in exog_processes.items():
if isinstance(exog_dict["transition"], Callable):
processed_exog_func = determine_function_arguments_and_partial_options(
func=exog_dict["transition"], options=options["model_params"]
func=exog_dict["transition"],
options=options["model_params"],
continuous_state_name=continuous_state_name,
)
exog_funcs += [processed_exog_func]
processed_exog_funcs[exog_name] = processed_exog_func
Expand Down
Loading

0 comments on commit 93f2dfb

Please sign in to comment.