diff --git a/Documentation/CHANGELOG.md b/Documentation/CHANGELOG.md index e2593eca6..978d117bf 100644 --- a/Documentation/CHANGELOG.md +++ b/Documentation/CHANGELOG.md @@ -21,6 +21,7 @@ Release Date: TBD #### Minor Changes - Fixes bug in `AgentPopulation` that caused discretization of distributions to not work. [1275](https://github.com/econ-ark/HARK/pull/1275) +- Adds support for distributions, booleans, and callables as parameters in the `Parameters` class. [1387](https://github.com/econ-ark/HARK/pull/1387) ### 0.15.1 @@ -59,7 +60,8 @@ This release drops support for Python 3.8 and 3.9, consistent with SPEC 0, and a #### Minor Changes -- Add option to pass pre-built grid to `LinearFast`. [#1388](https://github.com/econ-ark/HARK/pull/1388) + +- Add option to pass pre-built grid to `LinearFast`. [1388](https://github.com/econ-ark/HARK/pull/1388) - Moves calculation of stable points out of ConsIndShock solver, into method called by post_solve [#1349](https://github.com/econ-ark/HARK/pull/1349) - Adds cubic spline interpolation and value function construction to "warm glow bequest" models. - Fixes cubic spline interpolation for ConsMedShockModel. diff --git a/HARK/core.py b/HARK/core.py index e1da9e3eb..9550cfdc1 100644 --- a/HARK/core.py +++ b/HARK/core.py @@ -15,13 +15,11 @@ from copy import copy, deepcopy from dataclasses import dataclass, field from time import time -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from warnings import warn import numpy as np import pandas as pd -from xarray import DataArray - from HARK.distribution import ( Distribution, IndexDistribution, @@ -30,6 +28,7 @@ ) from HARK.parallel import multi_thread_commands, multi_thread_commands_fake from HARK.utilities import NullFunc, get_arg_names +from xarray import DataArray logging.basicConfig(format="%(message)s") _log = logging.getLogger("HARK") @@ -104,11 +103,11 @@ def __infer_dims__( """ Infers the age-varying dimensions of a parameter. - If the parameter is a scalar, numpy array, or None, it is assumed to be - invariant over time. If the parameter is a list or tuple, it is assumed - to be varying over time. If the parameter is a list or tuple of length - greater than 1, the length of the list or tuple must match the - `_term_age` attribute of the Parameters object. + If the parameter is a scalar, numpy array, boolean, distribution, callable or None, + it is assumed to be invariant over time. If the parameter is a list or + tuple, it is assumed to be varying over time. If the parameter is a list + or tuple of length greater than 1, the length of the list or tuple must match + the `_term_age` attribute of the Parameters object. Parameters ---------- @@ -118,7 +117,9 @@ def __infer_dims__( value of parameter """ - if isinstance(value, (int, float, np.ndarray, type(None))): + if isinstance( + value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) + ): self.__add_to_invariant__(key) return value if isinstance(value, (list, tuple)): diff --git a/HARK/tests/test_core.py b/HARK/tests/test_core.py index cff351200..102d23deb 100644 --- a/HARK/tests/test_core.py +++ b/HARK/tests/test_core.py @@ -184,12 +184,20 @@ def test_create_agents(self): class test_parameters(unittest.TestCase): def setUp(self): - self.params = Parameters(T_cycle=3, a=1, b=[2, 3, 4], c=np.array([5, 6, 7])) + self.params = Parameters( + T_cycle=3, + a=1, + b=[2, 3, 4], + c=np.array([5, 6, 7]), + d=[lambda x: x, lambda x: x**2, lambda x: x**3], + e=Uniform(), + f=[True, False, True], + ) def test_init(self): self.assertEqual(self.params._length, 3) - self.assertEqual(self.params._invariant_params, {"a", "c"}) - self.assertEqual(self.params._varying_params, {"b"}) + self.assertEqual(self.params._invariant_params, {"a", "c", "e"}) + self.assertEqual(self.params._varying_params, {"b", "d", "f"}) def test_getitem(self): self.assertEqual(self.params["a"], 1)