Skip to content

Commit

Permalink
Use isort instead of reorder-python-imports
Browse files Browse the repository at this point in the history
  • Loading branch information
segsell committed Jun 4, 2024
1 parent 2e4c9b7 commit 08d4120
Show file tree
Hide file tree
Showing 34 changed files with 111 additions and 107 deletions.
15 changes: 11 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,19 @@ repos:
- id: python-no-log-warn
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.13.0
# - repo: https://github.com/asottile/reorder-python-imports
# rev: v3.13.0
# hooks:
# - id: reorder-python-imports
# args:
# - --py37-plus
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: reorder-python-imports
- id: isort
name: isort
args:
- --py37-plus
- --profile=black
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.5.0
hooks:
Expand Down
7 changes: 3 additions & 4 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Callable, Dict, Tuple

from dcegm.interpolation import interp_value_and_policy_on_wealth
from jax import numpy as jnp
from jax import vmap

from dcegm.interpolation import interp_value_and_policy_on_wealth


def interpolate_value_and_marg_utility_on_next_period_wealth(
compute_marginal_utility: Callable,
Expand Down
4 changes: 1 addition & 3 deletions src/dcegm/egm/solve_euler_equation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Auxiliary functions for the EGM algorithm."""

from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Callable, Dict, Tuple

import numpy as np
from jax import numpy as jnp
Expand Down
7 changes: 3 additions & 4 deletions src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Wrapper to solve the final period of the model."""

from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Callable, Dict, Tuple

import jax.numpy as jnp
from dcegm.solve_single_period import solve_for_interpolated_values
from jax import vmap

from dcegm.solve_single_period import solve_for_interpolated_values


def solve_last_two_periods(
resources_beginning_of_period: jnp.ndarray,
Expand Down
9 changes: 6 additions & 3 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import jax.numpy as jnp
from dcegm.interpolation import interp_policy_on_wealth
from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation import interp_value_on_wealth

from dcegm.interpolation import (
interp_policy_on_wealth,
interp_value_and_policy_on_wealth,
interp_value_on_wealth,
)


def policy_and_value_for_state_choice_vec(
Expand Down
4 changes: 1 addition & 3 deletions src/dcegm/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Callable, Dict, Tuple

import jax.numpy as jnp
from jax import numpy as jnp
Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
"""

from typing import Any
from typing import Dict
from typing import Any, Dict

import jax
import jax.numpy as jnp
import numpy as np

from dcegm.egm.aggregate_marginal_utility import (
calculate_choice_probs_and_unsqueezed_logsum,
)
Expand Down
3 changes: 1 addition & 2 deletions src/dcegm/numerical_integration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Tuple

import numpy as np
from scipy.special import roots_hermite
from scipy.special import roots_sh_legendre
from scipy.special import roots_hermite, roots_sh_legendre
from scipy.stats import norm


Expand Down
7 changes: 5 additions & 2 deletions src/dcegm/pre_processing/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import numpy as np
import pandas as pd
from dcegm.pre_processing.state_space import process_endog_state_specifications
from dcegm.pre_processing.state_space import process_exog_model_specifications

from dcegm.pre_processing.state_space import (
process_endog_state_specifications,
process_exog_model_specifications,
)


def inspect_state_space(
Expand Down
3 changes: 2 additions & 1 deletion src/dcegm/pre_processing/exog_processes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from functools import partial
from typing import Callable

from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options
from jax import numpy as jnp

from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options


def create_exog_transition_function(options):
"""Create the exogenous process transition function.
Expand Down
6 changes: 3 additions & 3 deletions src/dcegm/pre_processing/model_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Callable
from typing import Dict
from typing import Callable, Dict

import jax.numpy as jnp
from upper_envelope.fues_jax.fues_jax import fast_upper_envelope_wrapper

from dcegm.pre_processing.exog_processes import create_exog_transition_function
from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options
from upper_envelope.fues_jax.fues_jax import fast_upper_envelope_wrapper


def process_model_functions(
Expand Down
3 changes: 1 addition & 2 deletions src/dcegm/pre_processing/params.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Dict
from typing import Union
from typing import Dict, Union

import pandas as pd

Expand Down
10 changes: 6 additions & 4 deletions src/dcegm/pre_processing/setup_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pickle
from typing import Callable
from typing import Dict
from typing import Callable, Dict

import numpy as np

from dcegm.pre_processing.batches import create_batches_and_information
from dcegm.pre_processing.exog_processes import create_exog_state_mapping
from dcegm.pre_processing.model_functions import process_model_functions
from dcegm.pre_processing.state_space import check_options
from dcegm.pre_processing.state_space import create_state_space_and_choice_objects
from dcegm.pre_processing.state_space import (
check_options,
create_state_space_and_choice_objects,
)


def setup_model(
Expand Down
1 change: 1 addition & 0 deletions src/dcegm/pre_processing/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
import numpy as np

from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options


Expand Down
5 changes: 3 additions & 2 deletions src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import jax
import numpy as np
import pandas as pd
from jax import numpy as jnp
from jax import vmap

from dcegm.budget import calculate_resources_for_all_agents
from dcegm.interface import get_state_choice_index_per_state
from dcegm.interpolation import interp_value_and_policy_on_wealth
from jax import numpy as jnp
from jax import vmap


def interpolate_policy_and_value_for_all_agents(
Expand Down
15 changes: 9 additions & 6 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import jax
import jax.numpy as jnp
import numpy as np
from dcegm.interface import get_state_choice_index_per_state
from dcegm.simulation.sim_utils import compute_final_utility_for_each_choice
from dcegm.simulation.sim_utils import draw_taste_shocks
from dcegm.simulation.sim_utils import interpolate_policy_and_value_for_all_agents
from dcegm.simulation.sim_utils import transition_to_next_period
from dcegm.simulation.sim_utils import vectorized_utility
from jax import vmap

from dcegm.interface import get_state_choice_index_per_state
from dcegm.simulation.sim_utils import (
compute_final_utility_for_each_choice,
draw_taste_shocks,
interpolate_policy_and_value_for_all_agents,
transition_to_next_period,
vectorized_utility,
)


def simulate_all_periods(
states_initial,
Expand Down
8 changes: 3 additions & 5 deletions src/dcegm/solve.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""Interface for the DC-EGM algorithm."""

from functools import partial
from typing import Any
from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Any, Callable, Dict, Tuple

import jax.lax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit

from dcegm.budget import calculate_resources
from dcegm.final_periods import solve_last_two_periods
from dcegm.numerical_integration import quadrature_legendre
from dcegm.pre_processing.params import process_params
from dcegm.pre_processing.setup_model import setup_model
from dcegm.solve_single_period import solve_single_period
from jax import jit


def solve_dcegm(
Expand Down
3 changes: 2 additions & 1 deletion src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from jax import vmap

from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values
from dcegm.egm.interpolate_marginal_utility import (
interpolate_value_and_marg_utility_on_next_period_wealth,
)
from dcegm.egm.solve_euler_equation import (
calculate_candidate_solutions_from_euler_equation,
)
from jax import vmap


def solve_single_period(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any
from typing import Dict
from typing import Any, Dict

import jax
import jax.numpy as jnp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any
from typing import Dict
from typing import Any, Dict

import jax.numpy as jnp

Expand Down
15 changes: 7 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@
import pandas as pd
import pytest
import yaml

from dcegm.pre_processing.setup_model import setup_model
from dcegm.solve import solve_dcegm
from tests.two_period_models.model import (
budget_dcegm_exog_ltc,
budget_dcegm_exog_ltc_and_job_offer,
prob_exog_job_offer,
prob_exog_ltc,
)
from toy_models.consumption_retirement_model.state_space_objects import (
create_state_space_function_dict,
)
from toy_models.consumption_retirement_model.utility_functions import (
create_final_period_utility_function_dict,
)
from toy_models.consumption_retirement_model.utility_functions import (
create_utility_function_dict,
)

from tests.two_period_models.model import budget_dcegm_exog_ltc
from tests.two_period_models.model import budget_dcegm_exog_ltc_and_job_offer
from tests.two_period_models.model import prob_exog_job_offer
from tests.two_period_models.model import prob_exog_ltc


# Obtain the test directory of the package
TEST_DIR = Path(__file__).parent

Expand Down
5 changes: 3 additions & 2 deletions tests/test_budget_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import numpy as np
import pytest
from dcegm.pre_processing.params import process_params
from numpy.testing import assert_array_almost_equal as aaae
from scipy.special import roots_sh_legendre
from scipy.stats import norm

from dcegm.pre_processing.params import process_params
from toy_models.consumption_retirement_model.budget_functions import (
_calc_stochastic_income,
budget_constraint,
)
from toy_models.consumption_retirement_model.budget_functions import budget_constraint

model = ["deaton", "retirement_taste_shocks", "retirement_no_taste_shocks"]
labor_choice = [0, 1]
Expand Down
7 changes: 2 additions & 5 deletions tests/test_changing_choice_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
import jax.numpy as jnp
import numpy as np
import pytest

from dcegm.pre_processing.setup_model import setup_model
from dcegm.solve import get_solve_function
from toy_models.consumption_retirement_model.utility_functions import (
inverse_marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
marginal_utility_final_consume_all,
utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import utility_crra

# Obtain the test directory of the package
TEST_DIR = Path(__file__).parent
Expand Down
8 changes: 3 additions & 5 deletions tests/test_exog_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@

import jax.numpy as jnp
import numpy as np
from numpy.testing import assert_almost_equal as aaae

from dcegm.pre_processing.exog_processes import create_exog_state_mapping
from dcegm.pre_processing.model_functions import process_model_functions
from dcegm.pre_processing.state_space import create_state_space_and_choice_objects
from numpy.testing import assert_almost_equal as aaae
from tests.two_period_models.model import prob_exog_health
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.state_space_objects import (
create_state_space_function_dict,
)
from toy_models.consumption_retirement_model.utility_functions import (
create_final_period_utility_function_dict,
)
from toy_models.consumption_retirement_model.utility_functions import (
create_utility_function_dict,
)

from tests.two_period_models.model import prob_exog_health


def trans_prob_care_demand(health_state, params):
prob_care_demand = (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from numpy.testing import assert_allclose
from scipy.interpolate import interp1d

from tests.utils.interpolations import linear_interpolation_with_extrapolation
from tests.utils.interpolations import (
linear_interpolation_with_extrapolation,
linear_interpolation_with_inserting_missing_values,
)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_numerical_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dcegm.numerical_integration import quadrature_hermite
from numpy.testing import assert_allclose

from dcegm.numerical_integration import quadrature_hermite


def test_normal_distribution():
draws, weights = quadrature_hermite(20, 1)
Expand Down
Loading

0 comments on commit 08d4120

Please sign in to comment.