Skip to content

Commit

Permalink
Remove SampledValue (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer authored Sep 11, 2024
1 parent 145ced6 commit bfaa0c0
Show file tree
Hide file tree
Showing 55 changed files with 373 additions and 1,168 deletions.
11 changes: 5 additions & 6 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -152,9 +151,9 @@ class MyRt(RandomVariable):
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rt_proc = MyRt()
Expand Down Expand Up @@ -220,11 +219,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt.value)
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.observed_infections.value)
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
Expand All @@ -242,7 +241,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections.value,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Expand Down
16 changes: 8 additions & 8 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ I0 = InfectionInitializationProcess(
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
# Generation interval and Rt
Expand All @@ -110,11 +109,11 @@ class MyRt(metaclass.RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
# Standard deviation of the random walk
sd_rt, *_ = self.sd_rv()
sd_rt = self.sd_rv()
# Random walk step
step_rv = randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
name="rw_step_rv", distribution=dist.Normal(0, sd_rt)
)
rt_init_rv = randomvariable.DistributionalVariable(
Expand All @@ -133,9 +132,9 @@ class MyRt(metaclass.RandomVariable):
base_rv=base_rv,
transforms=transformation.ExpTransform(),
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rtproc = MyRt(
Expand Down Expand Up @@ -168,7 +167,7 @@ obs = observation.NegativeBinomialObservation(
)
```

4. And finally, we built the model:
4. And finally, we build the model:

```{python}
# | label: init-model
Expand Down Expand Up @@ -282,10 +281,11 @@ As a result, we can see the posterior distribution of our novel day-of-the-week
# | label: fig-output-day-of-week
# | fig-cap: Day of the week effect
out = hosp_model_dow.plot_posterior(
var="dayofweek_effect", ylab="Day of the Week Effect", samples=500
var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500
)
sp = hosp_model_dow.spread_draws(["dayofweek_effect"])
sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"])
# dayofweek_effect is not recorded
```

The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect:
Expand Down
33 changes: 15 additions & 18 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ I0 = InfectionInitializationProcess(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
latent_infections = InfectionsWithFeedback(
Expand Down Expand Up @@ -85,9 +84,9 @@ class MyRt(RandomVariable):
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
```

With all the components defined, we can build the model:
Expand Down Expand Up @@ -118,7 +117,7 @@ with numpyro.handlers.seed(rng_seed=223):
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections.value)
ax.plot(model0_samp.latent_infections)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
Expand Down Expand Up @@ -164,7 +163,7 @@ from collections import namedtuple
# Creating a tuple to store the output
InfFeedbackSample = namedtuple(
typename="InfFeedbackSample",
field_names=["infections", "rt"],
field_names=["post_initialization_infections", "rt"],
defaults=(None, None),
)
```
Expand All @@ -175,7 +174,7 @@ The next step is to create the actual class. The bulk of its implementation lies
# | label: new-model-def
# | code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.metaclass import RandomVariable
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
Expand Down Expand Up @@ -219,11 +218,11 @@ class InfFeedback(RandomVariable):
I0_vec = I0[-gen_int_rev.size :]
# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength(
inf_feedback_strength = self.infection_feedback_strength(
**kwargs,
)
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value)
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength,
Expand All @@ -232,8 +231,8 @@ class InfFeedback(RandomVariable):
)
# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value)
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)
# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
Expand All @@ -250,20 +249,18 @@ class InfFeedback(RandomVariable):
# Preparing theoutput
return InfFeedbackSample(
infections=SampledValue(all_infections),
rt=SampledValue(Rt_adj),
post_initialization_infections=all_infections,
rt=Rt_adj,
)
```

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.
2. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
3. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
# | label: simulation2
Expand Down Expand Up @@ -293,8 +290,8 @@ Comparing `model0` with `model1`, these two should match:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections.value)
ax[1].plot(model1_samp.latent_infections.value)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
Expand Down
12 changes: 5 additions & 7 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ I0 = InfectionInitializationProcess(
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
# Generation interval and Rt
Expand Down Expand Up @@ -207,9 +206,9 @@ class MyRt(metaclass.RandomVariable):
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rtproc = MyRt()
Expand Down Expand Up @@ -256,7 +255,6 @@ import numpy as np
timeframe = 120
with numpyro.handlers.seed(rng_seed=223):
simulated_data = hosp_model.sample(n_datapoints=timeframe)
```
Expand All @@ -269,11 +267,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(simulated_data.Rt.value)
axs[0].plot(simulated_data.Rt)
axs[0].set_ylabel("Simulated Rt")
# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o")
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].set_ylabel("Simulated Admissions")
fig.suptitle("Basic renewal model")
Expand Down Expand Up @@ -483,7 +481,7 @@ def compute_eti(dataset, eti_prob):
eti_bdry = dataset.quantile(
((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
)
return eti_bdry.values.T
return eti_bdry.T
fig, axes = plt.subplots(figsize=(6, 5))
Expand Down
7 changes: 2 additions & 5 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the Rt values
import matplotlib.pyplot as plt
plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post")
plt.step(np.arange(len(sim_data)), sim_data, where="post")
plt.xlabel("Time")
plt.ylabel("Rt")
plt.title("Simulated Rt values")
Expand Down Expand Up @@ -79,7 +79,6 @@ dayofweek = process.DayOfWeekEffect(
quantity_to_broadcast=randomvariable.DistributionalVariable(
name="simp", distribution=mysimplex
),
t_start=0,
)
```

Expand All @@ -92,9 +91,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the effect values
import matplotlib.pyplot as plt
plt.step(
np.arange(len(sim_data.value.value)), sim_data.value.value, where="post"
)
plt.step(np.arange(len(sim_data)), sim_data, where="post")
plt.xlabel("Time")
plt.ylabel("Effect size")
plt.title("Simulated Day of Week Effect values")
Expand Down
51 changes: 0 additions & 51 deletions docs/source/tutorials/time.qmd

This file was deleted.

5 changes: 0 additions & 5 deletions docs/source/tutorials/time.rst

This file was deleted.

2 changes: 1 addition & 1 deletion pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class PeriodicProcessSample(NamedTuple):
value: ArrayLike | None = None

def __repr__(self):
return f"PeriodicProcessSample(value={self.value})"
return f"PeriodicProcessSample(value={self})"


def tile_until_n(
Expand Down
9 changes: 1 addition & 8 deletions pyrenew/deterministic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@

from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.deterministic.deterministicpmf import DeterministicPMF
from pyrenew.deterministic.nullrv import (
NullObservation,
NullProcess,
NullVariable,
)
from pyrenew.deterministic.process import DeterministicProcess
from pyrenew.deterministic.nullrv import NullObservation, NullVariable

__all__ = [
"DeterministicVariable",
"DeterministicPMF",
"DeterministicProcess",
"NullVariable",
"NullProcess",
"NullObservation",
]
Loading

0 comments on commit bfaa0c0

Please sign in to comment.