Skip to content

Commit

Permalink
Fix first simulation test
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Jun 6, 2024
1 parent e1420c8 commit 4d248d7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def simulate_all_periods(
value_solved,
model,
):
# Set initial states to internal dtype
state_space_dict = model["model_structure"]["state_space_dict"]
states_initial = {
key: value.astype(state_space_dict[key].dtype)
Expand Down
20 changes: 12 additions & 8 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def model_setup(toy_model_exog_ltc):
}


@pytest.mark.skip()
# @pytest.mark.skip()
def test_simulate_lax_scan(model_setup):
params = model_setup["params"]
options = model_setup["options"]
Expand All @@ -86,12 +86,18 @@ def test_simulate_lax_scan(model_setup):
policy = model_setup["policy"]
endog_grid = model_setup["endog_grid"]

initial_states_and_resources = (
model_setup["initial_states"],
model_setup["initial_resources"],
)
states_initial = model_setup["initial_states"]
resources_initial = model_setup["initial_resources"]

sim_specific_keys = model_setup["sim_specific_keys"]

state_space_dict = model_structure["state_space_dict"]
states_initial = {
key: value.astype(state_space_dict[key].dtype)
for key, value in states_initial.items()
}
initial_states_and_resources = states_initial, resources_initial

simulate_body = partial(
simulate_single_period,
params=params,
Expand All @@ -100,7 +106,7 @@ def test_simulate_lax_scan(model_setup):
value_solved=value,
policy_solved=policy,
map_state_choice_to_index=jnp.array(map_state_choice_to_index),
choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=jnp.int64),
choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=np.uint8),
compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"],
compute_utility=model_funcs["compute_utility"],
compute_beginning_of_period_resources=model_funcs[
Expand All @@ -110,7 +116,6 @@ def test_simulate_lax_scan(model_setup):
get_next_period_state=get_next_period_state,
)

# lax.scan
(
lax_states_and_resources_beginning_of_final_period,
lax_sim_dict_zero,
Expand Down Expand Up @@ -167,7 +172,6 @@ def test_simulate(model_setup):

n_agents = 100_000

# We need 64 because we do not alter the model array dtypes.
initial_states = {
"period": np.zeros(n_agents, dtype=np.int64),
"lagged_choice": np.zeros(
Expand Down

0 comments on commit 4d248d7

Please sign in to comment.