Skip to content

Commit

Permalink
Merge pull request #1387 from Mv77/plumbing/more_param_types
Browse files Browse the repository at this point in the history
Add more parameter types to `Parameters`
  • Loading branch information
mnwhite authored Jun 28, 2024
2 parents 3faa570 + f50b08d commit 7f5091f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
4 changes: 3 additions & 1 deletion Documentation/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Release Date: TBD
#### Minor Changes

- Fixes bug in `AgentPopulation` that caused discretization of distributions to not work. [1275](https://github.com/econ-ark/HARK/pull/1275)
- Adds support for distributions, booleans, and callables as parameters in the `Parameters` class. [1387](https://github.com/econ-ark/HARK/pull/1387)

### 0.15.1

Expand Down Expand Up @@ -59,7 +60,8 @@ This release drops support for Python 3.8 and 3.9, consistent with SPEC 0, and a

#### Minor Changes

- Add option to pass pre-built grid to `LinearFast`. [#1388](https://github.com/econ-ark/HARK/pull/1388)

- Add option to pass pre-built grid to `LinearFast`. [1388](https://github.com/econ-ark/HARK/pull/1388)
- Moves calculation of stable points out of ConsIndShock solver, into method called by post_solve [#1349](https://github.com/econ-ark/HARK/pull/1349)
- Adds cubic spline interpolation and value function construction to "warm glow bequest" models.
- Fixes cubic spline interpolation for ConsMedShockModel.
Expand Down
19 changes: 10 additions & 9 deletions HARK/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
from copy import copy, deepcopy
from dataclasses import dataclass, field
from time import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from warnings import warn

import numpy as np
import pandas as pd
from xarray import DataArray

from HARK.distribution import (
Distribution,
IndexDistribution,
Expand All @@ -30,6 +28,7 @@
)
from HARK.parallel import multi_thread_commands, multi_thread_commands_fake
from HARK.utilities import NullFunc, get_arg_names
from xarray import DataArray

logging.basicConfig(format="%(message)s")
_log = logging.getLogger("HARK")
Expand Down Expand Up @@ -104,11 +103,11 @@ def __infer_dims__(
"""
Infers the age-varying dimensions of a parameter.
If the parameter is a scalar, numpy array, or None, it is assumed to be
invariant over time. If the parameter is a list or tuple, it is assumed
to be varying over time. If the parameter is a list or tuple of length
greater than 1, the length of the list or tuple must match the
`_term_age` attribute of the Parameters object.
If the parameter is a scalar, numpy array, boolean, distribution, callable or None,
it is assumed to be invariant over time. If the parameter is a list or
tuple, it is assumed to be varying over time. If the parameter is a list
or tuple of length greater than 1, the length of the list or tuple must match
the `_term_age` attribute of the Parameters object.
Parameters
----------
Expand All @@ -118,7 +117,9 @@ def __infer_dims__(
value of parameter
"""
if isinstance(value, (int, float, np.ndarray, type(None))):
if isinstance(
value, (int, float, np.ndarray, type(None), Distribution, bool, Callable)
):
self.__add_to_invariant__(key)
return value
if isinstance(value, (list, tuple)):
Expand Down
14 changes: 11 additions & 3 deletions HARK/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,20 @@ def test_create_agents(self):

class test_parameters(unittest.TestCase):
def setUp(self):
self.params = Parameters(T_cycle=3, a=1, b=[2, 3, 4], c=np.array([5, 6, 7]))
self.params = Parameters(
T_cycle=3,
a=1,
b=[2, 3, 4],
c=np.array([5, 6, 7]),
d=[lambda x: x, lambda x: x**2, lambda x: x**3],
e=Uniform(),
f=[True, False, True],
)

def test_init(self):
self.assertEqual(self.params._length, 3)
self.assertEqual(self.params._invariant_params, {"a", "c"})
self.assertEqual(self.params._varying_params, {"b"})
self.assertEqual(self.params._invariant_params, {"a", "c", "e"})
self.assertEqual(self.params._varying_params, {"b", "d", "f"})

def test_getitem(self):
self.assertEqual(self.params["a"], 1)
Expand Down

0 comments on commit 7f5091f

Please sign in to comment.