From bfaa0c04bf4f08ce1cb3ccd6fa26895ec860355b Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Wed, 11 Sep 2024 12:06:13 -0500 Subject: [PATCH] Remove `SampledValue` (#441) --- docs/source/tutorials/basic_renewal_model.qmd | 11 +- docs/source/tutorials/day_of_the_week.qmd | 16 +- docs/source/tutorials/extending_pyrenew.qmd | 33 ++-- .../tutorials/hospital_admissions_model.qmd | 12 +- docs/source/tutorials/periodic_effects.qmd | 7 +- docs/source/tutorials/time.qmd | 51 ----- docs/source/tutorials/time.rst | 5 - pyrenew/arrayutils.py | 2 +- pyrenew/deterministic/__init__.py | 9 +- pyrenew/deterministic/deterministic.py | 28 +-- pyrenew/deterministic/deterministicpmf.py | 16 +- pyrenew/deterministic/nullrv.py | 60 +----- pyrenew/deterministic/process.py | 53 ------ pyrenew/latent/hospitaladmissions.py | 70 +++---- .../latent/infection_initialization_method.py | 2 +- .../infection_initialization_process.py | 38 +--- pyrenew/latent/infections.py | 10 +- pyrenew/latent/infectionswithfeedback.py | 32 ++-- pyrenew/metaclass.py | 177 ------------------ pyrenew/model/admissionsmodel.py | 48 ++--- pyrenew/model/rtinfectionsrenewalmodel.py | 69 +++---- pyrenew/observation/negativebinomial.py | 18 +- pyrenew/observation/poisson.py | 14 +- pyrenew/process/ar.py | 20 +- pyrenew/process/differencedprocess.py | 21 +-- pyrenew/process/iidrandomsequence.py | 27 +-- pyrenew/process/periodiceffect.py | 58 +----- pyrenew/process/rtperiodicdiffar.py | 63 ++----- .../randomvariable/distributionalvariable.py | 28 +-- pyrenew/randomvariable/transformedvariable.py | 20 +- pyrenew/regression.py | 19 +- test/test_ar_process.py | 6 +- test/test_assert_sample_and_rtype.py | 123 ------------ test/test_deterministic.py | 32 +--- test/test_differenced_process.py | 24 +-- test/test_distributional_rv.py | 6 +- test/test_forecast.py | 3 +- test/test_iid_random_sequence.py | 17 +- test/test_infection_initialization_method.py | 12 +- test/test_infection_initialization_process.py | 5 - test/test_infectionsrtfeedback.py | 18 +- test/test_latent_admissions.py | 21 +-- test/test_latent_infections.py | 14 +- test/test_model_basic_renewal.py | 28 ++- test/test_model_hosp_admissions.py | 71 ++----- test/test_observation_negativebinom.py | 21 +-- test/test_observation_poisson.py | 4 +- test/test_periodiceffect.py | 15 +- test/test_predictive.py | 1 - test/test_random_key.py | 3 +- test/test_random_walk.py | 17 +- test/test_regression.py | 10 +- test/test_rtperiodicdiff.py | 8 +- test/test_transformed_rv_class.py | 33 ++-- test/utils.py | 12 +- 55 files changed, 373 insertions(+), 1168 deletions(-) delete mode 100644 docs/source/tutorials/time.qmd delete mode 100644 docs/source/tutorials/time.rst delete mode 100644 pyrenew/deterministic/process.py delete mode 100644 test/test_assert_sample_and_rtype.py diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index bf262094..f9100da6 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -125,7 +125,6 @@ I0 = InfectionInitializationProcess( "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)), InitializeInfectionsZeroPad(pmf_array.size), - t_unit=1, ) @@ -152,9 +151,9 @@ class MyRt(RandomVariable): rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) - init_rt, *_ = rt_init_rv.sample() + init_rt = rt_init_rv.sample() - return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt, **kwargs) rt_proc = MyRt() @@ -220,11 +219,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt.value) +axs[0].plot(sim_data.Rt) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.observed_infections.value) +axs[1].plot(sim_data.observed_infections) axs[1].set_ylabel("Infections") fig.suptitle("Basic renewal model") @@ -242,7 +241,7 @@ import jax model1.run( num_warmup=2000, num_samples=1000, - data_observed_infections=sim_data.observed_infections.value, + data_observed_infections=sim_data.observed_infections, rng_key=jax.random.key(54), mcmc_args=dict(progress_bar=False, num_chains=2), ) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 4b4c6f8f..854f7454 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -89,7 +89,6 @@ I0 = InfectionInitializationProcess( n_initialization_points, deterministic.DeterministicVariable(name="rate", value=0.05), ), - t_unit=1, ) # Generation interval and Rt @@ -110,11 +109,11 @@ class MyRt(metaclass.RandomVariable): def sample(self, n: int, **kwargs) -> tuple: # Standard deviation of the random walk - sd_rt, *_ = self.sd_rv() + sd_rt = self.sd_rv() # Random walk step step_rv = randomvariable.DistributionalVariable( - name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) + name="rw_step_rv", distribution=dist.Normal(0, sd_rt) ) rt_init_rv = randomvariable.DistributionalVariable( @@ -133,9 +132,9 @@ class MyRt(metaclass.RandomVariable): base_rv=base_rv, transforms=transformation.ExpTransform(), ) - init_rt, *_ = rt_init_rv.sample() + init_rt = rt_init_rv.sample() - return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt, **kwargs) rtproc = MyRt( @@ -168,7 +167,7 @@ obs = observation.NegativeBinomialObservation( ) ``` -4. And finally, we built the model: +4. And finally, we build the model: ```{python} # | label: init-model @@ -282,10 +281,11 @@ As a result, we can see the posterior distribution of our novel day-of-the-week # | label: fig-output-day-of-week # | fig-cap: Day of the week effect out = hosp_model_dow.plot_posterior( - var="dayofweek_effect", ylab="Day of the Week Effect", samples=500 + var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500 ) -sp = hosp_model_dow.spread_draws(["dayofweek_effect"]) +sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"]) +# dayofweek_effect is not recorded ``` The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect: diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index f81de653..2cef0d2a 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -55,7 +55,6 @@ I0 = InfectionInitializationProcess( gen_int_array.size, DeterministicVariable(name="rate", value=0.05), ), - t_unit=1, ) latent_infections = InfectionsWithFeedback( @@ -85,9 +84,9 @@ class MyRt(RandomVariable): rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) - init_rt, *_ = rt_init_rv.sample() + init_rt = rt_init_rv.sample() - return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt, **kwargs) ``` With all the components defined, we can build the model: @@ -118,7 +117,7 @@ with numpyro.handlers.seed(rng_seed=223): import matplotlib.pyplot as plt fig, ax = plt.subplots() -ax.plot(model0_samp.latent_infections.value) +ax.plot(model0_samp.latent_infections) ax.set_xlabel("Time") ax.set_ylabel("Infections") plt.show() @@ -164,7 +163,7 @@ from collections import namedtuple # Creating a tuple to store the output InfFeedbackSample = namedtuple( typename="InfFeedbackSample", - field_names=["infections", "rt"], + field_names=["post_initialization_infections", "rt"], defaults=(None, None), ) ``` @@ -175,7 +174,7 @@ The next step is to create the actual class. The bulk of its implementation lies # | label: new-model-def # | code-line-numbers: true # Creating the class -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable from pyrenew.latent import compute_infections_from_rt_with_feedback from pyrenew import arrayutils as au from jax.typing import ArrayLike @@ -219,11 +218,11 @@ class InfFeedback(RandomVariable): I0_vec = I0[-gen_int_rev.size :] # Sampling inf feedback strength and adjusting the shape - inf_feedback_strength, *_ = self.infection_feedback_strength( + inf_feedback_strength = self.infection_feedback_strength( **kwargs, ) - inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value) + inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength) inf_feedback_strength = au.pad_x_to_match_y( x=inf_feedback_strength, @@ -232,8 +231,8 @@ class InfFeedback(RandomVariable): ) # Sampling inf feedback and adjusting the shape - inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) + inf_feedback_pmf = self.infection_feedback_pmf(**kwargs) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) # Generating the infections with feedback all_infections, Rt_adj = compute_infections_from_rt_with_feedback( @@ -250,8 +249,8 @@ class InfFeedback(RandomVariable): # Preparing theoutput return InfFeedbackSample( - infections=SampledValue(all_infections), - rt=SampledValue(Rt_adj), + post_initialization_infections=all_infections, + rt=Rt_adj, ) ``` @@ -259,11 +258,9 @@ The core of the class is implemented in the `sample()` method. Things to highlig 1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused. -2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`). - -3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later. +2. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later. -4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`. +3. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`. ```{python} # | label: simulation2 @@ -293,8 +290,8 @@ Comparing `model0` with `model1`, these two should match: import matplotlib.pyplot as plt fig, ax = plt.subplots(ncols=2) -ax[0].plot(model0_samp.latent_infections.value) -ax[1].plot(model1_samp.latent_infections.value) +ax[0].plot(model0_samp.latent_infections) +ax[1].plot(model1_samp.latent_infections) ax[0].set_xlabel("Time (model 0)") ax[1].set_xlabel("Time (model 1)") ax[0].set_ylabel("Infections") diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 07c38644..bf11bf3a 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -179,7 +179,6 @@ I0 = InfectionInitializationProcess( n_initialization_points, deterministic.DeterministicVariable(name="rate", value=0.05), ), - t_unit=1, ) # Generation interval and Rt @@ -207,9 +206,9 @@ class MyRt(metaclass.RandomVariable): rt_init_rv = randomvariable.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) - init_rt, *_ = rt_init_rv.sample() + init_rt = rt_init_rv.sample() - return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt, **kwargs) rtproc = MyRt() @@ -256,7 +255,6 @@ import numpy as np timeframe = 120 - with numpyro.handlers.seed(rng_seed=223): simulated_data = hosp_model.sample(n_datapoints=timeframe) ``` @@ -269,11 +267,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(simulated_data.Rt.value) +axs[0].plot(simulated_data.Rt) axs[0].set_ylabel("Simulated Rt") # Admissions plot -axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o") +axs[1].plot(simulated_data.observed_hosp_admissions, "-o") axs[1].set_ylabel("Simulated Admissions") fig.suptitle("Basic renewal model") @@ -483,7 +481,7 @@ def compute_eti(dataset, eti_prob): eti_bdry = dataset.quantile( ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") ) - return eti_bdry.values.T + return eti_bdry.T fig, axes = plt.subplots(figsize=(6, 5)) diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 1603ed59..7f3cf5b7 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20): # Plotting the Rt values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post") +plt.step(np.arange(len(sim_data)), sim_data, where="post") plt.xlabel("Time") plt.ylabel("Rt") plt.title("Simulated Rt values") @@ -79,7 +79,6 @@ dayofweek = process.DayOfWeekEffect( quantity_to_broadcast=randomvariable.DistributionalVariable( name="simp", distribution=mysimplex ), - t_start=0, ) ``` @@ -92,9 +91,7 @@ with numpyro.handlers.seed(rng_seed=20): # Plotting the effect values import matplotlib.pyplot as plt -plt.step( - np.arange(len(sim_data.value.value)), sim_data.value.value, where="post" -) +plt.step(np.arange(len(sim_data)), sim_data, where="post") plt.xlabel("Time") plt.ylabel("Effect size") plt.title("Simulated Day of Week Effect values") diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd deleted file mode 100644 index 02523cf2..00000000 --- a/docs/source/tutorials/time.qmd +++ /dev/null @@ -1,51 +0,0 @@ ---- -title: Time handling in pyrenew -format: gfm -engine: jupyter ---- - -Every `pyrenew` model has a _fundamental discrete time unit_. All time-aware arrays used in the model should have units expressible as integer multiples of this fundamental unit. - -The fundamental time unit should represent a period of fixed (or approximately fixed) duration. That is, "days" could make for a good fundamental time unit; "months" would not, since different months represent different absolute lengths of time. - -For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - -`pyrenew` deals with time by having `RandomVariable`s carry information about - -1. their own time unit expressed relative to the fundamental unit (`t_unit`) and -2. the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. - -Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes. - -By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are `sample()`-d. One might override this default to allow a `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time-units, or with different start-points relative to the `RandomVariable`'s own `t_start`. - -The `t_unit, t_start` pair can encode different types of time series data. For example: - -| Description | `t_unit` | `t_start` | -|:-----------------|----------------:|-----------------:| -| Daily starting on day two | 1 | 1 | -| Weekly starting on week two | 7 | 7 | -| Daily starting on day 40 | 1 | 39 | -| Biweekly starting on day 40 | 14 | 39 | -| Daily, with the first observation starting five days before the model (as in the initialization process) | 1 | -5 | - - -## How it relates to periodicity - -The `tile_until_n()` and `repeat_until_n()` functions provide a way of tiling and repeating data accounting starting time, but they do not encode the time unit, only the period length and starting point. Furthermore, samples returned from `PeriodicEffect()` and `RtPeriodicDiffProcess()` both currently return daily values shifted so that the first entry of their arrays matches day 0 in the model. - -## Unimplemented features - -The following section describes some preliminary design principles that may be included in future versions of `pyrenew`. - -### Array alignment - -Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the initialization process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.: - -```python -Rt_aligned, infections_aligned = align([Rt, infections]) -``` - -### Retrieving time information from sites - -Future versions of `pyrenew` could include a way to retrieve the time information for sites keyed by site name the model. diff --git a/docs/source/tutorials/time.rst b/docs/source/tutorials/time.rst deleted file mode 100644 index 4453da60..00000000 --- a/docs/source/tutorials/time.rst +++ /dev/null @@ -1,5 +0,0 @@ -.. WARNING -.. Please do not edit this file directly. -.. This file is just a placeholder. -.. For the source file, see: -.. diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 402fbe2c..9b8c4b21 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -109,7 +109,7 @@ class PeriodicProcessSample(NamedTuple): value: ArrayLike | None = None def __repr__(self): - return f"PeriodicProcessSample(value={self.value})" + return f"PeriodicProcessSample(value={self})" def tile_until_n( diff --git a/pyrenew/deterministic/__init__.py b/pyrenew/deterministic/__init__.py index 0bde1e3e..8dbd2151 100644 --- a/pyrenew/deterministic/__init__.py +++ b/pyrenew/deterministic/__init__.py @@ -2,18 +2,11 @@ from pyrenew.deterministic.deterministic import DeterministicVariable from pyrenew.deterministic.deterministicpmf import DeterministicPMF -from pyrenew.deterministic.nullrv import ( - NullObservation, - NullProcess, - NullVariable, -) -from pyrenew.deterministic.process import DeterministicProcess +from pyrenew.deterministic.nullrv import NullObservation, NullVariable __all__ = [ "DeterministicVariable", "DeterministicPMF", - "DeterministicProcess", "NullVariable", - "NullProcess", "NullObservation", ] diff --git a/pyrenew/deterministic/deterministic.py b/pyrenew/deterministic/deterministic.py index ad264570..ed61f1ba 100644 --- a/pyrenew/deterministic/deterministic.py +++ b/pyrenew/deterministic/deterministic.py @@ -5,7 +5,7 @@ import numpyro from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class DeterministicVariable(RandomVariable): @@ -18,8 +18,6 @@ def __init__( self, name: str, value: ArrayLike, - t_start: int | None = None, - t_unit: int | None = None, ) -> None: """Default constructor @@ -29,10 +27,6 @@ def __init__( A name to assign to the variable. value : ArrayLike An ArrayLike object. - t_start : int, optional - The start time of the variable, if any. - t_unit : int, optional - The unit of time relative to the model's fundamental (smallest) time unit, if any Returns ------- @@ -41,7 +35,6 @@ def __init__( self.name = name self.validate(value) self.value = value - self.set_timeseries(t_start, t_unit) return None @@ -77,7 +70,7 @@ def sample( self, record=False, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Retrieve the value of the deterministic Rv @@ -92,19 +85,8 @@ def sample( Returns ------- - tuple[SampledValue] - A length-one tuple whose single entry is a - :class:`SampledValue` - instance with `value=self.value`, - `t_start=self.t_start`, and - `t_unit=self.t_unit`. + ArrayLike """ if record: - numpyro.deterministic(self.name, self.value) - return ( - SampledValue( - value=self.value, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + numpyro.deterministic(self.name, self) + return self.value diff --git a/pyrenew/deterministic/deterministicpmf.py b/pyrenew/deterministic/deterministicpmf.py index daec0348..a3f629db 100644 --- a/pyrenew/deterministic/deterministicpmf.py +++ b/pyrenew/deterministic/deterministicpmf.py @@ -18,8 +18,6 @@ def __init__( name: str, value: ArrayLike, tol: float = 1e-5, - t_start: int | None = None, - t_unit: int | None = None, ) -> None: """ Default constructor @@ -33,16 +31,11 @@ def __init__( ---------- name : str A name to assign to the variable. - value : tuple + value : ArrayLike An ArrayLike object. tol : float, optional Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults to 1e-5. - t_start : int, optional - The start time of the process. - t_unit : int, optional - The unit of time relative to the model's fundamental (smallest) - time unit. Returns ------- @@ -56,8 +49,6 @@ def __init__( self.basevar = DeterministicVariable( name=name, value=value, - t_start=t_start, - t_unit=t_unit, ) return None @@ -81,7 +72,7 @@ def validate(value: ArrayLike) -> None: def sample( self, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Retrieves the deterministic PMF @@ -93,8 +84,7 @@ def sample( Returns ------- - tuple - Containing the stored values during construction wrapped in a SampledValue. + ArrayLike """ return self.basevar.sample(**kwargs) diff --git a/pyrenew/deterministic/nullrv.py b/pyrenew/deterministic/nullrv.py index 0827769e..1498d33a 100644 --- a/pyrenew/deterministic/nullrv.py +++ b/pyrenew/deterministic/nullrv.py @@ -5,7 +5,6 @@ from jax.typing import ArrayLike from pyrenew.deterministic.deterministic import DeterministicVariable -from pyrenew.metaclass import SampledValue class NullVariable(DeterministicVariable): @@ -37,7 +36,7 @@ def validate() -> None: def sample( self, **kwargs, - ) -> tuple: + ) -> None: """Retrieve the value of the Null (None) Parameters @@ -45,63 +44,13 @@ def sample( **kwargs : dict, optional Ignored. - Returns - ------- - tuple - Containing a SampledValue with None. - """ - - return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) - - -class NullProcess(NullVariable): - """A null random variable. Sampling returns None.""" - - def __init__(self) -> None: - """Default constructor - Returns ------- None """ - self.validate() - - return None - - @staticmethod - def validate() -> None: - """ - Not used - - Returns - ------- - None - """ return None - def sample( - self, - duration: int, - **kwargs, - ) -> tuple: - """Retrieve the value of the Null (None) - - Parameters - ---------- - duration : int - Number of timepoints to sample (ignored). - **kwargs : dict, optional - Ignored. - - Returns - ------- - tuple - Containing a SampledValue with None. - """ - - return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) - class NullObservation(NullVariable): """A null observation random variable. Sampling returns None.""" @@ -134,7 +83,7 @@ def sample( mu: ArrayLike, obs: ArrayLike | None = None, **kwargs, - ) -> tuple: + ) -> None: """ Retrieve the value of the Null (None) @@ -149,8 +98,7 @@ def sample( Returns ------- - tuple - Containing a SampledValue with None. + None """ - return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) + return None diff --git a/pyrenew/deterministic/process.py b/pyrenew/deterministic/process.py deleted file mode 100644 index 08e133ea..00000000 --- a/pyrenew/deterministic/process.py +++ /dev/null @@ -1,53 +0,0 @@ -# numpydoc ignore=GL08 - -import jax.numpy as jnp - -from pyrenew.deterministic.deterministic import DeterministicVariable -from pyrenew.metaclass import SampledValue - - -class DeterministicProcess(DeterministicVariable): - """ - A deterministic process (degenerate) random variable. - Useful to pass fixed quantities over time.""" - - __init__ = DeterministicVariable.__init__ - - def sample( - self, - duration: int, - **kwargs, - ) -> tuple: - """ - Retrieve the value of the deterministic Rv - - Parameters - ---------- - duration : int - Number of timepoints to sample. - **kwargs : dict, optional - Ignored. - - Returns - ------- - tuple[SampledValue] - containing the deterministic value(s) provided - at construction as a series of length `duration`. - """ - - res, *_ = super().sample(**kwargs) - - dif = duration - res.value.shape[0] - - if dif > 0: - value = jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]) - else: - value = res.value[:duration] - - res = SampledValue( - value, - t_start=self.t_start, - t_unit=self.t_unit, - ) - - return (res,) diff --git a/pyrenew/latent/hospitaladmissions.py b/pyrenew/latent/hospitaladmissions.py index 1fcfc581..99f46c13 100644 --- a/pyrenew/latent/hospitaladmissions.py +++ b/pyrenew/latent/hospitaladmissions.py @@ -6,11 +6,12 @@ import jax.numpy as jnp import numpyro +from jax.typing import ArrayLike import pyrenew.arrayutils as au from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class HospitalAdmissionsSample(NamedTuple): @@ -19,19 +20,19 @@ class HospitalAdmissionsSample(NamedTuple): Attributes ---------- - infection_hosp_rate : SampledValue, optional + infection_hosp_rate : ArrayLike, optional The infection-to-hospitalization rate. Defaults to None. - latent_hospital_admissions : SampledValue or None + latent_hospital_admissions : ArrayLike or None The computed number of hospital admissions. Defaults to None. - multiplier : SampledValue or None + multiplier : ArrayLike or None The day of the week effect multiplier. Defaults to None. It should match the number of timepoints in the latent hospital admissions. """ - infection_hosp_rate: SampledValue | None = None - latent_hospital_admissions: SampledValue | None = None - multiplier: SampledValue | None = None + infection_hosp_rate: ArrayLike | None = None + latent_hospital_admissions: ArrayLike | None = None + multiplier: ArrayLike | None = None def __repr__(self): return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions}, multiplier={self.multiplier})" @@ -86,7 +87,7 @@ def __init__( infection_hospitalization_ratio_rv : RandomVariable Infection to hospitalization rate random variable. day_of_week_effect_rv : RandomVariable, optional - Day of the week effect. Should return a SampledValue with 7 + Day of the week effect. Should return a ArrayLike with 7 values. Defaults to a deterministic variable with jax.numpy.ones(7) (no effect). hospitalization_reporting_ratio_rv : RandomVariable, optional @@ -182,7 +183,7 @@ def validate( def sample( self, - latent_infections: SampledValue, + latent_infections: ArrayLike, **kwargs, ) -> HospitalAdmissionsSample: """ @@ -190,7 +191,7 @@ def sample( Parameters ---------- - latent_infections : SampledValue + latent_infections : ArrayLike Latent infections. Possibly the output of the `latent.Infections()`. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` @@ -201,47 +202,36 @@ def sample( HospitalAdmissionsSample """ - infection_hosp_rate, *_ = self.infection_hospitalization_ratio_rv( - **kwargs - ) + infection_hosp_rate = self.infection_hospitalization_ratio_rv(**kwargs) - ( - infection_to_admission_interval, - *_, - ) = self.infection_to_admission_interval_rv(**kwargs) + infection_to_admission_interval = ( + self.infection_to_admission_interval_rv(**kwargs) + ) latent_hospital_admissions = compute_delay_ascertained_incidence( - latent_infections.value, - infection_to_admission_interval.value, - infection_hosp_rate.value, + latent_infections, + infection_to_admission_interval, + infection_hosp_rate, ) # Applying the day of the week effect. For this we need to: # 1. Get the day of the week effect # 2. Identify the offset of the latent_infections # 3. Apply the day of the week effect to the latent_hospital_admissions - dow_effect_sampled = self.day_of_week_effect_rv(**kwargs, record=True)[ - 0 - ] + dow_effect_sampled = self.day_of_week_effect_rv(**kwargs) - if dow_effect_sampled.value.size != 7: + if dow_effect_sampled.size != 7: raise ValueError( "Day of the week effect should have 7 values. " - f"Got {dow_effect_sampled.value.size} instead." + f"Got {dow_effect_sampled.size} instead." ) - # Identifying the offset - if latent_infections.t_start is None: - inf_offset = 0 - else: - inf_offset = latent_infections.t_start - - inf_offset = (inf_offset + self.obs_data_first_day_of_the_week) % 7 + inf_offset = self.obs_data_first_day_of_the_week % 7 # Replicating the day of the week effect to match the number of # timepoints dow_effect = au.tile_until_n( - data=dow_effect_sampled.value, + data=dow_effect_sampled, n_timepoints=latent_hospital_admissions.size, offset=inf_offset, ) @@ -251,7 +241,7 @@ def sample( # Applying reporting probability latent_hospital_admissions = ( latent_hospital_admissions - * self.hospitalization_reporting_ratio_rv(**kwargs)[0].value + * self.hospitalization_reporting_ratio_rv(**kwargs) ) numpyro.deterministic( @@ -260,14 +250,6 @@ def sample( return HospitalAdmissionsSample( infection_hosp_rate=infection_hosp_rate, - latent_hospital_admissions=SampledValue( - value=latent_hospital_admissions, - t_start=self.t_start, - t_unit=self.t_unit, - ), - multiplier=SampledValue( - dow_effect, - t_start=self.t_start, - t_unit=self.t_unit, - ), + latent_hospital_admissions=latent_hospital_admissions, + multiplier=dow_effect, ) diff --git a/pyrenew/latent/infection_initialization_method.py b/pyrenew/latent/infection_initialization_method.py index 67d25fb3..116785f0 100644 --- a/pyrenew/latent/infection_initialization_method.py +++ b/pyrenew/latent/infection_initialization_method.py @@ -177,7 +177,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): An array of length ``n_timepoints`` with the number of initialized infections at each time point. """ I_pre_init = jnp.array(I_pre_init) - rate = jnp.array(self.rate_rv()[0].value) + rate = jnp.array(self.rate_rv()) initial_infections = I_pre_init * jnp.exp( rate * (jnp.arange(self.n_timepoints)[:, jnp.newaxis] - self.t_pre_init) diff --git a/pyrenew/latent/infection_initialization_process.py b/pyrenew/latent/infection_initialization_process.py index 74f8cd1e..ef9525f4 100644 --- a/pyrenew/latent/infection_initialization_process.py +++ b/pyrenew/latent/infection_initialization_process.py @@ -1,8 +1,10 @@ # numpydoc ignore=GL08 +from jax.typing import ArrayLike + from pyrenew.latent.infection_initialization_method import ( InfectionInitializationMethod, ) -from pyrenew.metaclass import RandomVariable, SampledValue, _assert_type +from pyrenew.metaclass import RandomVariable, _assert_type class InfectionInitializationProcess(RandomVariable): @@ -13,8 +15,6 @@ def __init__( name, I_pre_init_rv: RandomVariable, infection_init_method: InfectionInitializationMethod, - t_unit: int, - t_start: int | None = None, ) -> None: """Default class constructor for InfectionInitializationProcess @@ -26,14 +26,6 @@ def __init__( A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each `infection_init_method` uses this random variable in different ways. infection_init_method : InfectionInitializationMethod An `InfectionInitializationMethod` that generates the initial infections for the renewal process. - t_unit : int - The unit of time for the time series passed to `RandomVariable.set_timeseries`. - t_start : int, optional - The relative starting time of the time series. If `None`, the relative starting time is set to `-infection_init_method.n_timepoints`. - - Notes - ----- - The relative starting time of the time series (`t_start`) is set to `-infection_init_method.n_timepoints`. Returns ------- @@ -46,13 +38,6 @@ def __init__( self.I_pre_init_rv = I_pre_init_rv self.infection_init_method = infection_init_method self.name = name - if t_start is None: - t_start = -infection_init_method.n_timepoints - - self.set_timeseries( - t_start=t_start, - t_unit=t_unit, - ) @staticmethod def validate( @@ -79,26 +64,19 @@ def validate( InfectionInitializationMethod, ) - def sample(self) -> tuple: + def sample(self) -> ArrayLike: """Sample the Infection Initialization Process. Returns ------- - tuple - a tuple where the only element is an array with + ArrayLike the number of initialized infections at each time point. """ - (I_pre_init,) = self.I_pre_init_rv() + I_pre_init = self.I_pre_init_rv() infection_initialization = self.infection_init_method( - I_pre_init.value, + I_pre_init, ) - return ( - SampledValue( - infection_initialization, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return infection_initialization diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index ec8e6d74..887d0e5a 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -8,7 +8,7 @@ from jax.typing import ArrayLike import pyrenew.latent.infection_functions as inf -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class InfectionsSample(NamedTuple): @@ -97,10 +97,4 @@ def sample( reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample( - SampledValue( - post_initialization_infections, - t_start=self.t_start, - t_unit=self.t_unit, - ) - ) + return InfectionsSample(post_initialization_infections) diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 833087e2..e4b03594 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -7,11 +7,7 @@ import pyrenew.arrayutils as au import pyrenew.latent.infection_functions as inf -from pyrenew.metaclass import ( - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) +from pyrenew.metaclass import RandomVariable class InfectionsRtFeedbackSample(NamedTuple): @@ -20,14 +16,14 @@ class InfectionsRtFeedbackSample(NamedTuple): Attributes ---------- - post_initialization_infections : SampledValue | None, optional + post_initialization_infections : ArrayLike | None, optional The estimated latent infections. Defaults to None. - rt : SampledValue | None, optional + rt : ArrayLike | None, optional The adjusted reproduction number. Defaults to None. """ - post_initialization_infections: SampledValue | None = None - rt: SampledValue | None = None + post_initialization_infections: ArrayLike | None = None + rt: ArrayLike | None = None def __repr__(self): return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections}, rt={self.rt})" @@ -112,8 +108,8 @@ def validate( ------- None """ - _assert_sample_and_rtype(inf_feedback_strength) - _assert_sample_and_rtype(inf_feedback_pmf) + assert isinstance(inf_feedback_strength, RandomVariable) + assert isinstance(inf_feedback_pmf, RandomVariable) return None @@ -162,7 +158,7 @@ def sample( inf_feedback_strength = jnp.atleast_1d( self.infection_feedback_strength( **kwargs, - )[0].value + ) ) # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: @@ -180,9 +176,9 @@ def sample( ) # Sampling inf feedback pmf - inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) + inf_feedback_pmf = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) ( post_initialization_infections, @@ -196,10 +192,6 @@ def sample( ) return InfectionsRtFeedbackSample( - post_initialization_infections=SampledValue( - value=post_initialization_infections, - t_start=self.t_start, - t_unit=self.t_unit, - ), - rt=SampledValue(Rt_adj, t_start=self.t_start, t_unit=self.t_unit), + post_initialization_infections=post_initialization_infections, + rt=Rt_adj, ) diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 00e85b43..eb7119eb 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -3,7 +3,6 @@ """ from abc import ABCMeta, abstractmethod -from typing import NamedTuple, get_type_hints import jax import jax.random as jr @@ -46,193 +45,17 @@ def _assert_type(arg_name: str, value, expected_type) -> None: ) -def _assert_sample_and_rtype( - rp: "RandomVariable", skip_if_none: bool = True -) -> None: - """ - Return type-checking for RandomVariable's sample function - - Objects passed as `RandomVariable` should (a) have a `sample()` method that - (b) returns either a tuple or a named tuple. - - Parameters - ---------- - rp : RandomVariable - Random variable to check. - skip_if_none : bool, optional - When `True` it returns if `rp` is None. Defaults to True. - - Returns - ------- - None - - Raises - ------ - Exception - If rp is not a RandomVariable, does not have a sample function, or - does not return a tuple. Also occurs if rettype does not initialized - properly. - """ - - # Addressing the None case - if (rp is None) and (not skip_if_none): - Exception( - "The passed object cannot be None. It should be RandomVariable" - ) - elif skip_if_none and (rp is None): - return None - - if not isinstance(rp, RandomVariable): - raise Exception(f"{rp} is not an instance of RandomVariable.") - - # Otherwise, checking for the sample function (must have one) - # with a defined rtype. - try: - sfun = rp.sample - except Exception: - raise Exception( - f"The RandomVariable {rp} does not have a sample function." - ) # noqa: E722 - - # Getting the return annotation (if any) - rettype = get_type_hints(sfun).get("return", None) - - if rettype is None: - raise Exception( - f"The RandomVariable {rp} does not have return type " - + "annotation." - ) - - try: - if not isinstance(rettype(), tuple): - raise Exception( - f"The RandomVariable {rp}'s return type annotation is not" - + "a tuple" - ) - except Exception: - raise Exception( - f"There was a problem when trying to initialize {rettype}." - + "the rtype of the random variable should be a tuple or a namedtuple" - + " with default values." - ) - - return None - - -class SampledValue(NamedTuple): - """ - A container for a value sampled from a RandomVariable. - - Attributes - ---------- - value : ArrayLike, optional - The sampled value. - t_start : int, optional - The start time of the value. - t_unit : int, optional - The unit of time relative to the model's fundamental - (smallest) time unit. - """ - - value: ArrayLike | None = None - t_start: int | None = None - t_unit: int | None = None - - def __repr__(self): - return f"SampledValue(value={self.value}, t_start={self.t_start}, t_unit={self.t_unit})" - - class RandomVariable(metaclass=ABCMeta): """ Abstract base class for latent and observed random variables. - - Notes - ----- - RandomVariables in pyrenew can be time-aware, meaning that they can - have a t_start and t_unit attribute. These attributes - are expected to be used internally mostly for tasks including padding, - alignment of time series, and other time-aware operations. - - Both attributes give information about the output of the `sample()` method, - in other words, the relative time units of the returning value. - - Attributes - ---------- - t_start : int - The start of the time series. - t_unit : int - The unit of the time series relative to the model's fundamental - (smallest) time unit. e.g. if the fundamental unit is days, - then 1 corresponds to units of days and 7 to units of weeks. """ - t_start: int = None - t_unit: int = None - def __init__(self, **kwargs): """ Default constructor """ pass - def set_timeseries( - self, - t_start: int, - t_unit: int, - ) -> None: - """ - Set the time series start and unit - - Parameters - ---------- - t_start : int - The start of the time series relative to the - model time. It could be negative, indicating - that the `sample()` method returns timepoints - that occur prior to the model t = 0. - t_unit : int - The unit of the time series relative - to the model's fundamental (smallest) - time unit. e.g. if the fundamental unit - is days, then 1 corresponds to units of - days and 7 to units of weeks. - - Returns - ------- - None - """ - - # Either both values are None or both are not None - assert (t_unit is not None and t_start is not None) or ( - t_unit is None and t_start is None - ), ( - "Both t_start and t_unit should be None or not None. " - "Currently, t_start is {t_start} and t_unit is {t_unit}." - ) - - if t_unit is None and t_start is None: - return None - - # Timeseries unit should be a positive integer - assert isinstance( - t_unit, int - ), f"t_unit should be an integer. It is {type(t_unit)}." - - # Timeseries unit should be a positive integer - assert ( - t_unit > 0 - ), f"t_unit should be a positive integer. It is {t_unit}." - - # Data starts should be a positive integer - assert isinstance( - t_start, int - ), f"t_start should be an integer. It is {type(t_start)}." - - self.t_start = t_start - self.t_unit = t_unit - - return None - @abstractmethod def sample( self, diff --git a/pyrenew/model/admissionsmodel.py b/pyrenew/model/admissionsmodel.py index 1536ca79..7ce27845 100644 --- a/pyrenew/model/admissionsmodel.py +++ b/pyrenew/model/admissionsmodel.py @@ -7,12 +7,7 @@ from jax.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import ( - Model, - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) +from pyrenew.metaclass import Model, RandomVariable from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -22,23 +17,23 @@ class HospModelSample(NamedTuple): Attributes ---------- - Rt : SampledValue | None, optional + Rt : ArrayLike | None, optional The reproduction number over time. Defaults to None. - latent_infections : SampledValue | None, optional + latent_infections : ArrayLike | None, optional The estimated number of new infections over time. Defaults to None. - infection_hosp_rate : SampledValue | None, optional + infection_hosp_rate : ArrayLike | None, optional The infected hospitalization rate. Defaults to None. - latent_hosp_admissions : SampledValue | None, optional + latent_hosp_admissions : ArrayLike | None, optional The estimated latent hospitalizations. Defaults to None. - observed_hosp_admissions : SampledValue | None, optional + observed_hosp_admissions : ArrayLike | None, optional The sampled or observed hospital admissions. Defaults to None. """ - Rt: SampledValue | None = None - latent_infections: SampledValue | None = None - infection_hosp_rate: SampledValue | None = None - latent_hosp_admissions: SampledValue | None = None - observed_hosp_admissions: SampledValue | None = None + Rt: ArrayLike | None = None + latent_infections: ArrayLike | None = None + infection_hosp_rate: ArrayLike | None = None + latent_hosp_admissions: ArrayLike | None = None + observed_hosp_admissions: ArrayLike | None = None def __repr__(self): return ( @@ -125,15 +120,11 @@ def validate( Returns ------- None - - See Also - -------- - _assert_sample_and_rtype : Perform type-checking and verify RV """ - _assert_sample_and_rtype(latent_hosp_admissions_rv, skip_if_none=False) - _assert_sample_and_rtype( - hosp_admission_obs_process_rv, skip_if_none=True - ) + assert isinstance(latent_hosp_admissions_rv, RandomVariable) + if hosp_admission_obs_process_rv is not None: + assert isinstance(hosp_admission_obs_process_rv, RandomVariable) + return None def sample( @@ -193,7 +184,6 @@ def sample( padding=padding, **kwargs, ) - # Sampling the latent hospital admissions ( infection_hosp_rate, @@ -203,12 +193,8 @@ def sample( latent_infections=basic_model.latent_infections, **kwargs, ) - - ( - observed_hosp_admissions, - *_, - ) = self.hosp_admission_obs_process_rv( - mu=latent_hosp_admissions.value[-n_datapoints:], + observed_hosp_admissions = self.hosp_admission_obs_process_rv( + mu=latent_hosp_admissions[-n_datapoints:], obs=data_observed_hosp_admissions, **kwargs, ) diff --git a/pyrenew/model/rtinfectionsrenewalmodel.py b/pyrenew/model/rtinfectionsrenewalmodel.py index fc8881dd..6a7e49e3 100644 --- a/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/pyrenew/model/rtinfectionsrenewalmodel.py @@ -9,12 +9,7 @@ from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import ( - Model, - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) +from pyrenew.metaclass import Model, RandomVariable # Output class of the RtInfectionsRenewalModel @@ -24,17 +19,17 @@ class RtInfectionsRenewalSample(NamedTuple): Attributes ---------- - Rt : SampledValue | None, optional + Rt : ArrayLike | None, optional The reproduction number over time. Defaults to None. - latent_infections : SampledValue | None, optional + latent_infections : ArrayLike | None, optional The estimated latent infections. Defaults to None. - observed_infections : SampledValue | None, optional + observed_infections : ArrayLike | None, optional The sampled infections. Defaults to None. """ - Rt: SampledValue | None = None - latent_infections: SampledValue | None = None - observed_infections: SampledValue | None = None + Rt: ArrayLike | None = None + latent_infections: ArrayLike | None = None + observed_infections: ArrayLike | None = None def __repr__(self): return ( @@ -131,16 +126,12 @@ def validate( Returns ------- None - - See Also - -------- - _assert_sample_and_rtype : Perform type-checking and verify RV """ - _assert_sample_and_rtype(gen_int_rv, skip_if_none=False) - _assert_sample_and_rtype(I0_rv, skip_if_none=False) - _assert_sample_and_rtype(latent_infections_rv, skip_if_none=False) - _assert_sample_and_rtype(infection_obs_process_rv, skip_if_none=False) - _assert_sample_and_rtype(Rt_process_rv, skip_if_none=False) + assert isinstance(gen_int_rv, RandomVariable) + assert isinstance(I0_rv, RandomVariable) + assert isinstance(latent_infections_rv, RandomVariable) + assert isinstance(infection_obs_process_rv, RandomVariable) + assert isinstance(Rt_process_rv, RandomVariable) return None def sample( @@ -191,41 +182,37 @@ def sample( n_timepoints = n_datapoints + padding # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) - Rt, *_ = self.Rt_process_rv( + + Rt = self.Rt_process_rv( n=n_timepoints, **kwargs, ) # Getting the generation interval - gen_int, *_ = self.gen_int_rv(**kwargs) + gen_int = self.gen_int_rv(**kwargs) # Sampling initial infections - I0, *_ = self.I0_rv(**kwargs) + I0 = self.I0_rv(**kwargs) + # Sampling from the latent process - ( - post_initialization_latent_infections, - *_, - ) = self.latent_infections_rv( - Rt=Rt.value, - gen_int=gen_int.value, - I0=I0.value, + post_initialization_latent_infections = self.latent_infections_rv( + Rt=Rt, + gen_int=gen_int, + I0=I0, **kwargs, - ) - - observed_infections, *_ = self.infection_obs_process_rv( - mu=post_initialization_latent_infections.value[padding:], + ).post_initialization_infections + observed_infections = self.infection_obs_process_rv( + mu=post_initialization_latent_infections[padding:], obs=data_observed_infections, **kwargs, ) - all_latent_infections = SampledValue( - jnp.hstack([I0.value, post_initialization_latent_infections.value]) - ) - numpyro.deterministic( - "all_latent_infections", all_latent_infections.value + all_latent_infections = jnp.hstack( + [I0, post_initialization_latent_infections] ) + numpyro.deterministic("all_latent_infections", all_latent_infections) - numpyro.deterministic("Rt", Rt.value) + numpyro.deterministic("Rt", Rt) return RtInfectionsRenewalSample( Rt=Rt, diff --git a/pyrenew/observation/negativebinomial.py b/pyrenew/observation/negativebinomial.py index 2029eb24..c15e96fc 100644 --- a/pyrenew/observation/negativebinomial.py +++ b/pyrenew/observation/negativebinomial.py @@ -6,7 +6,7 @@ import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class NegativeBinomialObservation(RandomVariable): @@ -70,7 +70,7 @@ def sample( mu: ArrayLike, obs: ArrayLike | None = None, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the negative binomial distribution @@ -85,23 +85,17 @@ def sample( Returns ------- - tuple + ArrayLike """ - concentration, *_ = self.concentration_rv.sample() + concentration = self.concentration_rv.sample() negative_binomial_sample = numpyro.sample( name=self.name, fn=dist.NegativeBinomial2( mean=mu + self.eps, - concentration=concentration.value, + concentration=concentration, ), obs=obs, ) - return ( - SampledValue( - negative_binomial_sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return negative_binomial_sample diff --git a/pyrenew/observation/poisson.py b/pyrenew/observation/poisson.py index f9ee29d5..23ae4a6b 100644 --- a/pyrenew/observation/poisson.py +++ b/pyrenew/observation/poisson.py @@ -6,7 +6,7 @@ import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class PoissonObservation(RandomVariable): @@ -49,7 +49,7 @@ def sample( mu: ArrayLike, obs: ArrayLike | None = None, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the Poisson process @@ -64,7 +64,7 @@ def sample( Returns ------- - tuple + ArrayLike """ poisson_sample = numpyro.sample( @@ -72,10 +72,4 @@ def sample( fn=dist.Poisson(rate=mu + self.eps), obs=obs, ) - return ( - SampledValue( - poisson_sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return poisson_sample diff --git a/pyrenew/process/ar.py b/pyrenew/process/ar.py index 5153f3e6..c7473395 100644 --- a/pyrenew/process/ar.py +++ b/pyrenew/process/ar.py @@ -6,7 +6,7 @@ from jax.typing import ArrayLike from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable from pyrenew.process.iidrandomsequence import StandardNormalSequence @@ -36,7 +36,7 @@ def sample( init_vals: ArrayLike, noise_sd: float | ArrayLike, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the AR process @@ -62,9 +62,7 @@ def sample( Returns ------- - tuple - With a single SampledValue containing an - array of shape (n,). + ArrayLike """ noise_sd_arr = jnp.atleast_1d(noise_sd) if not noise_sd_arr.shape == (1,): @@ -96,8 +94,8 @@ def sample( f"order {order}" ) - raw_noise, *_ = self.noise_rv_(n=n, **kwargs) - noise = noise_sd_arr * raw_noise.value + raw_noise = self.noise_rv_(n=n, **kwargs) + noise = noise_sd_arr * raw_noise def transition(recent_vals, next_noise): # numpydoc ignore=GL08 new_term = jnp.dot(autoreg, recent_vals) + next_noise @@ -107,13 +105,7 @@ def transition(recent_vals, next_noise): # numpydoc ignore=GL08 return new_recent_vals, new_term last, ts = scan(transition, init_vals, noise) - return ( - SampledValue( - jnp.hstack([init_vals, ts]), - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return jnp.hstack([init_vals, ts]) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/pyrenew/process/differencedprocess.py b/pyrenew/process/differencedprocess.py index 99cb87fc..d932ee33 100644 --- a/pyrenew/process/differencedprocess.py +++ b/pyrenew/process/differencedprocess.py @@ -6,7 +6,7 @@ from jax.typing import ArrayLike from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class DifferencedProcess(RandomVariable): @@ -151,7 +151,7 @@ def sample( *args, fundamental_process_init_vals: ArrayLike = None, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the process @@ -185,9 +185,8 @@ def sample( Returns ------- - SampledValue - Whose value entry is a single array representing the - undifferenced timeseries + ArrayLike + representing the undifferenced timeseries """ if not isinstance(n, int): raise ValueError("n must be an integer. " f"Got {type(n)}") @@ -196,23 +195,17 @@ def sample( n_diffs = n - self.differencing_order if n_diffs > 0: - diff_samp, *_ = self.fundamental_process.sample( + diff_samp = self.fundamental_process.sample( *args, n=n_diffs, init_vals=fundamental_process_init_vals, **kwargs, ) - diffs = diff_samp.value + diffs = diff_samp else: diffs = jnp.array([]) integrated_ts = self.integrate(init_vals, diffs)[:n] - return ( - SampledValue( - value=integrated_ts, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return integrated_ts @staticmethod def validate(): diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index 5f64d7f1..cdc93fae 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -1,9 +1,10 @@ # numpydoc ignore=GL08 import numpyro.distributions as dist +from jax.typing import ArrayLike from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable from pyrenew.randomvariable import DistributionalVariable @@ -42,7 +43,7 @@ def __init__( def sample( self, n: int, *args, vectorize: bool = False, **kwargs - ) -> tuple: + ) -> ArrayLike: """ Sample an IID random sequence. @@ -69,22 +70,18 @@ def sample( Returns ------- - tuple[SampledValue] - Whose value is an array of `n` - samples from `self.distribution` + ArrayLike + `n` samples from `self.distribution` """ if vectorize and hasattr(self.element_rv, "expand_by"): - result, *_ = self.element_rv.expand_by((n,)).sample( - *args, **kwargs - ) - result = result.value + result = self.element_rv.expand_by((n,)).sample(*args, **kwargs) else: def transition(_carry, _x): # numpydoc ignore=GL08 - el, *_ = self.element_rv.sample(*args, **kwargs) - return None, el.value + el = self.element_rv.sample(*args, **kwargs) + return None, el _, result = scan( transition, @@ -93,13 +90,7 @@ def transition(_carry, _x): length=n, ) - return ( - SampledValue( - result, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return result @staticmethod def validate(): diff --git a/pyrenew/process/periodiceffect.py b/pyrenew/process/periodiceffect.py index 1a8deb4f..df99d851 100644 --- a/pyrenew/process/periodiceffect.py +++ b/pyrenew/process/periodiceffect.py @@ -1,30 +1,8 @@ # numpydoc ignore=GL08 -from typing import NamedTuple import pyrenew.arrayutils as au -from pyrenew.metaclass import ( - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) - - -class PeriodicEffectSample(NamedTuple): - """ - A container for holding the output from - `process.PeriodicEffect()`. - - Attributes - ---------- - value: SampledValue - The sampled value. - """ - - value: SampledValue - - def __repr__(self): - return f"PeriodicEffectSample(value={self.value})" +from pyrenew.metaclass import RandomVariable class PeriodicEffect(RandomVariable): @@ -36,8 +14,6 @@ def __init__( self, offset: int, quantity_to_broadcast: RandomVariable, - t_start: int, - t_unit: int, ): """ Default constructor for PeriodicEffect class. @@ -49,10 +25,6 @@ def __init__( period_size - 1. quantity_to_broadcast : RandomVariable Values to be broadcasted (repeated or tiled). - t_start : int - Start time of the process. - t_unit : int - Unit of time relative to the model's fundamental (smallest) time unit. Returns ------- @@ -63,11 +35,6 @@ def __init__( self.offset = offset - self.set_timeseries( - t_start=t_start, - t_unit=t_unit, - ) - self.quantity_to_broadcast = quantity_to_broadcast @staticmethod @@ -85,7 +52,7 @@ def validate(quantity_to_broadcast: RandomVariable) -> None: None """ - _assert_sample_and_rtype(quantity_to_broadcast) + assert isinstance(quantity_to_broadcast, RandomVariable) return None @@ -102,19 +69,13 @@ def sample(self, duration: int, **kwargs): Returns ------- - PeriodicEffectSample + ArrayLike """ - return PeriodicEffectSample( - value=SampledValue( - au.tile_until_n( - data=self.quantity_to_broadcast.sample(**kwargs)[0].value, - n_timepoints=duration, - offset=self.offset, - ), - t_start=self.t_start, - t_unit=self.t_unit, - ) + return au.tile_until_n( + data=self.quantity_to_broadcast.sample(**kwargs), + n_timepoints=duration, + offset=self.offset, ) @@ -127,7 +88,6 @@ def __init__( self, offset: int, quantity_to_broadcast: RandomVariable, - t_start: int, ): """ Default constructor for DayOfWeekEffect class. @@ -139,8 +99,6 @@ def __init__( 6. quantity_to_broadcast : RandomVariable Values to be broadcasted (repeated or tiled). - t_start : int - Start time of the process. Returns ------- @@ -152,8 +110,6 @@ def __init__( super().__init__( offset=offset, quantity_to_broadcast=quantity_to_broadcast, - t_start=t_start, - t_unit=1, ) return None diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 9186b9ef..25fb1d9a 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -1,34 +1,13 @@ # numpydoc ignore=GL08 -from typing import NamedTuple import jax.numpy as jnp +from jax.typing import ArrayLike import pyrenew.arrayutils as au -from pyrenew.metaclass import ( - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) +from pyrenew.metaclass import RandomVariable from pyrenew.process import ARProcess, DifferencedProcess -class RtPeriodicDiffARProcessSample(NamedTuple): - """ - A container for holding the output from - `process.RtPeriodicDiffARProcess()`. - - Attributes - ---------- - rt : SampledValue, optional - The sampled Rt. - """ - - rt: SampledValue | None = None - - def __repr__(self): - return f"RtPeriodicDiffARProcessSample(rt={self.rt})" - - class RtPeriodicDiffARProcess(RandomVariable): r""" Periodic Rt with autoregressive first differences @@ -130,9 +109,9 @@ def validate( None """ - _assert_sample_and_rtype(log_rt_rv) - _assert_sample_and_rtype(autoreg_rv) - _assert_sample_and_rtype(periodic_diff_sd_rv) + assert isinstance(log_rt_rv, RandomVariable) + assert isinstance(autoreg_rv, RandomVariable) + assert isinstance(periodic_diff_sd_rv, RandomVariable) return None @@ -140,7 +119,7 @@ def sample( self, duration: int, **kwargs, - ) -> RtPeriodicDiffARProcessSample: + ) -> ArrayLike: """ Samples the periodic Rt with autoregressive difference. @@ -154,14 +133,14 @@ def sample( Returns ------- - RtPeriodicDiffARProcessSample - Named tuple with "rt". + ArrayLike + Sampled Rt values. """ # Initial sample - log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value - b = self.autoreg_rv.sample(**kwargs)[0].value - s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value + log_rt_rv = self.log_rt_rv.sample(**kwargs) + b = self.autoreg_rv.sample(**kwargs) + s_r = self.periodic_diff_sd_rv.sample(**kwargs) # How many periods to sample? n_periods = (duration + self.period_size - 1) // self.period_size @@ -176,19 +155,13 @@ def sample( fundamental_process_init_vals=jnp.array( [log_rt_rv[1] - log_rt_rv[0]] ), - )[0] - - return RtPeriodicDiffARProcessSample( - rt=SampledValue( - au.repeat_until_n( - data=jnp.exp(log_rt.value), - n_timepoints=duration, - offset=self.offset, - period_size=self.period_size, - ), - t_start=self.t_start, - t_unit=self.t_unit, - ), + ) + + return au.repeat_until_n( + data=jnp.exp(log_rt), + n_timepoints=duration, + offset=self.offset, + period_size=self.period_size, ) diff --git a/pyrenew/randomvariable/distributionalvariable.py b/pyrenew/randomvariable/distributionalvariable.py index 671dde08..20ec94a2 100644 --- a/pyrenew/randomvariable/distributionalvariable.py +++ b/pyrenew/randomvariable/distributionalvariable.py @@ -7,7 +7,7 @@ from jax.typing import ArrayLike from numpyro.infer.reparam import Reparam -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable class DynamicDistributionalVariable(RandomVariable): @@ -97,7 +97,7 @@ def sample( *args, obs: ArrayLike = None, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the distributional rv. @@ -113,8 +113,8 @@ def sample( Returns ------- - SampledValue - Containing a sample from the distribution. + ArrayLike + a sample from the distribution. """ distribution = self.distribution_constructor(*args, **kwargs) if self.expand_by_shape is not None: @@ -125,13 +125,7 @@ def sample( fn=distribution, obs=obs, ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return sample def expand_by(self, sample_shape) -> Self: """ @@ -224,7 +218,7 @@ def sample( self, obs: ArrayLike | None = None, **kwargs, - ) -> tuple: + ) -> ArrayLike: """ Sample from the distribution. @@ -239,7 +233,7 @@ def sample( Returns ------- - SampledValue + ArrayLike Containing a sample from the distribution. """ with numpyro.handlers.reparam(config=self.reparam_dict): @@ -248,13 +242,7 @@ def sample( fn=self.distribution, obs=obs, ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return sample def expand_by(self, sample_shape) -> Self: """ diff --git a/pyrenew/randomvariable/transformedvariable.py b/pyrenew/randomvariable/transformedvariable.py index 36519a24..59f5c6cc 100644 --- a/pyrenew/randomvariable/transformedvariable.py +++ b/pyrenew/randomvariable/transformedvariable.py @@ -2,7 +2,7 @@ import numpyro -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable from pyrenew.transformation import Transform @@ -70,18 +70,17 @@ def sample(self, record=False, **kwargs) -> tuple: """ untransformed_values = self.base_rv.sample(**kwargs) + + if not isinstance(untransformed_values, tuple): + untransformed_values = (untransformed_values,) + transformed_values = tuple( - SampledValue( - t(uv.value), - t_start=self.t_start, - t_unit=self.t_unit, - ) - for t, uv in zip(self.transforms, untransformed_values) + t(uv) for t, uv in zip(self.transforms, untransformed_values) ) if record: if len(untransformed_values) == 1: - numpyro.deterministic(self.name, transformed_values[0].value) + numpyro.deterministic(self.name, transformed_values) else: suffixes = ( untransformed_values._fields @@ -89,7 +88,10 @@ def sample(self, record=False, **kwargs) -> tuple: else range(len(transformed_values)) ) for suffix, tv in zip(suffixes, transformed_values): - numpyro.deterministic(f"{self.name}_{suffix}", tv.value) + numpyro.deterministic(f"{self.name}_{suffix}", tv) + + if len(transformed_values) == 1: + transformed_values = transformed_values[0] return transformed_values diff --git a/pyrenew/regression.py b/pyrenew/regression.py index 9348958e..e9c4eced 100755 --- a/pyrenew/regression.py +++ b/pyrenew/regression.py @@ -12,7 +12,6 @@ from jax.typing import ArrayLike import pyrenew.transformation as t -from pyrenew.metaclass import SampledValue class AbstractRegressionPrediction(metaclass=ABCMeta): # numpydoc ignore=GL08 @@ -39,19 +38,19 @@ class GLMPredictionSample(NamedTuple): Attributes ---------- - prediction : SampledValue | None, optional + prediction : ArrayLike | None, optional Transformed predictions. Defaults to None. - intercept : SampledValue | None, optional + intercept : ArrayLike | None, optional Sampled intercept from intercept priors. Defaults to None. - coefficients : SampledValue | None, optional + coefficients : ArrayLike | None, optional Prediction coefficients generated from coefficients priors. Defaults to None. """ - prediction: SampledValue | None = None - intercept: SampledValue | None = None - coefficients: SampledValue | None = None + prediction: ArrayLike | None = None + intercept: ArrayLike | None = None + coefficients: ArrayLike | None = None def __repr__(self): return ( @@ -182,9 +181,9 @@ def sample(self, predictor_values: ArrayLike) -> GLMPredictionSample: prediction = self.predict(intercept, coefficients, predictor_values) return GLMPredictionSample( - prediction=SampledValue(prediction), - intercept=SampledValue(intercept), - coefficients=SampledValue(coefficients), + prediction=prediction, + intercept=intercept, + coefficients=coefficients, ) def __call__(self, *args, **kwargs): diff --git a/test/test_ar_process.py b/test/test_ar_process.py index b1df31af..b16ebbc4 100755 --- a/test/test_ar_process.py +++ b/test/test_ar_process.py @@ -98,11 +98,11 @@ def test_ar_samples_correctly_distributed(): with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it - long_ts, *_ = ar( + long_ts = ar( n=10000, init_vals=ar_inits, autoreg=jnp.array([0.75]), noise_sd=noise_sd, ) - assert_almost_equal(long_ts.value[0], ar_inits) - assert jnp.abs(long_ts.value[-1]) < 4 * noise_sd + assert_almost_equal(long_ts[0], ar_inits) + assert jnp.abs(long_ts[-1]) < 4 * noise_sd diff --git a/test/test_assert_sample_and_rtype.py b/test/test_assert_sample_and_rtype.py deleted file mode 100644 index d0f9ee8a..00000000 --- a/test/test_assert_sample_and_rtype.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Tests for _assert_sample_and_rtype method -""" - -import jax.numpy as jnp -import numpyro.distributions as dist -import pytest -from numpy.testing import assert_equal - -from pyrenew.deterministic import DeterministicVariable, NullObservation -from pyrenew.metaclass import ( - RandomVariable, - SampledValue, - _assert_sample_and_rtype, -) -from pyrenew.randomvariable import DistributionalVariable - - -class RVreturnsTuple(RandomVariable): - """ - Class for a RandomVariable with - sample value 1 - """ - - def sample(self, **kwargs) -> tuple: - """ - Deterministic sampling method that returns 1 - - Returns - ------- - ( - SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), - ) - """ - - return ( - SampledValue(value=1, t_start=self.t_start, t_unit=self.t_unit), - ) - - def validate(self): - """ - No validation. - - Returns - ------- - None - """ - return None - - -class RVnoAnnotation(RandomVariable): - """ - Class for a RandomVariable with - sample value 1 - """ - - def sample(self, **kwargs): - """ - Deterministic sampling method that returns 1 - - Returns - ------- - ( - SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), - ) - """ - - return ( - SampledValue(value=1, t_start=self.t_start, t_unit=self.t_unit), - ) - - def validate(self): - """ - No validation. - - Returns - ------- - None - """ - return None - - -def test_none_rv(): # numpydoc ignore=GL08 - assert_equal(_assert_sample_and_rtype(None), None) - - with pytest.raises( - Exception, match="None is not an instance of RandomVariable" - ): - _assert_sample_and_rtype(None, skip_if_none=False) - - -def test_input_rv(): # numpydoc ignore=GL08 - valid_rv = [ - NullObservation(), - DeterministicVariable(name="rv1", value=jnp.array([1, 2, 3, 4])), - DistributionalVariable(name="rv2", distribution=dist.Normal(0, 1)), - ] - not_rv = jnp.array([1]) - - for rv in valid_rv: - _assert_sample_and_rtype(rv) - - with pytest.raises( - Exception, - match="is not an instance of RandomVariable", - ): - _assert_sample_and_rtype(not_rv) - - -def test_sample_return(): # numpydoc ignore=GL08 - """ - Test that RandomVariable has a sample method with return type tuple - """ - - rv3 = RVreturnsTuple() - _assert_sample_and_rtype(rv3) - - rv4 = RVnoAnnotation() - with pytest.raises( - Exception, - match="does not have return type annotation", - ): - _assert_sample_and_rtype(rv4) diff --git a/test/test_deterministic.py b/test/test_deterministic.py index 69892d4f..ccbf0c76 100644 --- a/test/test_deterministic.py +++ b/test/test_deterministic.py @@ -9,17 +9,15 @@ from pyrenew.deterministic import ( DeterministicPMF, - DeterministicProcess, DeterministicVariable, - NullProcess, NullVariable, ) def test_deterministic(): """ - Test the DeterministicVariable, DeterministicPMF, and - DeterministicProcess classes in the deterministic module. + Test the DeterministicVariable and DeterministicPMF classes in the + deterministic module. """ var1 = DeterministicVariable( @@ -33,12 +31,10 @@ def test_deterministic(): var2 = DeterministicPMF( name="var2", value=jnp.array([0.25, 0.25, 0.2, 0.3]) ) - var3 = DeterministicProcess(name="var3", value=jnp.array([1, 2, 3, 4])) - var4 = NullVariable() - var5 = NullProcess() + var3 = NullVariable() testing.assert_array_equal( - var1()[0].value, + var1(), jnp.array( [ 1, @@ -46,27 +42,11 @@ def test_deterministic(): ), ) testing.assert_array_equal( - var2()[0].value, + var2(), jnp.array([0.25, 0.25, 0.2, 0.3]), ) - testing.assert_array_equal( - var3(duration=5)[0].value, - jnp.array([1, 2, 3, 4, 4]), - ) - - testing.assert_array_equal( - var3(duration=3)[0].value, - jnp.array( - [ - 1, - 2, - 3, - ] - ), - ) - testing.assert_equal(var4()[0].value, None) - testing.assert_equal(var5(duration=1)[0].value, None) + testing.assert_equal(var3(), None) def test_deterministic_validation(): diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index ba4e95c9..07ad1854 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -189,18 +189,18 @@ def test_differenced_process_sample( n_fail = -1 n_fail_alt = 0 with numpyro.handlers.seed(rng_seed=6723): - samp, *_ = proc.sample(n=n_long, init_vals=init_diff_vals) - samp_alt, *_ = proc.sample(n=n_long_alt, init_vals=init_diff_vals) - samp_one_diff, *_ = proc.sample(n=n_one_diff, init_vals=init_diff_vals) - samp_no_diffs, *_ = proc.sample(n=n_no_diffs, init_vals=init_diff_vals) - samp_no_diffs_alt, *_ = proc.sample( + samp = proc.sample(n=n_long, init_vals=init_diff_vals) + samp_alt = proc.sample(n=n_long_alt, init_vals=init_diff_vals) + samp_one_diff = proc.sample(n=n_one_diff, init_vals=init_diff_vals) + samp_no_diffs = proc.sample(n=n_no_diffs, init_vals=init_diff_vals) + samp_no_diffs_alt = proc.sample( n=n_no_diffs_alt, init_vals=init_diff_vals ) - assert samp.value.shape == (n_long,) - assert samp_alt.value.shape == (n_long_alt,) - assert samp_one_diff.value.shape == (n_one_diff,) - assert samp_no_diffs.value.shape == (n_no_diffs,) - assert samp_no_diffs_alt.value.shape == (n_no_diffs_alt,) + assert samp.shape == (n_long,) + assert samp_alt.shape == (n_long_alt,) + assert samp_one_diff.shape == (n_one_diff,) + assert samp_no_diffs.shape == (n_no_diffs,) + assert samp_no_diffs_alt.shape == (n_no_diffs_alt,) with numpyro.handlers.seed(rng_seed=7834): with pytest.raises(ValueError, match="must be positive"): @@ -267,5 +267,5 @@ def test_manual_difference_process_sample( differencing_order=len(inits), fundamental_process=fundamental_process, ) - result, *_ = proc.sample(n=n, init_vals=inits) - assert_array_almost_equal(result.value, expected_solution) + result = proc.sample(n=n, init_vals=inits) + assert_array_almost_equal(result, expected_solution) diff --git a/test/test_distributional_rv.py b/test/test_distributional_rv.py index cebe6f8e..46c3d44d 100644 --- a/test/test_distributional_rv.py +++ b/test/test_distributional_rv.py @@ -153,7 +153,7 @@ def test_sampling_equivalent(dist, params): assert isinstance(static, StaticDistributionalVariable) assert isinstance(dynamic, DynamicDistributionalVariable) with numpyro.handlers.seed(rng_seed=5): - static_samp, *_ = static() + static_samp = static() with numpyro.handlers.seed(rng_seed=5): - dynamic_samp, *_ = dynamic(**params) - assert_array_equal(static_samp.value, dynamic_samp.value) + dynamic_samp = dynamic(**params) + assert_array_equal(static_samp, dynamic_samp) diff --git a/test/test_forecast.py b/test/test_forecast.py index d8d1d55c..df70436e 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -30,7 +30,6 @@ def test_forecast(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() observed_infections = PoissonObservation(name="poisson_rv") @@ -52,7 +51,7 @@ def test_forecast(): model.run( num_warmup=5, num_samples=5, - data_observed_infections=model_sample.observed_infections.value, + data_observed_infections=model_sample.observed_infections, rng_key=jr.key(54), ) diff --git a/test/test_iid_random_sequence.py b/test/test_iid_random_sequence.py index 73b683aa..9630a642 100755 --- a/test/test_iid_random_sequence.py +++ b/test/test_iid_random_sequence.py @@ -6,7 +6,6 @@ import pytest from scipy.stats import kstest -from pyrenew.metaclass import SampledValue from pyrenew.process import IIDRandomSequence, StandardNormalSequence from pyrenew.randomvariable import ( DistributionalVariable, @@ -37,22 +36,21 @@ def test_iidrandomsequence_with_dist_rv(distribution, n): expected_shape = tuple([n] + [x for x in distribution.batch_shape]) with numpyro.handlers.seed(rng_seed=62): - ans_vec, *_ = rseq.sample(n=n, vectorize=True) - ans_serial, *_ = rseq.sample(n=n, vectorize=False) + ans_vec = rseq.sample(n=n, vectorize=True) + ans_serial = rseq.sample(n=n, vectorize=False) # check that samples are the right type for ans in [ans_serial, ans_vec]: - assert isinstance(ans, SampledValue) # check that the samples are of the right shape - assert ans.value.shape == expected_shape + assert ans.shape == expected_shape # vectorized and unvectorized sampling should # not give the same answer # but they should give similar distributions - assert all(ans_serial.value.flatten() != ans_vec.value.flatten()) + assert all(ans_serial.flatten() != ans_vec.flatten()) if expected_shape == (n,): - kstest_out = kstest(ans_serial.value, ans_vec.value) + kstest_out = kstest(ans_serial, ans_vec) assert kstest_out.pvalue > 0.01 @@ -72,9 +70,8 @@ def test_standard_normal_sequence(): # should be sampleable with numpyro.handlers.seed(rng_seed=67): - ans, *_ = norm_seq.sample(n=50000) + ans = norm_seq.sample(n=50000) - assert isinstance(ans, SampledValue) # samples should be approximately standard normal - kstest_out = kstest(ans.value, "norm", (0, 1)) + kstest_out = kstest(ans, "norm", (0, 1)) assert kstest_out.pvalue > 0.01 diff --git a/test/test_infection_initialization_method.py b/test/test_infection_initialization_method.py index ce9796f5..51b1616a 100644 --- a/test/test_infection_initialization_method.py +++ b/test/test_infection_initialization_method.py @@ -20,8 +20,8 @@ def test_initialize_infections_exponential(): rate_RV = DeterministicVariable(name="rate_RV", value=np.array([0.5, 0.1])) rate_scalar_RV = DeterministicVariable(name="rate_RV", value=0.5) - rate = rate_RV()[0].value - rate_scalar = rate_scalar_RV()[0].value + rate = rate_RV() + rate_scalar = rate_scalar_RV() I_pre_init = np.array([5.0, 10.0]) I_pre_init_scalar = 5.0 @@ -131,8 +131,8 @@ def test_initialize_infections_zero_pad(): n_timepoints = 10 I_pre_init_RV = DeterministicVariable(name="I_pre_init_RV", value=10.0) - (I_pre_init,) = I_pre_init_RV() - I_pre_init = I_pre_init.value + I_pre_init = I_pre_init_RV() + I_pre_init = I_pre_init infections = InitializeInfectionsZeroPad( n_timepoints @@ -149,8 +149,8 @@ def test_initialize_infections_zero_pad(): name="I_pre_init_RV", value=np.array([10.0, 10.0]) ) - (I_pre_init_2,) = I_pre_init_RV_2() - I_pre_init_2 = I_pre_init_2.value + I_pre_init_2 = I_pre_init_RV_2() + I_pre_init_2 = I_pre_init_2 infections_2 = InitializeInfectionsZeroPad( n_timepoints diff --git a/test/test_infection_initialization_process.py b/test/test_infection_initialization_process.py index 069299cd..fcf94708 100644 --- a/test/test_infection_initialization_process.py +++ b/test/test_infection_initialization_process.py @@ -22,7 +22,6 @@ def test_infection_initialization_process(): "zero_pad_model", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints), - t_unit=1, ) exp_model = InfectionInitializationProcess( @@ -31,14 +30,12 @@ def test_infection_initialization_process(): InitializeInfectionsExponentialGrowth( n_timepoints, DeterministicVariable(name="rate", value=0.5) ), - t_unit=1, ) vec_model = InfectionInitializationProcess( "vec_model", DeterministicVariable(name="I0", value=jnp.arange(n_timepoints)), InitializeInfectionsFromVec(n_timepoints), - t_unit=1, ) for model in [zero_pad_model, exp_model, vec_model]: @@ -51,7 +48,6 @@ def test_infection_initialization_process(): "vec_model", jnp.arange(n_timepoints), InitializeInfectionsFromVec(n_timepoints), - t_unit=1, ) with pytest.raises(TypeError): @@ -59,5 +55,4 @@ def test_infection_initialization_process(): "vec_model", DeterministicVariable(name="I0", value=jnp.arange(n_timepoints)), 3, - t_unit=1, ) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 8d56b885..d241783c 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -96,10 +96,10 @@ def test_infectionsrtfeedback(): ) assert_array_equal( - samp1.post_initialization_infections.value, - samp2.post_initialization_infections.value, + samp1.post_initialization_infections, + samp2.post_initialization_infections, ) - assert_array_equal(samp1.rt.value, Rt) + assert_array_equal(samp1.rt, Rt) return None @@ -143,18 +143,18 @@ def test_infectionsrtfeedback_feedback(): gen_int=gen_int, Rt=Rt, I0=I0, - inf_feedback_strength=inf_feed_strength()[0].value, - inf_feedback_pmf=inf_feedback_pmf()[0].value, + inf_feedback_strength=inf_feed_strength(), + inf_feedback_pmf=inf_feedback_pmf(), ) assert not jnp.array_equal( - samp1.post_initialization_infections.value, - samp2.post_initialization_infections.value, + samp1.post_initialization_infections, + samp2, ) assert_array_almost_equal( - samp1.post_initialization_infections.value, + samp1.post_initialization_infections, res["post_initialization_infections"], ) - assert_array_almost_equal(samp1.rt.value, res["rt"]) + assert_array_almost_equal(samp1.rt, res["rt"]) return None diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index c2a260f7..e652e382 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -9,7 +9,6 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import SampledValue from pyrenew.randomvariable import DistributionalVariable @@ -25,7 +24,7 @@ def test_admissions_sample(): n_steps = 30 with numpyro.handlers.seed(rng_seed=223): - sim_rt = rt(n=n_steps)[0].value + sim_rt = rt(n=n_steps) gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) inf_hosp_int_array = jnp.array( @@ -71,16 +70,14 @@ def test_admissions_sample(): with numpyro.handlers.seed(rng_seed=223): sim_hosp_1 = hosp1( - latent_infections=SampledValue( - value=jnp.hstack( - [i0, inf_sampled1.post_initialization_infections.value] - ) + latent_infections=jnp.hstack( + [i0, inf_sampled1.post_initialization_infections] ) ) testing.assert_array_less( - sim_hosp_1.latent_hospital_admissions.value[-n_steps:], - inf_sampled1[0].value, + sim_hosp_1.latent_hospital_admissions[-n_steps:], + inf_sampled1.post_initialization_infections, ) inf_hosp2 = jnp.ones(30) inf_hosp2 = DeterministicPMF("i2h", inf_hosp2 / sum(inf_hosp2)) @@ -116,17 +113,17 @@ def test_admissions_sample(): obs_data_first_day_of_the_week=2, ) - inf_sampled2 = SampledValue(jnp.ones(30)) + inf_sampled2 = jnp.ones(30) with numpyro.handlers.seed(rng_seed=223): - sim_hosp_2a = hosp2a(latent_infections=inf_sampled2).multiplier.value + sim_hosp_2a = hosp2a(latent_infections=inf_sampled2).multiplier with numpyro.handlers.seed(rng_seed=223): - sim_hosp_2b = hosp2b(latent_infections=inf_sampled2).multiplier.value + sim_hosp_2b = hosp2b(latent_infections=inf_sampled2).multiplier with numpyro.handlers.seed(rng_seed=223): with testing.assert_raises(ValueError): - hosp3b(latent_infections=inf_sampled2).multiplier.value + hosp3b(latent_infections=inf_sampled2).multiplier testing.assert_array_equal( sim_hosp_2a[2 : (sim_hosp_2b.size - 2)], diff --git a/test/test_latent_infections.py b/test/test_latent_infections.py index f90819f2..0edd3231 100755 --- a/test/test_latent_infections.py +++ b/test/test_latent_infections.py @@ -19,24 +19,30 @@ def test_infections_as_deterministic(): rt = SimpleRt() with numpyro.handlers.seed(rng_seed=223): - sim_rt, *_ = rt(n=30) + sim_rt = rt(n=30) gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) inf1 = Infections() obs = dict( - Rt=sim_rt.value, + Rt=sim_rt, I0=jnp.zeros(gen_int.size), gen_int=gen_int, ) with numpyro.handlers.seed(rng_seed=223): + Infections()( + Rt=sim_rt, + I0=jnp.zeros(gen_int.size), + gen_int=gen_int, + ) + inf_sampled1 = inf1(**obs) inf_sampled2 = inf1(**obs) testing.assert_array_equal( - inf_sampled1.post_initialization_infections.value, - inf_sampled2.post_initialization_infections.value, + inf_sampled1, + inf_sampled2, ) # Check that Initial infections vector must be at least as long as the generation interval. diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index d83906cb..dfc33c3b 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -1,6 +1,5 @@ # numpydoc ignore=GL08 - from test.utils import SimpleRt import jax.numpy as jnp @@ -37,7 +36,6 @@ def test_model_basicrenewal_no_timepoints_or_observations(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() @@ -73,7 +71,6 @@ def test_model_basicrenewal_both_timepoints_and_observations(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() @@ -116,7 +113,6 @@ def test_model_basicrenewal_no_obs_model(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() @@ -144,21 +140,21 @@ def test_model_basicrenewal_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_equal(model0_samp.Rt.value, model1_samp.Rt.value) + np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt) np.testing.assert_array_equal( - model0_samp.latent_infections.value, - model1_samp.latent_infections.value, + model0_samp.latent_infections, + model1_samp.latent_infections, ) np.testing.assert_array_equal( - model0_samp.observed_infections.value, - model1_samp.observed_infections.value, + model0_samp.observed_infections, + model1_samp.observed_infections, ) model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_infections=model0_samp.latent_infections.value, + data_observed_infections=model0_samp.latent_infections, ) inf = model0.spread_draws(["all_latent_infections"]) @@ -187,7 +183,6 @@ def test_model_basicrenewal_with_obs_model(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() @@ -209,15 +204,15 @@ def test_model_basicrenewal_with_obs_model(): model1_samp = model1.sample(n_datapoints=30) print(model1_samp) - print(model1_samp.Rt.value.size) - print(model1_samp.latent_infections.value.size) - print(model1_samp.observed_infections.value.size) + print(model1_samp.Rt.size) + print(model1_samp.latent_infections.size) + print(model1_samp.observed_infections.size) model1.run( num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections.value, + data_observed_infections=model1_samp.observed_infections, ) inf = model1.spread_draws(["all_latent_infections"]) @@ -241,7 +236,6 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() @@ -267,7 +261,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections.value, + data_observed_infections=model1_samp.observed_infections, padding=pad_size, ) diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index a67056ca..4aed146c 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -1,6 +1,5 @@ # numpydoc ignore=GL08 - from test.utils import SimpleRt import jax.numpy as jnp @@ -22,39 +21,11 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation from pyrenew.randomvariable import DistributionalVariable -class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 - def __init__( - self, - size: int, - pname: str, - ): # numpydoc ignore=GL08 - self.size = size - self.name = pname - - return None - - @staticmethod - def validate(self): # numpydoc ignore=GL08 - return None - - def sample(self, **kwargs): # numpydoc ignore=GL08 - return ( - SampledValue( - numpyro.sample( - name=self.name, - fn=dist.Uniform(high=0.99, low=0.01), - sample_shape=(self.size,), - ) - ), - ) - - def test_model_hosp_no_timepoints_or_observations(): """ Checks that the hospital admissions model does not run @@ -228,7 +199,6 @@ def test_model_hosp_no_obs_model(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), - t_unit=1, ) latent_infections = Infections() @@ -260,31 +230,29 @@ def test_model_hosp_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_almost_equal( - model0_samp.Rt.value, model1_samp.Rt.value - ) + np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt) np.testing.assert_array_equal( - model0_samp.latent_infections.value, - model1_samp.latent_infections.value, + model0_samp.latent_infections, + model1_samp.latent_infections, ) np.testing.assert_array_equal( - model0_samp.infection_hosp_rate.value, - model1_samp.infection_hosp_rate.value, + model0_samp.infection_hosp_rate, + model1_samp.infection_hosp_rate, ) np.testing.assert_array_equal( - model0_samp.latent_hosp_admissions.value, - model1_samp.latent_hosp_admissions.value, + model0_samp.latent_hosp_admissions, + model1_samp.latent_hosp_admissions, ) # These are supposed to be none, both - assert model0_samp.observed_hosp_admissions.value is None - assert model1_samp.observed_hosp_admissions.value is None + assert model0_samp.observed_hosp_admissions is None + assert model1_samp.observed_hosp_admissions is None model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model0_samp.latent_hosp_admissions.value, + data_observed_hosp_admissions=model0_samp.latent_hosp_admissions, ) inf = model0.spread_draws(["latent_hospital_admissions"]) @@ -340,7 +308,6 @@ def test_model_hosp_with_obs_model(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), - t_unit=1, ) latent_infections = Infections() @@ -372,7 +339,7 @@ def test_model_hosp_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -429,15 +396,16 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), - t_unit=1, ) latent_infections = Infections() Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") - hosp_report_prob_dist = UniformProbForTest(1, "hosp_report_prob_dist") - weekday = UniformProbForTest(7, "weekday") + hosp_report_prob_dist = DistributionalVariable( + "hosp_report_prob_dist", dist.Uniform() + ) + weekday = DistributionalVariable("weekday", dist.Uniform()).expand_by((7,)) latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, @@ -457,7 +425,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): hosp_admission_obs_process_rv=observed_admissions, ) - # Sampling and fitting model 0 (with no obs for infections) + # Sampling and fitting model 0 (with no obs for admissions) with numpyro.handlers.seed(rng_seed=223): model1_samp = model1.sample(n_datapoints=30) @@ -465,7 +433,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -524,7 +492,6 @@ def test_model_hosp_with_obs_model_weekday_phosp(): "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), - t_unit=1, ) latent_infections = Infections() @@ -580,7 +547,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # obs = jnp.hstack( # [ # jnp.repeat(jnp.nan, pad_size), - # model1_samp.observed_hosp_admissions.value[pad_size:], + # model1_samp.observed_hosp_admissions[pad_size:], # ] # ) # Running with padding @@ -588,7 +555,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, padding=pad_size, ) diff --git a/test/test_observation_negativebinom.py b/test/test_observation_negativebinom.py index 77a6dcf1..7c9ef519 100644 --- a/test/test_observation_negativebinom.py +++ b/test/test_observation_negativebinom.py @@ -24,14 +24,12 @@ def test_negativebinom_deterministic_obs(): sim_nb1 = negb(mu=rates, obs=rates) sim_nb2 = negb(mu=rates, obs=rates) - assert isinstance(sim_nb1, tuple) - assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0].value, ArrayLike) - assert isinstance(sim_nb2[0].value, ArrayLike) + assert isinstance(sim_nb1, ArrayLike) + assert isinstance(sim_nb2, ArrayLike) testing.assert_array_equal( - sim_nb1[0].value, - sim_nb2[0].value, + sim_nb1, + sim_nb2, ) @@ -49,13 +47,12 @@ def test_negativebinom_random_obs(): with numpyro.handlers.seed(rng_seed=223): sim_nb1 = negb(mu=rates) sim_nb2 = negb(mu=rates) - assert isinstance(sim_nb1, tuple) - assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0].value, ArrayLike) - assert isinstance(sim_nb2[0].value, ArrayLike) + + assert isinstance(sim_nb1, ArrayLike) + assert isinstance(sim_nb2, ArrayLike) testing.assert_array_almost_equal( - np.mean(sim_nb1[0].value), - np.mean(sim_nb2[0].value), + np.mean(sim_nb1), + np.mean(sim_nb2), decimal=1, ) diff --git a/test/test_observation_poisson.py b/test/test_observation_poisson.py index 4fb7f664..b9d975be 100644 --- a/test/test_observation_poisson.py +++ b/test/test_observation_poisson.py @@ -17,6 +17,6 @@ def test_poisson_obs(): rates = np.random.randint(1, 5, size=10) with numpyro.handlers.seed(rng_seed=223): - sim_pois, *_ = pois(mu=rates) + sim_pois = pois(mu=rates) - testing.assert_array_equal(sim_pois.value, jnp.ceil(sim_pois.value)) + testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) diff --git a/test/test_periodiceffect.py b/test/test_periodiceffect.py index 98ec6fb9..dcac3260 100644 --- a/test/test_periodiceffect.py +++ b/test/test_periodiceffect.py @@ -19,8 +19,6 @@ def test_periodiceffect() -> None: params = { "offset": 0, "quantity_to_broadcast": rv, - "t_start": 0, - "t_unit": 1, } duration = 30 @@ -28,7 +26,7 @@ def test_periodiceffect() -> None: pe = PeriodicEffect(**params) with numpyro.handlers.seed(rng_seed=223): - ans = pe(duration=duration)[0].value + ans = pe(duration=duration) # Checking that the shape of the sampled Rt is correct assert ans.shape == (duration,) @@ -42,9 +40,9 @@ def test_periodiceffect() -> None: params["offset"] = 5 pe = PeriodicEffect(**params) with numpyro.handlers.seed(rng_seed=223): - ans2 = pe(duration=duration)[0].value + ans2 = pe(duration=duration) - ans2 = pe(duration=duration)[0].value + ans2 = pe(duration=duration) assert ans2.shape == (duration,) # This time series should be the same as the previous one, but shifted by @@ -63,14 +61,11 @@ def test_weeklyeffect() -> None: params = { "offset": 2, "quantity_to_broadcast": rv, - "t_start": 0, - "t_unit": 1, } params2 = { "offset": 2, "quantity_to_broadcast": rv, - "t_start": 0, } duration = 30 @@ -78,8 +73,8 @@ def test_weeklyeffect() -> None: pe = PeriodicEffect(**params) pe2 = DayOfWeekEffect(**params2) - ans1 = pe(duration=duration)[0].value - ans2 = pe2(duration=duration)[0].value + ans1 = pe(duration=duration) + ans2 = pe2(duration=duration) assert_array_equal(ans1, ans2) diff --git a/test/test_predictive.py b/test/test_predictive.py index 9e17a97a..414da12e 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -25,7 +25,6 @@ "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") diff --git a/test/test_random_key.py b/test/test_random_key.py index a6be9a11..99314e8f 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -29,7 +29,6 @@ def create_test_model(): # numpydoc ignore=GL08 "I0_initialization", DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") @@ -95,7 +94,7 @@ def test_rng_keys_produce_correct_samples(): # as the observed_infections for the rest of the models with numpyro.handlers.seed(rng_seed=223): model_sample = models[0].sample(n_datapoints=n_datapoints[0]) - obs_infections = [model_sample.observed_infections.value] * len(models) + obs_infections = [model_sample.observed_infections] * len(models) rng_keys = [jr.key(54), jr.key(54), None, None, jr.key(74)] # run test models with the different keys diff --git a/test/test_random_walk.py b/test/test_random_walk.py index 6997d679..b32896d4 100755 --- a/test/test_random_walk.py +++ b/test/test_random_walk.py @@ -37,7 +37,7 @@ def test_rw_can_be_sampled(element_rv, init_value): with numpyro.handlers.seed(rng_seed=62): # can sample with a fixed init # and with a random init - init_vals = init_rv()[0].value + init_vals = init_rv() ans_long = rw(n=5023, init_vals=init_vals) ans_short = rw(n=1, init_vals=init_vals) @@ -46,16 +46,12 @@ def test_rw_can_be_sampled(element_rv, init_value): with pytest.raises(ValueError, match="differencing order"): rw(n=523, init_vals=jnp.hstack([init_vals, 0.25])) # check that the samples are of the right shape - assert ans_long[0].value.shape == (5023,) - assert ans_short[0].value.shape == (1,) + assert ans_long.shape == (5023,) + assert ans_short.shape == (1,) # check that the first n_inits samples are the inits n_inits = jnp.atleast_1d(init_vals).size - assert_array_almost_equal( - ans_long[0].value[0:n_inits], jnp.atleast_1d(init_vals) - ) - assert_array_almost_equal( - ans_short[0].value, jnp.atleast_1d(init_vals)[:1] - ) + assert_array_almost_equal(ans_long[0:n_inits], jnp.atleast_1d(init_vals)) + assert_array_almost_equal(ans_short, jnp.atleast_1d(init_vals)[:1]) @pytest.mark.parametrize( @@ -89,8 +85,7 @@ def test_normal_rw_samples_correctly_distributed(step_mean, step_sd): ) with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal(n=n_samples, init_vals=rw_init_val) - samples = samples.value + samples = rw_normal(n=n_samples, init_vals=rw_init_val) # Checking the shape assert samples.shape == (n_samples,) diff --git a/test/test_regression.py b/test/test_regression.py index cdb39e32..95dabf0f 100755 --- a/test/test_regression.py +++ b/test/test_regression.py @@ -9,7 +9,6 @@ import pyrenew.regression as r import pyrenew.transformation as t -from pyrenew.metaclass import SampledValue def test_glm_prediction(): @@ -61,15 +60,12 @@ def test_glm_prediction(): ## check prediction output ## is of expected type and shape - assert isinstance(preds.prediction, SampledValue) - assert preds.prediction.value.shape[0] == predictor_values.shape[0] + assert preds.prediction.shape[0] == predictor_values.shape[0] ## check coeffficients and intercept - assert isinstance(preds.coefficients, SampledValue) - assert isinstance(preds.intercept, SampledValue) # check results agree with manual calculation assert_array_almost_equal( - preds.prediction.value, - preds.intercept.value + predictor_values @ preds.coefficients.value, + preds.prediction, + preds.intercept + predictor_values @ preds.coefficients, ) diff --git a/test/test_rtperiodicdiff.py b/test/test_rtperiodicdiff.py index 8d1ac28a..236a310b 100644 --- a/test/test_rtperiodicdiff.py +++ b/test/test_rtperiodicdiff.py @@ -32,7 +32,7 @@ def test_rtweeklydiff() -> None: rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): - rt = rtwd(duration=duration).rt.value + rt = rtwd(duration=duration) # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -47,7 +47,7 @@ def test_rtweeklydiff() -> None: rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): - rt2 = rtwd(duration=duration).rt.value + rt2 = rtwd(duration=duration) # Checking that the shape of the sampled Rt is correct assert rt2.shape == (duration,) @@ -83,7 +83,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: duration = 1000 with numpyro.handlers.seed(rng_seed=323): - rt = rtwd(duration=duration).rt.value + rt = rtwd(duration=duration) # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -125,7 +125,7 @@ def test_rtperiodicdiff_smallsample(inits): rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): - rt = rtwd(duration=6).rt.value + rt = rtwd(duration=6) # Checking that the shape of the sampled Rt is correct assert rt.shape == (6,) diff --git a/test/test_transformed_rv_class.py b/test/test_transformed_rv_class.py index a3a9d44e..3e1f354c 100644 --- a/test/test_transformed_rv_class.py +++ b/test/test_transformed_rv_class.py @@ -8,10 +8,11 @@ import numpyro import numpyro.distributions as dist import pytest +from jax.typing import ArrayLike from numpy.testing import assert_almost_equal import pyrenew.transformation as t -from pyrenew.metaclass import Model, RandomVariable, SampledValue +from pyrenew.metaclass import Model, RandomVariable from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -29,14 +30,10 @@ def sample(self, **kwargs): Returns ------- tuple - (SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - SampledValue(val, t_start=self.t_start, t_unit=self.t_unit)) + (val, val) """ val = numpyro.sample("my_normal", dist.Normal(0, 1)) - return ( - SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - ) + return (val, val) def sample_length(self): """ @@ -65,8 +62,8 @@ class RVSamples(NamedTuple): A container to hold the output of `NamedBaseRV()`. """ - rv1: SampledValue | None = None - rv2: SampledValue | None = None + rv1: ArrayLike | None = None + rv2: ArrayLike | None = None def __repr__(self): return f"RVSamples(rv1={self.rv1},rv2={self.rv2})" @@ -85,14 +82,10 @@ def sample(self, **kwargs): Returns ------- tuple - (rv1= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - rv2= SampledValue(val, t_start=self.t_start, t_unit=self.t_unit)) + (rv1=val, rv2=val) """ val = numpyro.sample("my_normal", dist.Normal(0, 1)) - return RVSamples( - rv1=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - rv2=SampledValue(val, t_start=self.t_start, t_unit=self.t_unit), - ) + return RVSamples(rv1=val, rv2=val) def validate(self): """ @@ -196,15 +189,13 @@ def test_transforms_applied_at_sampling(): norm_transformed_sample = tr_norm.sample() l2_transformed_sample = tr_l2.sample() - assert_almost_equal( - tr(norm_base_sample[0].value), norm_transformed_sample[0].value - ) + assert_almost_equal(tr(norm_base_sample), norm_transformed_sample) assert_almost_equal( ( - tr(l2_base_sample[0].value), - t.ExpTransform()(l2_base_sample[1].value), + tr(l2_base_sample[0]), + t.ExpTransform()(l2_base_sample[1]), ), - (l2_transformed_sample[0].value, l2_transformed_sample[1].value), + l2_transformed_sample, ) diff --git a/test/utils.py b/test/utils.py index 7a51e55c..11f7cad7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -3,9 +3,10 @@ """ import numpyro.distributions as dist +from jax.typing import ArrayLike import pyrenew.transformation as t -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import RandomVariable from pyrenew.process import RandomWalk from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -31,6 +32,7 @@ def __init__(self, name: str = "Rt_rv"): None """ self.name = name + name = "Rt_rv" self.rt_rv_ = TransformedVariable( name=f"{name}_log_rt_random_walk", base_rv=RandomWalk( @@ -45,16 +47,16 @@ def __init__(self, name: str = "Rt_rv"): name=f"{name}_init_log_rt", distribution=dist.Normal(0, 0.2) ) - def sample(self, n=None, **kwargs) -> SampledValue: + def sample(self, n=None, **kwargs) -> ArrayLike: """ Sample method Returns ------- - SampledValue + ArrayLike """ - init_rt, *_ = self.rt_init_rv_.sample() - return self.rt_rv_(init_vals=init_rt.value, n=n) + init_rt = self.rt_init_rv_() + return self.rt_rv_(init_vals=init_rt, n=n) @staticmethod def validate(self):