Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mlr-org/mlr3
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 19, 2024
2 parents 81136e6 + 666ad86 commit 43ceea5
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 0 deletions.
44 changes: 44 additions & 0 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,50 @@ Learner = R6Class("Learner",
private$.encapsulation = c(train = method, predict = method)
private$.fallback = fallback

return(invisible(self))
},

#' @description
#' Sets parameter values and fields of the learner.
#' All arguments whose names match the name of a parameter of the [paradox::ParamSet] are set as parameters.
#' All remaining arguments are assumed to be regular fields.
#'
#' @param ... (named `any`)\cr
#' Named arguments to set parameter values and fields.
#' @param .values (named `any`)\cr
#' Named list of parameter values and fields.
configure = function(..., .values = list()) {
dots = list(...)
assert_list(dots, names = "unique")
assert_list(.values, names = "unique")
assert_disjunct(names(dots), names(.values))
new_values = insert_named(dots, .values)

# set params in ParamSet
if (length(new_values)) {
param_ids = self$param_set$ids()
ii = names(new_values) %in% param_ids
if (any(ii)) {
self$param_set$values = insert_named(self$param_set$values, new_values[ii])
new_values = new_values[!ii]
}
} else {
param_ids = character()
}

# remaining args go into fields
if (length(new_values)) {
ndots = names(new_values)
for (i in seq_along(new_values)) {
nn = ndots[[i]]
if (!exists(nn, envir = self, inherits = FALSE)) {
stopf("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
nn, class(self)[1L], did_you_mean(nn, c(param_ids, setdiff(names(self), ".__enclos_env__")))) # nolint
}
self[[nn]] = new_values[[i]]
}
}

return(invisible(self))
}
),
Expand Down
24 changes: 24 additions & 0 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.rpart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.rpart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,37 @@ test_that("predict time is cumulative", {
t2 = learner$timings["predict"]
expect_true(t1 > t2)
})

test_that("configure method works", {
learner = lrn("classif.rpart")

expect_learner(learner$configure())
expect_learner(learner$configure(.values = list()))

# set new hyperparameter value
learner$configure(cp = 0.1)
expect_equal(learner$param_set$values$cp, 0.1)

# overwrite existing hyperparameter value
learner$configure(xval = 10)
expect_equal(learner$param_set$values$xval, 10)

# set field
learner$configure(predict_sets = "train")
expect_equal(learner$predict_sets, "train")

# hyperparameter and field
learner$configure(minbucket = 2, parallel_predict = TRUE)
expect_equal(learner$param_set$values$minbucket, 2)
expect_true(learner$parallel_predict)

# unknown hyperparameter and field
expect_error(learner$configure(xvald = 1), "Cannot set argument")

# use .values
learner = lrn("classif.rpart")
learner$configure(.values = list(cp = 0.1, xval = 10, predict_sets = "train"))
expect_equal(learner$param_set$values$cp, 0.1)
expect_equal(learner$param_set$values$xval, 10)
expect_equal(learner$predict_sets, "train")
})

0 comments on commit 43ceea5

Please sign in to comment.