Skip to content

Commit

Permalink
Merge branch 'refactor' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell authored Aug 8, 2024
2 parents 9c3f27e + 719b281 commit cfe62ca
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 42 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers =
Topic :: Scientific/Engineering
Topic :: Software Development :: Build Tools
project_urls =
Github = https://github.com/segsell/dcegm
Github = https://github.com/OpenSourceEconomics/dcegm

[options]
packages = find:
Expand Down
103 changes: 80 additions & 23 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from jax import numpy as jnp
from jax import vmap

from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_value_and_marg_utility_on_next_period_wealth(
def interpolate_value_and_marg_util(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -15,25 +15,81 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate value and policy in the child states and compute the marginal value.
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for all child states and compute marginal utility.
Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility. The input ```params``` is already partialled in.
compute_value (callable): Function for calculating the value from consumption
level, discrete choice and expected value. The inputs ```discount_rate```
and ```compute_utility``` are already partialled in.
wealth_next_period (jnp.ndarray): 2d array of shape
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
choice (int): The agent's discrete choice.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
choice_policies_child_state_choice (jnp.ndarray): 1d array containing the
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
choice_values_child_state_choice (jnp.ndarray): 1d array containing the
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
params (dict): Dictionary containing the model parameters.
Returns:
tuple:
- value_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated value function.
- marg_util_interp (jnp.ndarray): 2d array of shape (n_wealth_grid, n_income_shocks)
containing the interpolated marginal utilities for each wealth level and
income shock.
"""

interp_for_single_state_choice = vmap(
interpolate_value_and_marg_util_for_single_state_choice,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)

return interp_for_single_state_choice(
compute_marginal_utility,
compute_utility,
state_choice_vec,
wealth_beginning_of_period,
endog_grid_child_state_choice,
policy_child_state_choice,
value_child_state_choice,
params,
)


def interpolate_value_and_marg_util_for_single_state_choice(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
wealth_beginning_of_period: jnp.ndarray,
endog_grid_child_state_choice: jnp.ndarray,
policy_child_state_choice: jnp.ndarray,
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Interpolate value and policy for given child state and compute marginal utility.
Args:
compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility of consumption.
compute_utility (callable): Function for calculating the utility of consumption.
state_choice_vec (dict): Dictionary containing the state and choice of the agent.
wealth_beginning_of_period (jnp.ndarray): 2d array of shape
(n_quad_stochastic, n_grid_wealth,) containing the agent's beginning of
period wealth.
endog_grid_child_state_choice (jnp.ndarray): 1d array containing the endogenous
wealth grid of the child state/choice pair. Shape (n_grid_wealth,).
policy_child_state_choice (jnp.ndarray): 1d array containing the
corresponding policy function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
value_child_state_choice (jnp.ndarray): 1d array containing the
corresponding value function values of the endogenous wealth grid of the
child state/choice pair. Shape (n_grid_wealth,).
params (dict): Dictionary containing the model parameters.
Expand All @@ -49,10 +105,8 @@ def interpolate_value_and_marg_utility_on_next_period_wealth(
"""

# Generate interpolation function for single wealth point where the endogenous grid,
# policy and value are fixed.
def interp_on_single_wealth(wealth):
policy_interp, value_interp = interp_value_and_policy_on_wealth(
def interp_on_single_wealth_point(wealth):
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
Expand All @@ -61,14 +115,17 @@ def interp_on_single_wealth(wealth):
state_choice_vec=state_choice_vec,
params=params,
)
marg_utility_interp = compute_marginal_utility(
marg_util_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)
return marg_utility_interp, value_interp

# Generate vectorized function for savings and income shock dimension
vector_interp_func = vmap(vmap(interp_on_single_wealth))
return value_interp, marg_util_interp

marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period)
# Vectorize over savings and income shock dimension
interp_for_savings_point_and_income_shock_draw = vmap(
vmap(interp_on_single_wealth_point)
)
value_interp, marg_util_interp = interp_for_savings_point_and_income_shock_draw(
wealth_beginning_of_period
)

return marg_utils, value_interp
return value_interp, marg_util_interp
6 changes: 3 additions & 3 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import jax.numpy as jnp

from dcegm.interpolation import (
from dcegm.interpolation.interp1d import (
interp_policy_on_wealth,
interp_value_and_policy_on_wealth,
interp_value_on_wealth,
interpolate_policy_and_value_on_wealth,
)


Expand Down Expand Up @@ -34,7 +34,7 @@ def policy_and_value_for_state_choice_vec(
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]
policy, value = interp_value_and_policy_on_wealth(
policy, value = interpolate_policy_and_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_index_high_and_low(x, x_new):
return ind_high, ind_high - 1


def interp_value_and_policy_on_wealth(
def interpolate_policy_and_value_on_wealth(
wealth: float | jnp.ndarray,
endog_grid: jnp.ndarray,
policy: jnp.ndarray,
Expand All @@ -45,7 +45,7 @@ def interp_value_and_policy_on_wealth(
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate value and policy given a single wealth value.
"""Interpolate policy and value function given a single wealth grid point.
Args:
wealth (float): Wealth value to interpolate.
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
calculate_choice_probs_and_unsqueezed_logsum,
)
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation import interp_value_on_wealth
from dcegm.interpolation.interp1d import interp_value_on_wealth
from dcegm.solve import get_solve_func_for_model


Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dcegm.budget import calculate_resources_for_all_agents
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation.interp1d import interpolate_policy_and_value_on_wealth


def interpolate_policy_and_value_for_all_agents(
Expand Down Expand Up @@ -172,7 +172,7 @@ def interpolate_policy_and_value_function(
):
state_choice_vec = {**state, "choice": choice}

policy_interp, value_interp = interp_value_and_policy_on_wealth(
policy_interp, value_interp = interpolate_policy_and_value_on_wealth(
wealth=resources_beginning_of_period,
endog_grid=endog_grid_agent,
policy=policy_agent,
Expand Down
12 changes: 3 additions & 9 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from jax import vmap

from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values
from dcegm.egm.interpolate_marginal_utility import (
interpolate_value_and_marg_utility_on_next_period_wealth,
)
from dcegm.egm.interpolate_marginal_utility import interpolate_value_and_marg_util
from dcegm.egm.solve_euler_equation import (
calculate_candidate_solutions_from_euler_equation,
)
Expand All @@ -19,8 +17,7 @@ def solve_single_period(
model_funcs,
taste_shock_scale,
):
"""This function solves a single period of the model using the discrete continuous
endogenous grid method (DCEGM)."""
"""Solve a single period of the model using the DCEGM method."""
(value_solved, policy_solved, endog_grid_solved) = carry

(
Expand All @@ -34,10 +31,7 @@ def solve_single_period(
) = xs

# EGM step 1)
marginal_utility_interpolated, value_interpolated = vmap(
interpolate_value_and_marg_utility_on_next_period_wealth,
in_axes=(None, None, 0, 0, 0, 0, 0, None),
)(
value_interpolated, marginal_utility_interpolated = interpolate_value_and_marg_util(
model_funcs["compute_marginal_utility"],
model_funcs["compute_utility"],
state_choice_mat_child,
Expand Down
5 changes: 4 additions & 1 deletion tests/utils/interpolations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
from jax import numpy as jnp

from dcegm.interpolation import get_index_high_and_low, linear_interpolation_formula
from dcegm.interpolation.interp1d import (
get_index_high_and_low,
linear_interpolation_formula,
)


def linear_interpolation_with_extrapolation(x, y, x_new):
Expand Down

0 comments on commit cfe62ca

Please sign in to comment.