Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor interpolation wrappers #116

Merged
merged 6 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ repos:
hooks:
- id: setup-cfg-fmt
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
language_version: python3.10
Expand All @@ -84,7 +84,7 @@ repos:
- '88'
- --blank
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.8.5
rev: 1.8.7
hooks:
- id: nbqa-black
- id: nbqa-ruff
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers =
Topic :: Scientific/Engineering
Topic :: Software Development :: Build Tools
project_urls =
Github = https://github.com/segsell/dcegm
Github = https://github.com/OpenSourceEconomics/dcegm

[options]
packages = find:
Expand Down
103 changes: 80 additions & 23 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from jax import numpy as jnp
from jax import vmap

from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_value_and_marg_utility_on_next_period_wealth(
def interpolate_value_and_marg_util(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -15,25 +15,81 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate value and policy in the child states and compute the marginal value.
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for all child states and compute marginal utility.

Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility. The input ```params``` is already partialled in.
compute_value (callable): Function for calculating the value from consumption
level, discrete choice and expected value. The inputs ```discount_rate```
and ```compute_utility``` are already partialled in.
wealth_next_period (jnp.ndarray): 2d array of shape
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
choice (int): The agent's discrete choice.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
choice_policies_child_state_choice (jnp.ndarray): 1d array containing the
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
choice_values_child_state_choice (jnp.ndarray): 1d array containing the
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
params (dict): Dictionary containing the model parameters.

Returns:
tuple:

- value_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated value function.
- marg_util_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated marginal utilities for each wealth level and
income shock.

"""

interp_for_single_state_choice = vmap(
interpolate_value_and_marg_util_for_single_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
wealth_beginning_of_period,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)


def interpolate_value_and_marg_util_for_single_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
wealth_beginning_of_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for given child state and compute marginal utility.

Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
params (dict): Dictionary containing the model parameters.
Expand All @@ -49,10 +105,8 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(

"""

# Generate interpolation function for single wealth point where the endogenous grid,
# policy and value are fixed.
def interp_on_single_wealth(wealth):
policy_interp, value_interp = interp_value_and_policy_on_wealth(
def interp_on_single_wealth_point(wealth):
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
Expand All @@ -61,14 +115,17 @@ def interp_on_single_wealth(wealth):
state_choice_vec=state_choice_vec,
params=params,
)
marg_utility_interp = compute_marginal_utility(
marg_util_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)
return marg_utility_interp, value_interp

# Generate vectorized function for savings and income shock dimension
vector_interp_func = vmap(vmap(interp_on_single_wealth))
return value_interp, marg_util_interp

marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period)
# Vectorize over savings and income shock dimension
interp_for_savings_point_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point)
)
value_interp, marg_util_interp = interp_for_savings_point_and_income_shock_draw(
wealth_beginning_of_period
)

return marg_utils, value_interp
return value_interp, marg_util_interp
6 changes: 3 additions & 3 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import jax.numpy as jnp

