From a3acc04c009e4133cfabac37b1bc7e82c2799de5 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 12:22:22 +0100 Subject: [PATCH 1/4] feat: add selected_features method to all learners --- R/Learner.R | 27 +++++++++++++++++++++++++++ man/Learner.Rd | 18 ++++++++++++++++++ man/LearnerClassif.Rd | 1 + man/LearnerRegr.Rd | 1 + tests/testthat/test_Learner.R | 14 ++++++++++++++ 5 files changed, 61 insertions(+) diff --git a/R/Learner.R b/R/Learner.R index d1b7e7863..2ee18ecf1 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -549,6 +549,21 @@ Learner = R6Class("Learner", } return(invisible(self)) + }, + + #' @description + #' Returns the selected features of the model. + #' The field `selected_features_impute` controls the behavior if the learner does not support feature selection. + #' If set to `"error"`, an error is thrown, otherwise all features are returned. + selected_features = function() { + if (is.null(self$model)) { + stopf("No model stored") + } + if (private$.selected_features_impute == "error") { + stop("Learner does not support feature selection") + } else { + self$state$feature_names + } } ), @@ -672,6 +687,17 @@ Learner = R6Class("Learner", } assert_r6(rhs, "HotstartStack", null.ok = TRUE) private$.hotstart_stack = rhs + }, + + #' @field selected_features_impute (`character(1)`)\cr + #' Controls the behavior if the learner does not support feature selection. + #' If set to `"error"`, an error is thrown. + #' If set to `"all"` the complete feature set is returned. + selected_features_impute = function(rhs) { + if (missing(rhs)) { + return(private$.selected_features_impute) + } + private$.selected_features_impute = assert_choice(rhs, c("error", "all")) } ), @@ -681,6 +707,7 @@ Learner = R6Class("Learner", .predict_type = NULL, .param_set = NULL, .hotstart_stack = NULL, + .selected_features_impute = "error", deep_clone = function(name, value) { switch(name, diff --git a/man/Learner.Rd b/man/Learner.Rd index f5ede1944..7f81db5fa 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -302,6 +302,11 @@ Returns the encapsulation settings set with \verb{$encapsulate()}.} \item{\code{hotstart_stack}}{(\link{HotstartStack})\cr. Stores \code{HotstartStack}.} + +\item{\code{selected_features_impute}}{(\code{character(1)})\cr +Controls the behavior if the learner does not support feature selection. +If set to \code{"error"}, an error is thrown. +If set to \code{"all"} the complete feature set is returned.} } \if{html}{\out{}} } @@ -319,6 +324,7 @@ Stores \code{HotstartStack}.} \item \href{#method-Learner-base_learner}{\code{Learner$base_learner()}} \item \href{#method-Learner-encapsulate}{\code{Learner$encapsulate()}} \item \href{#method-Learner-configure}{\code{Learner$configure()}} +\item \href{#method-Learner-selected_features}{\code{Learner$selected_features()}} \item \href{#method-Learner-clone}{\code{Learner$clone()}} } } @@ -630,6 +636,18 @@ Named list of parameter values and fields.} } \if{html}{\out{}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Learner-selected_features}{}}} +\subsection{Method \code{selected_features()}}{ +Returns the selected features of the model. +The field \code{selected_features_impute} controls the behavior if the learner does not support feature selection. +If set to \code{"error"}, an error is thrown, otherwise all features are returned. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Learner$selected_features()}\if{html}{\out{
}} +} + } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/LearnerClassif.Rd b/man/LearnerClassif.Rd index 64a29cf74..81c9f456b 100644 --- a/man/LearnerClassif.Rd +++ b/man/LearnerClassif.Rd @@ -93,6 +93,7 @@ Other Learner:
  • mlr3::Learner$predict_newdata()
  • mlr3::Learner$print()
  • mlr3::Learner$reset()
  • +
  • mlr3::Learner$selected_features()
  • mlr3::Learner$train()
  • diff --git a/man/LearnerRegr.Rd b/man/LearnerRegr.Rd index dca2085c7..da505ee73 100644 --- a/man/LearnerRegr.Rd +++ b/man/LearnerRegr.Rd @@ -97,6 +97,7 @@ The quantile to be used as response.}
  • mlr3::Learner$predict_newdata()
  • mlr3::Learner$print()
  • mlr3::Learner$reset()
  • +
  • mlr3::Learner$selected_features()
  • mlr3::Learner$train()
  • diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 8917f4452..c6ab0b36a 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -663,3 +663,17 @@ test_that("configure method works", { expect_equal(learner$param_set$values$xval, 10) expect_equal(learner$predict_sets, "train") }) + +test_that("selected_features works", { + task = tsk("spam") + LearnerClassifRpart2 = LearnerClassifRpart + LearnerClassifRpart2$public_methods$selected_features = NULL + + learner = LearnerClassifRpart$new() + expect_error(learner$selected_features(), "No model stored") + learner$train(task) + expect_error(learner$selected_features(), "Learner does not support feature selection") + + learner$selected_features_impute = "all" + expect_equal(learner$selected_features(), task$feature_names) +}) From fc541a84b9084903f2a4033bc10b1fb174fcebab Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 13:16:57 +0100 Subject: [PATCH 2/4] ... --- tests/testthat/test_Learner.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index c6ab0b36a..e4c2a454f 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -666,10 +666,11 @@ test_that("configure method works", { test_that("selected_features works", { task = tsk("spam") - LearnerClassifRpart2 = LearnerClassifRpart + # alter rpart class to not support feature selection + LearnerClassifRpart2 = R6::R6Class("LearnerClassifRpart2", inherit = LearnerClassifRpart) LearnerClassifRpart2$public_methods$selected_features = NULL - learner = LearnerClassifRpart$new() + learner = LearnerClassifRpart2$new() expect_error(learner$selected_features(), "No model stored") learner$train(task) expect_error(learner$selected_features(), "Learner does not support feature selection") From 9df4b85e2ae9be54abb420b9789d5c8172f18181 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 13:19:13 +0100 Subject: [PATCH 3/4] ... --- R/Learner.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Learner.R b/R/Learner.R index 2ee18ecf1..044e5ea20 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -552,7 +552,7 @@ Learner = R6Class("Learner", }, #' @description - #' Returns the selected features of the model. + #' Returns the features selected by the model. #' The field `selected_features_impute` controls the behavior if the learner does not support feature selection. #' If set to `"error"`, an error is thrown, otherwise all features are returned. selected_features = function() { From 391be3641e1548fcda7d12b7a97a4feb8f1b3361 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 13:46:33 +0100 Subject: [PATCH 4/4] ... --- tests/testthat/test_Learner.R | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index e4c2a454f..4bef5c3ed 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -667,10 +667,13 @@ test_that("configure method works", { test_that("selected_features works", { task = tsk("spam") # alter rpart class to not support feature selection - LearnerClassifRpart2 = R6::R6Class("LearnerClassifRpart2", inherit = LearnerClassifRpart) - LearnerClassifRpart2$public_methods$selected_features = NULL + fun = LearnerClassifRpart$public_methods$selected_features + on.exit({ + LearnerClassifRpart$public_methods$selected_features = fun + }) + LearnerClassifRpart$public_methods$selected_features = NULL - learner = LearnerClassifRpart2$new() + learner = lrn("classif.rpart") expect_error(learner$selected_features(), "No model stored") learner$train(task) expect_error(learner$selected_features(), "Learner does not support feature selection")