Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weighted rvars #331

Open
wants to merge 52 commits into
base: master
Choose a base branch
from
Open

Weighted rvars #331

wants to merge 52 commits into from

Conversation

mjskay
Copy link
Collaborator

@mjskay mjskay commented Jan 6, 2024

Summary

This PR aims to address (at least part of) #184 by implementing weighted rvars.

Currently, rvars cannot contain weights, and weighting of them can only be done by putting them in a draws_rvars object that itself contains a ".log_weight" rvar containing the weights. This leads to counterintuitive behavior, like the default output of the rvar (showing mean and sd) using unweighted versions of those statistics.

This PR addresses that issue in the following ways:

  • It stores rvar weights as a "log_weights" attribute directly on the rvar, just like the "nchains" attribute is used to store chain count.
  • draws_rvars no longer use a ".log_weight" variable to store weights, instead storing them directly on each rvar they contain, and requiring all rvars they contain to have the same weights (the same way it handles "nchains").
  • A new log_weights() function for draws and rvars is added, which is a lower-level version of weights(x, log = TRUE, normalize = FALSE) that just returns the log weights stored in the object without transformation. I initially did not have this, but found it greatly eased programming with weights.
  • weight_draws(x, NULL) is now allowed as the canonical way to remove weights from a draws object, since remove_variables(x, ".log_weight") does not work on draws_rvars objects anymore.
  • All summary functions for rvars have been updated to incorporate weights (with a couple of exceptions I haven't gotten to yet, see TODOs and Questions below).
  • Since rvar internals are becoming (even more) complicated, I have added an "rvar Internals" section to ?rvar that hopefully will help in case others need to touch the code ;).

Demo

set.seed(1234)
x = rvar(rnorm(1000))
x
#> rvar<1000>[1] mean ± sd:
#> [1] -0.027 ± 1

w1 = rexp(1000)
x1 = weight_draws(x, w1)
x1
#> weighted rvar<1000>[1] mean ± sd:
#> [1] -0.00087 ± 1

w2 = rexp(1000)
x2 = weight_draws(x, w2)
x2
#> weighted rvar<1000>[1] mean ± sd:
#> [1] -0.003 ± 0.96

You can't combine two rvars with different weights:

x1 + x2
#> Error: Random variables have different log weights and cannot be used together:
#> <dbl> 0.794199981930473, -1.61888922585584, 1.02558084358998, -0.657945687118312, 0.132635682996154 ...
#> <dbl> 0.661721670766407, -1.46589074644228, -1.39312536919089, 0.318133129307739, 0.66043235310858 ...

The check for equality of weights is done on the internal weights using identical(), which should be fast, especially in cases where the two weight vectors are actually pointers to the same vector in memory (in which case the comparison is constant time). This does mean the weights vectors must be exactly the same (no tolerance for floating point error), but I suspect in most cases when weighting happens the exact same weight vector is being applied to many objects. In any case, if someone did encounter this issue they could simply assign the log weights from object to the other.

If one rvar is weighted and another is not, the weights of the weighted rvar are inherited, which I believe covers the use case of (weighted draws from some model) + (unweighted draws, e.g. used to simulate predictions):

x1 + rvar(rnorm(1000, 1))
#> weighted rvar<1000>[1] mean ± sd:
#> [1] 0.96 ± 1.4

If you install the dev version of {ggdist}:

remotes::install_github("mjskay/ggdist")

It supports weighted rvars in all functions (densities, CDFs, quantiles, all interval types and all point summaries):

Without weights:

library(ggplot2)
library(ggdist)

set.seed(1234)
x = rvar(rnorm(10000, c(1,5)))

ggplot() + stat_slabinterval(aes(xdist = x))

image

With weights:

xw = weight_draws(x, rep(c(1,2), 5000))
ggplot() + stat_slabinterval(aes(xdist = xw))

image

Weights should work basically everywhere:

ggplot() + 
  stat_slabinterval(
    aes(xdist = xw), 
    point_interval = mode_hdi, 
    density = "histogram", 
    breaks = 50
  )

image

TODOs and Questions