from dcegm.interpolation import (
from dcegm.interpolation.interp1d import (
interp_policy_on_wealth,
interp_value_and_policy_on_wealth,
interp_value_on_wealth,
interpolate_policy_and_value_on_wealth,
)


Expand Down Expand Up @@ -34,7 +34,7 @@ def policy_and_value_for_state_choice_vec(
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]
policy, value = interp_value_and_policy_on_wealth(
policy, value = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_index_high_and_low(x, x_new):
return ind_high, ind_high - 1


def interp_value_and_policy_on_wealth(
def interpolate_policy_and_value_on_wealth(
wealth: float | jnp.ndarray,
endog_grid: jnp.ndarray,
policy: jnp.ndarray,
Expand All @@ -45,7 +45,7 @@ def interp_value_and_policy_on_wealth(
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate value and policy given a single wealth value.
"""Interpolate policy and value function given a single wealth grid point.

Args:
wealth (float): Wealth value to interpolate.
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
calculate_choice_probs_and_unsqueezed_logsum,
)
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation import interp_value_on_wealth
from dcegm.interpolation.interp1d import interp_value_on_wealth
from dcegm.solve import get_solve_func_for_model


Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dcegm.budget import calculate_resources_for_all_agents
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_policy_and_value_for_all_agents(
Expand Down Expand Up @@ -172,7 +172,7 @@ def interpolate_policy_and_value_function(
):
state_choice_vec = {**state, "choice": choice}

policy_interp, value_interp = interp_value_and_policy_on_wealth(
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=resources_beginning_of_period,
endog_grid=endog_grid_agent,
policy=policy_agent,
Expand Down
65 changes: 44 additions & 21 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from jax import vmap

from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values
from dcegm.egm.interpolate_marginal_utility import (
interpolate_value_and_marg_utility_on_next_period_wealth,
)
from dcegm.egm.interpolate_marginal_utility import interpolate_value_and_marg_util
from dcegm.egm.solve_euler_equation import (
calculate_candidate_solutions_from_euler_equation,
)
Expand All @@ -19,8 +17,7 @@ def solve_single_period(
model_funcs,
taste_shock_scale,
):
"""This function solves a single period of the model using the discrete continuous
endogenous grid method (DCEGM)."""
"""Solve a single period of the model using the DCEGM method."""
(value_solved, policy_solved, endog_grid_solved) = carry

(
Expand All @@ -34,10 +31,7 @@ def solve_single_period(
) = xs

# EGM step 1)
marginal_utility_interpolated, value_interpolated = vmap(
interpolate_value_and_marg_utility_on_next_period_wealth,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)(
value_interpolated, marginal_utility_interpolated = interpolate_value_and_marg_util(
model_funcs["compute_marginal_utility"],
model_funcs["compute_utility"],
state_choice_mat_child,
Expand Down Expand Up @@ -89,8 +83,8 @@ def solve_for_interpolated_values(
model_funcs,
):
# EGM step 2)
# Aggregate the marginal utilities and expected values over all choices and
# income shock draws
# Aggregate the marginal utilities and expected values over all state-choice
# combinations and income shock draws
marg_util, emax = aggregate_marg_utils_and_exp_values(
value_state_choice_specific=value_interpolated,
marg_util_state_choice_specific=marginal_utility_interpolated,
Expand Down Expand Up @@ -119,26 +113,55 @@ def solve_for_interpolated_values(
params=params,
)

# Run upper envelope to remove suboptimal candidates
# Run upper envelope over all state-choice combinations to remove suboptimal
# candidates
(
endog_grid_state_choice,
policy_state_choice,
value_state_choice,
) = vmap(
model_funcs["compute_upper_envelope"],
) = run_upper_envelope(
endog_grid_candidate=endog_grid_candidate,
policy_candidate=policy_candidate,
value_candidate=value_candidate,
expected_values=expected_values,
state_choice_mat=state_choice_mat,
compute_utility=model_funcs["compute_utility"],
params=params,
compute_upper_envelope_for_state_choice=model_funcs["compute_upper_envelope"],
)

return (
endog_grid_state_choice,
policy_state_choice,
value_state_choice,
)


def run_upper_envelope(
endog_grid_candidate,
policy_candidate,
value_candidate,
expected_values,
state_choice_mat,
compute_utility,
params,
compute_upper_envelope_for_state_choice,
):
"""Run upper envelope to remove suboptimal candidates.

Vectorized over all state-choice combinations.

"""

return vmap(
compute_upper_envelope_for_state_choice,
in_axes=(0, 0, 0, 0, 0, None, None), # vmap over state-choice combs
)(
endog_grid_candidate,
policy_candidate,
value_candidate,
expected_values[:, 0],
state_choice_mat,
model_funcs["compute_utility"],
compute_utility,
params,
)

return (
endog_grid_state_choice,
policy_state_choice,
value_state_choice,
)
5 changes: 4 additions & 1 deletion tests/utils/interpolations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
from jax import numpy as jnp

from dcegm.interpolation import get_index_high_and_low, linear_interpolation_formula
from dcegm.interpolation.interp1d import (
get_index_high_and_low,
linear_interpolation_formula,
)


def linear_interpolation_with_extrapolation(x, y, x_new):
Expand Down
Loading