Skip to content

Commit

Permalink
refactor: predict_type and predict_types (#1233)
Browse files Browse the repository at this point in the history
* refactor: predict_type and predict_types

* ...

* ...
  • Loading branch information
be-marc authored Dec 20, 2024
1 parent aae5342 commit 2fe8746
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# mlr3 (development version)

* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.

# mlr3 0.22.1

* fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`.
Expand Down
21 changes: 13 additions & 8 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,6 @@ Learner = R6Class("Learner",
#' @template field_task_type
task_type = NULL,

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
predict_types = NULL,

#' @field feature_types (`character()`)\cr
#' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`.
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
Expand Down Expand Up @@ -214,7 +209,7 @@ Learner = R6Class("Learner",
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
Expand Down Expand Up @@ -627,6 +622,8 @@ Learner = R6Class("Learner",
#' @field predict_type (`character(1)`)\cr
#' Stores the currently active predict type, e.g. `"response"`.
#' Must be an element of `$predict_types`.
#' A few learners already use the predict type during training.
#' So there is no guarantee that changing the predict type after training will have any effect or does not lead to errors.
predict_type = function(rhs) {
if (missing(rhs)) {
return(private$.predict_type)
Expand All @@ -648,8 +645,6 @@ Learner = R6Class("Learner",
private$.param_set
},



#' @field fallback ([Learner])\cr
#' Returns the fallback learner set with `$encapsulate()`.
fallback = function(rhs) {
Expand All @@ -672,13 +667,23 @@ Learner = R6Class("Learner",
}
assert_r6(rhs, "HotstartStack", null.ok = TRUE)
private$.hotstart_stack = rhs
},

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
#' This field is read-only.
predict_types = function(rhs) {
assert_ro_binding(rhs)
return(private$.predict_types)
}
),

private = list(
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
.predict_types = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

Expand Down
13 changes: 8 additions & 5 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2fe8746

Please sign in to comment.