From abda590f03fa0353d9657ed8846c268af4be5ce2 Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Tue, 18 Jan 2022 14:48:06 +0100 Subject: [PATCH] feat: new argument to avoid clones in resample() and benchmark() (#756) --- R/benchmark.R | 16 ++++++++++++---- R/benchmark_grid.R | 4 ++-- R/helper_hashes.R | 1 - R/resample.R | 23 ++++++++++++----------- man-roxygen/param_clone.R | 5 +++++ man/benchmark.Rd | 9 ++++++++- man/resample.Rd | 9 ++++++++- tests/testthat/test_benchmark.R | 21 +++++++++++++++++++++ tests/testthat/test_resample.R | 17 +++++++++++++++++ 9 files changed, 85 insertions(+), 20 deletions(-) create mode 100644 man-roxygen/param_clone.R diff --git a/R/benchmark.R b/R/benchmark.R index e4f1b35ab..457c06282 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -12,6 +12,7 @@ #' @template param_store_backends #' @template param_encapsulate #' @template param_allow_hotstart +#' @template param_clone #' #' @return [BenchmarkResult]. #' @@ -75,7 +76,8 @@ #' ## Get the training set of the 2nd iteration of the featureless learner on penguins #' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]] #' rr$resampling$train_set(2) -benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE) { +benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) { + assert_subset(clone, c("task", "learner", "resampling")) assert_data_frame(design, min.rows = 1L) assert_names(names(design), permutation.of = c("task", "learner", "resampling")) design$task = list(assert_tasks(as_tasks(design$task))) @@ -90,9 +92,15 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps # clone inputs setDT(design) task = learner = resampling = NULL - design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))] - design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))] - design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))] + if ("task" %in% clone) { + design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))] + } + if ("learner" %in% clone) { + design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))] + } + if ("resampling" %in% clone) { + design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))] + } # set encapsulation + fallback set_encapsulation(design$learner, encapsulate) diff --git a/R/benchmark_grid.R b/R/benchmark_grid.R index 04622d6b3..18da53319 100644 --- a/R/benchmark_grid.R +++ b/R/benchmark_grid.R @@ -48,10 +48,10 @@ benchmark_grid = function(tasks, learners, resamplings) { if (any(is_instantiated)) { task_nrow = unique(map_int(tasks, "nrow")) if (length(task_nrow) != 1L) { - stopf("All resamplings must be uninstantiated, or must have the same number of rows") + stopf("All resamplings must be uninstantiated, or must operate on tasks with the same number of rows") } if (!identical(task_nrow, unique(map_int(resamplings, "task_nrow")))) { - stop("Resampling is instantiated for a task with a different number of observations") + stop("A Resampling is instantiated for a task with a different number of observations") } instances = pmap(grid, function(task, resampling) resamplings[[resampling]]$clone()) } else { diff --git a/R/helper_hashes.R b/R/helper_hashes.R index b2c633fd6..31d92f30c 100644 --- a/R/helper_hashes.R +++ b/R/helper_hashes.R @@ -23,4 +23,3 @@ task_hashes = function(task, resampling) { task$properties) }) } - diff --git a/R/resample.R b/R/resample.R index 64839b1ba..b63a85884 100644 --- a/R/resample.R +++ b/R/resample.R @@ -13,6 +13,7 @@ #' @template param_store_backends #' @template param_encapsulate #' @template param_allow_hotstart +#' @template param_clone #' @return [ResampleResult]. #' #' @template section_predict_sets @@ -53,20 +54,20 @@ #' bmr1 = as_benchmark_result(rr) #' bmr2 = as_benchmark_result(rr_featureless) #' print(bmr1$combine(bmr2)) -resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE) { - task = assert_task(as_task(task, clone = TRUE)) - learner = assert_learner(as_learner(learner, clone = TRUE)) - resampling = assert_resampling(as_resampling(resampling)) +resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) { + assert_subset(clone, c("task", "learner", "resampling")) + task = assert_task(as_task(task, clone = "task" %in% clone)) + learner = assert_learner(as_learner(learner, clone = "learner" %in% clone)) + resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %in% clone)) assert_flag(store_models) assert_flag(store_backends) assert_learnable(task, learner) set_encapsulation(list(learner), encapsulate) - instance = resampling$clone(deep = TRUE) - if (!instance$is_instantiated) { - instance = instance$instantiate(task) + if (!resampling$is_instantiated) { + resampling = resampling$instantiate(task) } - n = instance$iters + n = resampling$iters pb = if (isNamespaceLoaded("progressr")) { # NB: the progress bar needs to be created in this env pb = progressr::progressor(steps = n) @@ -108,7 +109,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe lg$info("Running resample() sequentially in debug mode with %i iterations", n) res = mapply(workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = instance, store_models = store_models, lgr_threshold = lgr_threshold, + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb), SIMPLIFY = FALSE ) } else { @@ -116,7 +117,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe res = future.apply::future_mapply(workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = instance, store_models = store_models, lgr_threshold = lgr_threshold, + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb), SIMPLIFY = FALSE, future.globals = FALSE, future.scheduling = structure(TRUE, ordering = "random"), future.packages = "mlr3", future.seed = TRUE, future.stdout = future_stdout() @@ -127,7 +128,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe task = list(task), learner = grid$learner, learner_state = map(res, "learner_state"), - resampling = list(instance), + resampling = list(resampling), iteration = seq_len(n), prediction = map(res, "prediction"), uhash = UUIDgenerate() diff --git a/man-roxygen/param_clone.R b/man-roxygen/param_clone.R new file mode 100644 index 000000000..965f7074a --- /dev/null +++ b/man-roxygen/param_clone.R @@ -0,0 +1,5 @@ +#' @param clone (`character()`)\cr +#' Select the input objects to be cloned before proceeding by +#' providing a set with possible values `"task"`, `"learner"` and +#' `"resampling"` for [Task], [Learner] and [Resampling], respectively. +#' Per default, all input objects are cloned. diff --git a/man/benchmark.Rd b/man/benchmark.Rd index f34c87527..4f67680d1 100644 --- a/man/benchmark.Rd +++ b/man/benchmark.Rd @@ -9,7 +9,8 @@ benchmark( store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, - allow_hotstart = FALSE + allow_hotstart = FALSE, + clone = c("task", "learner", "resampling") ) } \arguments{ @@ -48,6 +49,12 @@ does not already have a fallback configured.} \item{allow_hotstart}{(\code{logical(1)})\cr Determines if learner(s) are hot started with trained models in \verb{$hotstart_stack}. See also \link{HotstartStack}.} + +\item{clone}{(\code{character()})\cr +Select the input objects to be cloned before proceeding by +providing a set with possible values \code{"task"}, \code{"learner"} and +\code{"resampling"} for \link{Task}, \link{Learner} and \link{Resampling}, respectively. +Per default, all input objects are cloned.} } \value{ \link{BenchmarkResult}. diff --git a/man/resample.Rd b/man/resample.Rd index adbf2bf33..6f996329c 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -11,7 +11,8 @@ resample( store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, - allow_hotstart = FALSE + allow_hotstart = FALSE, + clone = c("task", "learner", "resampling") ) } \arguments{ @@ -50,6 +51,12 @@ does not already have a fallback configured.} \item{allow_hotstart}{(\code{logical(1)})\cr Determines if learner(s) are hot started with trained models in \verb{$hotstart_stack}. See also \link{HotstartStack}.} + +\item{clone}{(\code{character()})\cr +Select the input objects to be cloned before proceeding by +providing a set with possible values \code{"task"}, \code{"learner"} and +\code{"resampling"} for \link{Task}, \link{Learner} and \link{Resampling}, respectively. +Per default, all input objects are cloned.} } \value{ \link{ResampleResult}. diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R index dc45b4fcd..5800c2ca8 100644 --- a/tests/testthat/test_benchmark.R +++ b/tests/testthat/test_benchmark.R @@ -305,3 +305,24 @@ test_that("encapsulatiion", { expect_equal(learner$encapsulate[["predict"]], "evaluate") } }) + +test_that("disable cloning", { + grid = benchmark_grid( + tasks = tsk("iris"), + learners = lrn("classif.featureless"), + resamplings = rsmp("holdout") + ) + task = grid$task[[1L]] + learner = grid$learner[[1L]] + resampling = grid$resampling[[1L]] + + bmr = benchmark(grid, clone = c()) + + expect_same_address(task, bmr$tasks$task[[1]]) + expect_same_address(learner, get_private(bmr)$.data$data$learners$learner[[1]]) + expect_same_address(resampling, bmr$resamplings$resampling[[1]]) + + expect_identical(task$hash, bmr$tasks$task[[1]]$hash) + expect_identical(learner$hash, bmr$learners$learner[[1]]$hash) + expect_identical(resampling$hash, bmr$resamplings$resampling[[1]]$hash) +}) diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R index 68b3168ff..e5f6eb5b3 100644 --- a/tests/testthat/test_resample.R +++ b/tests/testthat/test_resample.R @@ -126,3 +126,20 @@ test_that("encapsulation", { expect_equal(rr$learner$encapsulate[["train"]], "evaluate") expect_equal(rr$learner$encapsulate[["predict"]], "evaluate") }) + +test_that("disable cloning", { + task = tsk("iris") + learner = lrn("classif.featureless") + resampling = rsmp("holdout") + + rr = resample(task, learner, resampling, clone = c()) + + expect_same_address(task, rr$task) + expect_same_address(learner, get_private(rr)$.data$data$learners$learner[[1]]) + expect_same_address(resampling, rr$resampling) + + expect_identical(task$hash, rr$task$hash) + expect_identical(learner$hash, rr$learner$hash) + expect_true(resampling$is_instantiated) + expect_identical(resampling$hash, rr$resampling$hash) +})