diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 2cef0d2a..3834d858 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple( ) ``` -The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: +The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: ```{python} # | label: new-model-def @@ -224,10 +224,9 @@ class InfFeedback(RandomVariable): inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength) - inf_feedback_strength = au.pad_x_to_match_y( + inf_feedback_strength, _ = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength[0], ) # Sampling inf feedback and adjusting the shape diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 9b8c4b21..f8ae8590 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -8,15 +8,16 @@ from jax.typing import ArrayLike -def pad_to_match( +def pad_edges_to_match( x: ArrayLike, y: ArrayLike, - fill_value: float = 0.0, + axis: int = 0, pad_direction: str = "end", fix_y: bool = False, ) -> tuple[ArrayLike, ArrayLike]: """ - Pad the shorter array at the start or end to match the length of the longer array. + Pad the shorter array at the start or end using the + edge values to match the length of the longer array. Parameters ---------- @@ -24,8 +25,8 @@ def pad_to_match( First array. y : ArrayLike Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. + axis : int, optional + Axis along which to add padding, by default 0 pad_direction : str, optional Direction to pad the shorter array, either "start" or "end", by default "end". fix_y : bool, optional @@ -38,64 +39,35 @@ def pad_to_match( """ x = jnp.atleast_1d(x) y = jnp.atleast_1d(y) - x_len = x.size - y_len = y.size + x_len = x.shape[axis] + y_len = y.shape[axis] pad_size = abs(x_len - y_len) + pad_width = [(0, 0)] * x.ndim - pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get( - pad_direction, None - ) - - if pad_width is None: + if pad_direction not in ["start", "end"]: raise ValueError( "pad_direction must be either 'start' or 'end'." f" Got {pad_direction}." ) + pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get( + pad_direction, None + ) + if x_len > y_len: if fix_y: raise ValueError( "Cannot fix y when x is longer than y." f" x_len: {x_len}, y_len: {y_len}." ) - y = jnp.pad(y, pad_width, constant_values=fill_value) + y = jnp.pad(y, pad_width, mode="edge") elif y_len > x_len: - x = jnp.pad(x, pad_width, constant_values=fill_value) + x = jnp.pad(x, pad_width, mode="edge") return x, y -def pad_x_to_match_y( - x: ArrayLike, - y: ArrayLike, - fill_value: float = 0.0, - pad_direction: str = "end", -) -> ArrayLike: - """ - Pad the `x` array at the start or end to match the length of the `y` array. - - Parameters - ---------- - x : ArrayLike - First array. - y : ArrayLike - Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. - pad_direction : str, optional - Direction to pad the shorter array, either "start" or "end", by default "end". - - Returns - ------- - Array - Padded array. - """ - return pad_to_match( - x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True - )[0] - - class PeriodicProcessSample(NamedTuple): """ A container for holding the output from `process.PeriodicProcess()`. diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 55b2227f..8b99853c 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -78,9 +78,12 @@ def _new_scanner( history_subset: ArrayLike, multiplier: float ) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08 new_val = transform( - multiplier * jnp.dot(array_to_convolve, history_subset) + multiplier + * jnp.einsum("i...,i...->...", array_to_convolve, history_subset) + ) + latest = jnp.concatenate( + [history_subset[1:], new_val[jnp.newaxis]], axis=0 ) - latest = jnp.hstack([history_subset[1:], new_val]) return latest, new_val return _new_scanner @@ -158,9 +161,13 @@ def _new_scanner( multipliers: tuple[float, float], ) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08 m1, m2 = multipliers - m_net1 = t1(m1 * jnp.dot(arr1, history_subset)) - new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset)) - latest = jnp.hstack([history_subset[1:], new_val]) + m_net1 = t1(m1 * jnp.einsum("i...,i...->...", arr1, history_subset)) + new_val = t2( + m2 * m_net1 * jnp.einsum("i...,i...->...", arr2, history_subset) + ) + latest = jnp.concatenate( + [history_subset[1:], new_val[jnp.newaxis]], axis=0 + ) return latest, (new_val, m_net1) return _new_scanner diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 887d0e5a..7a8b02a8 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -80,14 +80,21 @@ def sample( InfectionsSample Named tuple with "infections". """ - if I0.size < gen_int.size: + if I0.shape[0] < gen_int.size: raise ValueError( "Initial infections vector must be at least as long as " "the generation interval. " - f"Initial infections vector length: {I0.size}, " + f"Initial infections vector length: {I0.shape[0]}, " f"generation interval length: {gen_int.size}." ) + if I0.shape[1:] != Rt.shape[1:]: + raise ValueError( + "Initial infections and Rt must have the same batch shapes. " + f"Got initial infections of batch shape {I0.shape[1:]} " + f"and Rt of batch shape {Rt.shape[1:]}." + ) + gen_int_rev = jnp.flip(gen_int) recent_I0 = I0[-gen_int_rev.size :] diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index e4b03594..b344a1b5 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -143,11 +143,18 @@ def sample( InfectionsWithFeedback Named tuple with "infections". """ - if I0.size < gen_int.size: + if I0.shape[0] < gen_int.size: raise ValueError( "Initial infections must be at least as long as the " - f"generation interval. Got {I0.size} initial infections " - f"and {gen_int.size} generation interval." + f"generation interval. Got initial infections length {I0.shape[0]}" + f"and generation interval length {gen_int.size}." + ) + + if I0.shape[1:] != Rt.shape[1:]: + raise ValueError( + "Initial infections and Rt must have the same batch shapes. " + f"Got initial infections of batch shape {I0.shape[1:]} " + f"and Rt of batch shape {Rt.shape[1:]}." ) gen_int_rev = jnp.flip(gen_int) @@ -160,19 +167,23 @@ def sample( **kwargs, ) ) + + if inf_feedback_strength.ndim == Rt.ndim - 1: + inf_feedback_strength = inf_feedback_strength[jnp.newaxis] + # Making sure inf_feedback_strength spans the Rt length - if inf_feedback_strength.size == 1: - inf_feedback_strength = au.pad_x_to_match_y( + if inf_feedback_strength.shape[0] == 1: + inf_feedback_strength = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength[0], - ) - elif inf_feedback_strength.size != Rt.size: + axis=0, + )[0] + if inf_feedback_strength.shape != Rt.shape: raise ValueError( - "Infection feedback strength must be of size 1 " - "or the same size as the reproduction number array. " - f"Got {inf_feedback_strength.size} " - f"and {Rt.size} respectively." + "Infection feedback strength must be of length 1 " + "or the same length as the reproduction number array. " + f"Got {inf_feedback_strength.shape} " + f"and {Rt.shape} respectively." ) # Sampling inf feedback pmf diff --git a/test/test_arrayutils.py b/test/test_arrayutils.py index 9e168de1..0048fade 100644 --- a/test/test_arrayutils.py +++ b/test/test_arrayutils.py @@ -8,53 +8,72 @@ import pyrenew.arrayutils as au -def test_arrayutils_pad_to_match(): +def test_pad_edges_to_match(): """ - Verifies extension when required and error when `fix_y` is True. + Test function to verify padding along the edges for 1D and 2D arrays """ + # test when y gets padded x = jnp.array([1, 2, 3]) y = jnp.array([1, 2]) - x_pad, y_pad = au.pad_to_match(x, y) - + x_pad, y_pad = au.pad_edges_to_match(x, y) assert x_pad.size == y_pad.size - assert x_pad.size == 3 + assert y_pad[-1] == y[-1] + assert jnp.array_equal(x_pad, x) + # test when x gets padded x = jnp.array([1, 2]) y = jnp.array([1, 2, 3]) - x_pad, y_pad = au.pad_to_match(x, y) - + x_pad, y_pad = au.pad_edges_to_match(x, y) assert x_pad.size == y_pad.size - assert x_pad.size == 3 + assert x_pad[-1] == x[-1] + assert jnp.array_equal(y_pad, y) + # test when no padding required x = jnp.array([1, 2, 3]) - y = jnp.array([1, 2]) + y = jnp.array([4, 5, 6]) + x_pad, y_pad = au.pad_edges_to_match(x, y) - # Verify that the function raises an error when `fix_y` is True - with pytest.raises(ValueError): - x_pad, y_pad = au.pad_to_match(x, y, fix_y=True) + assert jnp.array_equal(x_pad, x) + assert jnp.array_equal(y_pad, y) # Verify function works with both padding directions - x_pad, y_pad = au.pad_to_match(x, y, pad_direction="start") + x = jnp.array([1, 2, 3]) + y = jnp.array([1, 2]) + + x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="start") assert x_pad.size == y_pad.size - assert x_pad.size == 3 + assert y_pad[0] == y[0] + assert jnp.array_equal(x_pad, x) + + # Verify that the function raises an error when `fix_y` is True + with pytest.raises( + ValueError, match="Cannot fix y when x is longer than y" + ): + x_pad, y_pad = au.pad_edges_to_match(x, y, fix_y=True) # Verify function raises an error when pad_direction is not "start" or "end" with pytest.raises(ValueError): - x_pad, y_pad = au.pad_to_match(x, y, pad_direction="middle") - - -def test_arrayutils_pad_x_to_match_y(): - """ - Verifies extension when required - """ - - x = jnp.array([1, 2]) - y = jnp.array([1, 2, 3]) - - x_pad = au.pad_x_to_match_y(x, y) - - assert x_pad.size == 3 + x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="middle") + + # test padding for 2D arrays + x = jnp.array([[1, 2], [3, 4]]) + y = jnp.array([[5, 6]]) + + # Padding along axis 0 + axis = 0 + x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end") + + assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) + assert jnp.array_equal(y_pad[-1], y[-1]) + assert jnp.array_equal(x_pad, x) + + # padding along axis 1 + axis = 1 + x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end") + assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) + assert jnp.array_equal(y[:, -1], y_pad[:, -1]) + assert jnp.array_equal(x_pad, x) diff --git a/test/test_convolve_scanners.py b/test/test_convolve_scanners.py index 12430004..63bdd6a1 100644 --- a/test/test_convolve_scanners.py +++ b/test/test_convolve_scanners.py @@ -6,21 +6,34 @@ import jax import jax.numpy as jnp import numpy as np +import pytest from numpy.testing import assert_array_equal import pyrenew.convolve as pc +import pyrenew.transformation as t -def test_double_scanner_reduces_to_single(): +@pytest.mark.parametrize( + ["inits", "to_scan_a", "multipliers"], + [ + [ + jnp.array([0.352, 5.2, -3]), + jnp.array([0.5, 0.3, 0.2]), + jnp.array(np.random.normal(0, 0.5, size=500)), + ], + [ + jnp.array(np.array([0.352, 5.2, -3] * 3).reshape(3, 3)), + jnp.array([0.5, 0.3, 0.2]), + jnp.array(np.random.normal(0, 0.5, size=(500, 3))), + ], + ], +) +def test_double_scanner_reduces_to_single(inits, to_scan_a, multipliers): """ Test that new_double_scanner() yields a function that is equivalent to a single scanner if the first scan is chosen appropriately """ - inits = jnp.array([0.352, 5.2, -3]) - to_scan_a = jnp.array([0.5, 0.3, 0.2]) - - multipliers = jnp.array(np.random.normal(0, 0.5, size=500)) def transform_a(x: any): """ @@ -42,7 +55,8 @@ def transform_a(x: any): scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a) double_scanner_a = pc.new_double_convolve_scanner( - (jnp.array([523, 2, -0.5233]), to_scan_a), (lambda x: 1, transform_a) + (jnp.array([523, 2, -0.5233]), to_scan_a), + (jnp.ones_like, transform_a), ) _, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers) @@ -53,3 +67,148 @@ def transform_a(x: any): assert_array_equal(result_a_double[1], jnp.ones_like(multipliers)) assert_array_equal(result_a_double[0], result_a) + + +@pytest.mark.parametrize( + ["arr", "history", "multipliers", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([3.0, 4.0]), + jnp.array([1, 2, 3]), + t.IdentityTransform(), + ], + [ + jnp.ones(3), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones((3, 3)), + t.ExpTransform(), + ], + ], +) +def test_convolve_scanner_using_scan(arr, history, multipliers, transform): + """ + Tests the output of new convolve scanner function + used with `jax.lax.scan` against values calculated + using a for loop + """ + scanner = pc.new_convolve_scanner(arr, transform) + + _, result = jax.lax.scan(f=scanner, init=history, xs=multipliers) + + result_not_scanned = [] + for multiplier in multipliers: + history, new_val = scanner(history, multiplier) + result_not_scanned.append(new_val) + + assert jnp.array_equal(result, result_not_scanned) + + +@pytest.mark.parametrize( + ["arr1", "arr2", "history", "m1", "m2", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([2.0, 1.0]), + jnp.array([0.1, 0.4]), + jnp.array([1, 2, 3]), + jnp.ones(3), + (t.IdentityTransform(), t.IdentityTransform()), + ], + [ + jnp.array([1.0, 2.0, 0.3]), + jnp.array([2.0, 1.0, 0.5]), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones((3, 3)), + jnp.ones((3, 3)), + (t.ExpTransform(), t.IdentityTransform()), + ], + ], +) +def test_double_convolve_scanner_using_scan( + arr1, arr2, history, m1, m2, transform +): + """ + Tests the output of new convolve double scanner function + used with `jax.lax.scan` against values calculated + using a for loop + """ + arr1 = jnp.array([1.0, 2.0]) + arr2 = jnp.array([2.0, 1.0]) + transform = (t.IdentityTransform(), t.IdentityTransform()) + history = jnp.array([0.1, 0.4]) + m1, m2 = (jnp.array([1, 2, 3]), jnp.ones(3)) + + scanner = pc.new_double_convolve_scanner((arr1, arr2), transform) + + _, result = jax.lax.scan(f=scanner, init=history, xs=(m1, m2)) + + res1, res2 = [], [] + for m1, m2 in zip(m1, m2): + history, new_val = scanner(history, (m1, m2)) + res1.append(new_val[0]) + res2.append(new_val[1]) + + assert jnp.array_equal(result, (res1, res2)) + + +@pytest.mark.parametrize( + ["arr", "history", "multiplier", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([3.0, 4.0]), + jnp.array(2), + t.IdentityTransform(), + ], + [ + jnp.ones(3), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones(3), + t.ExpTransform(), + ], + ], +) +def test_convolve_scanner(arr, history, multiplier, transform): + """ + Tests new convolve scanner function + """ + scanner = pc.new_convolve_scanner(arr, transform) + latest, new_val = scanner(history, multiplier) + assert jnp.array_equal( + new_val, transform(multiplier * jnp.dot(arr, history)) + ) + + +@pytest.mark.parametrize( + ["arr1", "arr2", "history", "m1", "m2", "transforms"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([2.0, 1.0]), + jnp.array([0.1, 0.4]), + jnp.array(1), + jnp.array(3), + (t.IdentityTransform(), t.IdentityTransform()), + ], + [ + jnp.array([1.0, 2.0, 0.3]), + jnp.array([2.0, 1.0, 0.5]), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones(3), + 0.1 * jnp.ones(3), + (t.ExpTransform(), t.IdentityTransform()), + ], + ], +) +def test_double_convolve_scanner(arr1, arr2, history, m1, m2, transforms): + """ + Tests new double convolve scanner function + """ + double_scanner = pc.new_double_convolve_scanner((arr1, arr2), transforms) + latest, (new_val, m_net) = double_scanner(history, (m1, m2)) + + assert jnp.array_equal(m_net, transforms[0](m1 * jnp.dot(arr1, history))) + assert jnp.array_equal( + new_val, transforms[1](m2 * m_net * jnp.dot(arr2, history)) + ) diff --git a/test/test_infection_functions.py b/test/test_infection_functions.py index f8e1a0b6..94a3ad5b 100644 --- a/test/test_infection_functions.py +++ b/test/test_infection_functions.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import pytest from numpy.testing import assert_array_equal from pyrenew.latent import infection_functions as inf @@ -54,6 +55,45 @@ def test_compute_infections_from_rt_with_feedback(): ) assert_array_equal(Rt_adj, Rt_raw) - pass - pass + return None + + +@pytest.mark.parametrize( + ["I0", "gen_int", "inf_pmf", "Rt_raw"], + [ + [ + jnp.array([[5.0, 0.2]]), + jnp.array([1.0]), + jnp.array([1.0]), + jnp.ones((5, 2)), + ], + [ + 3.5235 * jnp.ones((35, 3)), + jnp.ones(35) / 35, + jnp.ones(35), + jnp.zeros((253, 3)), + ], + ], +) +def test_compute_infections_from_rt_with_feedback_2d( + I0, gen_int, inf_pmf, Rt_raw +): + """ + Test implementation of infection feedback + when I0 and Rt are 2d arrays. + """ + ( + infs_feedback, + Rt_adj, + ) = inf.compute_infections_from_rt_with_feedback( + I0, Rt_raw, jnp.zeros_like(Rt_raw), gen_int, inf_pmf + ) + + assert_array_equal( + inf.compute_infections_from_rt(I0, Rt_raw, gen_int), + infs_feedback, + ) + + assert_array_equal(Rt_adj, Rt_raw) + return None diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index d241783c..9be03312 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import numpy as np import numpyro +import pytest from jax.typing import ArrayLike from numpy.testing import assert_array_almost_equal, assert_array_equal @@ -38,40 +39,64 @@ def _infection_w_feedback_alt( ------- tuple """ - - Rt = np.array(Rt) # coerce from jax to use numpy-like operations T = len(Rt) + Rt = np.array(Rt).reshape( + T, -1 + ) # coerce from jax to use numpy-like operations len_gen = len(gen_int) - I_vec = np.concatenate([I0, np.zeros(T)]) - Rt_adj = np.zeros(T) + infs = np.pad(I0.reshape(T, -1), ((0, Rt.shape[0]), (0, 0))) + Rt_adj = np.zeros(Rt.shape) + inf_feedback_strength = np.array(inf_feedback_strength).reshape(T, -1) - for t in range(T): - Rt_adj[t] = Rt[t] * np.exp( - inf_feedback_strength[t] - * np.dot(I_vec[t : t + len_gen], np.flip(inf_feedback_pmf)) - ) + for n in range(Rt.shape[1]): + for t in range(Rt.shape[0]): + Rt_adj[t, n] = Rt[t, n] * np.exp( + inf_feedback_strength[t, n] + * np.dot(infs[t : t + len_gen, n], np.flip(inf_feedback_pmf)) + ) - I_vec[t + len_gen] = Rt_adj[t] * np.dot( - I_vec[t : t + len_gen], np.flip(gen_int) - ) + infs[t + len_gen, n] = Rt_adj[t, n] * np.dot( + infs[t : t + len_gen, n], np.flip(gen_int) + ) - return {"post_initialization_infections": I_vec[I0.size :], "rt": Rt_adj} + return { + "post_initialization_infections": np.squeeze(infs[I0.shape[0] :]), + "rt": np.squeeze(Rt_adj), + } -def test_infectionsrtfeedback(): +@pytest.mark.parametrize( + ["Rt", "I0", "inf_feed_strength"], + [ + [ + jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]), + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + DeterministicVariable( + name="inf_feed_strength", value=jnp.array(0) + ), + ], + [ + jnp.array( + np.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25] * 3) + ).reshape((7, 3)), + jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, 3)), + DeterministicVariable( + name="inf_feed_strength", value=jnp.zeros(3) + ), + ], + ], +) +def test_infectionsrtfeedback(Rt, I0, inf_feed_strength): """ Test the InfectionsWithFeedback matching the Infections class. """ - - Rt = jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]) - I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) # By doing the infection feedback strength 0, Rt = Rt_adjusted # So infection should be equal in both - inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.zeros_like(Rt) - ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) # Test the InfectionsWithFeedback class @@ -104,17 +129,31 @@ def test_infectionsrtfeedback(): return None -def test_infectionsrtfeedback_feedback(): +@pytest.mark.parametrize( + ["Rt", "I0"], + [ + [ + jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]), + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + ], + [ + jnp.array( + np.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25] * 3) + ).reshape((7, 3)), + jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, 3)), + ], + ], +) +def test_infectionsrtfeedback_feedback(Rt, I0): """ Test the InfectionsWithFeedback with feedback """ - - Rt = jnp.array([0.5, 0.6, 1.5, 2.523, 0.7, 0.8]) - I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.repeat(0.5, len(Rt)) + name="inf_feed_strength", value=0.5 * jnp.ones_like(Rt) ) inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) @@ -158,3 +197,70 @@ def test_infectionsrtfeedback_feedback(): assert_array_almost_equal(samp1.rt, res["rt"]) return None + + +def test_infections_with_feedback_invalid_inputs(): + """ + Test the InfectionsWithFeedback class cannot + be sampled when Rt and I0 have invalid input shapes + """ + I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) + I0_2d = jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, -1)) + Rt = jnp.ones(10) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + inf_feed_strength = DeterministicVariable( + name="inf_feed_strength", value=0.5 + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + infections = latent.Infections() + + with numpyro.handlers.seed(rng_seed=0): + with pytest.raises( + ValueError, + match="Initial infections must be at least as long as the generation interval.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections vector must be at least as long as the generation interval.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same batch shapes.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same batch shapes.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) diff --git a/test/test_infectionwithfeedback_plate_compatibility.py b/test/test_infectionwithfeedback_plate_compatibility.py new file mode 100644 index 00000000..535902ab --- /dev/null +++ b/test/test_infectionwithfeedback_plate_compatibility.py @@ -0,0 +1,44 @@ +""" +Test the InfectionsWithFeedback class works well within numpyro plate +""" + +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as dist + +import pyrenew.latent as latent +from pyrenew.deterministic import DeterministicPMF +from pyrenew.randomvariable import DistributionalVariable + + +def test_infections_with_feedback_plate_compatibility(): + """ + Test the InfectionsWithFeedback matching the Infections class. + """ + I0 = jnp.array( + np.array([0.0, 0.0, 0.0, 0.5, 0.6, 0.7, 0.8] * 5).reshape(-1, 5) + ) + Rt = jnp.ones((10, 5)) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1]) + + inf_feed_strength = DistributionalVariable( + "inf_feed_strength", dist.Beta(1, 1) + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + with numpyro.handlers.seed(rng_seed=0): + with numpyro.plate("test_plate", 5): + samp = InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + assert samp.rt.shape == Rt.shape