Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alanlujan91 committed Aug 15, 2023
1 parent 03c49fd commit 38ccc75
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 53 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@ exclude: Documentation/example_notebooks/

repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.14.5
rev: v1.15.0
hooks:
- id: jupytext
args: [--sync, --set-formats, "ipynb", --pipe, black, --execute]
additional_dependencies: [jupytext, black, nbconvert]
files: ^examples/.*\.ipynb$

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
exclude: ^examples/

- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
rev: v3.10.1
hooks:
- id: pyupgrade
args: ["--py38-plus"]
exclude: ^examples/

- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
rev: 1.15.0
hooks:
- id: blacken-docs
exclude: ^examples/
Expand All @@ -37,7 +37,7 @@ repos:
exclude: ^examples/

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.9-for-vscode
rev: v3.0.1
hooks:
- id: prettier
exclude: ^examples/
Expand Down
96 changes: 51 additions & 45 deletions HARK/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
problem by finding a general equilibrium dynamic rule.
"""
import sys
from collections import namedtuple
from collections import defaultdict, namedtuple
from copy import copy, deepcopy
from dataclasses import dataclass, field
from time import time
Expand Down Expand Up @@ -37,15 +37,15 @@ class Parameters:
Attributes
----------
_term_age : int
_length : int
The terminal age of the agents in the model.
_age_inv : list
_invariant_params : list
A list of the names of the parameters that are invariant over time.
_age_var : list
_varying_params : list
A list of the names of the parameters that vary over time.
"""

def __init__(self, **parameters):
def __init__(self, **parameters: Any):
"""
Initializes a Parameters object and parses the age-varying
dynamics of the parameters.
Expand All @@ -58,15 +58,17 @@ def __init__(self, **parameters):
To parse a dictionary of parameters, use the ** operator.
"""
params = parameters.copy()
self._term_age = params.pop("T_cycle", None)
self._age_inv = set()
self._age_var = set()
self._parameters = {}
self._length = params.pop("T_cycle", None)
self._invariant_params = set()
self._varying_params = set()
self._parameters: Dict[str, Union[int, float, np.ndarray, list, tuple]] = {}

for key, value in params.items():
self._parameters[key] = self.__infer_dims__(key, value)

def __infer_dims__(self, key, value):
def __infer_dims__(
self, key: str, value: Union[int, float, np.ndarray, list, tuple, None]
) -> Union[int, float, np.ndarray, list, tuple]:
"""
Infers the age-varying dimensions of a parameter.
Expand All @@ -85,55 +87,61 @@ def __infer_dims__(self, key, value):
"""
if isinstance(value, (int, float, np.ndarray, type(None))):
self.__add_to_time_inv(key)
self.__add_to_invariant__(key)
return value
if isinstance(value, (list, tuple)):
if len(value) == 1:
self.__add_to_time_inv(key)
self.__add_to_invariant__(key)
return value[0]
if self._term_age is None or self._term_age == 1:
self._term_age = len(value)
if len(value) == self._term_age:
self.__add_to_time_vary(key)
if self._length is None or self._length == 1:
self._length = len(value)

Check warning on line 97 in HARK/core.py

View check run for this annotation

Codecov / codecov/patch

HARK/core.py#L97

Added line #L97 was not covered by tests
if len(value) == self._length:
self.__add_to_varying__(key)
return value
raise ValueError(f"Parameter {key} must be of length 1 or {self._term_age}")
raise ValueError(f"Parameter {key} has type {type(value)}")
raise ValueError(

Check warning on line 101 in HARK/core.py

View check run for this annotation

Codecov / codecov/patch

HARK/core.py#L101

Added line #L101 was not covered by tests
f"Parameter {key} must be of length 1 or {self._length}, not {len(value)}"
)
raise ValueError(f"Parameter {key} has unsupported type {type(value)}")

def __add_to_time_inv(self, key):
def __add_to_invariant__(self, key: str):
"""
Adds parameter name to invariant set and removes from varying set.
"""
self._age_var.discard(key)
self._age_inv.add(key)
self._varying_params.discard(key)
self._invariant_params.add(key)

def __add_to_time_vary(self, key):
def __add_to_varying__(self, key: str):
"""
Adds parameter name to varying set and removes from invariant set.
"""
self._age_inv.discard(key)
self._age_var.add(key)
self._invariant_params.discard(key)
self._varying_params.add(key)

def __getitem__(self, age_or_key):
def __getitem__(self, item_or_key: Union[int, str]):
"""
If age_or_key is an integer, returns a Parameters object with the parameters
If item_or_key is an integer, returns a Parameters object with the parameters
that apply to that age. This includes all invariant parameters and the
`age_or_key`th element of all age-varying parameters. If age_or_key is a string,
`item_or_key`th element of all age-varying parameters. If item_or_key is a string,
it returns the value of the parameter with that name.
"""
if isinstance(age_or_key, int):
if age_or_key >= self._term_age:
raise ValueError("Age is greater than or equal to terminal age.")
if isinstance(item_or_key, int):
if item_or_key >= self._length:
raise ValueError(
f"Age {item_or_key} is greater than or equal to terminal age {self._length}."
)

params = {key: self._parameters[key] for key in self._age_inv}
params = {key: self._parameters[key] for key in self._invariant_params}
params.update(
{key: self._parameters[key][age_or_key] for key in self._age_var}
{
key: self._parameters[key][item_or_key]
for key in self._varying_params
}
)
return Parameters(**params)
elif isinstance(item_or_key, str):
return self._parameters[item_or_key]

