Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New main introduces bug #132

Merged
merged 7 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/dcegm/law_of_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@ def calculate_resources_for_all_agents(
return resources_beginning_of_next_period


def calculate_second_continuous_state_for_all_agents(
    discrete_states_beginning_of_period,
    continuous_state_beginning_of_period,
    params,
    compute_continuous_state,
):
    """Vectorize the second-continuous-state update over all agents.

    Maps ``calc_continuous_state_for_each_grid_point`` over the agent axis
    (axis 0) of the discrete and continuous state inputs, while ``params``
    and the user-supplied ``compute_continuous_state`` function are shared
    across agents (``in_axes=None``).
    """
    vectorized_update = vmap(
        calc_continuous_state_for_each_grid_point,
        in_axes=(0, 0, None, None),
    )
    return vectorized_update(
        discrete_states_beginning_of_period,
        continuous_state_beginning_of_period,
        params,
        compute_continuous_state,
    )


def calculate_resources_given_second_continuous_state_for_all_agents(
states_beginning_of_period,
continuous_state_beginning_of_period,
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/pre_processing/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def determine_optimal_batch_size(
# the maximum index of the batch, i.e. if all state choices relevant to
# solving the current state choices of the batch are in previous batches
min_state_choice_idx = np.min(unique_child_state_choice_idxs)
if batch.max() > min_state_choice_idx:
if batch.max() >= min_state_choice_idx:
batch_not_found = True
need_to_reduce_batchsize = True
break
Expand Down
27 changes: 5 additions & 22 deletions src/dcegm/pre_processing/model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,10 @@ def process_model_functions(
)

# Now state space functions
(
get_state_specific_choice_set,
get_next_period_state,
update_continuous_state,
update_continuous_state_for_next_period,
) = process_state_space_functions(
state_space_functions, options, continuous_state_name
get_state_specific_choice_set, get_next_period_state, update_continuous_state = (
process_state_space_functions(
state_space_functions, options, continuous_state_name
)
)

# Budget equation
Expand All @@ -133,7 +130,6 @@ def process_model_functions(
"compute_marginal_utility_final": compute_marginal_utility_final,
"compute_beginning_of_period_resources": compute_beginning_of_period_resources,
"update_continuous_state": update_continuous_state,
"update_continuous_state_for_next_period": update_continuous_state_for_next_period,
"compute_exog_transition_vec": compute_exog_transition_vec,
"processed_exog_funcs": processed_exog_funcs_dict,
"get_state_specific_choice_set": get_state_specific_choice_set,
Expand Down Expand Up @@ -207,23 +203,10 @@ def get_next_period_state(**kwargs):
options=options["model_params"],
continuous_state_name=continuous_state_name,
)
update_continuous_state_for_next_period = (
determine_function_arguments_and_partial_options(
func=state_space_functions["update_continuous_state_for_next_period"],
options=options["model_params"],
continuous_state_name=continuous_state_name,
)
)
else:
update_continuous_state = None
update_continuous_state_for_next_period = None

return (
get_state_specific_choice_set,
get_next_period_state,
update_continuous_state,
update_continuous_state_for_next_period,
)
return get_state_specific_choice_set, get_next_period_state, update_continuous_state


def create_upper_envelope_function(options, continuous_state=None):
Expand Down
6 changes: 3 additions & 3 deletions src/dcegm/pre_processing/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


def determine_function_arguments_and_partial_options_beginning_of_period(
func, options, continuous_state=None
func, options, continuous_state_name=None
):
signature = set(inspect.signature(func).parameters)

Expand All @@ -41,8 +41,8 @@
@functools.wraps(func)
def processed_func(**kwargs):

if continuous_state:
kwargs[continuous_state] = kwargs["continuous_state"]
if continuous_state_name:
kwargs[continuous_state_name] = kwargs["continuous_state"]

Check warning on line 45 in src/dcegm/pre_processing/shared.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/shared.py#L44-L45

Added lines #L44 - L45 were not covered by tests

func_kwargs = {key: kwargs[key] for key in signature}

Expand Down
153 changes: 11 additions & 142 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dcegm.law_of_motion import (
calculate_resources_for_all_agents,
calculate_resources_given_second_continuous_state_for_all_agents,
calculate_second_continuous_state_for_all_agents,
)


Expand Down Expand Up @@ -120,70 +121,6 @@ def interpolate_policy_and_value_for_all_agents(
return policy_agent, value_agent


# def interp1d_policy_and_value_for_all_agents(
# states_beginning_of_period,
# resources_beginning_of_period,
# value_solved,
# policy_solved,
# endog_grid_solved,
# map_state_choice_to_index,
# choice_range,
# params,
# state_space_names,
# compute_utility,
# second_continuous_state,
# ):
# """This function interpolates the policy and value function for all agents.

# It uses the states at the beginning of period to select the solved policy and value
# and then interpolates the wealth at the beginning of period on them.

# """
# breakpoint()
# discrete_state_choice_indexes = get_state_choice_index_per_discrete_state(
# map_state_choice_to_index=map_state_choice_to_index,
# states=states_beginning_of_period,
# state_space_names=state_space_names,
# )

# value_grid_agent = jnp.take(
# value_solved,
# discrete_state_choice_indexes,
# axis=0,
# mode="fill",
# fill_value=jnp.nan,
# )
# policy_grid_agent = jnp.take(policy_solved, discrete_state_choice_indexes, axis=0)
# endog_grid_agent = jnp.take(
# endog_grid_solved, discrete_state_choice_indexes, axis=0
# )

# # =================================================================================

# vectorized_interp = vmap(
# vmap(
# interpolate_policy_and_value_function,
# in_axes=(None, None, 0, 0, 0, 0, None, None), # wealth grid
# ),
# in_axes=(0, 0, 0, 0, 0, None, None, None), # discrete state-choices
# )

# # =================================================================================

# policy_agent, value_per_agent_interp = vectorized_interp(
# resources_beginning_of_period,
# states_beginning_of_period,
# endog_grid_agent,
# value_grid_agent,
# policy_grid_agent,
# choice_range,
# params,
# compute_utility,
# )

# return policy_agent, value_per_agent_interp


def transition_to_next_period(
discrete_states_beginning_of_period,
continuous_state_beginning_of_period,
Expand All @@ -197,6 +134,7 @@ def transition_to_next_period(
sim_specific_keys,
):
n_agents = savings_current_period.shape[0]

exog_states_next_period = vmap(
realize_exog_process, in_axes=(0, 0, 0, None, None, None)
)(
Expand Down Expand Up @@ -230,16 +168,16 @@ def transition_to_next_period(
)

if continuous_state_beginning_of_period is not None:
continuous_state_next_period = vmap(
update_continuous_state_for_one_agent,
in_axes=(None, 0, 0, 0, None), # choice
)(
compute_next_period_states["update_continuous_state_for_next_period"],
discrete_states_beginning_of_period,
continuous_state_beginning_of_period,
choice,
params,

continuous_state_next_period = calculate_second_continuous_state_for_all_agents(
discrete_states_beginning_of_period=discrete_states_next_period,
continuous_state_beginning_of_period=continuous_state_beginning_of_period,
params=params,
compute_continuous_state=compute_next_period_states[
"update_continuous_state"
],
)

resources_beginning_of_next_period = calculate_resources_given_second_continuous_state_for_all_agents(
states_beginning_of_period=discrete_states_next_period,
continuous_state_beginning_of_period=continuous_state_next_period,
Expand Down Expand Up @@ -267,75 +205,6 @@ def transition_to_next_period(
)


def _transition_to_next_period(
    discrete_states_beginning_of_period,
    continuous_state_beginning_of_period,
    savings_current_period,
    choice,
    params,
    compute_exog_transition_vec,
    exog_state_mapping,
    compute_beginning_of_period_resources,
    compute_next_period_states,
    sim_specific_keys,
):
    """Simulate the transition of all agents from this period to the next.

    Realizes the exogenous processes, updates the endogenous discrete states
    and the second continuous state, draws next period's income shocks, and
    applies the budget constraint to obtain beginning-of-next-period
    resources.

    Args:
        discrete_states_beginning_of_period: Dict of per-agent discrete
            state arrays at the beginning of the current period.
        continuous_state_beginning_of_period: Per-agent values of the second
            continuous state at the beginning of the current period.
        savings_current_period: Per-agent end-of-period savings; its first
            axis determines the number of agents.
        choice: Per-agent choice taken this period.
        params: Model parameters; must contain the shock scale "sigma".
        compute_exog_transition_vec: Function returning the exogenous
            transition probabilities.
        exog_state_mapping: Mapping from exogenous process index to
            exogenous state values.
        compute_beginning_of_period_resources: Budget-constraint function.
        compute_next_period_states: Dict with entries "get_next_period_state"
            and "update_continuous_state" holding the state-update functions.
        sim_specific_keys: Array of PRNG keys; row 1 is used for income
            shocks, rows 2: for realizing the exogenous processes.

    Returns:
        Tuple of (resources_beginning_of_next_period, states_next_period,
        income_shocks_next_period).

    """
    n_agents = savings_current_period.shape[0]
    # Realize the exogenous processes agent by agent; keys, states and
    # choices are mapped over, the transition machinery is shared.
    exog_states_next_period = vmap(
        realize_exog_process, in_axes=(0, 0, 0, None, None, None)
    )(
        discrete_states_beginning_of_period,
        choice,
        sim_specific_keys[2:, :],
        params,
        compute_exog_transition_vec,
        exog_state_mapping,
    )

    # Update the endogenous discrete states given this period's choice.
    discrete_states_next_period = vmap(
        update_discrete_states_for_one_agent, in_axes=(None, 0, 0, None)  # choice
    )(
        compute_next_period_states["get_next_period_state"],
        discrete_states_beginning_of_period,
        choice,
        params,
    )
    # Update the second continuous state (e.g. experience) per agent.
    continuous_state_next_period = vmap(
        update_continuous_state_for_one_agent,
        in_axes=(None, 0, 0, 0, None),  # choice
    )(
        compute_next_period_states["update_continuous_state"],
        discrete_states_beginning_of_period,
        continuous_state_beginning_of_period,
        choice,
        params,
    )

    # Generate states next period and apply budget constraint for wealth at the
    # beginning of next period.
    # Initialize states by copying, then overwrite the updated entries.
    states_next_period = discrete_states_beginning_of_period.copy()
    states_to_update = {**discrete_states_next_period, **exog_states_next_period}
    states_next_period.update(states_to_update)

    # Draw income shocks.
    income_shocks_next_period = draw_normal_shocks(
        key=sim_specific_keys[1, :], num_agents=n_agents, mean=0, std=params["sigma"]
    )
    resources_beginning_of_next_period = calculate_resources_for_all_agents(
        states_beginning_of_period=states_next_period,
        savings_end_of_previous_period=savings_current_period,
        income_shocks_of_period=income_shocks_next_period,
        params=params,
        compute_beginning_of_period_resources=compute_beginning_of_period_resources,
    )

    return (
        resources_beginning_of_next_period,
        states_next_period,
        income_shocks_next_period,
    )


def compute_final_utility_for_each_choice(
state_vec, choice, resources, params, compute_utility_final_period
):
Expand Down
5 changes: 1 addition & 4 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ def simulate_all_periods(

compute_next_period_states = {
"get_next_period_state": model_funcs["get_next_period_state"],
# "update_continuous_state": model_funcs["update_continuous_state"],
"update_continuous_state_for_next_period": model_funcs[
"update_continuous_state_for_next_period"
],
"update_continuous_state": model_funcs["update_continuous_state"],
}

simulate_body = partial(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ def budget_constraint_cont_exp(
savings_end_of_previous_period,
income_shock_previous_period,
params,
options,
):
experience_years = experience * period
max_init_experience_period = period + options["max_init_experience"]
experience_years = experience * max_init_experience_period

return budget_constraint_exp(
lagged_choice=lagged_choice,
Expand Down
Original file line number Diff line number Diff line change
def create_state_space_function_dict():
    """Return the dict of state-space functions used by the model.

    Maps the function-role keys expected by the pre-processing step to the
    concrete implementations defined in this module.
    """
    return {
        "get_state_specific_choice_set": get_state_specific_feasible_choice_set,
        "update_continuous_state": get_next_period_experience,
        "update_continuous_state_for_next_period": get_next_period_experience_simulation,
    }


def get_next_period_experience(period, lagged_choice, experience):
    """Update the experience share given last period's choice.

    ``experience`` is the fraction of the previous ``period - 1`` periods
    spent working; a lagged choice of 0 counts as one additional period of
    work. Returns the updated fraction over ``period`` periods.
    """
    worked_last_period = lagged_choice == 0
    accumulated_experience = (period - 1) * experience + worked_last_period
    return (1 / period) * accumulated_experience


def get_next_period_experience_simulation(period, choice, experience):
    """Update the experience share using this period's choice (simulation).

    Unlike the solution-side update, this uses the current ``choice`` and
    advances the normalization from ``period`` to ``period + 1`` periods.
    """
    works_this_period = choice == 0
    accumulated_experience = period * experience + works_this_period
    return (1 / (period + 1)) * accumulated_experience
def get_next_period_experience(period, lagged_choice, experience, options):
    """Update the experience share, accounting for maximum initial experience.

    The normalization period is shifted by ``options["max_init_experience"]``
    so that agents who start with positive experience are handled correctly.
    A lagged choice of 0 counts as one additional period of work.
    """
    max_experience_period = period + options["max_init_experience"]
    worked_last_period = lagged_choice == 0
    accumulated = (max_experience_period - 1) * experience + worked_last_period
    return (1 / max_experience_period) * accumulated
8 changes: 5 additions & 3 deletions src/toy_models/cons_ret_model_with_exp/state_space_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def sparsity_condition(
lagged_choice,
options,
):
if period < experience:
max_exp_period = period + options["max_init_experience"]
max_total_experience = options["n_periods"] + options["max_init_experience"]
if max_exp_period < experience:
return False
elif experience >= options["n_periods"]:
elif max_total_experience <= experience:
return False
elif (experience == period) & (lagged_choice == 1):
elif (experience == max_exp_period) & (lagged_choice == 1):
return False
elif (lagged_choice == 0) & (experience == 0):
return False
Expand Down
Loading
Loading