Skip to content

Commit

Permalink
Cleanup assertions (#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg authored Aug 10, 2022
1 parent 159a0f9 commit 7ad7d70
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ assert_learner = function(learner, task = NULL, task_type = NULL, properties = c
task_type = task_type %??% task$task_type
# check on class(learner) does not work with GraphLearner and AutoTuner
# check on learner$task_type does not work with TaskUnsupervised
if (!is.null(task_type) && fget(mlr_reflections$task_types, task_type, "learner", "type") != fget(mlr_reflections$task_types, learner$task_type, "learner", "type") && fget(mlr_reflections$task_types, task_type, "learner", "type") %nin% class(learner)) {
if (!test_matching_task_type(task_type, learner, "learner")) {
stopf("Learner '%s' must have task type '%s'", learner$id, task_type)
}

Expand All @@ -89,6 +89,20 @@ assert_learner = function(learner, task = NULL, task_type = NULL, properties = c
invisible(learner)
}

test_matching_task_type = function(task_type, object, class) {
if (is.null(task_type) || object$task_type == task_type) {
return(TRUE)
}

cl_task_type = fget(mlr_reflections$task_types, task_type, class, "type")
if (inherits(object, cl_task_type)) {
return(TRUE)
}

cl_object = fget(mlr_reflections$task_types, object$task_type, class, "type")
return(cl_task_type == cl_object)
}


#' @export
#' @param learners (list of [Learner]).
Expand All @@ -104,7 +118,8 @@ assert_task_learner = function(task, learner, cols = NULL) {
}
# check on class(learner) does not work with GraphLearner and AutoTuner
# check on learner$task_type does not work with TaskUnsupervised
if (fget(mlr_reflections$task_types, task$task_type, "learner", "type") != fget(mlr_reflections$task_types, learner$task_type, "learner", "type") && fget(mlr_reflections$task_types, task$task_type, "learner", "type") %nin% class(learner)) {

if (!test_matching_task_type(task$task_type, learner, "learner")) {
stopf("Type '%s' of %s does not match type '%s' of %s",
task$task_type, task$format(), learner$task_type, learner$format())
}
Expand Down Expand Up @@ -155,9 +170,8 @@ assert_predictable = function(task, learner) {
assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vname(measure)) {
assert_class(measure, "Measure", .var.name = .var.name)


if (!is.null(task)) {
if (!is_scalar_na(measure$task_type) && fget(mlr_reflections$task_types, task$task_type, "measure", "type") %nin% class(measure)) {
if (!is_scalar_na(measure$task_type) && !test_matching_task_type(task$task_type, measure, "measure")) {
stopf("Measure '%s' is not compatible with type '%s' of task '%s'",
measure$id, task$task_type, task$id)
}
Expand Down
2 changes: 1 addition & 1 deletion R/helper_exec.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
chunk_size = getOption("mlr3.exec_chunk_size", 1)
stdout = if (is_sequential) NA else TRUE

lg$debug("Running resample() via future with %n iterations", n)
lg$debug("Running resample() via future with %i iterations", n)
future.apply::future_mapply(
FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE,
future.globals = FALSE, future.packages = "mlr3", future.seed = TRUE,
Expand Down

0 comments on commit 7ad7d70

Please sign in to comment.