From 4d248d705d148139de33ca3fa5e0d4f7543b1af3 Mon Sep 17 00:00:00 2001 From: Sebastian Gsell Date: Thu, 6 Jun 2024 11:35:38 +0200 Subject: [PATCH] Fix first simulation test --- src/dcegm/simulation/simulate.py | 1 + tests/test_simulate.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index e7968ddd..a1ad0d63 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -24,6 +24,7 @@ def simulate_all_periods( value_solved, model, ): + # Set initial states to internal dtype state_space_dict = model["model_structure"]["state_space_dict"] states_initial = { key: value.astype(state_space_dict[key].dtype) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index be5d6bb3..4d1a2604 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -68,7 +68,7 @@ def model_setup(toy_model_exog_ltc): } -@pytest.mark.skip() +# @pytest.mark.skip() def test_simulate_lax_scan(model_setup): params = model_setup["params"] options = model_setup["options"] @@ -86,12 +86,18 @@ def test_simulate_lax_scan(model_setup): policy = model_setup["policy"] endog_grid = model_setup["endog_grid"] - initial_states_and_resources = ( - model_setup["initial_states"], - model_setup["initial_resources"], - ) + states_initial = model_setup["initial_states"] + resources_initial = model_setup["initial_resources"] + sim_specific_keys = model_setup["sim_specific_keys"] + state_space_dict = model_structure["state_space_dict"] + states_initial = { + key: value.astype(state_space_dict[key].dtype) + for key, value in states_initial.items() + } + initial_states_and_resources = states_initial, resources_initial + simulate_body = partial( simulate_single_period, params=params, @@ -100,7 +106,7 @@ def test_simulate_lax_scan(model_setup): value_solved=value, policy_solved=policy, map_state_choice_to_index=jnp.array(map_state_choice_to_index), - choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=jnp.int64), + choice_range=jnp.arange(map_state_choice_to_index.shape[-1], dtype=np.uint8), compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"], compute_utility=model_funcs["compute_utility"], compute_beginning_of_period_resources=model_funcs[ @@ -110,7 +116,6 @@ def test_simulate_lax_scan(model_setup): get_next_period_state=get_next_period_state, ) - # lax.scan ( lax_states_and_resources_beginning_of_final_period, lax_sim_dict_zero, @@ -167,7 +172,6 @@ def test_simulate(model_setup): n_agents = 100_000 - # We need 64 because we do not alter the model array dtypes. initial_states = { "period": np.zeros(n_agents, dtype=np.int64), "lagged_choice": np.zeros(