From d52c6fa66f10e0553682ef47b618758b44b3a4c5 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 22 Aug 2023 09:46:58 +0200 Subject: [PATCH] update doc --- man/mlr_learners_surv.bart.Rd | 84 +++-------------------------------- 1 file changed, 7 insertions(+), 77 deletions(-) diff --git a/man/mlr_learners_surv.bart.Rd b/man/mlr_learners_surv.bart.Rd index bd043fc2a..1fb67cc0f 100644 --- a/man/mlr_learners_surv.bart.Rd +++ b/man/mlr_learners_surv.bart.Rd @@ -7,14 +7,12 @@ \description{ Fits a Bayesian Additive Regression Trees (BART) learner to right-censored survival data. + For prediction, we return the mean posterior estimates of the survival function and the corresponding \code{crank} (expected mortality) using \link[mlr3proba:dot-surv_return]{mlr3proba::.surv_return}. The full posterior estimates are currently stored in the -\code{learner$state$surv_test} slot, along with the number of test observations -\code{N}, number of unique times in the train set \code{K} and number of posterior -draws \code{M}. -See example for more details. +\code{learner$model$surv.test} slot. Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. } @@ -25,7 +23,7 @@ Calls \code{\link[BART:surv.bart]{BART::mc.surv.bart()}} from \CRANpkg{BART}. } } -\section{Initial parameter values}{ +\section{Custom mlr3 parameters}{ \itemize{ \item \code{quiet} allows to suppress messages generated by the wrapped C++ code. Is \code{TRUE} by default. @@ -90,79 +88,11 @@ lrn("surv.bart") } \examples{ -library(mlr3proba) -library(dplyr) -library(tidyr) -library(ggplot2) - -learner = lrn("surv.bart", nskip = 10, ndpost = 20, keepevery = 2) -task = tsk("lung") -task$missings() # has missing values - -# split to train and test sets -set.seed(42) -part = partition(task) - -# Train -learner$train(task, row_ids = part$train) - -# Importance: average number of times a feature has been used in the trees -learner$importance() - -# Test -p = learner$predict(task, row_ids = part$test) -p$score() # C-index - -# Mean survival probabilities for the first 3 patients at given time points -p$distr$survival(times = c(1,50,150))[,1:3] - -# number of posterior draws -M = learner$state$M -stopifnot(M == 20) -# number of test observations -N = learner$state$N -stopifnot(N == length(part$test)) -# number of unique time points in the train set -K = learner$state$K -stopifnot(K == length(task$unique_times(rows = part$train))) -# the actual times are also available in the `$model` slot: -head(learner$model$times) - -# Full posterior prediction matrix -surv_test = learner$state$surv_test -stopifnot(all(dim(surv_test) == c(M, K * N))) - -# Posterior survival function estimates for the 1st test patient for all -# time points (from the train set) - see Sparapani (2021), pages 34-35 -post_surv = surv_test[, 1:K] - -# For every time point, get the median survival estimate as well as -# the lower and upper bounds of the 95\% quantile credible interval -surv_data = post_surv \%>\% - as.data.frame() \%>\% - `colnames<-` (learner$model$times) \%>\% - summarise(across(everything(), list( - median = ~ median(.), - low_qi = ~ quantile(., 0.025), - high_qi = ~ quantile(., 0.975) - ))) \%>\% - pivot_longer( - cols = everything(), - names_to = c("times", ".value"), - names_pattern = "(^[^_]+)_(.*)" # everything until the first underscore - ) \%>\% - mutate(times = as.numeric(times)) -surv_data +learner = mlr3::lrn("surv.bart") +print(learner) -# Draw a survival curve for the first patient in the test set with -# uncertainty quantified -surv_data \%>\% - ggplot(aes(x = times, y = median)) + - geom_step(col = 'black') + - xlab('Time (Days)') + - ylab('Survival Probability') + - geom_ribbon(aes(ymin = low_qi, ymax = high_qi), alpha = 0.3) + - theme_bw() +# available parameters: +learner$param_set$ids() } \references{ Sparapani, Rodney, Spanbauer, Charles, McCulloch, Robert (2021).