diff --git a/NEWS.md b/NEWS.md index 7b9af2b..a690914 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,10 +2,11 @@ ### Enhancements +* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth + weights before adding to a draws object. * Matrix multiplication of `rvar`s can now be done with the base matrix multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. - # posterior 1.5.0 ### Enhancements diff --git a/R/convergence.R b/R/convergence.R index fa53119..cf895f0 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -21,6 +21,8 @@ #' | [mcse_mean()] | Monte Carlo standard error for the mean | #' | [mcse_quantile()] | Monte Carlo standard error for quantiles | #' | [mcse_sd()] | Monte Carlo standard error for standard deviations | +#' | [pareto_khat()] | Pareto khat diagnostic for tail(s) | +#' | [pareto_diags()] | Additional diagnostics related to Pareto khat | #' | [rhat_basic()] | Basic version of Rhat | #' | [rhat()] | Improved, rank-based version of Rhat | #' | [rhat_nested()] | Rhat for use with many short chains | diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 228aede..52ec6bb 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -5,10 +5,14 @@ #' the number of fractional moments that is useful for convergence #' diagnostics. For further details see Vehtari et al. (2022). #' +#' @family diagnostics #' @template args-pareto #' @template args-methods-dots #' @template ref-vehtari-paretosmooth-2022 #' @return `khat` estimated Generalized Pareto Distribution shape parameter k +#' +#' @seealso [`pareto_diags`] for additional related diagnostics, and +#' [`pareto_smooth`] for Pareto smoothed draws. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_khat(mu) @@ -25,6 +29,7 @@ pareto_khat.default <- function(x, r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ...) { smoothed <- pareto_smooth.default( x, @@ -34,6 +39,7 @@ pareto_khat.default <- function(x, verbose = verbose, return_k = TRUE, smooth_draws = FALSE, + are_log_weights = are_log_weights, ...) return(smoothed$diagnostics) } @@ -65,6 +71,7 @@ pareto_khat.rvar <- function(x, ...) { #' replacing tail draws by order statistics of a generalized Pareto #' distribution fit to the tail(s). #' +#' @family diagnostics #' @template args-pareto #' @template args-methods-dots #' @template ref-vehtari-paretosmooth-2022 @@ -100,6 +107,8 @@ pareto_khat.rvar <- function(x, ...) { #' when the sample size is increased, compared to the central limit #' theorem convergence rate. See Appendix B in Vehtari et al. (2022). #' +#' @seealso [`pareto_khat`] for only calculating khat, and +#' [`pareto_smooth`] for Pareto smoothed draws. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_diags(mu) @@ -113,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags") #' @rdname pareto_diags #' @export pareto_diags.default <- function(x, - tail = c("both", "right", "left"), - r_eff = NULL, - ndraws_tail = NULL, - verbose = FALSE, - ...) { + tail = c("both", "right", "left"), + r_eff = NULL, + ndraws_tail = NULL, + verbose = FALSE, + are_log_weights = FALSE, + ...) { smoothed <- pareto_smooth.default( x, @@ -128,6 +138,7 @@ pareto_diags.default <- function(x, extra_diags = TRUE, verbose = verbose, smooth_draws = FALSE, + are_log_weights = FALSE, ...) return(smoothed$diagnostics) @@ -189,6 +200,8 @@ pareto_diags.rvar <- function(x, ...) { #' Pareto smoothed estimates #' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates #' +#' @seealso [`pareto_khat`] for only calculating khat, and +#' [`pareto_diags`] for additional diagnostics. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_smooth(mu) @@ -225,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) { ) } out <- list( - x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)), - diagnostics = diags + x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)), + diagnostics = diags ) } else { out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x)) @@ -238,25 +251,36 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = NULL, + r_eff = 1, ndraws_tail = NULL, return_k = TRUE, extra_diags = FALSE, verbose = FALSE, + are_log_weights = FALSE, ...) { - checkmate::assert_number(ndraws_tail, null.ok = TRUE) - checkmate::assert_number(r_eff, null.ok = TRUE) - checkmate::assert_logical(extra_diags) - checkmate::assert_logical(return_k) - checkmate::assert_logical(verbose) + checkmate::expect_numeric(ndraws_tail, null.ok = TRUE) + checkmate::expect_numeric(r_eff, null.ok = TRUE) + extra_diags <- as_one_logical(extra_diags) + return_k <- as_one_logical(return_k) + verbose <- as_one_logical(verbose) + are_log_weights <- as_one_logical(are_log_weights) # check for infinite or na values if (should_return_NA(x)) { - warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.") - return(list(x = x, diagnostics = NA_real_)) + warning_no_call("Input contains infinite or NA values, or is constant. Fitting of generalized Pareto distribution not performed.") + if (!return_k) { + out <- x + } else { + out <- list(x = x, diagnostics = NA_real_) + } + return(out) } + if (are_log_weights) { + tail <- "right" + } + tail <- match.arg(tail) S <- length(x) @@ -290,6 +314,7 @@ pareto_smooth.default <- function(x, x, ndraws_tail = ndraws_tail, tail = "left", + are_log_weights = are_log_weights, ... ) left_k <- smoothed$k @@ -299,12 +324,14 @@ pareto_smooth.default <- function(x, x = smoothed$x, ndraws_tail = ndraws_tail, tail = "right", + are_log_weights = are_log_weights, ... ) right_k <- smoothed$k k <- max(left_k, right_k) x <- smoothed$x + } else { smoothed <- .pareto_smooth_tail( @@ -326,10 +353,11 @@ pareto_smooth.default <- function(x, if (verbose) { if (!extra_diags) { - diags_list <- .pareto_smooth_extra_diags(diags_list$khat, length(x)) + diags_list <- c(diags_list, .pareto_smooth_extra_diags(diags_list$khat, length(x))) } pareto_k_diagmsg( - diags = diags_list + diags = diags_list, + are_weights = are_log_weights ) } @@ -349,26 +377,32 @@ pareto_smooth.default <- function(x, ndraws_tail, smooth_draws = TRUE, tail = c("right", "left"), + are_log_weights = FALSE, ... ) { + if (are_log_weights) { + # shift log values for safe exponentiation + x <- x - max(x) + } + tail <- match.arg(tail) S <- length(x) tail_ids <- seq(S - ndraws_tail + 1, S) - if (tail == "left") { x <- -x } ord <- sort.int(x, index.return = TRUE) draws_tail <- ord$x[tail_ids] - cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values + cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -380,12 +414,19 @@ pareto_smooth.default <- function(x, k <- NA } else { # save time not sorting since x already sorted - fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE) + if (are_log_weights) { + draws_tail <- exp(draws_tail) + cutoff <- exp(cutoff) + } + fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...) k <- fit$k sigma <- fit$sigma if (is.finite(k) && smooth_draws) { p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma) + if (are_log_weights) { + smoothed <- log(smoothed) + } } else { smoothed <- NULL } @@ -445,11 +486,11 @@ pareto_smooth.default <- function(x, #' @noRd ps_min_ss <- function(k, ...) { if (k < 1) { - out <- 10^(1 / (1 - max(0, k))) + out <- 10^(1 / (1 - max(0, k))) } else { - out <- Inf + out <- Inf } - out + out } @@ -506,27 +547,38 @@ ps_tail_length <- function(S, r_eff, ...) { #' #' Given S and scalar and k, form a diagnostic message string #' @param diags (numeric) named vector of diagnostic values +#' @param are_weights (logical) are the diagnostics for weights #' @param ... unused #' @return diagnostic message #' @noRd -pareto_k_diagmsg <- function(diags, ...) { +pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { khat <- diags$khat min_ss <- diags$min_ss khat_threshold <- diags$khat_threshold convergence_rate <- diags$convergence_rate msg <- NULL - if (khat > 1) { - msg <- paste0(msg,'All estimates are unreliable. If the distribution of ratios is bounded,\n', - 'further draws may improve the estimates, but it is not possible to predict\n', - 'whether any feasible sample size is sufficient.') - } else { - if (khat > khat_threshold) { - msg <- paste0(msg, 'S is too small, and sample size larger than ', round(min_ss, 0), ' is needed for reliable results.\n') + + if (!are_weights) { + + if (khat > 1) { + msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n", + "further draws may improve the estimates, but it is not possible to predict\n", + "whether any feasible sample size is sufficient.") } else { - msg <- paste0(msg, 'To halve the RMSE, approximately ', round(2^(2/convergence_rate),1), ' times bigger S is needed.\n') + if (khat > khat_threshold) { + msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") + } else { + msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n") + } + if (khat > 0.7) { + msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n") + } } - if (khat > 0.7) { - msg <- paste0(msg, 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n') + + } else { + + if (khat > khat_threshold || khat > 0.7) { + msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message(msg) diff --git a/R/weight_draws.R b/R/weight_draws.R index 5232639..3449482 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -14,6 +14,8 @@ #' @param log (logical) Are the weights passed already on the log scale? The #' default is `FALSE`, that is, expecting `weights` to be on the standard #' (non-log) scale. +#' @param pareto_smooth (logical) Should the weights be Pareto-smoothed? +#' The default is `FALSE`. #' @template args-methods-dots #' @template return-draws #' @@ -43,6 +45,9 @@ #' head(weights(x)) #' head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts #' +#' # add weights on log scale and Pareto smooth them +#' x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) +#' #' @export weight_draws <- function(x, weights, ...) { UseMethod("weight_draws") @@ -50,9 +55,15 @@ weight_draws <- function(x, weights, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, ".log_weight"] <- log_weights @@ -66,9 +77,14 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_array <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, , ".log_weight"] <- log_weights @@ -82,18 +98,28 @@ weight_draws.draws_array <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_df <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } x$.log_weight <- log_weights x } #' @rdname weight_draws #' @export -weight_draws.draws_list <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } niterations <- niterations(x) for (i in seq_len(nchains(x))) { sel <- (1 + (i - 1) * niterations):(i * niterations) @@ -104,9 +130,14 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_rvars <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } x$.log_weight <- rvar(log_weights) x } @@ -161,3 +192,14 @@ validate_weights <- function(weights, draws, log = FALSE) { } weights } + + +pareto_smooth_log_weights <- function(log_weights) { + pareto_smooth( + log_weights, + tail = "right", + return_k = TRUE, + are_log_weights = TRUE, + extra_diags = TRUE + )$x +} diff --git a/man-roxygen/args-pareto.R b/man-roxygen/args-pareto.R index a406dbd..8a4d92b 100644 --- a/man-roxygen/args-pareto.R +++ b/man-roxygen/args-pareto.R @@ -11,10 +11,14 @@ #' @param ndraws_tail (numeric) number of draws for the tail. If #' `ndraws_tail` is not specified, it will be calculated as #' ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -#' length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)). +#' length(x) / 5 otherwise (see Appendix H in Vehtari et +#' al. (2022)). #' @param r_eff (numeric) relative effective sample size estimate. If -#' `r_eff` is omitted, it will be calculated assuming the draws are -#' from MCMC. +#' `r_eff` is NULL, it will be calculated assuming the draws are +#' from MCMC. Default is 1. #' @param verbose (logical) Should diagnostic messages be printed? If #' `TRUE`, messages related to Pareto diagnostics will be #' printed. Default is `FALSE`. +#' @param are_log_weights (logical) Are the draws log weights? Default is +#' `FALSE`. If `TRUE` computation will take into account that the +#' draws are log weights, and only right tail will be smoothed. diff --git a/man-roxygen/ref-vehtari-paretosmooth-2022.R b/man-roxygen/ref-vehtari-paretosmooth-2022.R index 267ec29..30f526f 100644 --- a/man-roxygen/ref-vehtari-paretosmooth-2022.R +++ b/man-roxygen/ref-vehtari-paretosmooth-2022.R @@ -1,4 +1,4 @@ #' @references #' Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and #' Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -#' arxiv:arXiv:1507.02646 +#' arxiv:arXiv:1507.02646 (version 8) diff --git a/man/diagnostics.Rd b/man/diagnostics.Rd index 43d5418..f2ba499 100644 --- a/man/diagnostics.Rd +++ b/man/diagnostics.Rd @@ -21,6 +21,8 @@ A list of available diagnostics and links to their individual help pages. \code{\link[=mcse_mean]{mcse_mean()}} \tab Monte Carlo standard error for the mean \cr \code{\link[=mcse_quantile]{mcse_quantile()}} \tab Monte Carlo standard error for quantiles \cr \code{\link[=mcse_sd]{mcse_sd()}} \tab Monte Carlo standard error for standard deviations \cr + \code{\link[=pareto_khat]{pareto_khat()}} \tab Pareto khat diagnostic for tail(s) \cr + \code{\link[=pareto_diags]{pareto_diags()}} \tab Additional diagnostics related to Pareto khat \cr \code{\link[=rhat_basic]{rhat_basic()}} \tab Basic version of Rhat \cr \code{\link[=rhat]{rhat()}} \tab Improved, rank-based version of Rhat \cr \code{\link[=rhat_nested]{rhat_nested()}} \tab Rhat for use with many short chains \cr diff --git a/man/ess_basic.Rd b/man/ess_basic.Rd index e300ad5..867076c 100755 --- a/man/ess_basic.Rd +++ b/man/ess_basic.Rd @@ -79,6 +79,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_bulk.Rd b/man/ess_bulk.Rd index adf3faf..c1456be 100755 --- a/man/ess_bulk.Rd +++ b/man/ess_bulk.Rd @@ -72,6 +72,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_quantile.Rd b/man/ess_quantile.Rd index 6bfc3cd..aa85c90 100755 --- a/man/ess_quantile.Rd +++ b/man/ess_quantile.Rd @@ -81,6 +81,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_sd.Rd b/man/ess_sd.Rd index 2344211..38475d2 100755 --- a/man/ess_sd.Rd +++ b/man/ess_sd.Rd @@ -66,6 +66,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_tail.Rd b/man/ess_tail.Rd index f211f7a..8f95971 100755 --- a/man/ess_tail.Rd +++ b/man/ess_tail.Rd @@ -72,6 +72,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_mean.Rd b/man/mcse_mean.Rd index 9afaa7b..c75935b 100755 --- a/man/mcse_mean.Rd +++ b/man/mcse_mean.Rd @@ -63,6 +63,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_quantile.Rd b/man/mcse_quantile.Rd index cc4f968..2d05f62 100755 --- a/man/mcse_quantile.Rd +++ b/man/mcse_quantile.Rd @@ -78,6 +78,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_mean}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_sd.Rd b/man/mcse_sd.Rd index 7e32286..671ef24 100755 --- a/man/mcse_sd.Rd +++ b/man/mcse_sd.Rd @@ -68,6 +68,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/pareto_diags.Rd b/man/pareto_diags.Rd index c2558a9..9a1d577 100644 --- a/man/pareto_diags.Rd +++ b/man/pareto_diags.Rd @@ -14,6 +14,7 @@ pareto_diags(x, ...) r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ... ) @@ -39,17 +40,22 @@ pareto_diags(x, ...) The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ List of Pareto smoothing diagnostics: @@ -101,5 +107,25 @@ pareto_diags(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_khat}} for only calculating khat, and +\code{\link{pareto_smooth}} for Pareto smoothed draws. + +Other diagnostics: +\code{\link{ess_basic}()}, +\code{\link{ess_bulk}()}, +\code{\link{ess_quantile}()}, +\code{\link{ess_sd}()}, +\code{\link{ess_tail}()}, +\code{\link{mcse_mean}()}, +\code{\link{mcse_quantile}()}, +\code{\link{mcse_sd}()}, +\code{\link{pareto_khat}()}, +\code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, +\code{\link{rhat}()}, +\code{\link{rstar}()} } +\concept{diagnostics} diff --git a/man/pareto_khat.Rd b/man/pareto_khat.Rd index c3383eb..a4f9170 100644 --- a/man/pareto_khat.Rd +++ b/man/pareto_khat.Rd @@ -14,6 +14,7 @@ pareto_khat(x, ...) r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ... ) @@ -39,17 +40,22 @@ pareto_khat(x, ...) The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ \code{khat} estimated Generalized Pareto Distribution shape parameter k @@ -70,5 +76,25 @@ pareto_khat(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_diags}} for additional related diagnostics, and +\code{\link{pareto_smooth}} for Pareto smoothed draws. + +Other diagnostics: +\code{\link{ess_basic}()}, +\code{\link{ess_bulk}()}, +\code{\link{ess_quantile}()}, +\code{\link{ess_sd}()}, +\code{\link{ess_tail}()}, +\code{\link{mcse_mean}()}, +\code{\link{mcse_quantile}()}, +\code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, +\code{\link{rhat}()}, +\code{\link{rstar}()} } +\concept{diagnostics} diff --git a/man/pareto_smooth.Rd b/man/pareto_smooth.Rd index 259421d..c0e6f01 100644 --- a/man/pareto_smooth.Rd +++ b/man/pareto_smooth.Rd @@ -13,11 +13,12 @@ pareto_smooth(x, ...) \method{pareto_smooth}{default}( x, tail = c("both", "right", "left"), - r_eff = NULL, + r_eff = 1, ndraws_tail = NULL, return_k = TRUE, extra_diags = FALSE, verbose = FALSE, + are_log_weights = FALSE, ... ) } @@ -50,17 +51,22 @@ returned. Default is \code{FALSE}.} The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ Either a vector \code{x} of smoothed values or a named list @@ -91,5 +97,9 @@ pareto_smooth(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_khat}} for only calculating khat, and +\code{\link{pareto_diags}} for additional diagnostics. } diff --git a/man/rhat.Rd b/man/rhat.Rd index fed2c14..263561c 100755 --- a/man/rhat.Rd +++ b/man/rhat.Rd @@ -67,6 +67,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rstar}()} diff --git a/man/rhat_basic.Rd b/man/rhat_basic.Rd index 8a94efb..16ffd33 100755 --- a/man/rhat_basic.Rd +++ b/man/rhat_basic.Rd @@ -75,6 +75,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index 2e23242..f2536ef 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -83,6 +83,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat}()}, \code{\link{rstar}()} diff --git a/man/rstar.Rd b/man/rstar.Rd index 87e8e37..c947990 100644 --- a/man/rstar.Rd +++ b/man/rstar.Rd @@ -115,6 +115,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()} diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index 4601c98..d866d46 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -11,15 +11,15 @@ \usage{ weight_draws(x, weights, ...) -\method{weight_draws}{draws_matrix}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_matrix}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_array}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_array}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_df}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_df}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_list}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_list}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_rvars}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_rvars}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) } \arguments{ \item{x}{(draws) A \code{draws} object or another \R object for which the method @@ -35,6 +35,9 @@ can be returned via the \code{\link[=weights.draws]{weights.draws()}} method lat \item{log}{(logical) Are the weights passed already on the log scale? The default is \code{FALSE}, that is, expecting \code{weights} to be on the standard (non-log) scale.} + +\item{pareto_smooth}{(logical) Should the weights be Pareto-smoothed? +The default is \code{FALSE}.} } \value{ A \code{draws} object of the same class as \code{x}. @@ -70,6 +73,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# add weights on log scale and Pareto smooth them +x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) + } \seealso{ \code{\link[=weights.draws]{weights.draws()}}, \code{\link[=resample_draws]{resample_draws()}} diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 6b2a46f..1a47788 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -48,6 +48,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# add weights on log scale and Pareto smooth them +x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) + } \seealso{ \code{\link{weight_draws}}, \code{\link{resample_draws}} diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index dff22ee..6b67d2b 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -73,7 +73,7 @@ test_that("pareto_khat diagnostics messages are as expected", { diags$khat <- 1.1 expect_message(pareto_k_diagmsg(diags), - paste0('All estimates are unreliable. If the distribution of ratios is bounded,\n', + paste0('All estimates are unreliable. If the distribution of draws is bounded,\n', 'further draws may improve the estimates, but it is not possible to predict\n', 'whether any feasible sample size is sufficient.')) @@ -192,3 +192,16 @@ test_that("pareto_smooth returns x with smoothed tail", { expect_false(isTRUE(all.equal(sort(tau), sort(tau_smoothed)))) }) + +test_that("pareto_smooth works for log_weights", { + w <- c(1:25, 1e3, 1e3, 1e3) + lw <- log(w) + + ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10) + + # only right tail is smoothed + expect_equal(ps$x[1:15], lw[1:15]) + + expect_true(ps$diagnostics$khat > 0.7) + +}) diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 8733743..fb6e6cc 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -63,7 +63,6 @@ test_that("weight_draws works on draws_rvars", { expect_equal(weights2, weights) }) - # conversion preserves weights -------------------------------------------- test_that("conversion between formats preserves weights", { @@ -88,3 +87,13 @@ test_that("conversion between formats preserves weights", { expect_equal(as_draws_rvars(draws[[!!type]]), draws$rvars) } }) + +# pareto smoothing ---------------- + +test_that("pareto smoothing smooths weights in weight_draws", { + x <- example_draws() + lw <- sort(log(abs(rt(ndraws(x), 1)))) + weighted <- weight_draws(x, lw, pareto_smooth = FALSE, log = TRUE) + smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE) + expect_false(all(weights(weighted) == weights(smoothed))) +})