diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index e79d2523..1cefe7ba 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -229,7 +229,7 @@ def solve_final_period_discrete( ) # Choose which draw we take for policy and value function as those are not # saved with respect to the draws - middle_of_draws = int(value.shape[2] + 1 / 2) + middle_of_draws = int((value.shape[2] - 1) / 2) # Select solutions to store value_final = value[:, :, middle_of_draws] diff --git a/tests/test_discrete_versus_continuous_experience.py b/tests/test_discrete_versus_continuous_experience.py index 7113c25e..9e6537fd 100644 --- a/tests/test_discrete_versus_continuous_experience.py +++ b/tests/test_discrete_versus_continuous_experience.py @@ -17,7 +17,7 @@ N_DISCRETE_CHOICES = 2 MAX_WEALTH = 50 WEALTH_GRID_POINTS = 100 -EXPERIENCE_GRID_POINTS = 6 +EXPERIENCE_GRID_POINTS = 5 PARAMS = { @@ -129,6 +129,7 @@ def test_setup(): @pytest.mark.parametrize( "period, experience, lagged_choice, choice", [ + (1, 0, 1, 0), (1, 0, 1, 0), (1, 1, 0, 0), (2, 1, 0, 1), @@ -220,5 +221,5 @@ def test_replication_discrete_versus_continuous_experience( params=PARAMS, ) - aaae(value_cont_interp, value_disc_interp, decimal=1e-6) - aaae(policy_cont_interp, policy_disc_interp, decimal=1e-6) + aaae(value_cont_interp, value_disc_interp, decimal=4) + aaae(policy_cont_interp, policy_disc_interp, decimal=4)