From 66c0be24758461d1cb3c2ed340c4900eaff09c06 Mon Sep 17 00:00:00 2001 From: Santtu Tikka Date: Thu, 15 Feb 2024 15:50:30 +0200 Subject: [PATCH] test prediction summarization with funs when type = 'mean' --- tests/testthat/test-predict.R | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 99d6d2d..ba42d42 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -370,6 +370,7 @@ test_that("global_fixed options produce equal results with balanced data", { }) test_that("summarising via funs is equivalent to manual summary", { + # type = "response" set.seed(1) pred1 <- predict( gaussian_example_fit, @@ -378,14 +379,36 @@ test_that("summarising via funs is equivalent to manual summary", { ) pred1 <- pred1$simulated |> dplyr::filter(time > 1) set.seed(1) - pred2 <- predict(gaussian_example_fit, n_draws = 2L, expand = FALSE) + pred2 <- predict( + gaussian_example_fit, n_draws = 2L, expand = FALSE + ) pred2 <- pred2$simulated |> dplyr::group_by(time, .draw) |> - dplyr::summarise(y_mean = mean(y_new), y_sd = sd(y_new)) |> dplyr::filter(time > 1) |> + dplyr::summarise(mean_y = mean(y_new), sd_y = sd(y_new)) |> + dplyr::arrange(.draw) + expect_equal(pred1$mean_y, pred2$mean_y) + expect_equal(pred1$sd_y, pred2$sd_y) + # type = "mean" + set.seed(1) + pred3 <- predict( + gaussian_example_fit, + type = "mean", + funs = list(y = list(mean = mean, sd = sd)), + n_draws = 2L + ) + pred3 <- pred3$simulated |> dplyr::filter(time > 1) + set.seed(1) + pred4 <- predict( + gaussian_example_fit, type = "mean", n_draws = 2L, expand = FALSE + ) + pred4 <- pred4$simulated |> + dplyr::group_by(time, .draw) |> + dplyr::filter(time > 1) |> + dplyr::summarise(mean_y = mean(y_mean), sd_y = sd(y_mean)) |> dplyr::arrange(.draw) - expect_equal(pred1$mean_y, pred2$y_mean) - expect_equal(pred1$sd_y, pred2$y_sd) + expect_equal(pred3$mean_y, pred4$mean_y) + expect_equal(pred3$sd_y, pred4$sd_y) }) test_that("predict with loglik works", {