From 2fe8746f71aaa55e12fff49f7934c96cb0d29ee8 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:52:13 +0100 Subject: [PATCH] refactor: predict_type and predict_types (#1233) * refactor: predict_type and predict_types * ... * ... --- NEWS.md | 3 +++ R/Learner.R | 21 +++++++++++++-------- man/Learner.Rd | 13 ++++++++----- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/NEWS.md b/NEWS.md index 83d6dfbb0..6992b98cd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3 (development version) +* BREAKING CHANGE: `Learner$predict_types` is read-only now. +* docs: Clear up behavior of `Learner$predict_type` after training. + # mlr3 0.22.1 * fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`. diff --git a/R/Learner.R b/R/Learner.R index d1b7e7863..461dd150f 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -157,11 +157,6 @@ Learner = R6Class("Learner", #' @template field_task_type task_type = NULL, - #' @field predict_types (`character()`)\cr - #' Stores the possible predict types the learner is capable of. - #' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections]. - predict_types = NULL, - #' @field feature_types (`character()`)\cr #' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`. #' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections]. @@ -214,7 +209,7 @@ Learner = R6Class("Learner", self$task_type = assert_choice(task_type, mlr_reflections$task_types$type) private$.param_set = assert_param_set(param_set) self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types") - self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), + private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), empty.ok = FALSE, .var.name = "predict_types") private$.predict_type = predict_types[1L] self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]])) @@ -627,6 +622,8 @@ Learner = R6Class("Learner", #' @field predict_type (`character(1)`)\cr #' Stores the currently active predict type, e.g. `"response"`. #' Must be an element of `$predict_types`. + #' A few learners already use the predict type during training. + #' So there is no guarantee that changing the predict type after training will have any effect or does not lead to errors. predict_type = function(rhs) { if (missing(rhs)) { return(private$.predict_type) @@ -648,8 +645,6 @@ Learner = R6Class("Learner", private$.param_set }, - - #' @field fallback ([Learner])\cr #' Returns the fallback learner set with `$encapsulate()`. fallback = function(rhs) { @@ -672,6 +667,15 @@ Learner = R6Class("Learner", } assert_r6(rhs, "HotstartStack", null.ok = TRUE) private$.hotstart_stack = rhs + }, + + #' @field predict_types (`character()`)\cr + #' Stores the possible predict types the learner is capable of. + #' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections]. + #' This field is read-only. + predict_types = function(rhs) { + assert_ro_binding(rhs) + return(private$.predict_types) } ), @@ -679,6 +683,7 @@ Learner = R6Class("Learner", .encapsulation = c(train = "none", predict = "none"), .fallback = NULL, .predict_type = NULL, + .predict_types = NULL, .param_set = NULL, .hotstart_stack = NULL, diff --git a/man/Learner.Rd b/man/Learner.Rd index f5ede1944..70d7cf8f9 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -195,10 +195,6 @@ Task type, e.g. \code{"classif"} or \code{"regr"}. For a complete list of possible task types (depending on the loaded packages), see \code{\link[=mlr_reflections]{mlr_reflections$task_types$type}}.} -\item{\code{predict_types}}{(\code{character()})\cr -Stores the possible predict types the learner is capable of. -A complete list of candidate predict types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}}.} - \item{\code{feature_types}}{(\code{character()})\cr Stores the feature types the learner can handle, e.g. \code{"logical"}, \code{"numeric"}, or \code{"factor"}. A complete list of candidate feature types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$task_feature_types}}.} @@ -289,7 +285,9 @@ Hash (unique identifier) for this partial object, excluding some components whic \item{\code{predict_type}}{(\code{character(1)})\cr Stores the currently active predict type, e.g. \code{"response"}. -Must be an element of \verb{$predict_types}.} +Must be an element of \verb{$predict_types}. +A few learners already use the predict type during training. +So there is no guarantee that changing the predict type after training will have any effect or does not lead to errors.} \item{\code{param_set}}{(\link[paradox:ParamSet]{paradox::ParamSet})\cr Set of hyperparameters.} @@ -302,6 +300,11 @@ Returns the encapsulation settings set with \verb{$encapsulate()}.} \item{\code{hotstart_stack}}{(\link{HotstartStack})\cr. Stores \code{HotstartStack}.} + +\item{\code{predict_types}}{(\code{character()})\cr +Stores the possible predict types the learner is capable of. +A complete list of candidate predict types, grouped by task type, is stored in \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}}. +This field is read-only.} } \if{html}{\out{}} }