Skip to content

Commit

Permalink
Renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Aug 21, 2024
1 parent 582a010 commit d727b7a
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 32 deletions.
8 changes: 4 additions & 4 deletions src/dcegm/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _transform_lagged_choice_to_working_hours(lagged_choice):
def calculate_resources_for_second_continuous_state(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
savings_end_of_previous_period,
savings_grid,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
Expand All @@ -97,7 +97,7 @@ def calculate_resources_for_second_continuous_state(
)(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
savings_end_of_previous_period,
savings_grid,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
Expand All @@ -110,7 +110,7 @@ def calculate_resources_for_second_continuous_state(

def calculate_resources(
discrete_states_beginning_of_period,
savings_end_of_previous_period,
savings_grid,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
Expand All @@ -126,7 +126,7 @@ def calculate_resources(
in_axes=(0, None, None, None, None), # discrete states
)(
discrete_states_beginning_of_period,
savings_end_of_previous_period,
savings_grid,
income_shocks_current_period,
params,
compute_beginning_of_period_resources,
Expand Down
12 changes: 6 additions & 6 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def interpolate_value_and_marg_util(
interp_for_single_state_choice = vmap(

Check warning on line 70 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L70

Added line #L70 was not covered by tests
vmap(
interp2d_value_and_marg_util_for_state_choice,
in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # state-choice
in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state
),
in_axes=(None, None, None, 0, 1, 1, 0, 0, 0, None), # continuous state
in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, None), # discrete state-choice
)

return interp_for_single_state_choice(

Check warning on line 78 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L78

Added line #L78 was not covered by tests
Expand All @@ -91,7 +91,7 @@ def interpolate_value_and_marg_util(
else:
interp_for_single_state_choice = vmap(
interp1d_value_and_marg_util_for_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
in_axes=(None, None, 0, 0, 0, 0, 0, None), # discrete state-choice
)

return interp_for_single_state_choice(
Expand Down Expand Up @@ -181,12 +181,12 @@ def interp2d_value_and_marg_util_for_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
regular_grid: jnp.ndarray,
wealth_beginning_of_next_period: jnp.ndarray,
# regular_grid_child_state_choice: jnp.ndarray,
continuous_state_beginning_of_next_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
has_second_continuous_state: bool,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for given child state and compute marginal utility.
Expand Down Expand Up @@ -231,7 +231,7 @@ def interp_on_single_wealth_point(regular_point, wealth_point):

policy_interp, value_interp = (

Check warning on line 232 in src/dcegm/egm/interpolate_marginal_utility.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/egm/interpolate_marginal_utility.py#L232

Added line #L232 was not covered by tests
interp2d_policy_and_value_on_wealth_and_regular_grid(
# regular_grid=regular_grid_child_state_choice,
regular_grid=regular_grid,
wealth_grid=endog_grid_child_state_choice,
policy_grid=policy_child_state_choice,
value_grid=value_child_state_choice,
Expand Down
3 changes: 1 addition & 2 deletions src/dcegm/interpolation/interp1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def interp_value_and_check_creditconstraint(
x_new=new_wealth,
)

# Now recalculate the value when consumed all wealth
# Now recalculate the closed-form value of consuming all wealth
utility = compute_utility(
consumption=new_wealth,
params=params,
Expand All @@ -222,7 +222,6 @@ def interp_value_and_check_creditconstraint(
# Check if we are in the credit constrained region
credit_constraint = new_wealth <= endog_grid_min

# If so we return the value if all is consumed.
value_interp = (
credit_constraint * value_interp_closed_form
+ (1 - credit_constraint) * value_interp_on_grid
Expand Down
13 changes: 10 additions & 3 deletions src/dcegm/interpolation/interp2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Jax implementation of 2D interpolation."""

from typing import Callable
from typing import Callable, Dict

import jax.lax
import jax.numpy as jnp
Expand All @@ -23,6 +23,7 @@ def interp2d_policy_and_value_on_wealth_and_regular_grid(
regular_point_to_interp: jnp.ndarray | float,
wealth_point_to_interp: jnp.ndarray | float,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
params: dict,
):
"""Linear 2D interpolation on two grids where wealth has irregular spacing.
Expand Down Expand Up @@ -85,6 +86,7 @@ def interp2d_policy_and_value_on_wealth_and_regular_grid(
compute_utility=compute_utility,
wealth_min_unconstrained=wealth_grid[:, 1],
value_at_zero_wealth=value_grid[:, 0],
state_choice_vec=state_choice_vec,
params=params,
)

Expand Down Expand Up @@ -147,6 +149,7 @@ def interp2d_value_and_check_creditconstraint(
compute_utility,
wealth_min_unconstrained,
value_at_zero_wealth,
state_choice_vec,
params,
):
"""Interpolate the value function on a 2D grid and check for credit constraints.
Expand Down Expand Up @@ -189,7 +192,9 @@ def interp2d_value_and_check_creditconstraint(

# Now recalculate the closed-form value of consuming all wealth
value_calc_left = (
compute_utility(wealth_point_to_interp, params)
compute_utility(
consumption=wealth_point_to_interp, params=params, **state_choice_vec
)
+ params["beta"] * value_at_zero_wealth[regular_idx_left]
)

Expand All @@ -198,7 +203,9 @@ def interp2d_value_and_check_creditconstraint(
wealth_point_to_interp <= wealth_min_unconstrained[regular_idx_right]
)
value_calc_right = (
compute_utility(wealth_point_to_interp, params)
compute_utility(
consumption=wealth_point_to_interp, params=params, **state_choice_vec
)
+ params["beta"] * value_at_zero_wealth[regular_idx_right]
)

Expand Down
9 changes: 4 additions & 5 deletions src/dcegm/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ def backward_induction(
taste_shock_scale = params["lambda"]

if has_second_continuous_state:
continuous_grid = jnp.linspace(0, 1, 10)
# continuous_grid = jnp.linspace(0, 1, 10)
continuous_state_next_period = calculate_continuous_state(

Check warning on line 232 in src/dcegm/solve.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/solve.py#L232

Added line #L232 was not covered by tests
discrete_states_beginning_of_period=state_space_dict,
continuous_grid=continuous_grid,
continuous_grid=exog_grids[1],
params=params,
compute_continuous_state=model_funcs[
"compute_beginning_of_period_continuous_state"
Expand All @@ -243,7 +243,7 @@ def backward_induction(
calculate_resources_for_second_continuous_state(
discrete_states_beginning_of_next_period=state_space_dict,
continuous_state_beginning_of_next_period=continuous_state_next_period,
savings_end_of_previous_period=exog_grids,
savings_grid=exog_grids[0],
income_shocks_current_period=income_shock_draws_unscaled
* params["sigma"],
params=params,
Expand All @@ -261,15 +261,14 @@ def backward_induction(
else:
wealth_and_continuous_state_next_period = calculate_resources(
discrete_states_beginning_of_period=state_space_dict,
savings_end_of_previous_period=exog_grids[0],
savings_grid=exog_grids[0],
income_shocks_current_period=income_shock_draws_unscaled * params["sigma"],
params=params,
compute_beginning_of_period_resources=model_funcs[
"compute_beginning_of_period_resources"
],
)

# breakpoint()
# Create solution containers. The 20 percent extra in wealth grid needs to go
# into tuning parameters
(
Expand Down
26 changes: 17 additions & 9 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from numpy.testing import assert_array_almost_equal as aaae
from scipy.interpolate import griddata, interp1d

from dcegm.egm.interpolate_marginal_utility import (
interp2d_value_and_marg_util_for_state_choice,
interpolate_value_and_marg_util,
)
from dcegm.interpolation.interp2d import (
interp2d_policy_and_value_on_wealth_and_regular_grid,
)
Expand All @@ -28,14 +32,12 @@
custom_interp2d_quad,
custom_interp2d_quad_value_function,
)
from toy_models.consumption_retirement_model.utility_functions import (
marginal_utility_crra,
utility_crra,
)

PARAMS = {"beta": 0.95, "theta": 0.5}


def flow_util(cons, params):
"""Flow utility in the Epstein Zin case."""
theta = params["theta"]
return (cons ** (1 - theta)) / (1 - theta)
PARAMS = {"beta": 0.95, "rho": 0.5, "delta": -1}


@pytest.fixture()
Expand Down Expand Up @@ -120,7 +122,7 @@ def functional_form(x, y):
regular_grid,
values=value,
points=test_points,
float_util=flow_util,
flow_util=utility_crra,
params=PARAMS,
)

Expand All @@ -139,7 +141,8 @@ def functional_form(x, y):
value_grid=value_jax,
wealth_point_to_interp=x_in,
regular_point_to_interp=y_in,
compute_utility=flow_util,
compute_utility=utility_crra,
state_choice_vec={"choice": 0},
params=PARAMS,
)
)
Expand All @@ -154,3 +157,8 @@ def functional_form(x, y):

aaae(policy_interp_jax, policy_interp_scipy, decimal=7)
aaae(value_interp_jax, value_interp_custom, decimal=7)


def test_interp2d_value_and_marg_util():

pass
6 changes: 3 additions & 3 deletions tests/utils/interp2d_auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def custom_interp2d_quad(x_grids, y_grid, values, points):


def custom_interp2d_quad_value_function(
x_grids, y_grid, values, points, *, float_util, params
x_grids, y_grid, values, points, *, flow_util, params
):
"""This function is able to interpolate linearly on a two dimensional grid where one
dimension has irregular spacing between the grid points.
Expand Down Expand Up @@ -141,15 +141,15 @@ def custom_interp2d_quad_value_function(
# z[i, 0] = transform(values[y_indx][0], theta)

# else:
z[i, 0] = float_util(x, params) + params["beta"] * values[y_indx][0]
z[i, 0] = flow_util(x, params) + params["beta"] * values[y_indx][0]

if x < x_grids[y_indx + 1][1]:
x_cords[i, 1] = x
# if x == 0:
# z[i, 1] = transform(values[y_indx + 1][0], params)

# else:
z[i, 1] = float_util(x, params) + params["beta"] * values[y_indx + 1][0]
z[i, 1] = flow_util(x, params) + params["beta"] * values[y_indx + 1][0]

alpha, beta = (x_cords @ A), (y_cords @ A)
m, l = calculate_map_params(points[:, 0], points[:, 1], alpha, beta)
Expand Down

0 comments on commit d727b7a

Please sign in to comment.