Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more parameter types to Parameters #1387

Merged
merged 10 commits into from
Jun 28, 2024
4 changes: 3 additions & 1 deletion Documentation/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Release Date: TBD
#### Minor Changes

- Fixes bug in `AgentPopulation` that caused discretization of distributions to not work. [1275](https://github.com/econ-ark/HARK/pull/1275)
- Adds support for distributions, booleans, and callables as parameters in the `Parameters` class. [1387](https://github.com/econ-ark/HARK/pull/1387)

### 0.15.1

Expand Down Expand Up @@ -59,7 +60,8 @@ This release drops support for Python 3.8 and 3.9, consistent with SPEC 0, and a

#### Minor Changes

- Add option to pass pre-built grid to `LinearFast`. [#1388](https://github.com/econ-ark/HARK/pull/1388)

- Add option to pass pre-built grid to `LinearFast`. [1388](https://github.com/econ-ark/HARK/pull/1388)
- Moves calculation of stable points out of ConsIndShock solver, into method called by post_solve [#1349](https://github.com/econ-ark/HARK/pull/1349)
- Adds cubic spline interpolation and value function construction to "warm glow bequest" models.
- Fixes cubic spline interpolation for ConsMedShockModel.
Expand Down
19 changes: 10 additions & 9 deletions HARK/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
from copy import copy, deepcopy
from dataclasses import dataclass, field
from time import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from warnings import warn

import numpy as np
import pandas as pd
from xarray import DataArray

from HARK.distribution import (
Distribution,
IndexDistribution,
Expand All @@ -30,6 +28,7 @@
)
from HARK.parallel import multi_thread_commands, multi_thread_commands_fake
from HARK.utilities import NullFunc, get_arg_names
from xarray import DataArray

logging.basicConfig(format="%(message)s")
_log = logging.getLogger("HARK")
Expand Down Expand Up @@ -104,11 +103,11 @@ def __infer_dims__(
"""
Infers the age-varying dimensions of a parameter.

If the parameter is a scalar, numpy array, or None, it is assumed to be
invariant over time. If the parameter is a list or tuple, it is assumed
to be varying over time. If the parameter is a list or tuple of length
greater than 1, the length of the list or tuple must match the
`_term_age` attribute of the Parameters object.
If the parameter is a scalar, numpy array, boolean, distribution, callable or None,
it is assumed to be invariant over time. If the parameter is a list or
tuple, it is assumed to be varying over time. If the parameter is a list
or tuple of length greater than 1, the length of the list or tuple must match
the `_term_age` attribute of the Parameters object.

Parameters
----------
Expand All @@ -118,7 +117,9 @@ def __infer_dims__(
value of parameter

"""
if isinstance(value, (int, float, np.ndarray, type(None))):
if isinstance(
value, (int, float, np.ndarray, type(None), Distribution, bool, Callable)
):
self.__add_to_invariant__(key)
return value
if isinstance(value, (list, tuple)):
Expand Down
14 changes: 11 additions & 3 deletions HARK/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,20 @@ def test_create_agents(self):

class test_parameters(unittest.TestCase):
def setUp(self):
self.params = Parameters(T_cycle=3, a=1, b=[2, 3, 4], c=np.array([5, 6, 7]))
self.params = Parameters(
T_cycle=3,
a=1,
b=[2, 3, 4],
c=np.array([5, 6, 7]),
d=[lambda x: x, lambda x: x**2, lambda x: x**3],
e=Uniform(),
f=[True, False, True],
)

def test_init(self):
self.assertEqual(self.params._length, 3)
self.assertEqual(self.params._invariant_params, {"a", "c"})
self.assertEqual(self.params._varying_params, {"b"})
self.assertEqual(self.params._invariant_params, {"a", "c", "e"})
self.assertEqual(self.params._varying_params, {"b", "d", "f"})

def test_getitem(self):
self.assertEqual(self.params["a"], 1)
Expand Down
Loading