Skip to content

Commit

Permalink
Add income shock dimension to test
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Aug 22, 2024
1 parent 6cc27fe commit b4a7f70
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 108 deletions.
52 changes: 10 additions & 42 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,6 @@ def interpolate_value_and_marg_util(
wealth_next, continuous_state_next = wealth_and_continuous_state_next
regular_grid = exog_grids[1]

Check warning on line 60 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L59-L60

Added lines #L59 - L60 were not covered by tests

# interp_for_single_state_choice = vmap(
# interp2d_value_and_marg_util_for_state_choice,
# in_axes=(None, None, 0, 0, 0, 0, 0, 0, None),
# )
# interp_for_single_state_choice = vmap(
# interp_for_single_state_choice, in_axes=(None, None, 0, 1, 1, 0, 0, 0, None)
# )

interp_for_single_state_choice = vmap(

Check warning on line 62 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L62

Added line #L62 was not covered by tests
vmap(
interp2d_value_and_marg_util_for_state_choice,
Expand Down Expand Up @@ -167,9 +159,9 @@ def interp_on_single_wealth_point(wealth_point):

return value_interp, marg_util_interp

interp_over_single_wealth_and_income_shock_draw = vmap( # outer: income shocks
vmap(interp_on_single_wealth_point) # inner: wealth
)
interp_over_single_wealth_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point) # income shocks
) # wealth grid

value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw(
wealth_beginning_of_next_period
Expand Down Expand Up @@ -223,7 +215,6 @@ def interp2d_value_and_marg_util_for_state_choice(
containing the interpolated value function.
"""
# breakpoint()

def interp_on_single_wealth_point(wealth_point, regular_point):

Expand Down Expand Up @@ -252,7 +243,13 @@ def interp_on_single_wealth_point(wealth_point, regular_point):

# Outer vmap applies first
interp_over_single_wealth_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point, in_axes=(0, None)), # wealth grid
vmap(
vmap(
interp_on_single_wealth_point,
in_axes=(0, None), # income shocks
),
in_axes=(0, None), # wealth grid
),
in_axes=(0, 0), # continuous state grid
)

Expand All @@ -264,32 +261,3 @@ def interp_on_single_wealth_point(wealth_point, regular_point):
)

return value_interp, marg_util_interp


def interp_over_continuous_grid(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
regular_grid: jnp.ndarray,
wealth_beginning_of_next_period: jnp.ndarray,
continuous_state_beginning_of_next_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
):
return vmap(
interp2d_value_and_marg_util_for_state_choice,
in_axes=(None, None, None, None, 0, 0, 0, 0, 0, None),
)(
compute_marginal_utility,
compute_utility,
state_choice_vec,
regular_grid,
wealth_beginning_of_next_period,
continuous_state_beginning_of_next_period,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)
2 changes: 0 additions & 2 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ def solve_for_interpolated_values(
# Run upper envelope over all state-choice combinations to remove suboptimal
# candidates
# extra dimension for second continuous state

# breakpoint()
(
endog_grid_state_choice,
policy_state_choice,
Expand Down
70 changes: 6 additions & 64 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

from dcegm.egm.interpolate_marginal_utility import (
interp2d_value_and_marg_util_for_state_choice,
interp_over_continuous_grid,
interpolate_value_and_marg_util,
)
from dcegm.interpolation.interp2d import (
interp2d_policy_and_value_on_wealth_and_regular_grid,
Expand Down Expand Up @@ -174,36 +172,13 @@ def functional_form(x, y):
# =====================================================================================


# def budget(lagged_choice, savings_end_of_previous_period, options, params):
# retired = lagged_choice == 0
# working = lagged_choice == 1

# retirement_income = params["pension"]
# labor_income = params["labor_income"]
# income = working * labor_income + retired * retirement_income

# return jnp.maximum(
# income + (1 + params["interest_rate"]) * savings_end_of_previous_period,
# params["consumption_floor"],
# )
import pytest


# @pytest.mark.skip()
def test_interp2d_value_and_marg_util():

marginal_utility_crra_partial = partial(marginal_utility_crra, options={})

savings_grid = np.linspace(0, 10_000, 100)
experience_grid = np.linspace(0, 1, 6)

exog_grids = (savings_grid, experience_grid) = np.meshgrid(
np.linspace(0, 1, 10), np.linspace(0, 1, 10)
)

# for _ in range(20):
np.random.seed(1234)

a, b = np.random.uniform(1, 10), np.random.uniform(1, 10)

def functional_form(x, y):
Expand All @@ -222,49 +197,16 @@ def functional_form(x, y):
wealth_next = np.random.uniform(30, 40, 100)
experience_next = np.random.choice(experience_grid, 6)

wealth_next_state_choice = np.tile(wealth_next, (2, 6, 1))
# wealth_next_state_choice = np.tile(
# np.expand_dims(np.tile(wealth_next, (2, 6, 1)), axis=-1), (1, 1, 1, 5)
# )
wealth_next_state_choice = np.tile(
np.expand_dims(np.tile(wealth_next, (2, 6, 1)), axis=-1), (1, 1, 1, 5)
)
experience_next_state_choice = np.tile(experience_next, (2, 1))
# experience_next_state_choice = np.tile(
# np.expand_dims(experience_next_state_choice, axis=-1), (1, 1, 100)
# )

policy_state_choice = np.tile(policy, (2, 1, 1))
value_state_choice = np.tile(value, (2, 1, 1))
wealth_grid_state_choice = np.tile(wealth_grid, (2, 1, 1))
# experience_grid_state_choice = np.tile(experience_grid, (2, 1))

# wealth_and_continuous_state_next_period = (wealth_next, experience_next)
# value_interpolated, marginal_utility_interpolated = interpolate_value_and_marg_util(
# compute_marginal_utility=marginal_utility_crra,
# compute_utility=utility_crra,
# state_choice_vec={"choice": 0},
# exog_grids=exog_grids,
# wealth_and_continuous_state_next=wealth_and_continuous_state_next_period,
# endog_grid_child_state_choice=wealth_grid,
# policy_child_state_choice=policy,
# value_child_state_choice=value,
# has_second_continuous_state=True,
# params=PARAMS,
# )

interp_for_single_state_choice = vmap(
# vmap(
# in_axes=(
# None,
# None,
# None,
# 0,
# 1,
# 0,
# 1,
# 1,
# 1,
# None,
# ), # continuous state
# ),
# interp_over_continuous_grid,
interp2d_value_and_marg_util_for_state_choice,
in_axes=(None, None, 0, None, 0, 0, 0, 0, 0, None), # discrete state-choice
)
Expand Down Expand Up @@ -292,5 +234,5 @@ def functional_form(x, y):
PARAMS,
)

np.testing.assert_equal(marg_util.shape, (2, 6, 100))
np.testing.assert_equal(val.shape, (2, 6, 100))
np.testing.assert_equal(marg_util.shape, (2, 6, 100, 5))
np.testing.assert_equal(val.shape, (2, 6, 100, 5))

0 comments on commit b4a7f70

Please sign in to comment.