From 51c4c8b877dcc9e2df0f2bcdd22b961c058b95fa Mon Sep 17 00:00:00 2001 From: AFg6K7h4fhy2 <127630341+AFg6K7h4fhy2@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:19:00 -0400 Subject: [PATCH] adding first tutorial edited w/ arviz, still need to get proper plot_ppc --- docs/source/tutorials/day_of_the_week.qmd | 39 ++++++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index f6558698..98f1c2fc 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -197,13 +197,42 @@ hosp_model.run( rng_key=jax.random.key(54), mcmc_args=dict(progress_bar=False), ) +``` -# Plotting the posterior -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=daily_hosp_admits.astype(float), +```{python} +import arviz as az +import matplotlib.pyplot as plt + +ppc_samples = hosp_model.posterior_predictive( + n_datapoints=daily_hosp_admits.size +) +idata = az.from_numpyro( + posterior=hosp_model.mcmc, + posterior_predictive=ppc_samples, + constant_data={"daily_hosp_admits": daily_hosp_admits}, + coords={"time": np.arange(daily_hosp_admits.size)}, + dims={ + "daily_hosp_admits": ["time"], + "latent_hospital_admissions": ["time"], + }, ) +print(idata.observed_data) +fig, ax = plt.subplots(figsize=(8, 6)) +az.plot_ppc(data=idata, kind="kde", ax=ax, num_pp_samples=100) +# ax.plot(np.arange(daily_hosp_admits.size), daily_hosp_admits.astype(float), color="black", label="Observed signal") + +ax.legend() +plt.xlabel("Time") +plt.ylabel("Hospital Admissions") +plt.show() + + +# # Plotting the posterior +# out = hosp_model.plot_posterior( +# var="latent_hospital_admissions", +# ylab="Hospital Admissions", +# obs_signal=daily_hosp_admits.astype(float), +# ) ```