Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weighted rvars #331

Open
wants to merge 52 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
719e3e1
move common weight-processing code into validate_weights
mjskay Jan 5, 2024
c1c9595
basic setup of weighting for rvars
mjskay Jan 5, 2024
243492a
add log_weights() to get internal weights for easier programming
mjskay Jan 5, 2024
4e768c3
use rvar weights instead of .log_weight variable in draws_rvars
mjskay Jan 6, 2024
48c17ee
weighting for rvar functions except rvar-dist
mjskay Jan 6, 2024
4ba1582
weighted discrete summaries
mjskay Jan 6, 2024
2db69fe
tests for weighted rvars
mjskay Jan 6, 2024
6c87a3b
allow weight_draws(x, NULL) to remove weights
mjskay Jan 6, 2024
6ae97f8
test fixes
mjskay Jan 6, 2024
ecfbb84
make test reliable
mjskay Jan 6, 2024
ca67b6e
add documentation of rvar internals
mjskay Jan 7, 2024
739a5b4
minor edits to docs
mjskay Jan 8, 2024
fe307df
Merge branch 'rvar_weights' of github.com:stan-dev/posterior into rva…
Jan 17, 2024
19d2ff7
updating pareto functions for weighted rvars
Jan 17, 2024
53a0b78
cleanup rvar conform functions and drop unused keep_constants arg
mjskay Jan 18, 2024
476cdc2
prevent binding weighted and unweighted non-constant rvars
mjskay Jan 18, 2024
316a81e
Merge branch 'master' into rvar_weights
mjskay Jan 18, 2024
534d7fb
test coverage improvements
mjskay Jan 18, 2024
b0a778a
density, cdf, quantiles for weighted rvar
mjskay Jan 19, 2024
f282f42
use toString instead of paste(collapse = ", ")
mjskay Jan 22, 2024
a97b1fd
Merge branch 'rvar_weights' of github.com:stan-dev/posterior into sta…
Jan 22, 2024
221e5b2
Merge branch 'stan-dev-rvar_weights' into rvar_weights
Jan 22, 2024
6d88269
Merge branch 'master' into rvar_weights
mjskay Feb 2, 2024
83439f4
add weights to rvar vignette
mjskay Feb 3, 2024
22b9691
test coverage improvements
mjskay Feb 3, 2024
cb44b4a
start on weighted convergence
Feb 28, 2024
074349d
Merge branch 'rvar_weights' of github.com:stan-dev/posterior into rva…
Mar 7, 2024
165fb4f
improvements to weighted ess, mcse
Mar 7, 2024
8e1e0df
tweak weighted diagnostics
Mar 11, 2024
a9ba2b6
add r_eff into calculation of weighted ess and mcse
Mar 13, 2024
ccdfb2f
fixes to weighted ess and mcse
Mar 15, 2024
eed7221
use weighted quantile in weighted mcse for quantile
Mar 19, 2024
2c70d46
add tests for weighted convergence measures
Mar 19, 2024
1079cef
check weights must be a vector
mjskay Mar 20, 2024
0bfe6e6
updating pareto functions for weighted rvars
Jan 17, 2024
691d240
start on weighted convergence
Feb 28, 2024
74b3b71
improvements to weighted ess, mcse
Mar 7, 2024
a1b6564
tweak weighted diagnostics
Mar 11, 2024
6ccf496
add r_eff into calculation of weighted ess and mcse
Mar 13, 2024
e47fd92
fixes to weighted ess and mcse
Mar 15, 2024
1734e3d
use weighted quantile in weighted mcse for quantile
Mar 19, 2024
222893a
add tests for weighted convergence measures
Mar 19, 2024
a732ad8
fix r_eff calculations for each quantity
Mar 22, 2024
980dcde
resolve merge conflicts
Mar 22, 2024
6ce80df
fix weighted mcse for sd
Mar 22, 2024
c0d8b8e
fixes to pareto smoothing for weighted draws
Apr 8, 2024
dbdab63
do not unintentionally merge chains in pareto smoothing
Apr 9, 2024
a6cb3fb
add test for pareto_khat on weighted rvar
Apr 9, 2024
d928bd4
updates to weighted mcse for sd
Apr 9, 2024
14c01de
updates to mcse for weighted draws
Apr 9, 2024
33779ca
Merge pull request #2 from n-kall/rvar_weights
n-kall Apr 9, 2024
ffc9c6a
Merge pull request #357 from n-kall/weighted_diagnostics
paul-buerkner Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ S3method(iteration_ids,draws_rvars)
S3method(iteration_ids,rvar)
S3method(length,rvar)
S3method(levels,rvar)
S3method(log_weights,draws)
S3method(log_weights,draws_rvars)
S3method(log_weights,rvar)
S3method(mad,default)
S3method(mad,rvar)
S3method(mad,rvar_ordered)
Expand Down Expand Up @@ -394,7 +397,9 @@ S3method(weight_draws,draws_df)
S3method(weight_draws,draws_list)
S3method(weight_draws,draws_matrix)
S3method(weight_draws,draws_rvars)
S3method(weight_draws,rvar)
S3method(weights,draws)
S3method(weights,rvar)
export("%**%")
export("%in%")
export("draws_of<-")
Expand Down Expand Up @@ -454,6 +459,7 @@ export(is_rvar)
export(is_rvar_factor)
export(is_rvar_ordered)
export(iteration_ids)
export(log_weights)
export(mad)
export(match)
export(mcse_mean)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth
weights before adding to a draws object.
* Add support for applying weights to individual `rvar` objects.
* Add `log_weights()` function for easy access to raw internal weights.
* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.
* `variables()`, `variables<-()`, `set_variables()`, and `nvariables()` now
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_array.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ as_draws_array.draws_rvars <- function(x, ...) {
x <- check_variables_are_numeric(
x, to = "draws_array", is_non_numeric = is_rvar_factor, convert = FALSE
)
x <- promote_rvar_weights_to_variable(x)

# cbind discards class information when applied to vectors, which converts
# the underlying factors to numeric
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_df.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ as_draws_df.draws_rvars <- function(x, ...) {
if (ndraws(x) == 0L) {
return(empty_draws_df(variables(x)))
}
x <- promote_rvar_weights_to_variable(x)
out <- do.call(cbind, lapply(seq_along(x), function(i) {
# flatten each rvar so it only has two dimensions: draws and variables
# this also collapses indices into variable names in the format "var[i,j,k,...]"
Expand Down
1 change: 1 addition & 0 deletions R/as_draws_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ as_draws_matrix.draws_rvars <- function(x, ...) {
x <- check_variables_are_numeric(
x, to = "draws_matrix", is_non_numeric = is_rvar_factor, convert = FALSE
)
x <- promote_rvar_weights_to_variable(x)

# cbind discards class information when applied to vectors, which converts
# the underlying factors to numeric
Expand Down
30 changes: 29 additions & 1 deletion R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,27 @@ as_draws_rvars.mcmc.list <- function(x, ...) {

check_new_variables(names(x))

x <- conform_rvar_ndraws_nchains(x)
x <- conform_rvar_nchains_ndraws_weights(x)

class(x) <- class_draws_rvars()

# move the .log_weight column into the log_weights attribute of each rvar,
# but only if there is no conflict between any existing weights on the rvars
if (".log_weight" %in% names(x)) {
existing_weights <- log_weights(x[[1]])
.log_weight <- as.vector(draws_of(x$.log_weight))
if (is.null(existing_weights)) {
x$.log_weight <- NULL
x <- weight_draws(x, .log_weight, log = TRUE)
} else {
# if we reach this point either existing_weights and .log_weight
# are identical (so we don't have to do anything) or they aren't
# and weights2_common will throw the appropriate error --- thus
# we don't need to do anything with its output
weights2_common(existing_weights, .log_weight)
}
}

x
}

Expand Down Expand Up @@ -258,3 +276,13 @@ empty_draws_rvars <- function(variables = character(0), nchains = 0) {
class(out) <- class_draws_rvars()
out
}

# when converting draws_rvars to other formats, we must promote log weights
# to be a variable before doing the conversion
promote_rvar_weights_to_variable <- function(x) {
.log_weights <- log_weights(x)
if (!is.null(.log_weights)) {
x$.log_weight <- rvar(log_weights(x), nchains = nchains(x))
}
x
}
16 changes: 9 additions & 7 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,9 @@ quantile2.default <- function(
) {
names <- as_one_logical(names)
na.rm <- as_one_logical(na.rm)
if (!na.rm && anyNA(x)) {
# quantile itself doesn't handle this case (#110)
out <- rep(NA_real_, length(probs))
} else {
out <- quantile(x, probs = probs, na.rm = na.rm, ...)
}

out <- weighted_quantile(x, probs = probs, na.rm = na.rm, ...)

if (names) {
names(out) <- paste0("q", probs * 100)
} else {
Expand All @@ -560,7 +557,12 @@ quantile2.default <- function(
quantile2.rvar <- function(
x, probs = c(0.05, 0.95), na.rm = FALSE, names = TRUE, ...
) {
summarise_rvar_by_element_with_chains(x, quantile2, probs, na.rm, names, ...)
weights <- weights(x)
summarise_rvar_by_element(x, function(draws) {
quantile2(
draws, probs = probs, weights = weights, na.rm = na.rm, names = names, ...
)
})
}

# internal ----------------------------------------------------------------
Expand Down
73 changes: 47 additions & 26 deletions R/discrete-summaries.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
#'
#' Normalized entropy, for measuring dispersion in draws from categorical distributions.
#'
#' @param x (multiple options) A vector to be interpreted as draws from
#' a categorical distribution, such as:
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-x-categorical
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Calculates the normalized Shannon entropy of the draws in `x`. This value is
Expand Down Expand Up @@ -51,14 +49,14 @@
#' xy
#' entropy(xy)
#' @export
entropy <- function(x) {
entropy <- function(x, ...) {
UseMethod("entropy")
}
#' @rdname entropy
#' @export
entropy.default <- function(x) {
entropy.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA_real_)
p <- prop.table(simple_table(x)$count)
p <- prop.table(weighted_simple_table(x, weights)$count)
n <- length(p)

if (n == 1) {
Expand All @@ -71,8 +69,8 @@ entropy.default <- function(x) {
}
#' @rdname entropy
#' @export
entropy.rvar <- function(x) {
summarise_rvar_by_element(x, entropy)
entropy.rvar <- function(x, ...) {
summarise_rvar_by_element(x, entropy, weights = weights(x))
}


Expand All @@ -85,6 +83,8 @@ entropy.rvar <- function(x) {
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Calculates Tastle and Wierman's (2007) *dissention* measure:
Expand Down Expand Up @@ -125,12 +125,12 @@ entropy.rvar <- function(x) {
#' xy
#' dissent(xy)
#' @export
dissent <- function(x) {
dissent <- function(x, ...) {
UseMethod("dissent")
}
#' @rdname dissent
#' @export
dissent.default <- function(x) {
dissent.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA_real_)
if (length(x) == 0) return(0)

Expand All @@ -141,33 +141,32 @@ dissent.default <- function(x) {
d <- diff(range(x))
}

tab <- simple_table(x)
tab <- weighted_simple_table(x, weights)
p <- prop.table(tab$count)

if (length(p) == 1) {
out <- 0
} else {
x_i <- tab$x
out <- -sum(p * log2(1 - abs(x_i - mean(x)) / d))
mean_x <- if (is.null(weights)) mean(x) else weighted.mean(x, weights)
out <- -sum(p * log2(1 - abs(x_i - mean_x) / d))
}
out
}
#' @rdname dissent
#' @export
dissent.rvar <- function(x) {
summarise_rvar_by_element(x, dissent)
dissent.rvar <- function(x, ...) {
summarise_rvar_by_element(x, dissent, weights = weights(x))
}


#' Modal category
#'
#' Modal category of a vector.
#'
#' @param x (multiple options) A vector to be interpreted as draws from
#' a categorical distribution, such as:
#' - A [factor]
#' - A [numeric] (should be [integer] or integer-like)
#' - An [rvar], [rvar_factor], or [rvar_ordered]
#' @template args-summaries-x-categorical
#' @template args-summaries-weights
#' @template args-methods-dots
#'
#' @details
#' Finds the modal category (i.e., most frequent value) in `x`. In the case of
Expand All @@ -192,20 +191,20 @@ dissent.rvar <- function(x) {
#' xy
#' modal_category(xy)
#' @export
modal_category <- function(x) {
modal_category <- function(x, ...) {
UseMethod("modal_category")
}
#' @rdname modal_category
#' @export
modal_category.default <- function(x) {
modal_category.default <- function(x, weights = NULL, ...) {
if (anyNA(x)) return(NA)
tab <- simple_table(x)
tab <- weighted_simple_table(x, weights)
tab$x[which.max(tab$count)]
}
#' @rdname modal_category
#' @export
modal_category.rvar <- function(x) {
summarise_rvar_by_element(x, modal_category)
modal_category.rvar <- function(x, ...) {
summarise_rvar_by_element(x, modal_category, weights = weights(x))
}


Expand All @@ -231,3 +230,25 @@ simple_table <- function(x) {
count = tabulate(x_int, nbins = length(values))
)
}

#' A weighted version of simple_table
#' @param x a vector (numeric, factor, character, etc)
#' @param weights weights
#' @returns a list with two components of the same length
#' - `x`: unique values from the input `x`
#' - `count`: sum of weights for each unique value of `x`
#' @noRd
weighted_simple_table <- function(x, weights) {
if (is.null(weights)) return(simple_table(x))
stopifnot(identical(length(x), length(weights)))

if (is.factor(x)) {
values <- levels(x)
} else {
values <- unique(x)
}
list(
x = values,
count = vapply(split(weights, factor(x, values)), sum, numeric(1), USE.NAMES = FALSE)
)
}
7 changes: 5 additions & 2 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,11 @@ nchains.rvar <- function(x) {
# attribute on an rvar, ALWAYS use this function so that the proxy
# cache is invalidated
`nchains_rvar<-` <- function(x, value) {
attr(x, "nchains") <- value
invalidate_rvar_cache(x)
if (attr(x, "nchains") != value) {
attr(x, "nchains") <- value
x <- invalidate_rvar_cache(x)
}
x
}


Expand Down
2 changes: 1 addition & 1 deletion R/mutate_variables.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mutate_variables.draws_rvars <- function(.x, ...) {
for (var in names(dots)) {
.x[[var]] <- as_rvar(eval_tidy(dots[[var]], .x, env))
}
conform_rvar_ndraws_nchains(.x)
conform_rvar_nchains_ndraws_weights(.x)
}

# evaluate an expression passed to 'mutate_variables' and check its validity
Expand Down
2 changes: 1 addition & 1 deletion R/resample_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ resample_draws.draws <- function(x, weights = NULL, method = "stratified",
weights <- rep.int(1/ndraws_total, ndraws_total)
}
# resampling invalidates stored weights
x <- remove_variables(x, ".log_weight")
x <- weight_draws(x, NULL)
} else {
weights <- weights / sum(weights)
}
Expand Down
Loading
Loading