diff --git a/R/Learner.R b/R/Learner.R index 461dd150f..abb7e0d78 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -544,6 +544,21 @@ Learner = R6Class("Learner", } return(invisible(self)) + }, + + #' @description + #' Returns the features selected by the model. + #' The field `selected_features_impute` controls the behavior if the learner does not support feature selection. + #' If set to `"error"`, an error is thrown, otherwise all features are returned. + selected_features = function() { + if (is.null(self$model)) { + stopf("No model stored") + } + if (private$.selected_features_impute == "error") { + stop("Learner does not support feature selection") + } else { + self$state$feature_names + } } ), @@ -669,6 +684,17 @@ Learner = R6Class("Learner", private$.hotstart_stack = rhs }, + #' @field selected_features_impute (`character(1)`)\cr + #' Controls the behavior if the learner does not support feature selection. + #' If set to `"error"`, an error is thrown. + #' If set to `"all"` the complete feature set is returned. + selected_features_impute = function(rhs) { + if (missing(rhs)) { + return(private$.selected_features_impute) + } + private$.selected_features_impute = assert_choice(rhs, c("error", "all")) + }, + #' @field predict_types (`character()`)\cr #' Stores the possible predict types the learner is capable of. #' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections]. @@ -686,6 +712,7 @@ Learner = R6Class("Learner", .predict_types = NULL, .param_set = NULL, .hotstart_stack = NULL, + .selected_features_impute = "error", deep_clone = function(name, value) { switch(name, diff --git a/man/Learner.Rd b/man/Learner.Rd index 70d7cf8f9..faef571a5 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -301,6 +301,11 @@ Returns the encapsulation settings set with \verb{$encapsulate()}.} \item{\code{hotstart_stack}}{(\link{HotstartStack})\cr. Stores \code{HotstartStack}.} +\item{\code{selected_features_impute}}{(\code{character(1)})\cr +Controls the behavior if the learner does not support feature selection. +If set to \code{"error"}, an error is thrown. +If set to \code{"all"} the complete feature set is returned.} + \item{\code{predict_types}}{(\code{character()})\cr Stores the possible predict types the learner is capable of. A complete list of candidate predict types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}}. @@ -322,6 +327,7 @@ This field is read-only.} \item \href{#method-Learner-base_learner}{\code{Learner$base_learner()}} \item \href{#method-Learner-encapsulate}{\code{Learner$encapsulate()}} \item \href{#method-Learner-configure}{\code{Learner$configure()}} +\item \href{#method-Learner-selected_features}{\code{Learner$selected_features()}} \item \href{#method-Learner-clone}{\code{Learner$clone()}} } } @@ -633,6 +639,18 @@ Named list of parameter values and fields.} } \if{html}{\out{}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Learner-selected_features}{}}} +\subsection{Method \code{selected_features()}}{ +Returns the features selected by the model. +The field \code{selected_features_impute} controls the behavior if the learner does not support feature selection. +If set to \code{"error"}, an error is thrown, otherwise all features are returned. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Learner$selected_features()}\if{html}{\out{
}} +} + } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/LearnerClassif.Rd b/man/LearnerClassif.Rd index 64a29cf74..81c9f456b 100644 --- a/man/LearnerClassif.Rd +++ b/man/LearnerClassif.Rd @@ -93,6 +93,7 @@ Other Learner:
  • mlr3::Learner$predict_newdata()
  • mlr3::Learner$print()
  • mlr3::Learner$reset()
  • +
  • mlr3::Learner$selected_features()
  • mlr3::Learner$train()
  • diff --git a/man/LearnerRegr.Rd b/man/LearnerRegr.Rd index dca2085c7..da505ee73 100644 --- a/man/LearnerRegr.Rd +++ b/man/LearnerRegr.Rd @@ -97,6 +97,7 @@ The quantile to be used as response.}
  • mlr3::Learner$predict_newdata()
  • mlr3::Learner$print()
  • mlr3::Learner$reset()
  • +
  • mlr3::Learner$selected_features()
  • mlr3::Learner$train()
  • diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 8917f4452..4bef5c3ed 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -663,3 +663,21 @@ test_that("configure method works", { expect_equal(learner$param_set$values$xval, 10) expect_equal(learner$predict_sets, "train") }) + +test_that("selected_features works", { + task = tsk("spam") + # alter rpart class to not support feature selection + fun = LearnerClassifRpart$public_methods$selected_features + on.exit({ + LearnerClassifRpart$public_methods$selected_features = fun + }) + LearnerClassifRpart$public_methods$selected_features = NULL + + learner = lrn("classif.rpart") + expect_error(learner$selected_features(), "No model stored") + learner$train(task) + expect_error(learner$selected_features(), "Learner does not support feature selection") + + learner$selected_features_impute = "all" + expect_equal(learner$selected_features(), task$feature_names) +})