diff --git a/DESCRIPTION b/DESCRIPTION index ca7492a28..5f5c658b3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -181,6 +181,7 @@ Collate: 'benchmark.R' 'benchmark_grid.R' 'bibentries.R' + 'default_fallback.R' 'default_measures.R' 'fix_factor_levels.R' 'helper.R' diff --git a/NAMESPACE b/NAMESPACE index cd7d2137c..eadee06bf 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -74,6 +74,9 @@ S3method(col_info,DataBackend) S3method(col_info,data.table) S3method(create_empty_prediction_data,TaskClassif) S3method(create_empty_prediction_data,TaskRegr) +S3method(default_fallback,Learner) +S3method(default_fallback,LearnerClassif) +S3method(default_fallback,LearnerRegr) S3method(default_values,Learner) S3method(default_values,LearnerClassifRpart) S3method(default_values,LearnerRegrRpart) diff --git a/NEWS.md b/NEWS.md index 279e5f748..e9f06a425 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,7 @@ * feat: Add option to calculate the mean of the true values on the train set in `msr("regr.rsq")`. * feat: Default fallback learner is set when encapsulation is activated. * feat: Learners classif.debug and regr.debug have new methods `$importance()` and `$selected_features()` for testing, also in downstream packages +* feat: Create default fallback learner with `default_fallback()`. * feat: Check column roles when using `$set_col_roles()` and `$col_roles`. # mlr3 0.20.2 diff --git a/R/Learner.R b/R/Learner.R index b3ff1840e..94ff3edad 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -567,11 +567,11 @@ Learner = R6Class("Learner", assert_names(names(rhs), subset.of = c("train", "predict")) private$.encapsulate = insert_named(default, rhs) + # if there is no fallback, we get a default one if (is.null(private$.fallback)) { - # if there is no fallback, we get a default one from the reflections table - fallback_id = mlr_reflections$learner_fallback[[self$task_type]] - if (!is.null(fallback_id)) { - self$fallback = lrn(fallback_id, predict_type = self$predict_type) + fallback = default_fallback(self) + if (!is.null(fallback)) { + self$fallback = fallback } } }, diff --git a/R/default_fallback.R b/R/default_fallback.R new file mode 100644 index 000000000..5d618a583 --- /dev/null +++ b/R/default_fallback.R @@ -0,0 +1,64 @@ +#' @title Create a Fallback Learner +#' +#' @description +#' Create a fallback learner for a given learner. +#' The function searches for a suitable fallback learner based on the task type. +#' Additional checks are performed to ensure that the fallback learner supports the predict type. +#' +#' @param learner [Learner]\cr +#' The learner for which a fallback learner should be created. +#' @param ... `any`\cr +#' ignored. +#' +#' @return [Learner] +default_fallback = function(learner, ...) { + UseMethod("default_fallback") +} + +#' @rdname default_fallback +#' @export +default_fallback.Learner = function(learner, ...) { + # FIXME: remove when new encapsulate/fallback system is in place + return(NULL) +} + +#' @rdname default_fallback +#' @export +default_fallback.LearnerClassif = function(learner, ...) { + fallback = lrn("classif.featureless") + + # set predict type + if (learner$predict_type %nin% fallback$predict_types) { + stopf("Fallback learner '%s' does not support predict type '%s'.", fallback$id, learner$predict_type) + } + + fallback$predict_type = learner$predict_type + + return(fallback) +} + +#' @rdname default_fallback +#' @export +default_fallback.LearnerRegr = function(learner, ...) { + fallback = lrn("regr.featureless") + + # set predict type + if (learner$predict_type %nin% fallback$predict_types) { + stopf("Fallback learner '%s' does not support predict type '%s'.", fallback$id, learner$predict_type) + } + + fallback$predict_type = learner$predict_type + + # set quantiles + if (learner$predict_type == "quantiles") { + + if (is.null(learner$quantiles) || is.null(learner$quantile_response)) { + stopf("Cannot set quantiles for fallback learner. Set `$quantiles` and `$quantile_response` in %s.", learner$id) + } + + fallback$quantiles = learner$quantiles + fallback$quantile_response = learner$quantile_response + } + + return(fallback) +} diff --git a/R/mlr_reflections.R b/R/mlr_reflections.R index f103871b4..6a18adc32 100644 --- a/R/mlr_reflections.R +++ b/R/mlr_reflections.R @@ -127,11 +127,6 @@ local({ regr = list(response = "response", se = c("response", "se"), quantiles = c("response", "quantiles"), distr = c("response", "se", "distr")) ) - mlr_reflections$learner_fallback = list( - classif = "classif.featureless", - regr = "regr.featureless" - ) - # Allowed tags for parameters mlr_reflections$learner_param_tags = c("train", "predict", "hotstart", "importance", "threads", "required", "internal_tuning") diff --git a/man/default_fallback.Rd b/man/default_fallback.Rd new file mode 100644 index 000000000..becad7031 --- /dev/null +++ b/man/default_fallback.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/default_fallback.R +\name{default_fallback} +\alias{default_fallback} +\alias{default_fallback.Learner} +\alias{default_fallback.LearnerClassif} +\alias{default_fallback.LearnerRegr} +\title{Create a Fallback Learner} +\usage{ +default_fallback(learner, ...) + +\method{default_fallback}{Learner}(learner, ...) + +\method{default_fallback}{LearnerClassif}(learner, ...) + +\method{default_fallback}{LearnerRegr}(learner, ...) +} +\arguments{ +\item{learner}{\link{Learner}\cr +The learner for which a fallback learner should be created.} + +\item{...}{\code{any}\cr +ignored.} +} +\value{ +\link{Learner} +} +\description{ +Create a fallback learner for a given learner. +The function searches for a suitable fallback learner based on the task type. +Additional checks are performed to ensure that the fallback learner supports the predict type. +} diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 41461b0e7..57c90b87b 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -78,6 +78,7 @@ reference: - starts_with("mlr_learners") - as_learner - HotstartStack + - default_fallback - title: Measures contents: - starts_with("mlr_measures") diff --git a/tests/testthat/test_set_fallback.R b/tests/testthat/test_set_fallback.R new file mode 100644 index 000000000..ea1060ed6 --- /dev/null +++ b/tests/testthat/test_set_fallback.R @@ -0,0 +1,25 @@ +test_that("fallback = default_fallback() works", { + learner = lrn("classif.rpart") + fallback = default_fallback(learner) + + expect_class(fallback, "LearnerClassifFeatureless") + expect_equal(fallback$predict_type, "response") + + learner = lrn("classif.rpart", predict_type = "prob") + fallback = default_fallback(learner) + + expect_class(fallback, "LearnerClassifFeatureless") + expect_equal(fallback$predict_type, "prob") + + learner = lrn("regr.rpart") + fallback = default_fallback(learner) + + expect_class(fallback, "LearnerRegrFeatureless") + expect_equal(fallback$predict_type, "response") + + learner = lrn("regr.debug", predict_type = "se") + fallback = default_fallback(learner) + + expect_class(fallback, "LearnerRegrFeatureless") + expect_equal(fallback$predict_type, "se") +})