Skip to content

Commit

Permalink
Modified sa_diff() to support multiple input variables
Browse files Browse the repository at this point in the history
  • Loading branch information
tripartio committed Nov 5, 2024
1 parent e2b48e5 commit 52a830f
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 77 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Authors@R:
person("Chitu", "Okoli", , "Chitu.Okoli@skema.edu", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5574-7572"))
Language: en-US
Description: Standardized accuracy (staccuracy) is framework for expressing accuracy scores such that 50% represents a reference level of performance and 100% is perfect prediction. The 'staccuracy' package provides tools for creating staccuracy functions as well as some recommended staccuracy measures. It also provides functions for some classic performance metrics such as mean absolute error (MAE), root mean squared error (RMSE), and area under the receiver operating characteristic curve (AUCROC), as well as their winsorized versions when applicable.
Description: Standardized accuracy (staccuracy) is framework for expressing accuracy scores such that 50% represents a reference level of performance and 100% is a perfect prediction. The 'staccuracy' package provides tools for creating staccuracy functions as well as some recommended staccuracy measures. It also provides functions for some classic performance metrics such as mean absolute error (MAE), root mean squared error (RMSE), and area under the receiver operating characteristic curve (AUCROC), as well as their winsorized versions when applicable.
License: MIT + file LICENSE
Suggests:
testthat (>= 3.0.0)
Expand All @@ -18,7 +18,7 @@ Imports:
dplyr,
purrr,
rlang,
tidyr,
utils
stringr,
tidyr
URL: https://github.com/tripartio/staccuracy, https://tripartio.github.io/staccuracy/
BugReports: https://github.com/tripartio/staccuracy/issues
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export(sa_rmse_sd)
export(sa_wmae_mad)
export(sa_wrmse_sd)
export(staccuracy)
export(var_type)
export(win_mae)
export(win_rmse)
export(winsorize)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# staccuracy (development version)

* Added [sa_diff()] function for bootstrapped-based comparison of staccuracies.
* Added {sa_diff()} function for bootstrapped-based comparison of staccuracies.
* var_type() is no longer exported since it is really an internal function.

# staccuracy 0.1.0

Expand Down
143 changes: 84 additions & 59 deletions R/staccuracy.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,38 +106,41 @@ sa_wrmse_sd <- staccuracy(win_rmse, stats::sd)

#' Statistical tests for the differences between standardized accuracies (staccuracies)
#'
#' Because the distribution of staccuracies is uncertain (and indeed, different staccuracies likely have different distributions), bootstrapping is used to empirically estimate the distributions and calculate the p-values.
# But what test is appropriate? I don't think the t-test (a test of differences in means) is appropriate to test the difference between staccuracies. At the very simplest, I could bootstrap the differences and then compare confidence intervals. But ideally, I should work with a statistician to develop and analytical test of differences between staccuracies. The challenge is that different staccuracies might have different distributions. So, bootstrapping would be a simple, useful function. In the meantime, I could use univariateML to observe the distributions and thus gain an intuition. Parallelization should be an option to speed things up.
#'
#'
#'
#' Because the distribution of staccuracies is uncertain (and indeed, different staccuracies likely have different distributions), bootstrapping is used to empirically estimate the distributions and calculate the p-values. See the return value description for details on what the function provides.
#'
#' @param actual numeric vector. The actual (true) labels.
#' @param preds named list of at least two numeric vectors. Each element is a vector of the same length as actual with predictions for each row corresponding to each element of actual. The names of the list elements will be used for the columns of the result.
#' @param preds named list of at least two numeric vectors. Each element is a vector of the same length as actual with predictions for each row corresponding to each element of actual. The names of the list elements should be the names of the models that produced each respective prediction; these names will be used to distinguish the results.
#' @param ... not used. Forces explicit naming of subsequent arguments.
#' @param sa list of functions. Each element is the unquoted name of a valid staccuracy function (see [staccuracy()] for the required function signature.) If an element is named, the name will be displayed as the value of the `sa` column of the result. Otherwise, the function name will be displayed. If NULL (default), staccuracy functions will be automatically selected based on the datatypes of actual and `preds`.
#' @param na.rm See documentation for [staccuracy()]
#' @param sa list of functions. Each element is the unquoted name of a staccuracy function. If an element is named, the name will be displayed as the value of the `sa` column of the result. Otherwise, the function name will be displayed. If NULL (default), staccuracy functions will be automatically selected based on the datatypes of actual and `preds`.
#' @param pct numeric with values from (0, 1). The percentage values on which the difference in staccuracies will be tested.
#' @param boot_alpha numeric(1) from 0 to 1. Alpha for percentile-based confidence interval range for the bootstrapped means; the bootstrap confidence intervals will be the lowest and highest `(1 - 0.05) / 2` percentiles. For example, if `boot_alpha = 0.05` (default), the intervals will be at the 2.5 and 97.5 percentiles.
#' @param boot_it positive integer(1). The number of bootstrap iterations.
#' @param seed integer(1). Random seed for the bootstrap sampling. Supply this between runs to assure identical results.
#'
#' @return tibble. Columns are `sa` (name of staccuracy measure), then a column for each named element in `preds` (staccuracy for that prediction), then a column for each element of `pct` (p-value; percentage of iterations where the difference between staccuracy in the pair of `preds` is equal or greater than the `pct` value). E.g., for the default `pct = c(0.01, 0.02, 0.03, 0.04, 0.05)`, these columns would be `p_1`, `p_2`, `p_3`, `p_4`, and `p_5`. When there are more than two predictions, then each row will compare only two at a time.
#' @return tibble with staccuracy difference results:
#' * `staccuracy`: name of staccuracy measure
#' * `pred`, `type`: When `type` is 'pred', the `pred` column gives named element in the input `preds`. The row values give the staccuracy for that prediction. When `type` is 'diff', the `pred` column is of the form 'model1-model2', where 'model1' and 'model2' are names from the input `preds`, which should be the names of each model that provided the predictions. The row values give the difference between staccuracies of model1 and model2.
#' * `lo`, `mean`, `hi`: The lower bound, mean, and upper bound of the bootstrapped staccuracy. The lower and upper bounds are confidence intervals specified by the input `boot_alpha`.
#' * `p__`: p-values that the staccuracies are at least the specified percentage difference or greater. E.g., for the default input `pct = c(0.01, 0.02, 0.03, 0.04, 0.05)`, these columns would be `p01`, `p02`, `p03`, `p04`, and `p05`. As they apply only to differences between staccuracies, they are `NA` for rows of `type` 'pred'. As an example of their meaning, if the `mean` difference for 'model1-model2' is 0.0832 with `p01` of 0.012 and `p02` of 0.035, then it means that 1.2% of bootstrapped staccuracies had a difference of model1 - model2 less than 0.01 and 3.5% were less than 0.02. (That is, 98.8% of differences were greater than 0.01 and 96.5% were greater than 0.02.)
#'
#' @export
#'
#' @examples
#' lm_attitude_all <- lm(rating ~ ., data = attitude)
#' lm_attitude__a <- lm(rating ~ . - advance, data = attitude)
#' lm_attitude__c <- lm(rating ~ . - complaints, data = attitude)
#'
#' sa_diff(
#' sdf <- sa_diff(
#' attitude$rating,
#' list(
#' all = predict(lm_attitude_all),
#' madv = predict(lm_attitude__c)
#' madv = predict(lm_attitude__a),
#' mcmp = predict(lm_attitude__c)
#' ),
#' boot_it = 10
#' )
#' sdf
#'
sa_diff <- function(
actual,
Expand All @@ -152,12 +155,12 @@ sa_diff <- function(
) {
d_type <- var_type(actual)

lgth <- length(actual)
len <- length(actual)

if (d_type == 'numeric') {
sa <- list(
`Staccuracy WinMAE on MAD` = sa_wmae_mad,
`Staccuarcy WinRMSE on SD` = sa_wrmse_sd
`WinMAE on MAD` = sa_wmae_mad,
`WinRMSE on SD` = sa_wrmse_sd
)
}

Expand All @@ -172,28 +175,31 @@ sa_diff <- function(
# row_idxs: row indices of each bootstrap sample. Store just the indices rather than duplicating the entire dataset multiple times.
row_idxs = map(0:boot_it, \(it.bt) {
if (it.bt == 0) { # row 0 is the full dataset without bootstrapping
1:lgth
} else { # bootstrap: sample lgth with replacement
sample.int(lgth, replace = TRUE)
1:len
} else { # bootstrap: sample len with replacement
sample.int(len, replace = TRUE)
}
}),
staccuracy = character(boot_it + 1),
pred = character(boot_it + 1),
val = double(boot_it + 1)
)


# Bootstrap the calculations of staccuracy
sa_boot <-
map(0:boot_it, \(it) {
# Iteration 0 is the full sample
map(0:boot_it, \(btit) {
# Iterate across staccuracy measures
imap(sa, \(it.sa, it.sa_name) {
# Iterate across model predictions
imap(preds, \(it.pred, it.pred_name) {
tibble(
it = it,
it = btit,
staccuracy = it.sa_name,
pred = it.pred_name,
val = it.sa(
actual[boot_tbl$row_idxs[[it+1]]],
it.pred[boot_tbl$row_idxs[[it+1]]],
actual[boot_tbl$row_idxs[[btit+1]]],
it.pred[boot_tbl$row_idxs[[btit+1]]],
na.rm = FALSE
)
)
Expand All @@ -206,51 +212,70 @@ sa_diff <- function(
tidyr::pivot_wider(
names_from = 'pred',
values_from = 'val'
) |>
# Add the difference between the measures
rowwise() |>
mutate(
diff = max(c_across(where(is.double))) - min(c_across(where(is.double)))
) |>
ungroup()
)

if (boot_it != 0) {
sa_boot <- sa_boot |>
filter(.data$it != 0)
}

# fs (first,second): tbl of all possible combinations of differences between model staccuracies.
# The difference is first - second.
fs <- tidyr::expand_grid(
first = names(preds),
second = names(preds)
) |>
filter(first != .data$second)

# Staccuracies of the full vectors
sa_full <- sa_boot |>
filter(.data$it == 0) |>
select(-'it')
# Remove duplicate difference pairs
unique_fs <- purrr::map_lgl(1:nrow(fs), \(i.r) {
(fs[i.r, 'first'] %+% '|' %+% fs[i.r, 'second']) %notin%
(pull(fs[1:(i.r-1), 'second']) %+% '|' %+% pull(fs[1:(i.r-1), 'first']))
})
fs <- fs[unique_fs, ]

# browser()
# Calculate the differences for each pair of model staccuracies
for (i.row in 1:nrow(fs)) {
#sa_boost$`first-second` <- sa_boot$first - sa_boot$second
sa_boot[[
fs[[i.row, 'first']] %+% '-' %+% fs[[i.row, 'second']]
]] <-
sa_boot[[fs[[i.row, 'first']]]] - sa_boot[[fs[[i.row, 'second']]]]
}

# Create names for the percent difference columns
names_pct <- 'p' %+%
stringr::str_pad(pct * 100, width = 2, pad = '0')

# Generate expressions for p-values each percent difference threshold
p_pct_exprs <- map(pct, ~ expr(sum(value < !!.x) / boot_it))
names(p_pct_exprs) <- names_pct

# Summarize the bootstrapped staccuracies without the full vectors
# Summarize the bootstrapped staccuracies and differences
sa_tbl <- sa_boot |>
filter(.data$it != 0) |>
select(-'it') |>
# Pivot long for easier summarization
tidyr::pivot_longer(
!c('it', 'staccuracy'),
names_to = 'pred'
) |>
summarize(
.by = c(staccuracy),
across(
where(is.numeric),
list(
lo = ~ quantile(.x, boot_alpha / 2),
mn = mean,
hi = ~ quantile(.x, 1 - (boot_alpha / 2))
)
),
.by = c('staccuracy', 'pred'),
lo = stats::quantile(.data$value, boot_alpha / 2),
mean = mean(.data$value),
hi = stats::quantile(.data$value, 1 - (boot_alpha / 2)),
# Summarize p-values for requested percent thresholds
!!!p_pct_exprs,
) |>
mutate(
type = if_else(stringr::str_detect(.data$pred, '-'), 'diff', 'pred'),
# Delete p-value differences for actual staccuracies; they are irrelevant here
across(
diff,
map(pct, \(it.p) {
\(.x) {sum(.x < it.p) / !!boot_it}
}) |>
set_names(
formatC(pct * 100, width = 2, flag = '0')
),
.names = 'p_{.fn}'
all_of(names_pct),
~ if_else(type == 'diff', .x, NA)
)
# p_1 = sum(diff < 0.01) / boot_it,
# p_2 = sum(diff < 0.02) / boot_it,
# p_3 = sum(diff < 0.03) / boot_it,
# p_4 = sum(diff < 0.04) / boot_it,
# p_5 = sum(diff < 0.05) / boot_it,
)
) |>
select('staccuracy', 'pred', 'type', everything())


return(sa_tbl)
}
21 changes: 15 additions & 6 deletions man/sa_diff.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 9 additions & 7 deletions tests/testthat/_snaps/staccuracy.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
sa_diff(attitude$rating, list(all = predict(lm_attitude_all), madv = predict(
lm_attitude__c)), boot_it = 10)
Output
# A tibble: 2 x 15
staccuracy all_lo all_mn all_hi madv_lo madv_mn madv_hi diff_lo diff_mn
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 Staccuracy WinMA~ 0.672 0.719 0.776 0.586 0.635 0.692 0.0440 0.0840
2 Staccuarcy WinRM~ 0.684 0.737 0.781 0.616 0.670 0.723 0.0335 0.0666
# i 6 more variables: diff_hi <dbl>, p_01 <dbl>, p_02 <dbl>, p_03 <dbl>,
# p_04 <dbl>, p_05 <dbl>
# A tibble: 6 x 11
staccuracy pred type lo mean hi p01 p02 p03 p04 p05
<chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 WinMAE on MAD all pred 0.672 0.719 0.776 NA NA NA NA NA
2 WinMAE on MAD madv pred 0.586 0.635 0.692 NA NA NA NA NA
3 WinMAE on MAD all-madv diff 0.0440 0.0840 0.133 0 0 0 0 0.1
4 WinRMSE on SD all pred 0.684 0.737 0.781 NA NA NA NA NA
5 WinRMSE on SD madv pred 0.616 0.670 0.723 NA NA NA NA NA
6 WinRMSE on SD all-madv diff 0.0335 0.0666 0.107 0 0 0.1 0.1 0.4

0 comments on commit 52a830f

Please sign in to comment.