Skip to content

Commit

Permalink
Merge branch 'main' into selected_features
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 20, 2024
2 parents 391be36 + 89544ff commit eadc63f
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 46 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# mlr3 (development version)

* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.

# mlr3 0.22.1

* fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`.
Expand Down
2 changes: 1 addition & 1 deletion R/DataBackendRename.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
assert_character(old, any.missing = FALSE, unique = TRUE)
assert_subset(old, b$colnames)
assert_character(new, any.missing = FALSE, len = length(old))
assert_names(new, if (allow_utf8_names()) "unique" else "strict")
assert_names(new, "unique")

ii = old != new
old = old[ii]
Expand Down
21 changes: 13 additions & 8 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,6 @@ Learner = R6Class("Learner",
#' @template field_task_type
task_type = NULL,

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
predict_types = NULL,

#' @field feature_types (`character()`)\cr
#' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`.
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
Expand Down Expand Up @@ -214,7 +209,7 @@ Learner = R6Class("Learner",
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
Expand Down Expand Up @@ -642,6 +637,8 @@ Learner = R6Class("Learner",
#' @field predict_type (`character(1)`)\cr
#' Stores the currently active predict type, e.g. `"response"`.
#' Must be an element of `$predict_types`.
#' A few learners already use the predict type during training.
#' Therefore, there is no guarantee that changing the predict type after training will have any effect or will not lead to errors.
predict_type = function(rhs) {
if (missing(rhs)) {
return(private$.predict_type)
Expand All @@ -663,8 +660,6 @@ Learner = R6Class("Learner",
private$.param_set
},



#' @field fallback ([Learner])\cr
#' Returns the fallback learner set with `$encapsulate()`.
fallback = function(rhs) {
Expand Down Expand Up @@ -698,13 +693,23 @@ Learner = R6Class("Learner",
return(private$.selected_features_impute)
}
private$.selected_features_impute = assert_choice(rhs, c("error", "all"))
},

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
#' This field is read-only.
predict_types = function(rhs) {
assert_ro_binding(rhs)
return(private$.predict_types)
}
),

private = list(
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
.predict_types = NULL,
.param_set = NULL,
.hotstart_stack = NULL,
.selected_features_impute = "error",
Expand Down
10 changes: 3 additions & 7 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,9 @@ Task = R6Class("Task",
cn = self$backend$colnames
rn = self$backend$rownames

if (allow_utf8_names()) {
assert_names(cn, "unique", .var.name = "column names")
if (any(grepl("%", cn, fixed = TRUE))) {
stopf("Column names may not contain special character '%%'")
}
} else {
assert_names(cn, "strict", .var.name = "column names")
assert_names(cn, "unique", .var.name = "column names")
if (any(grepl("%", cn, fixed = TRUE))) {
stopf("Column names may not contain special character '%%'")
}

self$col_info = col_info(self$backend)
Expand Down
5 changes: 0 additions & 5 deletions R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ translate_types = function(x) {
factor(map_values(x, r_types, p_types), levels = p_types)
}


allow_utf8_names = function() {
isTRUE(getOption("mlr3.allow_utf8_names"))
}

get_featureless_learner = function(task_type) {
if (!is.na(task_type)) {
id = paste0(task_type, ".featureless")
Expand Down
4 changes: 0 additions & 4 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@
#' * `"mlr3.debug"`: If set to `TRUE`, parallelization via \CRANpkg{future} is disabled to simplify
#' debugging and provide more concise tracebacks.
#' Note that results computed in debug mode use a different seeding mechanism and are **not reproducible**.
#' * `"mlr3.allow_utf8_names"`: If set to `TRUE`, checks on the feature names are relaxed, allowing
#' non-ascii characters in column names. This is an experimental and temporal option to
#' pave the way for text analysis, and will likely be removed in a future version of the package.
#' analysis.
#' * `"mlr3.warn_version_mismatch"`: Set to `FALSE` to silence warnings raised during predict if a learner has been
#' trained with a different version of mlr3.
#'
Expand Down
2 changes: 0 additions & 2 deletions inst/testthat/helper_autotest.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ generate_generic_tasks = function(learner, proto) {

# task with non-ascii feature names
if (p > 0L) {
opts = options(mlr3.allow_utf8_names = TRUE)
on.exit(options(opts))
sel = proto$feature_types[list(learner$feature_types), "id", on = "type", with = FALSE, nomatch = NULL][[1L]]
tasks$utf8_feature_names = proto$clone(deep = TRUE)$select(sel)
old = sel[1L]
Expand Down
15 changes: 9 additions & 6 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions man/mlr3-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 0 additions & 9 deletions tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -552,15 +552,6 @@ test_that("set_levels", {
})

test_that("special chars in feature names (#697)", {
prev = options(mlr3.allow_utf8_names = FALSE)
on.exit(options(prev))

expect_error(
TaskRegr$new("test", data.table(`%^` = 1:3, t = 3:1), target = "t"),
"comply"
)
options(mlr3.allow_utf8_names = TRUE)

expect_error(
TaskRegr$new("test", data.table(`%asd` = 1:3, t = 3:1), target = "t")
,
Expand Down

0 comments on commit eadc63f

Please sign in to comment.