Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pareto-smoothing updates #314

Merged
merged 17 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

### Enhancements

* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth
weights before adding to a draws object.
* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.


# posterior 1.5.0

### Enhancements
Expand Down
2 changes: 2 additions & 0 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#' | [mcse_mean()] | Monte Carlo standard error for the mean |
#' | [mcse_quantile()] | Monte Carlo standard error for quantiles |
#' | [mcse_sd()] | Monte Carlo standard error for standard deviations |
#' | [pareto_khat()] | Pareto khat diagnostic for tail(s) |
#' | [pareto_diags()] | Additional diagnostics related to Pareto khat |
#' | [rhat_basic()] | Basic version of Rhat |
#' | [rhat()] | Improved, rank-based version of Rhat |
#' | [rhat_nested()] | Rhat for use with many short chains |
Expand Down
122 changes: 87 additions & 35 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
#' the number of fractional moments that is useful for convergence
#' diagnostics. For further details see Vehtari et al. (2022).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return `khat` estimated Generalized Pareto Distribution shape parameter k
#'
#' @seealso [`pareto_diags`] for additional related diagnostics, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_khat(mu)
Expand All @@ -25,6 +29,7 @@ pareto_khat.default <- function(x,
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
are_log_weights = FALSE,
...) {
smoothed <- pareto_smooth.default(
x,
Expand All @@ -34,6 +39,7 @@ pareto_khat.default <- function(x,
verbose = verbose,
return_k = TRUE,
smooth_draws = FALSE,
are_log_weights = are_log_weights,
...)
return(smoothed$diagnostics)
}
Expand Down Expand Up @@ -65,6 +71,7 @@ pareto_khat.rvar <- function(x, ...) {
#' replacing tail draws by order statistics of a generalized Pareto
#' distribution fit to the tail(s).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
Expand Down Expand Up @@ -100,6 +107,8 @@ pareto_khat.rvar <- function(x, ...) {
#' when the sample size is increased, compared to the central limit
#' theorem convergence rate. See Appendix B in Vehtari et al. (2022).
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_diags(mu)
Expand All @@ -113,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags")
#' @rdname pareto_diags
#' @export
pareto_diags.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
...) {
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
are_log_weights = FALSE,
...) {

smoothed <- pareto_smooth.default(
x,
Expand All @@ -128,6 +138,7 @@ pareto_diags.default <- function(x,
extra_diags = TRUE,
verbose = verbose,
smooth_draws = FALSE,
are_log_weights = FALSE,
...)

return(smoothed$diagnostics)
Expand Down Expand Up @@ -189,6 +200,8 @@ pareto_diags.rvar <- function(x, ...) {
#' Pareto smoothed estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_smooth(mu)
Expand Down Expand Up @@ -225,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
)
}
out <- list(
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
)
} else {
out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x))
Expand All @@ -238,25 +251,36 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
#' @export
pareto_smooth.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
r_eff = 1,
ndraws_tail = NULL,
return_k = TRUE,
extra_diags = FALSE,
verbose = FALSE,
are_log_weights = FALSE,
...) {

checkmate::assert_number(ndraws_tail, null.ok = TRUE)
checkmate::assert_number(r_eff, null.ok = TRUE)
checkmate::assert_logical(extra_diags)
checkmate::assert_logical(return_k)
checkmate::assert_logical(verbose)
checkmate::expect_numeric(ndraws_tail, null.ok = TRUE)
checkmate::expect_numeric(r_eff, null.ok = TRUE)
extra_diags <- as_one_logical(extra_diags)
return_k <- as_one_logical(return_k)
verbose <- as_one_logical(verbose)
are_log_weights <- as_one_logical(are_log_weights)

# check for infinite or na values
if (should_return_NA(x)) {
warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.")
return(list(x = x, diagnostics = NA_real_))
warning_no_call("Input contains infinite or NA values, or is constant. Fitting of generalized Pareto distribution not performed.")
if (!return_k) {
out <- x
} else {
out <- list(x = x, diagnostics = NA_real_)
}
return(out)
}

if (are_log_weights) {
tail <- "right"
}

tail <- match.arg(tail)
S <- length(x)

Expand Down Expand Up @@ -290,6 +314,7 @@ pareto_smooth.default <- function(x,
x,
ndraws_tail = ndraws_tail,
tail = "left",
are_log_weights = are_log_weights,
...
)
left_k <- smoothed$k
Expand All @@ -299,12 +324,14 @@ pareto_smooth.default <- function(x,
x = smoothed$x,
ndraws_tail = ndraws_tail,
tail = "right",
are_log_weights = are_log_weights,
...
)
right_k <- smoothed$k

k <- max(left_k, right_k)
x <- smoothed$x

} else {

smoothed <- .pareto_smooth_tail(
Expand All @@ -326,10 +353,11 @@ pareto_smooth.default <- function(x,

if (verbose) {
if (!extra_diags) {
diags_list <- .pareto_smooth_extra_diags(diags_list$khat, length(x))
diags_list <- c(diags_list, .pareto_smooth_extra_diags(diags_list$khat, length(x)))
}
pareto_k_diagmsg(
diags = diags_list
diags = diags_list,
are_weights = are_log_weights
)
}

Expand All @@ -349,26 +377,32 @@ pareto_smooth.default <- function(x,
ndraws_tail,
smooth_draws = TRUE,
tail = c("right", "left"),
are_log_weights = FALSE,
...
) {

if (are_log_weights) {
# shift log values for safe exponentiation
x <- x - max(x)
}

tail <- match.arg(tail)

S <- length(x)
tail_ids <- seq(S - ndraws_tail + 1, S)


if (tail == "left") {
x <- -x
}

ord <- sort.int(x, index.return = TRUE)
draws_tail <- ord$x[tail_ids]
cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values

cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values

max_tail <- max(draws_tail)
min_tail <- min(draws_tail)

if (ndraws_tail >= 5) {
ord <- sort.int(x, index.return = TRUE)
if (abs(max_tail - min_tail) < .Machine$double.eps / 100) {
Expand All @@ -380,12 +414,19 @@ pareto_smooth.default <- function(x,
k <- NA
} else {
# save time not sorting since x already sorted
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE)
if (are_log_weights) {
draws_tail <- exp(draws_tail)
cutoff <- exp(cutoff)
}
fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...)
k <- fit$k
sigma <- fit$sigma
if (is.finite(k) && smooth_draws) {
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
if (are_log_weights) {
smoothed <- log(smoothed)
}
} else {
smoothed <- NULL
}
Expand Down Expand Up @@ -445,11 +486,11 @@ pareto_smooth.default <- function(x,
#' @noRd
ps_min_ss <- function(k, ...) {
if (k < 1) {
out <- 10^(1 / (1 - max(0, k)))
out <- 10^(1 / (1 - max(0, k)))
} else {
out <- Inf
out <- Inf
}
out
out
}


Expand Down Expand Up @@ -506,27 +547,38 @@ ps_tail_length <- function(S, r_eff, ...) {
#'
#' Given Pareto diagnostic values, form a diagnostic message string
#' @param diags (list) named list of diagnostic values; the elements `khat`,
#'   `min_ss`, `khat_threshold` and `convergence_rate` are read
#' @param are_weights (logical) are the diagnostics for weights
#' @param ... unused
#' @return diagnostic message
#' @noRd
pareto_k_diagmsg <- function(diags, ...) {
pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
khat <- diags$khat
min_ss <- diags$min_ss
khat_threshold <- diags$khat_threshold
convergence_rate <- diags$convergence_rate
msg <- NULL
if (khat > 1) {
msg <- paste0(msg,'All estimates are unreliable. If the distribution of ratios is bounded,\n',
'further draws may improve the estimates, but it is not possible to predict\n',
'whether any feasible sample size is sufficient.')
} else {
if (khat > khat_threshold) {
msg <- paste0(msg, 'S is too small, and sample size larger than ', round(min_ss, 0), ' is needed for reliable results.\n')

if (!are_weights) {

if (khat > 1) {
msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n",
"further draws may improve the estimates, but it is not possible to predict\n",
"whether any feasible sample size is sufficient.")
} else {
msg <- paste0(msg, 'To halve the RMSE, approximately ', round(2^(2/convergence_rate),1), ' times bigger S is needed.\n')
if (khat > khat_threshold) {
msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
} else {
msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n")
}
if (khat > 0.7) {
msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n")
}
}
if (khat > 0.7) {
msg <- paste0(msg, 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n')

} else {

if (khat > khat_threshold || khat > 0.7) {
msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
}
}
message(msg)
Expand Down
Loading
Loading