Skip to content

Commit

Permalink
minor cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Oct 25, 2023
1 parent 32cd38b commit 04f30ab
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
22 changes: 11 additions & 11 deletions R/nested_rhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
#'
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' rhat_nested(mu, superchain_ids = c(1,1,2,2))
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2))
#'
#' d <- as_draws_rvars(example_draws("multi_normal"))
#' rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2))
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2))
#'
#' @export
rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested")
rhat_nested <- function(x, ...) UseMethod("rhat_nested")

#' @rdname rhat_nested
#' @export
Expand All @@ -30,7 +30,9 @@ rhat_nested.default <- function(x, superchain_ids, ...) {
#' @rdname rhat_nested
#' @export
rhat_nested.rvar <- function(x, superchain_ids, ...) {
summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...)
summarise_rvar_by_element_with_chains(
x, rhat_nested, superchain_ids = superchain_ids, ...
)
}

.rhat_nested <- function(x, superchain_ids, ...) {
Expand All @@ -42,35 +44,33 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) {
niterations <- NROW(x)
nchains <- NCOL(x)


# check that all chains are assigned a superchain
if (length(superchain_ids) != nchains) {
warning_no_call("Length of superchain_ids not equal to number of chains, returning NA.")
warning_no_call("Length of superchain_ids not equal to number of chains, ",
"returning NA.")
return(NA_real_)
}


# check that superchains are equal length
superchain_id_table <- table(superchain_ids)
nchains_per_superchain <- max(superchain_id_table)

if (nchains_per_superchain != min(superchain_id_table)) {
warning_no_call("Number of chains per superchain is not the same for each superchain, returning NA.")
warning_no_call("Number of chains per superchain is not the same for ",
"each superchain, returning NA.")
return(NA_real_)
}

superchains <- unique(superchain_ids)


# mean and variance of chains calculated as in rhat
chain_mean <- matrixStats::colMeans2(x)
chain_var <- matrixStats::colVars(x, center = chain_mean)

# mean of superchains calculated by only including specified chains
# (equation 15 in Margossian et al. 2023)
superchain_mean <- sapply(
superchains,
function(k) mean(x[, which(superchain_ids == k)])
superchains, function(k) mean(x[, which(superchain_ids == k)])
)

# overall mean (as defined in equation 16 in Margossian et al. 2023)
Expand Down
10 changes: 5 additions & 5 deletions man/rhat_nested.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 04f30ab

Please sign in to comment.