TODOs:

  • I still have to implement density(<rvar>), cdf(<rvar>), and quantile(<rvar>) / quantile2(<rvar>). The first two are straightforward. For weighted quantiles, I have an implementation in {ggdist} that I can port over, but I may want to update it first; some thoughts on weighted quantiles are here and feedback is welcome. (In fact, since writing that document my thinking has changed a bit---I originally thought the way I suggested implementing weighted quantiles in that document is an improvement on ggdist's current implementation, but after further investigation I might be leaning back towards how I did it in ggdist originally...).
  • Mention weights in vignette("rvar")

Questions:

  • I don't know if any of the other functions in R/convergence.R should be modified for weighted rvars. @avehtari?
  • Have I missed anything else?

Would love for folks to kick the tires. I think once this is in we could also start thinking about what a successor to summarise_draws() might look like that supports weights (and solves the various other open issues on summarise_draws()).

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to
license the submitted work under the following licenses:

@codecov-commenter
Copy link

codecov-commenter commented Jan 6, 2024

Codecov Report

Attention: Patch coverage is 98.89299% with 3 lines in your changes are missing coverage. Please review.

Project coverage is 95.80%. Comparing base (c312846) to head (88fa83d).
Report is 12 commits behind head on master.

❗ Current head 88fa83d differs from pull request most recent head 1079cef. Consider uploading reports for the commit 1079cef to get more accurate results

Files Patch % Lines
R/rvar-.R 94.87% 2 Missing ⚠️
R/weighted.R 98.07% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #331      +/-   ##
==========================================
+ Coverage   95.31%   95.80%   +0.49%     
==========================================
  Files          50       51       +1     
  Lines        3840     3979     +139     
==========================================
+ Hits         3660     3812     +152     
+ Misses        180      167      -13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

This comment was marked as outdated.

This comment was marked as outdated.

This comment was marked as outdated.

This comment was marked as outdated.

@avehtari
Copy link
Collaborator

avehtari commented Jan 8, 2024

I don't know if any of the other functions in R/convergence.R should be modified for weighted rvars.

Currently, everything else than pareto_ functions assume non-weighted MCMC. I have so far assumed that MCMC and weighting are independent of each other (there might be some less common algorithms that jointly sample parameter values and weights).

  • In PSIS paper experiments, I computed separately ESS for MCMC and ESS for PSIS and combined them as ESS_MCMC*ESS_PSIS/S, which worked well for getting MCSE that matched RMSE (given khat<0.7)
  • I did not try what would happen if in ESS for MCMC computation we would just replace autocorrelation computation with functions that would use weighted means and variances, and I have not checked what it would produce, but I guess that would be not what we want.
  • It would be nice to have a flag stating if rvar does not have Markov dependency, and then ESS and MCSE would be based just on weights.
  • If there are both (assumed) Markov dependency and weights, we could follow the approach presented in PSIS paper for ess_ and mcse_ functions
  • If we assume MCMC and weighting are independent, then rhat_ and rstar could do the MCMC convergence check without weights (until we are aware of an algorithm that would have the weighting inside the MCMC already)
  • pareto_ functions are checking the tail(s) of a given argument, and it has been used to check tails of raw weights/ratios (r or r(theta) in PSIS paper notation), function of a variable (h or h(theta)), or the product (hr). With the weight support, they could automatically make the diagnostics for r and hr (and if no weights then just h). Here I'm assuming that we almost always use self-normalization so that we need to check the normalization (E[r]) and the quantity of interest (E[hr])

Pinging @n-kall , too

@n-kall
Copy link
Collaborator

n-kall commented Jan 15, 2024

  • If there are both (assumed) Markov dependency and weights, we could follow the approach presented in PSIS paper for ess_ and mcse_ functions

For reference: Equations 6 (MCSE) and 7 (ESS) in preprint v6

  • pareto_ functions are checking the tail(s) of a given argument, and it has been used to check tails of raw weights/ratios (r or r(theta) in PSIS paper notation), function of a variable (h or h(theta)), or the product (hr). With the weight support, they could automatically make the diagnostics for r and hr (and if no weights then just h). Here I'm assuming that we almost always use self-normalization so that we need to check the normalization (E[r]) and the quantity of interest (E[hr])

Any thoughts on how the two sets of diagnostics should be presented in summarise_draws?
Would it make sense to have separate e.g. pareto_khat_quantity, pareto_khat_weights columns?

@n-kall
Copy link
Collaborator

n-kall commented Jan 17, 2024

I'm currently working on updating the pareto_, ess_ and mcse_ functions for weighted rvars in a fork

This comment was marked as outdated.

This comment was marked as outdated.

R/rvar-.R Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants