Skip to content

Commit

Permalink
feat: set default fallback with set_fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Aug 30, 2024
1 parent 57b6109 commit f07c045
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 7 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ Collate:
'predict.R'
'reexports.R'
'resample.R'
'set_fallback.R'
'set_threads.R'
'set_validate.R'
'task_converters.R'
Expand Down
9 changes: 2 additions & 7 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,8 @@ Learner = R6Class("Learner",
assert_names(names(rhs), subset.of = c("train", "predict"))
private$.encapsulate = insert_named(default, rhs)

if (is.null(private$.fallback)) {
# if there is no fallback, we get a default one from the reflections table
fallback_id = mlr_reflections$learner_fallback[[self$task_type]]
if (!is.null(fallback_id)) {
self$fallback = lrn(fallback_id, predict_type = self$predict_type)
}
}
# if there is no fallback, we get a default one from the reflections table
if (is.null(private$.fallback)) set_fallback(self)
},

#' @field fallback ([Learner])\cr
Expand Down
45 changes: 45 additions & 0 deletions R/set_fallback.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#' @title Set a Fallback Learner
#'
#' @description
#' Set a fallback learner for a given learner.
#' The function searches for a suitable fallback learner based on the task type.
#' Additional checks are performed to ensure that the fallback learner supports the predict type.
#'
#' @param learner [Learner]\cr
#' The learner for which a fallback learner should be set.
#'
#' @return
#' Returns the learner itself, but modified **by reference**.
set_fallback = function(learner) {
assert_learner(learner)

# search for suitable fallback learner
fallback_id = mlr_reflections$learner_fallback[[learner$task_type]]

if (is.null(fallback_id)) {
stopf("No fallback learner available for task type '%s'.", learner$task_type)
}

fallback = lrn(fallback_id)

# set predict type
if (learner$predict_type %nin% fallback$predict_types) {
stopf("Fallback learner '%s' does not support predict type '%s'.", fallback_id, learner$predict_type)
}

fallback$predict_type = learner$predict_type

# set quantiles
if (learner$predict_type == "quantiles") {

if (is.null(learner$quantiles) || is.null(learner$quantile_response)) {
stopf("Cannot set quantiles for fallback learner. Set `$quantiles` and `$quantile_response` in %s.", learner$id)
}

fallback$quantiles = learner$quantiles
fallback$quantile_response = learner$quantile_response
}

learner$fallback = fallback
return(learner)
}
20 changes: 20 additions & 0 deletions man/set_fallback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test_set_fallback.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
test_that("set_fallback() works", {
learner = lrn("classif.rpart")
set_fallback(learner)

expect_class(learner, "LearnerClassifRpart")
expect_class(learner$fallback, "LearnerClassifFeatureless")
expect_equal(learner$fallback$predict_type, "response")

learner = lrn("classif.rpart", predict_type = "prob")
set_fallback(learner)

expect_class(learner, "LearnerClassifRpart")
expect_class(learner$fallback, "LearnerClassifFeatureless")
expect_equal(learner$fallback$predict_type, "prob")

learner = lrn("regr.rpart")
set_fallback(learner)

expect_class(learner, "LearnerRegrRpart")
expect_class(learner$fallback, "LearnerRegrFeatureless")
expect_equal(learner$fallback$predict_type, "response")

learner = lrn("regr.debug", predict_type = "se")
set_fallback(learner)

expect_class(learner, "LearnerRegrDebug")
expect_class(learner$fallback, "LearnerRegrFeatureless")
expect_equal(learner$fallback$predict_type, "se")
})

0 comments on commit f07c045

Please sign in to comment.