elif isinstance(age_or_key, str):
return self._parameters[age_or_key]

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any):
"""
Sets the value of a parameter.
Expand All @@ -147,25 +155,25 @@ def __setitem__(self, key, value):
"""
if not isinstance(key, str):
raise ValueError("Parameters must be set with a string key")

Check warning on line 157 in HARK/core.py

View check run for this annotation

Codecov / codecov/patch

HARK/core.py#L157

Added line #L157 was not covered by tests
self._parameters[key] = value
self._parameters[key] = self.__infer_dims__(key, value)

def keys(self):
"""
Returns a list of the names of the parameters.
"""
return self._age_inv + self._age_var
return self._invariant_params | self._varying_params

def values(self):
"""
Returns a list of the values of the parameters.
"""
return [self._parameters[key] for key in self.keys()]
return list(self._parameters.values())

def items(self):
"""
Returns a list of tuples of the form (name, value) for each parameter.
"""
return [(key, self._parameters[key]) for key in self.keys()]
return list(self._parameters.items())

def __iter__(self):
"""
Expand Down Expand Up @@ -202,11 +210,9 @@ def update(self, other_params):
Parameters object or dictionary of parameters to update with.
"""
if isinstance(other_params, Parameters):
for key, value in other_params:
self._parameters[key] = value
self._parameters.update(other_params.to_dict())
elif isinstance(other_params, dict):
for key, value in other_params.items():
self._parameters[key] = value
self._parameters.update(other_params)
else:
raise ValueError("Parameters must be a dict or a Parameters object")

Check warning on line 217 in HARK/core.py

View check run for this annotation

Codecov / codecov/patch

HARK/core.py#L217

Added line #L217 was not covered by tests

Expand All @@ -220,7 +226,7 @@ def __repr__(self):
"""
Returns a detailed string representation of the Parameters object.
"""
return f"Parameters( _age_inv = {self._age_inv}, _age_var = {self._age_var}, | {self.to_dict()})"
return f"Parameters( _age_inv = {self._invariant_params}, _age_var = {self._varying_params}, | {self.to_dict()})"

Check warning on line 229 in HARK/core.py

View check run for this annotation

Codecov / codecov/patch

HARK/core.py#L229

Added line #L229 was not covered by tests


class Model:
Expand Down
77 changes: 74 additions & 3 deletions HARK/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest

import numpy as np
import pytest

from HARK.ConsumptionSaving.ConsIndShockModel import (
IndShockConsumerType,
Expand Down Expand Up @@ -175,9 +176,9 @@ def setUp(self):
self.params = Parameters(T_cycle=3, a=1, b=[2, 3, 4], c=np.array([5, 6, 7]))

def test_init(self):
self.assertEqual(self.params._term_age, 3)
self.assertEqual(self.params._age_inv, {"a", "c"})
self.assertEqual(self.params._age_var, {"b"})
self.assertEqual(self.params._length, 3)
self.assertEqual(self.params._invariant_params, {"a", "c"})
self.assertEqual(self.params._varying_params, {"b"})

def test_getitem(self):
self.assertEqual(self.params["a"], 1)
Expand All @@ -192,3 +193,73 @@ def test_update(self):
self.params.update({"a": 9, "b": [10, 11, 12]})
self.assertEqual(self.params["a"], 9)
self.assertEqual(self.params[0]["b"], 10)

def test_initialization(self):
params = Parameters(a=1, b=[1, 2], T_cycle=2)
assert params._length == 2
assert params._invariant_params == {"a"}
assert params._varying_params == {"b"}

def test_infer_dims_scalar(self):
params = Parameters(a=1)
assert params["a"] == 1

def test_infer_dims_array(self):
params = Parameters(b=np.array([1, 2]))
assert all(params["b"] == np.array([1, 2]))

def test_infer_dims_list_varying(self):
params = Parameters(b=[1, 2], T_cycle=2)
assert params["b"] == [1, 2]

def test_infer_dims_list_invariant(self):
params = Parameters(b=[1])
assert params["b"] == 1

def test_setitem(self):
params = Parameters(a=1)
params["b"] = 2
assert params["b"] == 2

def test_keys_values_items(self):
params = Parameters(a=1, b=2)
assert set(params.keys()) == {"a", "b"}
assert set(params.values()) == {1, 2}
assert set(params.items()) == {("a", 1), ("b", 2)}

def test_to_dict(self):
params = Parameters(a=1, b=2)
assert params.to_dict() == {"a": 1, "b": 2}

def test_to_namedtuple(self):
params = Parameters(a=1, b=2)
named_tuple = params.to_namedtuple()
assert named_tuple.a == 1
assert named_tuple.b == 2

def test_update_params(self):
params1 = Parameters(a=1, b=2)
params2 = Parameters(a=3, c=4)
params1.update(params2)
assert params1["a"] == 3
assert params1["c"] == 4

def test_unsupported_type_error(self):
with pytest.raises(ValueError):
Parameters(b={1, 2})

def test_get_item_dimension_error(self):
params = Parameters(b=[1, 2], T_cycle=2)
with pytest.raises(ValueError):
params[2]

def test_getitem_with_key(self):
params = Parameters(a=1, b=[2, 3], T_cycle=2)
assert params["a"] == 1
assert params["b"] == [2, 3]

def test_getitem_with_item(self):
params = Parameters(a=1, b=[2, 3], T_cycle=2)
age_params = params[1]
assert age_params["a"] == 1
assert age_params["b"] == 3

0 comments on commit 38ccc75

Please sign in to comment.