Skip to content

Commit

Permalink
Merge branch 'main' into 442-vectorize-differencedprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Sep 17, 2024
2 parents 3701dd4 + ab8238a commit 0454b4e
Show file tree
Hide file tree
Showing 10 changed files with 491 additions and 127 deletions.
5 changes: 2 additions & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple(
)
```

The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:
The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:

```{python}
# | label: new-model-def
Expand Down Expand Up @@ -224,10 +224,9 @@ class InfFeedback(RandomVariable):
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)
inf_feedback_strength = au.pad_x_to_match_y(
inf_feedback_strength, _ = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)
# Sampling inf feedback and adjusting the shape
Expand Down
60 changes: 16 additions & 44 deletions pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,25 @@
from jax.typing import ArrayLike


def pad_to_match(
def pad_edges_to_match(
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
axis: int = 0,
pad_direction: str = "end",
fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]:
"""
Pad the shorter array at the start or end to match the length of the longer array.
Pad the shorter array at the start or end using the
edge values to match the length of the longer array.
Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
axis : int, optional
Axis along which to add padding, by default 0
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
fix_y : bool, optional
Expand All @@ -38,64 +39,35 @@ def pad_to_match(
"""
x = jnp.atleast_1d(x)
y = jnp.atleast_1d(y)
x_len = x.size
y_len = y.size
x_len = x.shape[axis]
y_len = y.shape[axis]
pad_size = abs(x_len - y_len)
pad_width = [(0, 0)] * x.ndim

pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if pad_width is None:
if pad_direction not in ["start", "end"]:
raise ValueError(
"pad_direction must be either 'start' or 'end'."
f" Got {pad_direction}."
)

pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if x_len > y_len:
if fix_y:
raise ValueError(
"Cannot fix y when x is longer than y."
f" x_len: {x_len}, y_len: {y_len}."
)
y = jnp.pad(y, pad_width, constant_values=fill_value)
y = jnp.pad(y, pad_width, mode="edge")

elif y_len > x_len:
x = jnp.pad(x, pad_width, constant_values=fill_value)
x = jnp.pad(x, pad_width, mode="edge")

return x, y


def pad_x_to_match_y(
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
pad_direction: str = "end",
) -> ArrayLike:
"""
Pad the `x` array at the start or end to match the length of the `y` array.
Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
Returns
-------
Array
Padded array.
"""
return pad_to_match(
x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True
)[0]


class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess()`.
Expand Down
17 changes: 12 additions & 5 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ def _new_scanner(
history_subset: ArrayLike, multiplier: float
) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08
new_val = transform(
multiplier * jnp.dot(array_to_convolve, history_subset)
multiplier
* jnp.einsum("i...,i...->...", array_to_convolve, history_subset)
)
latest = jnp.concatenate(
[history_subset[1:], new_val[jnp.newaxis]], axis=0
)
latest = jnp.hstack([history_subset[1:], new_val])
return latest, new_val

return _new_scanner
Expand Down Expand Up @@ -158,9 +161,13 @@ def _new_scanner(
multipliers: tuple[float, float],
) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08
m1, m2 = multipliers
m_net1 = t1(m1 * jnp.dot(arr1, history_subset))
new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset))
latest = jnp.hstack([history_subset[1:], new_val])
m_net1 = t1(m1 * jnp.einsum("i...,i...->...", arr1, history_subset))
new_val = t2(
m2 * m_net1 * jnp.einsum("i...,i...->...", arr2, history_subset)
)
latest = jnp.concatenate(
[history_subset[1:], new_val[jnp.newaxis]], axis=0
)
return latest, (new_val, m_net1)

return _new_scanner
Expand Down
11 changes: 9 additions & 2 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,21 @@ def sample(
InfectionsSample
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
raise ValueError(
"Initial infections vector must be at least as long as "
"the generation interval. "
f"Initial infections vector length: {I0.size}, "
f"Initial infections vector length: {I0.shape[0]}, "
f"generation interval length: {gen_int.size}."
)

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)

gen_int_rev = jnp.flip(gen_int)
recent_I0 = I0[-gen_int_rev.size :]

Expand Down
35 changes: 23 additions & 12 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,18 @@ def sample(
InfectionsWithFeedback
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
raise ValueError(
"Initial infections must be at least as long as the "
f"generation interval. Got {I0.size} initial infections "
f"and {gen_int.size} generation interval."
f"generation interval. Got initial infections length {I0.shape[0]}"
f"and generation interval length {gen_int.size}."
)

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)

gen_int_rev = jnp.flip(gen_int)
Expand All @@ -160,19 +167,23 @@ def sample(
**kwargs,
)
)

if inf_feedback_strength.ndim == Rt.ndim - 1:
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.size == 1:
inf_feedback_strength = au.pad_x_to_match_y(
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)
elif inf_feedback_strength.size != Rt.size:
axis=0,
)[0]
if inf_feedback_strength.shape != Rt.shape:
raise ValueError(
"Infection feedback strength must be of size 1 "
"or the same size as the reproduction number array. "
f"Got {inf_feedback_strength.size} "
f"and {Rt.size} respectively."
"Infection feedback strength must be of length 1 "
"or the same length as the reproduction number array. "
f"Got {inf_feedback_strength.shape} "
f"and {Rt.shape} respectively."
)

# Sampling inf feedback pmf
Expand Down
75 changes: 47 additions & 28 deletions test/test_arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,72 @@
import pyrenew.arrayutils as au


def test_arrayutils_pad_to_match():
def test_pad_edges_to_match():
"""
Verifies extension when required and error when `fix_y` is True.
Test function to verify padding along the edges for 1D and 2D arrays
"""

# test when y gets padded
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_to_match(x, y)

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert y_pad[-1] == y[-1]
assert jnp.array_equal(x_pad, x)

# test when x gets padded
x = jnp.array([1, 2])
y = jnp.array([1, 2, 3])

x_pad, y_pad = au.pad_to_match(x, y)

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert x_pad[-1] == x[-1]
assert jnp.array_equal(y_pad, y)

# test when no padding required
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])
y = jnp.array([4, 5, 6])
x_pad, y_pad = au.pad_edges_to_match(x, y)

# Verify that the function raises an error when `fix_y` is True
with pytest.raises(ValueError):
x_pad, y_pad = au.pad_to_match(x, y, fix_y=True)
assert jnp.array_equal(x_pad, x)
assert jnp.array_equal(y_pad, y)

# Verify function works with both padding directions
x_pad, y_pad = au.pad_to_match(x, y, pad_direction="start")
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="start")

assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert y_pad[0] == y[0]
assert jnp.array_equal(x_pad, x)

# Verify that the function raises an error when `fix_y` is True
with pytest.raises(
ValueError, match="Cannot fix y when x is longer than y"
):
x_pad, y_pad = au.pad_edges_to_match(x, y, fix_y=True)

# Verify function raises an error when pad_direction is not "start" or "end"
with pytest.raises(ValueError):
x_pad, y_pad = au.pad_to_match(x, y, pad_direction="middle")


def test_arrayutils_pad_x_to_match_y():
"""
Verifies extension when required
"""

x = jnp.array([1, 2])
y = jnp.array([1, 2, 3])

x_pad = au.pad_x_to_match_y(x, y)

assert x_pad.size == 3
x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="middle")

# test padding for 2D arrays
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6]])

# Padding along axis 0
axis = 0
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")

assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y_pad[-1], y[-1])
assert jnp.array_equal(x_pad, x)

# padding along axis 1
axis = 1
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")
assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y[:, -1], y_pad[:, -1])
assert jnp.array_equal(x_pad, x)
Loading

0 comments on commit 0454b4e

Please sign in to comment.