Skip to content

Commit

Permalink
fix summary output for non-MCMC algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Mar 18, 2024
1 parent 02784e9 commit b46b706
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,36 @@ summary.brmsfit <- function(object, priors = FALSE, prob = 0.95,
# the model does not contain posterior draws
return(out)
}
out$chains <- nchains(object)
# iterations before thinning
out$iter <- object$fit@sim$iter
out$warmup <- object$fit@sim$warmup
out$thin <- nthin(object)
stan_args <- object$fit@stan_args[[1]]
out$sampler <- paste0(stan_args$method, "(", stan_args$algorithm, ")")
if (priors) {
out$prior <- prior_summary(object, all = FALSE)
}

variables <- variables(object)
incl_classes <- c(
"b", "bs", "bcs", "bsp", "bmo", "bme", "bmi", "bm",
valid_dpars(object), "delta", "lncor", "rescor", "ar", "ma", "sderr",
"cosy", "cortime", "lagsar", "errorsar", "car", "sdcar", "rhocar",
"sd", "cor", "df", "sds", "sdgp", "lscale", "simo"
)
incl_regex <- paste0("^", regex_or(incl_classes), "(_|$|\\[)")
variables <- variables[grepl(incl_regex, variables)]
draws <- as_draws_array(object, variable = variables)

out$total_ndraws <- ndraws(draws)
out$chains <- nchains(object)
if (length(object$fit@sim$iter)) {
# MCMC algorithms
out$iter <- object$fit@sim$iter
out$warmup <- object$fit@sim$warmup
} else {
# non-MCMC algorithms
out$iter <- out$total_ndraws
out$warmup <- 0
}
out$thin <- nthin(object)

# compute a summary for given set of parameters
# TODO: align names with summary outputs of other methods and packages
.summary <- function(draws, variables, probs, robust) {
Expand Down Expand Up @@ -96,16 +115,6 @@ summary.brmsfit <- function(object, priors = FALSE, prob = 0.95,
return(out)
}

variables <- variables(object)
incl_classes <- c(
"b", "bs", "bcs", "bsp", "bmo", "bme", "bmi", "bm",
valid_dpars(object), "delta", "lncor", "rescor", "ar", "ma", "sderr",
"cosy", "cortime", "lagsar", "errorsar", "car", "sdcar", "rhocar",
"sd", "cor", "df", "sds", "sdgp", "lscale", "simo"
)
incl_regex <- paste0("^", regex_or(incl_classes), "(_|$|\\[)")
variables <- variables[grepl(incl_regex, variables)]
draws <- as_draws_array(object, variable = variables)
full_summary <- .summary(draws, variables, probs, robust)
if (algorithm(object) == "sampling") {
Rhats <- full_summary[, "Rhat"]
Expand Down Expand Up @@ -246,11 +255,10 @@ print.brmssummary <- function(x, digits = 2, ...) {
# TODO: make this option a user-facing argument?
short <- as_one_logical(getOption("brms.short_summary", FALSE))
if (!short) {
total_ndraws <- ceiling((x$iter - x$warmup) / x$thin * x$chains)
cat(paste0(
" Draws: ", x$chains, " chains, each with iter = ", x$iter,
"; warmup = ", x$warmup, "; thin = ", x$thin, ";\n",
" total post-warmup draws = ", total_ndraws, "\n"
" total post-warmup draws = ", x$total_ndraws, "\n"
))
}
cat("\n")
Expand Down

0 comments on commit b46b706

Please sign in to comment.