Skip to content

Commit

Permalink
Merge pull request #429 from mlr-org/brier_fix
Browse files Browse the repository at this point in the history
Brier fix
  • Loading branch information
bblodfon authored Dec 18, 2024
2 parents e4a478b + e3766e8 commit a35a4c4
Show file tree
Hide file tree
Showing 33 changed files with 636 additions and 143 deletions.
8 changes: 6 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# mlr3proba 0.7.1

* Removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
* Bug fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
* cleanup: removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
* fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
* fix: G(t) is not filtered when `t_max|p_max` is specified in scoring rules (didn't influence evaluation at all)
* docs: Clarified the use and impact of using `t_max` in scoring rules, added examples in scoring rules and AUC scores
* feat: Added new argument `remove_obs` in scoring rules to remove observations with observed time `t > t_max` as a processing step to alleviate IPCW issues.
This was before 'hard-coded' which made the Integrated Brier Score (`msr("surv.graf")`) differ minimally from other implementations and the original definition.

# mlr3proba 0.7.0

Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvChamblessAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
inherit = MeasureSurvAUC,
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvCindex.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#'
#' # Harrell's C-index evaluated up to a specific time horizon
#' p$score(msr("surv.cindex", t_max = 97))
#'
#' # Harrell's C-index evaluated up to the time corresponding to 30% of censoring
#' p$score(msr("surv.cindex", p_max = 0.3))
#'
Expand Down
19 changes: 11 additions & 8 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#' @templateVar fullname MeasureSurvDCalibration
#'
#' @description
#' `r lifecycle::badge("experimental")`
#'
#' This calibration method is defined by calculating the following statistic:
#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
Expand All @@ -12,8 +14,8 @@
#' falls within the corresponding interval.
#' This statistic assumes that censoring time is independent of death time.
#'
#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated).
#' A model is well D-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated, i.e. higher p-values are preferred).
#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)},
#' meaning that *lower values* of this measure are preferred.
#'
Expand All @@ -23,7 +25,7 @@
#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually
#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`.
#'
#' NOTE: This measure is still experimental both theoretically and in implementation. Results
#' **NOTE**: This measure is still experimental both theoretically and in implementation. Results
#' should therefore only be taken as an indicator of performance and not for
#' conclusive judgements about model calibration.
#'
Expand All @@ -38,11 +40,12 @@
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
#' The null hypothesis is that the model is D-calibrated.
#' - `truncate` (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
#' when `chisq` is `FALSE`. We use `truncate = Inf` by default but \eqn{10} may be sufficient
#' for most purposes, which corresponds to a p-value of 0.35 for the chisq.test using
#' \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower p-values and thus
#' less calibrated models. If the number of buckets \eqn{B} changes, you probably will want to
#' This parameter controls the upper bound of the output statistic, when `chisq` is `FALSE`.
#' We use `truncate = Inf` by default but values between \eqn{10-16} are sufficient
#' for most purposes, which correspond to p-values of \eqn{0.35-0.06} for the `chisq.test` using
#' the default \eqn{B = 10} buckets.
#' Values \eqn{B > 10} translate to even lower p-values and thus less D-calibrated models.
#' If the number of buckets \eqn{B} changes, you probably will want to
#' change the `truncate` value as well to correspond to the same p-value significance.
#' Note that truncation may severely limit automated tuning with this measure.
#'
Expand Down
15 changes: 9 additions & 6 deletions R/MeasureSurvGraf.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @aliases MeasureSurvBrier mlr_measures_surv.brier
#'
Expand All @@ -25,13 +26,13 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISBS** (RISBS) is:
#'
#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
Expand All @@ -48,10 +49,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999")`
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvGraf = R6Class("MeasureSurvGraf",
inherit = MeasureSurv,
Expand All @@ -73,11 +75,12 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -132,7 +135,7 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvHungAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvHungAUC = R6Class("MeasureSurvHungAUC",
inherit = MeasureSurvAUC,
Expand Down
15 changes: 9 additions & 6 deletions R/MeasureSurvIntLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @description
#' Calculates the **Integrated Survival Log-Likelihood** (ISLL) or Integrated
Expand All @@ -23,13 +24,13 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = -\text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = - \int^{\tau^*}_0 \frac{log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISLL** (RISLL) is:
#'
#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{\log[1-S_i(\tau)]) \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \frac{\int^{\tau^*}_0 \log[1-S_i(\tau)]) \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
Expand All @@ -46,10 +47,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999")`
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
inherit = MeasureSurv,
Expand All @@ -71,11 +73,12 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -130,7 +133,7 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
5 changes: 4 additions & 1 deletion R/MeasureSurvLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' Calculates the cross-entropy, or negative log-likelihood (NLL) or logarithmic (log), loss.
#' @section Parameter details:
#' - `IPCW` (`logical(1)`)\cr
#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper).
#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper). See Sonabend et al. (2024) for more details.
#'
#' @details
#' The Log Loss, in the context of probabilistic predictions, is defined as the
Expand All @@ -33,6 +33,9 @@
#'
#' @template details_trainG
#'
#' @references
#' `r format_bib("sonabend2024")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
Expand Down
22 changes: 9 additions & 13 deletions R/MeasureSurvSchmid.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @description
#' Calculates the **Integrated Schmid Score** (ISS), aka integrated absolute loss.
Expand All @@ -22,27 +23,20 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISS** (RISS) is:
#'
#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
#' To get a single score across all \eqn{N} observations of the test set, we
#' return the average of the time-integrated observation-wise scores:
#' \deqn{\sum_{i=1}^N L(S_i, t_i, \delta_i) / N}
#'
#'
#' \deqn{L_{ISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t^*))]}
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The re-weighted ISS, RISS is given by
#' \deqn{L_{RISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t))]}
#'
#' @template properness
#' @templateVar improper_id ISS
#' @templateVar proper_id RISS
Expand All @@ -52,10 +46,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("schemper_2000", "schmid_2011")`
#' `r format_bib("schemper_2000", "schmid_2011", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
inherit = MeasureSurv,
Expand All @@ -77,11 +72,12 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -135,7 +131,7 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvSongAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvSongAUC = R6Class("MeasureSurvSongAUC",
inherit = MeasureSurvAUC,
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvUnoAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvUnoAUC = R6Class("MeasureSurvUnoAUC",
inherit = MeasureSurvAUC,
Expand Down
20 changes: 20 additions & 0 deletions R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -741,5 +741,25 @@ bibentries = c( # nolint start
title = "Simulating Survival Data Using the simsurv R Package",
volume = "97",
year = "2021"
),
sonabend2024 = bibentry("misc",
archivePrefix = "arXiv",
arxivId = "2212.05260",
author = "Sonabend, Raphael and Zobolas, John and Kopper, Philipp and Burk, Lukas and Bender, Andreas",
month = "dec",
title = "Examining properness in the external validation of survival models with squared and logarithmic losses",
url = "https://arxiv.org/abs/2212.05260v2",
year = "2024"
),
kvamme2023 = bibentry("article",
author = "Kvamme, Havard and Borgan, Ornulf",
issn = "1533-7928",
journal = "Journal of Machine Learning Research",
number = "2",
pages = "1--26",
title = "The Brier Score under Administrative Censoring: Problems and a Solution",
url = "http://jmlr.org/papers/v24/19-1030.html",
volume = "24",
year = "2023"
)
)
12 changes: 4 additions & 8 deletions R/integrated_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ score_graf_schmid = function(true_times, unique_times, cdf, power = 2) {
# - `t_max` > 0
# - `p_max` in [0,1]
weighted_survival_score = function(loss, truth, distribution, times = NULL,
t_max = NULL, p_max = NULL, proper, train = NULL, eps, ...) {
t_max = NULL, p_max = NULL, proper, train = NULL, eps, remove_obs = FALSE) {
assert_surv(truth)
# test set's (times, status)
test_times = truth[, "time"]
Expand Down Expand Up @@ -90,8 +90,8 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,
rownames(cdf) = unique_times # times x obs
}

# apply `t_max` cutoff to remove observations
if (tmax_apply) {
# apply `t_max` cutoff to remove observations as a preprocessing step to alleviate inflation
if (tmax_apply && remove_obs) {
true_times = test_times[test_times <= t_max]
true_status = test_status[test_times <= t_max]
cdf = cdf[, test_times <= t_max, drop = FALSE]
Expand All @@ -118,6 +118,7 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,

# use the `truth` (time, status) information from the train or test set
if (is.null(train)) {
# no filtering of observations from test data: use ALL
cens = survival::survfit(Surv(test_times, 1 - test_status) ~ 1)
} else {
# no filtering of observations from train data: use ALL
Expand All @@ -128,11 +129,6 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,
# G(t): KM estimate of the censoring distribution
cens = matrix(c(cens$time, cens$surv), ncol = 2L)

# filter G(t) time points based on `t_max` cutoff
if (tmax_apply) {
cens = cens[cens[, 1L] <= t_max, , drop = FALSE]
}

score = .c_weight_survival_score(score, true_truth, unique_times, cens, proper, eps)
colnames(score) = unique_times

Expand Down
Loading

0 comments on commit a35a4c4

Please sign in to comment.