Skip to content

Commit

Permalink
add more test for convolve scanner functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Sep 12, 2024
1 parent 97d19d0 commit d933df4
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ def sample(
f"generation interval length: {gen_int.size}."
)

if I0.shape != Rt.shape:
raise ValueError(
"Initial infections and Rt must have the same shape. "
f"Got initial infections of shape {I0.shape} "
f"and Rt of shape {Rt.shape}."
)

gen_int_rev = jnp.flip(gen_int)
recent_I0 = I0[-gen_int_rev.size :]

Expand Down
11 changes: 9 additions & 2 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,15 @@ def sample(
if I0.shape[0] < gen_int.size:
raise ValueError(
"Initial infections must be at least as long as the "
f"generation interval. Got {I0.shape[0]} initial infections "
f"and {gen_int.size} generation interval."
f"generation interval. Got initial infections length {I0.shape[0]}"
f"and generation interval length {gen_int.size}."
)

if I0.shape != Rt.shape:
raise ValueError(
"Initial infections and Rt must have the same shape. "
f"Got initial infections of shape {I0.shape} "
f"and Rt of shape {Rt.shape}."
)

gen_int_rev = jnp.flip(gen_int)
Expand Down
146 changes: 146 additions & 0 deletions test/test_convolve_scanners.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numpy.testing import assert_array_equal

import pyrenew.convolve as pc
import pyrenew.transformation as t


@pytest.mark.parametrize(
Expand Down Expand Up @@ -82,3 +83,148 @@ def transform_ones_like(x: any):

assert_array_equal(result_a_double[1], jnp.ones_like(multipliers))
assert_array_equal(result_a_double[0], result_a)


@pytest.mark.parametrize(
["arr", "history", "multipliers", "transform"],
[
[
jnp.array([1.0, 2.0]),
jnp.array([3.0, 4.0]),
jnp.array([1, 2, 3]),
t.IdentityTransform(),
],
[
jnp.ones(3),
jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3),
jnp.ones((3, 3)),
t.ExpTransform(),
],
],
)
def test_convolve_scanner_using_scan(arr, history, multipliers, transform):
"""
Tests the output of new convolve scanner function
used with `jax.lax.scan` against values calculated
using a for loop
"""
scanner = pc.new_convolve_scanner(arr, transform)

_, result = jax.lax.scan(f=scanner, init=history, xs=multipliers)

result_not_scanned = []
for multiplier in multipliers:
history, new_val = scanner(history, multiplier)
result_not_scanned.append(new_val)

assert jnp.array_equal(result, result_not_scanned)


@pytest.mark.parametrize(
["arr1", "arr2", "history", "m1", "m2", "transform"],
[
[
jnp.array([1.0, 2.0]),
jnp.array([2.0, 1.0]),
jnp.array([0.1, 0.4]),
jnp.array([1, 2, 3]),
jnp.ones(3),
(t.IdentityTransform(), t.IdentityTransform()),
],
[
jnp.array([1.0, 2.0, 0.3]),
jnp.array([2.0, 1.0, 0.5]),
jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3),
jnp.ones((3, 3)),
jnp.ones((3, 3)),
(t.ExpTransform(), t.IdentityTransform()),
],
],
)
def test_double_convolve_scanner_using_scan(
arr1, arr2, history, m1, m2, transform
):
"""
Tests the output of new convolve double scanner function
used with `jax.lax.scan` against values calculated
using a for loop
"""
arr1 = jnp.array([1.0, 2.0])
arr2 = jnp.array([2.0, 1.0])
transform = (t.IdentityTransform(), t.IdentityTransform())
history = jnp.array([0.1, 0.4])
m1, m2 = (jnp.array([1, 2, 3]), jnp.ones(3))

scanner = pc.new_double_convolve_scanner((arr1, arr2), transform)

_, result = jax.lax.scan(f=scanner, init=history, xs=(m1, m2))

res1, res2 = [], []
for m1, m2 in zip(m1, m2):
history, new_val = scanner(history, (m1, m2))
res1.append(new_val[0])
res2.append(new_val[1])

assert jnp.array_equal(result, (res1, res2))


@pytest.mark.parametrize(
["arr", "history", "multiplier", "transform"],
[
[
jnp.array([1.0, 2.0]),
jnp.array([3.0, 4.0]),
jnp.array(2),
t.IdentityTransform(),
],
[
jnp.ones(3),
jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3),
jnp.ones(3),
t.ExpTransform(),
],
],
)
def test_convolve_scanner(arr, history, multiplier, transform):
"""
Tests new convolve scanner function
"""
scanner = pc.new_convolve_scanner(arr, transform)
latest, new_val = scanner(history, multiplier)
assert jnp.array_equal(
new_val, transform(multiplier * jnp.dot(arr, history))
)


@pytest.mark.parametrize(
["arr1", "arr2", "history", "m1", "m2", "transforms"],
[
[
jnp.array([1.0, 2.0]),
jnp.array([2.0, 1.0]),
jnp.array([0.1, 0.4]),
jnp.array(1),
jnp.array(3),
(t.IdentityTransform(), t.IdentityTransform()),
],
[
jnp.array([1.0, 2.0, 0.3]),
jnp.array([2.0, 1.0, 0.5]),
jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3),
jnp.ones(3),
0.1 * jnp.ones(3),
(t.ExpTransform(), t.IdentityTransform()),
],
],
)
def test_double_convolve_scanner(arr1, arr2, history, m1, m2, transforms):
"""
Tests new double convolve scanner function
"""
double_scanner = pc.new_double_convolve_scanner((arr1, arr2), transforms)
latest, (new_val, m_net) = double_scanner(history, (m1, m2))

assert jnp.array_equal(m_net, transforms[0](m1 * jnp.dot(arr1, history)))
assert jnp.array_equal(
new_val, transforms[1](m2 * m_net * jnp.dot(arr2, history))
)

0 comments on commit d933df4

Please sign in to comment.