From b46b706bd21661edab3ae0a5b23908a5fab97a9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Mon, 18 Mar 2024 10:17:52 +0100 Subject: [PATCH] fix summary output for non-MCMC algorithms --- R/summary.R | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/R/summary.R b/R/summary.R index 741c118ef..eed14a5e0 100644 --- a/R/summary.R +++ b/R/summary.R @@ -47,17 +47,36 @@ summary.brmsfit <- function(object, priors = FALSE, prob = 0.95, # the model does not contain posterior draws return(out) } - out$chains <- nchains(object) - # iterations before thinning - out$iter <- object$fit@sim$iter - out$warmup <- object$fit@sim$warmup - out$thin <- nthin(object) stan_args <- object$fit@stan_args[[1]] out$sampler <- paste0(stan_args$method, "(", stan_args$algorithm, ")") if (priors) { out$prior <- prior_summary(object, all = FALSE) } + variables <- variables(object) + incl_classes <- c( + "b", "bs", "bcs", "bsp", "bmo", "bme", "bmi", "bm", + valid_dpars(object), "delta", "lncor", "rescor", "ar", "ma", "sderr", + "cosy", "cortime", "lagsar", "errorsar", "car", "sdcar", "rhocar", + "sd", "cor", "df", "sds", "sdgp", "lscale", "simo" + ) + incl_regex <- paste0("^", regex_or(incl_classes), "(_|$|\\[)") + variables <- variables[grepl(incl_regex, variables)] + draws <- as_draws_array(object, variable = variables) + + out$total_ndraws <- ndraws(draws) + out$chains <- nchains(object) + if (length(object$fit@sim$iter)) { + # MCMC algorithms + out$iter <- object$fit@sim$iter + out$warmup <- object$fit@sim$warmup + } else { + # non-MCMC algorithms + out$iter <- out$total_ndraws + out$warmup <- 0 + } + out$thin <- nthin(object) + # compute a summary for given set of parameters # TODO: align names with summary outputs of other methods and packages .summary <- function(draws, variables, probs, robust) { @@ -96,16 +115,6 @@ summary.brmsfit <- function(object, priors = FALSE, prob = 0.95, return(out) } - variables <- variables(object) - incl_classes <- c( - "b", "bs", "bcs", "bsp", "bmo", "bme", "bmi", "bm", - valid_dpars(object), "delta", "lncor", "rescor", "ar", "ma", "sderr", - "cosy", "cortime", "lagsar", "errorsar", "car", "sdcar", "rhocar", - "sd", "cor", "df", "sds", "sdgp", "lscale", "simo" - ) - incl_regex <- paste0("^", regex_or(incl_classes), "(_|$|\\[)") - variables <- variables[grepl(incl_regex, variables)] - draws <- as_draws_array(object, variable = variables) full_summary <- .summary(draws, variables, probs, robust) if (algorithm(object) == "sampling") { Rhats <- full_summary[, "Rhat"] @@ -246,11 +255,10 @@ print.brmssummary <- function(x, digits = 2, ...) { # TODO: make this option a user-facing argument? short <- as_one_logical(getOption("brms.short_summary", FALSE)) if (!short) { - total_ndraws <- ceiling((x$iter - x$warmup) / x$thin * x$chains) cat(paste0( " Draws: ", x$chains, " chains, each with iter = ", x$iter, "; warmup = ", x$warmup, "; thin = ", x$thin, ";\n", - " total post-warmup draws = ", total_ndraws, "\n" + " total post-warmup draws = ", x$total_ndraws, "\n" )) } cat("\n")