diff --git a/HARK/distribution.py b/HARK/distribution.py index 22314c9cb..7c3d39321 100644 --- a/HARK/distribution.py +++ b/HARK/distribution.py @@ -1181,7 +1181,7 @@ def func_wrapper(x: np.ndarray, *args: Any) -> np.ndarray: if len(kwargs): f_query = func(self.dataset, **kwargs) - ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.pmv) + ldd = DiscreteDistributionLabeled.from_dataset(f_query, self.probability) return ldd diff --git a/HARK/tests/test_distribution.py b/HARK/tests/test_distribution.py index ced58fbf5..ec685eaa3 100644 --- a/HARK/tests/test_distribution.py +++ b/HARK/tests/test_distribution.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import xarray as xr from HARK.distribution import ( Bernoulli, @@ -603,3 +604,41 @@ def test_combine_labeled_dist(self): np.concatenate([de.expected(), abc.expected()]), ) ) + + +class labeled_transition_tests(unittest.TestCase): + def setUp(self) -> None: + return super().setUp() + + def test_expectation_transformation(self): + # Create a basic labeled distribution + base_dist = DiscreteDistributionLabeled( + pmv=np.array([0.5, 0.5]), + atoms=np.array([[1.0, 2.0], [3.0, 4.0]]), + var_names=["a", "b"], + ) + + # Define a transition function + def transition(shocks, state): + state_new = {} + state_new["m"] = state["m"] * shocks["a"] + state_new["n"] = state["n"] * shocks["b"] + return state_new + + m = xr.DataArray(np.linspace(0, 10, 11), name="m", dims=("grid",)) + n = xr.DataArray(np.linspace(0, -10, 11), name="n", dims=("grid",)) + state_grid = xr.Dataset({"m": m, "n": n}) + + # Evaluate labeled transformation + + # Direct expectation + exp1 = base_dist.expected(transition, state=state_grid) + # Expectation after transformation + new_state_dstn = base_dist.dist_of_func(transition, state=state_grid) + # TODO: needs a cluncky identity function with an extra argument because + # DDL.expected() behavior is very different with and without kwargs. + # Fix! + exp2 = new_state_dstn.expected(lambda x, unused: x, unused=0) + + assert np.all(exp1["m"] == exp2["m"]).item() + assert np.all(exp1["n"] == exp2["n"]).item()