diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edb768ec5..1acc0a54e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: Documentation/example_notebooks/ repos: - repo: https://github.com/mwouts/jupytext - rev: v1.14.5 + rev: v1.15.0 hooks: - id: jupytext args: [--sync, --set-formats, "ipynb", --pipe, black, --execute] @@ -10,20 +10,20 @@ repos: files: ^examples/.*\.ipynb$ - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black exclude: ^examples/ - repo: https://github.com/asottile/pyupgrade - rev: v3.4.0 + rev: v3.10.1 hooks: - id: pyupgrade args: ["--py38-plus"] exclude: ^examples/ - repo: https://github.com/asottile/blacken-docs - rev: 1.13.0 + rev: 1.15.0 hooks: - id: blacken-docs exclude: ^examples/ @@ -37,7 +37,7 @@ repos: exclude: ^examples/ - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.9-for-vscode + rev: v3.0.1 hooks: - id: prettier exclude: ^examples/ diff --git a/HARK/core.py b/HARK/core.py index 2c13d22fc..6ae9623da 100644 --- a/HARK/core.py +++ b/HARK/core.py @@ -7,7 +7,7 @@ problem by finding a general equilibrium dynamic rule. """ import sys -from collections import namedtuple +from collections import defaultdict, namedtuple from copy import copy, deepcopy from dataclasses import dataclass, field from time import time @@ -37,15 +37,15 @@ class Parameters: Attributes ---------- - _term_age : int + _length : int The terminal age of the agents in the model. - _age_inv : list + _invariant_params : list A list of the names of the parameters that are invariant over time. - _age_var : list + _varying_params : list A list of the names of the parameters that vary over time. """ - def __init__(self, **parameters): + def __init__(self, **parameters: Any): """ Initializes a Parameters object and parses the age-varying dynamics of the parameters. @@ -58,15 +58,17 @@ def __init__(self, **parameters): To parse a dictionary of parameters, use the ** operator. """ params = parameters.copy() - self._term_age = params.pop("T_cycle", None) - self._age_inv = set() - self._age_var = set() - self._parameters = {} + self._length = params.pop("T_cycle", None) + self._invariant_params = set() + self._varying_params = set() + self._parameters: Dict[str, Union[int, float, np.ndarray, list, tuple]] = {} for key, value in params.items(): self._parameters[key] = self.__infer_dims__(key, value) - def __infer_dims__(self, key, value): + def __infer_dims__( + self, key: str, value: Union[int, float, np.ndarray, list, tuple, None] + ) -> Union[int, float, np.ndarray, list, tuple]: """ Infers the age-varying dimensions of a parameter. @@ -85,55 +87,61 @@ def __infer_dims__(self, key, value): """ if isinstance(value, (int, float, np.ndarray, type(None))): - self.__add_to_time_inv(key) + self.__add_to_invariant__(key) return value if isinstance(value, (list, tuple)): if len(value) == 1: - self.__add_to_time_inv(key) + self.__add_to_invariant__(key) return value[0] - if self._term_age is None or self._term_age == 1: - self._term_age = len(value) - if len(value) == self._term_age: - self.__add_to_time_vary(key) + if self._length is None or self._length == 1: + self._length = len(value) + if len(value) == self._length: + self.__add_to_varying__(key) return value - raise ValueError(f"Parameter {key} must be of length 1 or {self._term_age}") - raise ValueError(f"Parameter {key} has type {type(value)}") + raise ValueError( + f"Parameter {key} must be of length 1 or {self._length}, not {len(value)}" + ) + raise ValueError(f"Parameter {key} has unsupported type {type(value)}") - def __add_to_time_inv(self, key): + def __add_to_invariant__(self, key: str): """ Adds parameter name to invariant set and removes from varying set. """ - self._age_var.discard(key) - self._age_inv.add(key) + self._varying_params.discard(key) + self._invariant_params.add(key) - def __add_to_time_vary(self, key): + def __add_to_varying__(self, key: str): """ Adds parameter name to varying set and removes from invariant set. """ - self._age_inv.discard(key) - self._age_var.add(key) + self._invariant_params.discard(key) + self._varying_params.add(key) - def __getitem__(self, age_or_key): + def __getitem__(self, item_or_key: Union[int, str]): """ - If age_or_key is an integer, returns a Parameters object with the parameters + If item_or_key is an integer, returns a Parameters object with the parameters that apply to that age. This includes all invariant parameters and the - `age_or_key`th element of all age-varying parameters. If age_or_key is a string, + `item_or_key`th element of all age-varying parameters. If item_or_key is a string, it returns the value of the parameter with that name. """ - if isinstance(age_or_key, int): - if age_or_key >= self._term_age: - raise ValueError("Age is greater than or equal to terminal age.") + if isinstance(item_or_key, int): + if item_or_key >= self._length: + raise ValueError( + f"Age {item_or_key} is greater than or equal to terminal age {self._length}." + ) - params = {key: self._parameters[key] for key in self._age_inv} + params = {key: self._parameters[key] for key in self._invariant_params} params.update( - {key: self._parameters[key][age_or_key] for key in self._age_var} + { + key: self._parameters[key][item_or_key] + for key in self._varying_params + } ) return Parameters(**params) + elif isinstance(item_or_key, str): + return self._parameters[item_or_key] - elif isinstance(age_or_key, str): - return self._parameters[age_or_key] - - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any): """ Sets the value of a parameter. @@ -147,25 +155,25 @@ def __setitem__(self, key, value): """ if not isinstance(key, str): raise ValueError("Parameters must be set with a string key") - self._parameters[key] = value + self._parameters[key] = self.__infer_dims__(key, value) def keys(self): """ Returns a list of the names of the parameters. """ - return self._age_inv + self._age_var + return self._invariant_params | self._varying_params def values(self): """ Returns a list of the values of the parameters. """ - return [self._parameters[key] for key in self.keys()] + return list(self._parameters.values()) def items(self): """ Returns a list of tuples of the form (name, value) for each parameter. """ - return [(key, self._parameters[key]) for key in self.keys()] + return list(self._parameters.items()) def __iter__(self): """ @@ -202,11 +210,9 @@ def update(self, other_params): Parameters object or dictionary of parameters to update with. """ if isinstance(other_params, Parameters): - for key, value in other_params: - self._parameters[key] = value + self._parameters.update(other_params.to_dict()) elif isinstance(other_params, dict): - for key, value in other_params.items(): - self._parameters[key] = value + self._parameters.update(other_params) else: raise ValueError("Parameters must be a dict or a Parameters object") @@ -220,7 +226,7 @@ def __repr__(self): """ Returns a detailed string representation of the Parameters object. """ - return f"Parameters( _age_inv = {self._age_inv}, _age_var = {self._age_var}, | {self.to_dict()})" + return f"Parameters( _age_inv = {self._invariant_params}, _age_var = {self._varying_params}, | {self.to_dict()})" class Model: diff --git a/HARK/tests/test_core.py b/HARK/tests/test_core.py index 87d91919d..0cda93381 100644 --- a/HARK/tests/test_core.py +++ b/HARK/tests/test_core.py @@ -4,6 +4,7 @@ import unittest import numpy as np +import pytest from HARK.ConsumptionSaving.ConsIndShockModel import ( IndShockConsumerType, @@ -175,9 +176,9 @@ def setUp(self): self.params = Parameters(T_cycle=3, a=1, b=[2, 3, 4], c=np.array([5, 6, 7])) def test_init(self): - self.assertEqual(self.params._term_age, 3) - self.assertEqual(self.params._age_inv, {"a", "c"}) - self.assertEqual(self.params._age_var, {"b"}) + self.assertEqual(self.params._length, 3) + self.assertEqual(self.params._invariant_params, {"a", "c"}) + self.assertEqual(self.params._varying_params, {"b"}) def test_getitem(self): self.assertEqual(self.params["a"], 1) @@ -192,3 +193,73 @@ def test_update(self): self.params.update({"a": 9, "b": [10, 11, 12]}) self.assertEqual(self.params["a"], 9) self.assertEqual(self.params[0]["b"], 10) + + def test_initialization(self): + params = Parameters(a=1, b=[1, 2], T_cycle=2) + assert params._length == 2 + assert params._invariant_params == {"a"} + assert params._varying_params == {"b"} + + def test_infer_dims_scalar(self): + params = Parameters(a=1) + assert params["a"] == 1 + + def test_infer_dims_array(self): + params = Parameters(b=np.array([1, 2])) + assert all(params["b"] == np.array([1, 2])) + + def test_infer_dims_list_varying(self): + params = Parameters(b=[1, 2], T_cycle=2) + assert params["b"] == [1, 2] + + def test_infer_dims_list_invariant(self): + params = Parameters(b=[1]) + assert params["b"] == 1 + + def test_setitem(self): + params = Parameters(a=1) + params["b"] = 2 + assert params["b"] == 2 + + def test_keys_values_items(self): + params = Parameters(a=1, b=2) + assert set(params.keys()) == {"a", "b"} + assert set(params.values()) == {1, 2} + assert set(params.items()) == {("a", 1), ("b", 2)} + + def test_to_dict(self): + params = Parameters(a=1, b=2) + assert params.to_dict() == {"a": 1, "b": 2} + + def test_to_namedtuple(self): + params = Parameters(a=1, b=2) + named_tuple = params.to_namedtuple() + assert named_tuple.a == 1 + assert named_tuple.b == 2 + + def test_update_params(self): + params1 = Parameters(a=1, b=2) + params2 = Parameters(a=3, c=4) + params1.update(params2) + assert params1["a"] == 3 + assert params1["c"] == 4 + + def test_unsupported_type_error(self): + with pytest.raises(ValueError): + Parameters(b={1, 2}) + + def test_get_item_dimension_error(self): + params = Parameters(b=[1, 2], T_cycle=2) + with pytest.raises(ValueError): + params[2] + + def test_getitem_with_key(self): + params = Parameters(a=1, b=[2, 3], T_cycle=2) + assert params["a"] == 1 + assert params["b"] == [2, 3] + + def test_getitem_with_item(self): + params = Parameters(a=1, b=[2, 3], T_cycle=2) + age_params = params[1] + assert age_params["a"] == 1 + assert age_params["b"] == 3