diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 266283fb..abe7c3b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,12 +41,19 @@ repos: - id: python-no-log-warn - id: python-use-type-annotations - id: text-unicode-replacement-char - - repo: https://github.com/asottile/reorder-python-imports - rev: v3.13.0 + # - repo: https://github.com/asottile/reorder-python-imports + # rev: v3.13.0 + # hooks: + # - id: reorder-python-imports + # args: + # - --py37-plus + - repo: https://github.com/pycqa/isort + rev: 5.13.2 hooks: - - id: reorder-python-imports + - id: isort + name: isort args: - - --py37-plus + - --profile=black - repo: https://github.com/asottile/setup-cfg-fmt rev: v2.5.0 hooks: diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index 6f092283..caea9e7d 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ b/src/dcegm/egm/interpolate_marginal_utility.py @@ -1,11 +1,10 @@ -from typing import Callable -from typing import Dict -from typing import Tuple +from typing import Callable, Dict, Tuple -from dcegm.interpolation import interp_value_and_policy_on_wealth from jax import numpy as jnp from jax import vmap +from dcegm.interpolation import interp_value_and_policy_on_wealth + def interpolate_value_and_marg_utility_on_next_period_wealth( compute_marginal_utility: Callable, diff --git a/src/dcegm/egm/solve_euler_equation.py b/src/dcegm/egm/solve_euler_equation.py index f64190ce..f30668f3 100644 --- a/src/dcegm/egm/solve_euler_equation.py +++ b/src/dcegm/egm/solve_euler_equation.py @@ -1,8 +1,6 @@ """Auxiliary functions for the EGM algorithm.""" -from typing import Callable -from typing import Dict -from typing import Tuple +from typing import Callable, Dict, Tuple import numpy as np from jax import numpy as jnp diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index 6be1b29d..4070b09d 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -1,13 +1,12 @@ """Wrapper to solve the final period of the model.""" -from typing import Callable -from typing import Dict -from typing import Tuple +from typing import Callable, Dict, Tuple import jax.numpy as jnp -from dcegm.solve_single_period import solve_for_interpolated_values from jax import vmap +from dcegm.solve_single_period import solve_for_interpolated_values + def solve_last_two_periods( resources_beginning_of_period: jnp.ndarray, diff --git a/src/dcegm/interface.py b/src/dcegm/interface.py index 2d41db61..857f9054 100644 --- a/src/dcegm/interface.py +++ b/src/dcegm/interface.py @@ -1,7 +1,10 @@ import jax.numpy as jnp -from dcegm.interpolation import interp_policy_on_wealth -from dcegm.interpolation import interp_value_and_policy_on_wealth -from dcegm.interpolation import interp_value_on_wealth + +from dcegm.interpolation import ( + interp_policy_on_wealth, + interp_value_and_policy_on_wealth, + interp_value_on_wealth, +) def policy_and_value_for_state_choice_vec( diff --git a/src/dcegm/interpolation.py b/src/dcegm/interpolation.py index 65e1ab2b..49816573 100644 --- a/src/dcegm/interpolation.py +++ b/src/dcegm/interpolation.py @@ -1,6 +1,4 @@ -from typing import Callable -from typing import Dict -from typing import Tuple +from typing import Callable, Dict, Tuple import jax.numpy as jnp from jax import numpy as jnp diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 9ae5d49b..ae1981a3 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -4,12 +4,12 @@ """ -from typing import Any -from typing import Dict +from typing import Any, Dict import jax import jax.numpy as jnp import numpy as np + from dcegm.egm.aggregate_marginal_utility import ( calculate_choice_probs_and_unsqueezed_logsum, ) diff --git a/src/dcegm/numerical_integration.py b/src/dcegm/numerical_integration.py index b275c84d..0c6cb4cd 100644 --- a/src/dcegm/numerical_integration.py +++ b/src/dcegm/numerical_integration.py @@ -1,8 +1,7 @@ from typing import Tuple import numpy as np -from scipy.special import roots_hermite -from scipy.special import roots_sh_legendre +from scipy.special import roots_hermite, roots_sh_legendre from scipy.stats import norm diff --git a/src/dcegm/pre_processing/debugging.py b/src/dcegm/pre_processing/debugging.py index 531355b7..4dabad29 100644 --- a/src/dcegm/pre_processing/debugging.py +++ b/src/dcegm/pre_processing/debugging.py @@ -2,8 +2,11 @@ import numpy as np import pandas as pd -from dcegm.pre_processing.state_space import process_endog_state_specifications -from dcegm.pre_processing.state_space import process_exog_model_specifications + +from dcegm.pre_processing.state_space import ( + process_endog_state_specifications, + process_exog_model_specifications, +) def inspect_state_space( diff --git a/src/dcegm/pre_processing/exog_processes.py b/src/dcegm/pre_processing/exog_processes.py index 61052f55..fe38effd 100644 --- a/src/dcegm/pre_processing/exog_processes.py +++ b/src/dcegm/pre_processing/exog_processes.py @@ -1,9 +1,10 @@ from functools import partial from typing import Callable -from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options from jax import numpy as jnp +from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options + def create_exog_transition_function(options): """Create the exogenous process transition function. diff --git a/src/dcegm/pre_processing/model_functions.py b/src/dcegm/pre_processing/model_functions.py index 7b52d3a5..5eb2aaa0 100644 --- a/src/dcegm/pre_processing/model_functions.py +++ b/src/dcegm/pre_processing/model_functions.py @@ -1,10 +1,10 @@ -from typing import Callable -from typing import Dict +from typing import Callable, Dict import jax.numpy as jnp +from upper_envelope.fues_jax.fues_jax import fast_upper_envelope_wrapper + from dcegm.pre_processing.exog_processes import create_exog_transition_function from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options -from upper_envelope.fues_jax.fues_jax import fast_upper_envelope_wrapper def process_model_functions( diff --git a/src/dcegm/pre_processing/params.py b/src/dcegm/pre_processing/params.py index 1dd8873f..8e5e0a91 100644 --- a/src/dcegm/pre_processing/params.py +++ b/src/dcegm/pre_processing/params.py @@ -1,5 +1,4 @@ -from typing import Dict -from typing import Union +from typing import Dict, Union import pandas as pd diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index cc26beaf..4a459cdb 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -1,13 +1,15 @@ import pickle -from typing import Callable -from typing import Dict +from typing import Callable, Dict import numpy as np + from dcegm.pre_processing.batches import create_batches_and_information from dcegm.pre_processing.exog_processes import create_exog_state_mapping from dcegm.pre_processing.model_functions import process_model_functions -from dcegm.pre_processing.state_space import check_options -from dcegm.pre_processing.state_space import create_state_space_and_choice_objects +from dcegm.pre_processing.state_space import ( + check_options, + create_state_space_and_choice_objects, +) def setup_model( diff --git a/src/dcegm/pre_processing/state_space.py b/src/dcegm/pre_processing/state_space.py index ffa8387d..a6c98cb1 100644 --- a/src/dcegm/pre_processing/state_space.py +++ b/src/dcegm/pre_processing/state_space.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np + from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index eb5fdccb..19fcb186 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -1,11 +1,12 @@ import jax import numpy as np import pandas as pd +from jax import numpy as jnp +from jax import vmap + from dcegm.budget import calculate_resources_for_all_agents from dcegm.interface import get_state_choice_index_per_state from dcegm.interpolation import interp_value_and_policy_on_wealth -from jax import numpy as jnp -from jax import vmap def interpolate_policy_and_value_for_all_agents( diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index dd859833..97719595 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -5,14 +5,17 @@ import jax import jax.numpy as jnp import numpy as np -from dcegm.interface import get_state_choice_index_per_state -from dcegm.simulation.sim_utils import compute_final_utility_for_each_choice -from dcegm.simulation.sim_utils import draw_taste_shocks -from dcegm.simulation.sim_utils import interpolate_policy_and_value_for_all_agents -from dcegm.simulation.sim_utils import transition_to_next_period -from dcegm.simulation.sim_utils import vectorized_utility from jax import vmap +from dcegm.interface import get_state_choice_index_per_state +from dcegm.simulation.sim_utils import ( + compute_final_utility_for_each_choice, + draw_taste_shocks, + interpolate_policy_and_value_for_all_agents, + transition_to_next_period, + vectorized_utility, +) + def simulate_all_periods( states_initial, diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index b7e42658..be69c457 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -1,22 +1,20 @@ """Interface for the DC-EGM algorithm.""" from functools import partial -from typing import Any -from typing import Callable -from typing import Dict -from typing import Tuple +from typing import Any, Callable, Dict, Tuple import jax.lax import jax.numpy as jnp import numpy as np import pandas as pd +from jax import jit + from dcegm.budget import calculate_resources from dcegm.final_periods import solve_last_two_periods from dcegm.numerical_integration import quadrature_legendre from dcegm.pre_processing.params import process_params from dcegm.pre_processing.setup_model import setup_model from dcegm.solve_single_period import solve_single_period -from jax import jit def solve_dcegm( diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index 38412146..fa72ba91 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -1,3 +1,5 @@ +from jax import vmap + from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values from dcegm.egm.interpolate_marginal_utility import ( interpolate_value_and_marg_utility_on_next_period_wealth, @@ -5,7 +7,6 @@ from dcegm.egm.solve_euler_equation import ( calculate_candidate_solutions_from_euler_equation, ) -from jax import vmap def solve_single_period( diff --git a/src/toy_models/consumption_retirement_model/budget_functions.py b/src/toy_models/consumption_retirement_model/budget_functions.py index 3631c5de..24760e14 100644 --- a/src/toy_models/consumption_retirement_model/budget_functions.py +++ b/src/toy_models/consumption_retirement_model/budget_functions.py @@ -1,5 +1,4 @@ -from typing import Any -from typing import Dict +from typing import Any, Dict import jax import jax.numpy as jnp diff --git a/src/toy_models/consumption_retirement_model/utility_functions.py b/src/toy_models/consumption_retirement_model/utility_functions.py index 193324e3..9b0a853a 100644 --- a/src/toy_models/consumption_retirement_model/utility_functions.py +++ b/src/toy_models/consumption_retirement_model/utility_functions.py @@ -1,5 +1,4 @@ -from typing import Any -from typing import Dict +from typing import Any, Dict import jax.numpy as jnp diff --git a/tests/conftest.py b/tests/conftest.py index 7f24fd64..6d30a8b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,24 +9,23 @@ import pandas as pd import pytest import yaml + from dcegm.pre_processing.setup_model import setup_model from dcegm.solve import solve_dcegm +from tests.two_period_models.model import ( + budget_dcegm_exog_ltc, + budget_dcegm_exog_ltc_and_job_offer, + prob_exog_job_offer, + prob_exog_ltc, +) from toy_models.consumption_retirement_model.state_space_objects import ( create_state_space_function_dict, ) from toy_models.consumption_retirement_model.utility_functions import ( create_final_period_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( create_utility_function_dict, ) -from tests.two_period_models.model import budget_dcegm_exog_ltc -from tests.two_period_models.model import budget_dcegm_exog_ltc_and_job_offer -from tests.two_period_models.model import prob_exog_job_offer -from tests.two_period_models.model import prob_exog_ltc - - # Obtain the test directory of the package TEST_DIR = Path(__file__).parent diff --git a/tests/test_budget_equation.py b/tests/test_budget_equation.py index ee717eb1..069740dd 100644 --- a/tests/test_budget_equation.py +++ b/tests/test_budget_equation.py @@ -2,14 +2,15 @@ import numpy as np import pytest -from dcegm.pre_processing.params import process_params from numpy.testing import assert_array_almost_equal as aaae from scipy.special import roots_sh_legendre from scipy.stats import norm + +from dcegm.pre_processing.params import process_params from toy_models.consumption_retirement_model.budget_functions import ( _calc_stochastic_income, + budget_constraint, ) -from toy_models.consumption_retirement_model.budget_functions import budget_constraint model = ["deaton", "retirement_taste_shocks", "retirement_no_taste_shocks"] labor_choice = [0, 1] diff --git a/tests/test_changing_choice_set.py b/tests/test_changing_choice_set.py index 9df20536..90e863c5 100644 --- a/tests/test_changing_choice_set.py +++ b/tests/test_changing_choice_set.py @@ -14,18 +14,15 @@ import jax.numpy as jnp import numpy as np import pytest + from dcegm.pre_processing.setup_model import setup_model from dcegm.solve import get_solve_function from toy_models.consumption_retirement_model.utility_functions import ( inverse_marginal_utility_crra, -) -from toy_models.consumption_retirement_model.utility_functions import ( marginal_utility_crra, -) -from toy_models.consumption_retirement_model.utility_functions import ( marginal_utility_final_consume_all, + utility_crra, ) -from toy_models.consumption_retirement_model.utility_functions import utility_crra # Obtain the test directory of the package TEST_DIR = Path(__file__).parent diff --git a/tests/test_exog_processes.py b/tests/test_exog_processes.py index f2c25987..16ca4fb0 100644 --- a/tests/test_exog_processes.py +++ b/tests/test_exog_processes.py @@ -2,23 +2,21 @@ import jax.numpy as jnp import numpy as np +from numpy.testing import assert_almost_equal as aaae + from dcegm.pre_processing.exog_processes import create_exog_state_mapping from dcegm.pre_processing.model_functions import process_model_functions from dcegm.pre_processing.state_space import create_state_space_and_choice_objects -from numpy.testing import assert_almost_equal as aaae +from tests.two_period_models.model import prob_exog_health from toy_models.consumption_retirement_model.budget_functions import budget_constraint from toy_models.consumption_retirement_model.state_space_objects import ( create_state_space_function_dict, ) from toy_models.consumption_retirement_model.utility_functions import ( create_final_period_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( create_utility_function_dict, ) -from tests.two_period_models.model import prob_exog_health - def trans_prob_care_demand(health_state, params): prob_care_demand = ( diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 56f3c235..598352ec 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -12,8 +12,8 @@ from numpy.testing import assert_allclose from scipy.interpolate import interp1d -from tests.utils.interpolations import linear_interpolation_with_extrapolation from tests.utils.interpolations import ( + linear_interpolation_with_extrapolation, linear_interpolation_with_inserting_missing_values, ) diff --git a/tests/test_numerical_integration.py b/tests/test_numerical_integration.py index 13295366..db3e777b 100644 --- a/tests/test_numerical_integration.py +++ b/tests/test_numerical_integration.py @@ -1,6 +1,7 @@ -from dcegm.numerical_integration import quadrature_hermite from numpy.testing import assert_allclose +from dcegm.numerical_integration import quadrature_hermite + def test_normal_distribution(): draws, weights = quadrature_hermite(20, 1) diff --git a/tests/test_pre_processing.py b/tests/test_pre_processing.py index ee0e58a2..3f215917 100644 --- a/tests/test_pre_processing.py +++ b/tests/test_pre_processing.py @@ -1,23 +1,22 @@ import jax.numpy as jnp import numpy as np import pytest +from jax import vmap + from dcegm.pre_processing.params import process_params -from dcegm.pre_processing.setup_model import load_and_setup_model -from dcegm.pre_processing.setup_model import setup_and_save_model -from dcegm.pre_processing.setup_model import setup_model +from dcegm.pre_processing.setup_model import ( + load_and_setup_model, + setup_and_save_model, + setup_model, +) from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options -from jax import vmap from toy_models.consumption_retirement_model.budget_functions import budget_constraint from toy_models.consumption_retirement_model.state_space_objects import ( create_state_space_function_dict, ) from toy_models.consumption_retirement_model.utility_functions import ( create_final_period_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( create_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( utiility_log_crra, ) diff --git a/tests/test_replication.py b/tests/test_replication.py index 75b05d71..161bba50 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -3,29 +3,25 @@ import jax.numpy as jnp import pytest +from numpy.testing import assert_array_almost_equal as aaae + from dcegm.pre_processing.setup_model import setup_model from dcegm.solve import solve_dcegm -from numpy.testing import assert_array_almost_equal as aaae +from tests.utils.interpolations import ( + interpolate_policy_and_value_on_wealth_grid, + linear_interpolation_with_extrapolation, +) from toy_models.consumption_retirement_model.budget_functions import budget_constraint from toy_models.consumption_retirement_model.state_space_objects import ( create_state_space_function_dict, ) from toy_models.consumption_retirement_model.utility_functions import ( create_final_period_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( create_utility_function_dict, -) -from toy_models.consumption_retirement_model.utility_functions import ( utiility_log_crra, -) -from toy_models.consumption_retirement_model.utility_functions import ( utiility_log_crra_final_consume_all, ) -from tests.utils.interpolations import interpolate_policy_and_value_on_wealth_grid -from tests.utils.interpolations import linear_interpolation_with_extrapolation - # Obtain the test directory of the package TEST_DIR = Path(__file__).parent diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 2a64e6de..55503ddf 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -6,12 +6,15 @@ import jax.numpy as jnp import numpy as np import pytest -from dcegm.simulation.sim_utils import create_simulation_df -from dcegm.simulation.simulate import simulate_all_periods -from dcegm.simulation.simulate import simulate_final_period -from dcegm.simulation.simulate import simulate_single_period from numpy.testing import assert_array_almost_equal as aaae +from dcegm.simulation.sim_utils import create_simulation_df +from dcegm.simulation.simulate import ( + simulate_all_periods, + simulate_final_period, + simulate_single_period, +) + def _create_test_objects_from_df(df, params): _cond = [df["choice"] == 0, df["choice"] == 1] diff --git a/tests/test_state_space.py b/tests/test_state_space.py index caf92ce7..87ed92f2 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import numpy as np import pytest + from dcegm.pre_processing.debugging import inspect_state_space from dcegm.pre_processing.state_space import create_state_space from toy_models.consumption_retirement_model.state_space_objects import ( diff --git a/tests/test_two_period_model.py b/tests/test_two_period_model.py index 037a7514..33573f1a 100644 --- a/tests/test_two_period_model.py +++ b/tests/test_two_period_model.py @@ -12,11 +12,12 @@ from scipy.special import roots_sh_legendre from scipy.stats import norm -from tests.two_period_models.euler_equation import euler_rhs_exog_ltc -from tests.two_period_models.euler_equation import euler_rhs_exog_ltc_and_job_offer +from tests.two_period_models.euler_equation import ( + euler_rhs_exog_ltc, + euler_rhs_exog_ltc_and_job_offer, +) from tests.two_period_models.model import marginal_utility - WEALTH_GRID_POINTS = 100 ALL_WEALTH_GRIDS = list(range(WEALTH_GRID_POINTS)) RANDOM_TEST_SET_LTC = np.random.choice(ALL_WEALTH_GRIDS, size=10, replace=False) diff --git a/tests/two_period_models/euler_equation.py b/tests/two_period_models/euler_equation.py index f83845a3..f01f6deb 100644 --- a/tests/two_period_models/euler_equation.py +++ b/tests/two_period_models/euler_equation.py @@ -1,7 +1,6 @@ import numpy as np -from tests.two_period_models.model import flow_utility -from tests.two_period_models.model import marginal_utility +from tests.two_period_models.model import flow_utility, marginal_utility def prob_long_term_care_patient(params, lagged_bad_health, bad_health): diff --git a/tests/two_period_models/model.py b/tests/two_period_models/model.py index 22e5fb38..c84f7a08 100644 --- a/tests/two_period_models/model.py +++ b/tests/two_period_models/model.py @@ -1,6 +1,5 @@ from jax import numpy as jnp - # ===================================================================================== # Utility functions # ===================================================================================== diff --git a/tests/utils/interpolations.py b/tests/utils/interpolations.py index 8c7cb705..f97f712f 100644 --- a/tests/utils/interpolations.py +++ b/tests/utils/interpolations.py @@ -1,8 +1,8 @@ import numpy as np -from dcegm.interpolation import get_index_high_and_low -from dcegm.interpolation import linear_interpolation_formula from jax import numpy as jnp +from dcegm.interpolation import get_index_high_and_low, linear_interpolation_formula + def linear_interpolation_with_extrapolation(x, y, x_new): """Linear interpolation with extrapolation.