Skip to content

Commit

Permalink
test prediction summarization with funs when type = 'mean'
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Feb 15, 2024
1 parent 6bd95e1 commit 66c0be2
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ test_that("global_fixed options produce equal results with balanced data", {
})

test_that("summarising via funs is equivalent to manual summary", {
# type = "response"
set.seed(1)
pred1 <- predict(
gaussian_example_fit,
Expand All @@ -378,14 +379,36 @@ test_that("summarising via funs is equivalent to manual summary", {
)
pred1 <- pred1$simulated |> dplyr::filter(time > 1)
set.seed(1)
pred2 <- predict(gaussian_example_fit, n_draws = 2L, expand = FALSE)
pred2 <- predict(
gaussian_example_fit, n_draws = 2L, expand = FALSE
)
pred2 <- pred2$simulated |>
dplyr::group_by(time, .draw) |>
dplyr::summarise(y_mean = mean(y_new), y_sd = sd(y_new)) |>
dplyr::filter(time > 1) |>
dplyr::summarise(mean_y = mean(y_new), sd_y = sd(y_new)) |>
dplyr::arrange(.draw)
expect_equal(pred1$mean_y, pred2$mean_y)
expect_equal(pred1$sd_y, pred2$sd_y)
# type = "mean"
set.seed(1)
pred3 <- predict(
gaussian_example_fit,
type = "mean",
funs = list(y = list(mean = mean, sd = sd)),
n_draws = 2L
)
pred3 <- pred3$simulated |> dplyr::filter(time > 1)
set.seed(1)
pred4 <- predict(
gaussian_example_fit, type = "mean", n_draws = 2L, expand = FALSE
)
pred4 <- pred4$simulated |>
dplyr::group_by(time, .draw) |>
dplyr::filter(time > 1) |>
dplyr::summarise(mean_y = mean(y_mean), sd_y = sd(y_mean)) |>
dplyr::arrange(.draw)
expect_equal(pred1$mean_y, pred2$y_mean)
expect_equal(pred1$sd_y, pred2$y_sd)
expect_equal(pred3$mean_y, pred4$mean_y)
expect_equal(pred3$sd_y, pred4$sd_y)
})

test_that("predict with loglik works", {
Expand Down

0 comments on commit 66c0be2

Please sign in to comment.