Rewrote and corrected bug in sa_diff()
tripartio committed Nov 7, 2024
1 parent e0a694a commit dc2c4a4
Showing 5 changed files with 73 additions and 33 deletions.
10 changes: 7 additions & 3 deletions NEWS.md
@@ -1,8 +1,12 @@
# staccuracy (development version)

- Corrected bug in p-value calculation. Rewrote `sa_diff()` output to separate staccuracies from their differences. Use the `(r+1)/(n+1)` p-value calculation from [North et al. (2003)](https://pmc.ncbi.nlm.nih.gov/articles/PMC379244/).

# staccuracy 0.2.0

* Added sa_diff() function for bootstrapped-based comparison of staccuracies.
* var_type() is no longer exported since it is really an internal function.
- Added `sa_diff()` function for bootstrap-based comparison of staccuracies.
- `var_type()` is no longer exported since it is really an internal function.

# staccuracy 0.1.0

* Initial CRAN submission.
- Initial CRAN submission.
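
The `(r+1)/(n+1)` correction named in the development-version NEWS entry can be sketched in a few lines of base R. This is only an illustration: `empirical_p` is a hypothetical helper, not a function exported by staccuracy. Here `r` is the number of bootstrap values on the unfavourable side of a threshold and `n` is the number of bootstrap iterations (`boot_it`).

```r
# Empirical p-value with the (r + 1) / (n + 1) correction.
# `empirical_p` is a hypothetical helper, not part of the staccuracy package.
empirical_p <- function(r, n) {
  stopifnot(n > 0, all(r >= 0), all(r <= n))
  (r + 1) / (n + 1)
}

# With n = 10 bootstrap iterations, as in the snapshot test (boot_it = 10):
p <- empirical_p(r = c(0, 1, 4), n = 10)
round(p, 4)
#> [1] 0.0909 0.1818 0.4545
```

With only 10 iterations the smallest attainable p-value is 1/11, which is why 0.0909 appears so often in the updated snapshot and why 0 no longer does.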
60 changes: 43 additions & 17 deletions R/staccuracy.R
@@ -120,9 +120,10 @@ sa_wrmse_sd <- staccuracy(win_rmse, stats::sd)
#'
#' @return tibble with staccuracy difference results:
#' * `staccuracy`: name of staccuracy measure
#' * `pred`, `type`: When `type` is 'pred', the `pred` column gives named element in the input `preds`. The row values give the staccuracy for that prediction. When `type` is 'diff', the `pred` column is of the form 'model1-model2', where 'model1' and 'model2' are names from the input `preds`, which should be the names of each model that provided the predictions. The row values give the difference between staccuracies of model1 and model2.
#' * `pred`: Each named element (model name) in the input `preds`. The row values give the staccuracy for that prediction. When `pred` is `NA`, the row represents the difference between prediction staccuracies (`diff`) instead of staccuracies themselves.
#' * `diff`: When `diff` takes the form 'model1-model2', then the row values give the difference in staccuracies between two named elements (model names) in the input `preds`. When `diff` is `NA`, the row instead represents the staccuracy of a specific model prediction (`pred`).
#' * `lo`, `mean`, `hi`: The lower bound, mean, and upper bound of the bootstrapped staccuracy. The lower and upper bounds delimit the `1 - boot_alpha` confidence interval, as set by the input `boot_alpha`.
#' * `p__`: p-values that the staccuracies are at least the specified percentage difference or greater. E.g., for the default input `pct = c(0.01, 0.02, 0.03, 0.04, 0.05)`, these columns would be `p01`, `p02`, `p03`, `p04`, and `p05`. As they apply only to differences between staccuracies, they are `NA` for rows of `type` 'pred'. As an example of their meaning, if the `mean` difference for 'model1-model2' is 0.0832 with `p01` of 0.012 and `p02` of 0.035, then it means that 1.2% of bootstrapped staccuracies had a difference of model1 - model2 less than 0.01 and 3.5% were less than 0.02. (That is, 98.8% of differences were greater than 0.01 and 96.5% were greater than 0.02.)
#' * `p__`: p-values that the difference in staccuracies is at least the specified amount. E.g., for the default input `pct = c(0.01, 0.02, 0.03, 0.04, 0.05)`, these columns would be `p01`, `p02`, `p03`, `p04`, and `p05`. As they apply only to differences between staccuracies, they are provided only for `diff` rows and are `NA` for `pred` rows. As an example of their meaning, if the `mean` difference for 'model1-model2' is 0.0832 with `p01` of 0.012 and `p02` of 0.035, then 1.2% of bootstrapped staccuracy differences (model1 - model2) were less than 0.01 and 3.5% were less than 0.02. (That is, 98.8% of differences were greater than 0.01 and 96.5% were greater than 0.02.)
#'
#' @export
#'
@@ -247,10 +248,6 @@ sa_diff <- function(
names_pct <- 'p' %+%
stringr::str_pad(pct * 100, width = 2, pad = '0')

# Generate expressions for p-values each percent difference threshold
p_pct_exprs <- map(pct, ~ expr(sum(value < !!.x) / boot_it))
names(p_pct_exprs) <- names_pct

# Summarize the bootstrapped staccuracies and differences
sa_tbl <- sa_boot |>
# Pivot long for easier summarization
@@ -263,18 +260,47 @@ sa_diff <- function(
lo = stats::quantile(.data$value, boot_alpha / 2),
mean = mean(.data$value),
hi = stats::quantile(.data$value, 1 - (boot_alpha / 2)),
# Summarize p-values for requested percent thresholds
!!!p_pct_exprs,
) |>
# Count the number of times the value is greater than or equal to the p threshold.
# Create as a list column with the integer vector of counts.
num_gte = purrr::map_int(pct, \(it.pct) {
sum(.data$value >= it.pct)
}) |>
set_names(names_pct) |>
list(),
# Count the number of times the value is less than or equal to the p threshold.
num_lte = purrr::map_int(pct, \(it.pct) {
sum(.data$value <= it.pct)
}) |>
set_names(names_pct) |>
list()
)

# Create diff column and distinguish from pred
sa_tbl <- sa_tbl |>
mutate(
type = if_else(stringr::str_detect(.data$pred, '-'), 'diff', 'pred'),
# Delete p-value differences for actual staccuracies; they are irrelevant here
across(
all_of(names_pct),
~ if_else(type == 'diff', .x, NA)
)
) |>
select('staccuracy', 'pred', 'type', everything())
diff = if_else(stringr::str_detect(.data$pred, '-'), .data$pred, NA),
pred = if_else(is.na(.data$diff), .data$pred, NA)
)

# Add p-value columns.
# Iterate rows that express differences
for (i.r in which(!is.na(sa_tbl$diff))) {
# Iterate requested p-value thresholds
for (it.pct in names_pct) {
sa_tbl[[i.r, it.pct]] <-
# Use the less-than-or-equal counts when the mean difference is non-negative, otherwise the greater-than-or-equal counts
if (sa_tbl[[i.r, 'mean']] >= 0) {
# p = (r+1)/(n+1): https://europepmc.org/article/MED/12111669
(sa_tbl[[i.r, 'num_lte']][[1]][[it.pct]] + 1) / (boot_it + 1)
} else {
(sa_tbl[[i.r, 'num_gte']][[1]][[it.pct]] + 1) / (boot_it + 1)
}
}
}

sa_tbl <- sa_tbl |>
select('staccuracy', 'pred', 'diff', everything()) |>
select(-'num_gte', -'num_lte')


return(sa_tbl)
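
The direction-dependent counting in the hunk above can be sketched in isolation with base R. This is a simplified illustration under stated assumptions, not the package's implementation: `boot_diff_p` is a hypothetical helper, and the `diffs` vector is made-up data standing in for one row's bootstrapped staccuracy differences.

```r
# Hypothetical helper mirroring the branch in the diff; not part of staccuracy.
boot_diff_p <- function(diffs, pct) {
  n <- length(diffs)
  # r counts the bootstrap differences at or below each threshold when the
  # mean difference is non-negative, and at or above it when the mean is negative.
  r <- if (mean(diffs) >= 0) {
    vapply(pct, function(th) sum(diffs <= th), integer(1))
  } else {
    vapply(pct, function(th) sum(diffs >= th), integer(1))
  }
  # p = (r + 1) / (n + 1)
  p <- (r + 1) / (n + 1)
  names(p) <- sprintf("p%02d", round(pct * 100))
  p
}

# Made-up bootstrap differences, for illustration only
diffs <- c(0.005, 0.015, 0.030, 0.045, 0.060, 0.070, 0.080, 0.090, 0.100, 0.120)
res <- boot_diff_p(diffs, pct = c(0.01, 0.02, 0.05))
res
```

Here the mean difference is positive, so the counts are 1, 2, and 4 values at or below the thresholds, giving p-values of 2/11, 3/11, and 5/11 for `p01`, `p02`, and `p05`.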
5 changes: 3 additions & 2 deletions man/sa_diff.Rd

27 changes: 17 additions & 10 deletions tests/testthat/_snaps/staccuracy.md
@@ -2,15 +2,22 @@

Code
sa_diff(attitude$rating, list(all = predict(lm_attitude_all), madv = predict(
lm_attitude__c)), boot_it = 10)
lm_attitude__a), mcmp = predict(lm_attitude__c)), boot_it = 10)
Output
# A tibble: 6 x 11
staccuracy pred type lo mean hi p01 p02 p03 p04 p05
<chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 WinMAE on MAD all pred 0.672 0.719 0.776 NA NA NA NA NA
2 WinMAE on MAD madv pred 0.586 0.635 0.692 NA NA NA NA NA
3 WinMAE on MAD all-madv diff 0.0440 0.0840 0.133 0 0 0 0 0.1
4 WinRMSE on SD all pred 0.684 0.737 0.781 NA NA NA NA NA
5 WinRMSE on SD madv pred 0.616 0.670 0.723 NA NA NA NA NA
6 WinRMSE on SD all-madv diff 0.0335 0.0666 0.107 0 0 0.1 0.1 0.4
# A tibble: 12 x 11
staccuracy pred diff lo mean hi p01 p02 p03
<chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 WinMAE on MAD all <NA> 0.672 0.719 0.776 NA NA NA
2 WinMAE on MAD madv <NA> 0.640 0.705 0.767 NA NA NA
3 WinMAE on MAD mcmp <NA> 0.586 0.635 0.692 NA NA NA
4 WinMAE on MAD <NA> all-madv -0.00660 0.0139 0.0369 0.455 0.727 0.818
5 WinMAE on MAD <NA> all-mcmp 0.0440 0.0840 0.133 0.0909 0.0909 0.0909
6 WinMAE on MAD <NA> madv-mcmp 0.0291 0.0702 0.122 0.0909 0.0909 0.182
7 WinRMSE on SD all <NA> 0.684 0.737 0.781 NA NA NA
8 WinRMSE on SD madv <NA> 0.670 0.732 0.782 NA NA NA
9 WinRMSE on SD mcmp <NA> 0.616 0.670 0.723 NA NA NA
10 WinRMSE on SD <NA> all-madv -0.00781 0.00529 0.0272 0.636 0.909 0.909
11 WinRMSE on SD <NA> all-mcmp 0.0335 0.0666 0.107 0.0909 0.0909 0.182
12 WinRMSE on SD <NA> madv-mcmp 0.0273 0.0613 0.108 0.0909 0.0909 0.182
# i 2 more variables: p04 <dbl>, p05 <dbl>
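
Because the new layout marks each row by which of `pred` and `diff` is `NA`, the two kinds of rows can be separated with a plain `is.na()` filter. The sketch below uses a small hand-built data frame that copies three rows (and only a few columns) from the snapshot above; it does not call `sa_diff()` itself.

```r
# Toy table copying three rows of the new snapshot output (subset of columns)
res <- data.frame(
  staccuracy = "WinMAE on MAD",
  pred = c("all", "mcmp", NA),
  diff = c(NA, NA, "all-mcmp"),
  mean = c(0.719, 0.635, 0.0840),
  p01  = c(NA, NA, 0.0909)
)

# Rows holding the staccuracy of a single model's predictions
pred_rows <- res[is.na(res$diff), ]

# Rows holding the difference between two models' staccuracies
diff_rows <- res[!is.na(res$diff), ]
```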

4 changes: 3 additions & 1 deletion tests/testthat/test-staccuracy.R
@@ -4,6 +4,7 @@

test_that("sa_diff() works correctly", {
lm_attitude_all <- lm(rating ~ ., data = attitude)
lm_attitude__a <- lm(rating ~ . - advance, data = attitude)
lm_attitude__c <- lm(rating ~ . - complaints, data = attitude)

expect_equal(
@@ -44,7 +45,8 @@ test_that("sa_diff() works correctly", {
attitude$rating,
list(
all = predict(lm_attitude_all),
madv = predict(lm_attitude__c)
madv = predict(lm_attitude__a),
mcmp = predict(lm_attitude__c)
),
boot_it = 10
)