Skip to content

Commit

Permalink
feat: new argument to avoid clones in resample() and benchmark() (#756)
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Jan 18, 2022
1 parent 6fd52b4 commit abda590
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 20 deletions.
16 changes: 12 additions & 4 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#' @template param_store_backends
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#'
#' @return [BenchmarkResult].
#'
Expand Down Expand Up @@ -75,7 +76,8 @@
#' ## Get the training set of the 2nd iteration of the featureless learner on penguins
#' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]]
#' rr$resampling$train_set(2)
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE) {
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
assert_subset(clone, c("task", "learner", "resampling"))
assert_data_frame(design, min.rows = 1L)
assert_names(names(design), permutation.of = c("task", "learner", "resampling"))
design$task = list(assert_tasks(as_tasks(design$task)))
Expand All @@ -90,9 +92,15 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
# clone inputs
setDT(design)
task = learner = resampling = NULL
design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))]
design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))]
design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))]
if ("task" %in% clone) {
design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))]
}
if ("learner" %in% clone) {
design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))]
}
if ("resampling" %in% clone) {
design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))]
}

# set encapsulation + fallback
set_encapsulation(design$learner, encapsulate)
Expand Down
4 changes: 2 additions & 2 deletions R/benchmark_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ benchmark_grid = function(tasks, learners, resamplings) {
if (any(is_instantiated)) {
task_nrow = unique(map_int(tasks, "nrow"))
if (length(task_nrow) != 1L) {
stopf("All resamplings must be uninstantiated, or must have the same number of rows")
stopf("All resamplings must be uninstantiated, or must operate on tasks with the same number of rows")
}
if (!identical(task_nrow, unique(map_int(resamplings, "task_nrow")))) {
stop("Resampling is instantiated for a task with a different number of observations")
stop("A Resampling is instantiated for a task with a different number of observations")
}
instances = pmap(grid, function(task, resampling) resamplings[[resampling]]$clone())
} else {
Expand Down
1 change: 0 additions & 1 deletion R/helper_hashes.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ task_hashes = function(task, resampling) {
task$properties)
})
}

23 changes: 12 additions & 11 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' @template param_store_backends
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#' @return [ResampleResult].
#'
#' @template section_predict_sets
Expand Down Expand Up @@ -53,20 +54,20 @@
#' bmr1 = as_benchmark_result(rr)
#' bmr2 = as_benchmark_result(rr_featureless)
#' print(bmr1$combine(bmr2))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE) {
task = assert_task(as_task(task, clone = TRUE))
learner = assert_learner(as_learner(learner, clone = TRUE))
resampling = assert_resampling(as_resampling(resampling))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %in% clone))
resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %in% clone))
assert_flag(store_models)
assert_flag(store_backends)
assert_learnable(task, learner)

set_encapsulation(list(learner), encapsulate)
instance = resampling$clone(deep = TRUE)
if (!instance$is_instantiated) {
instance = instance$instantiate(task)
if (!resampling$is_instantiated) {
resampling = resampling$instantiate(task)
}
n = instance$iters
n = resampling$iters
pb = if (isNamespaceLoaded("progressr")) {
# NB: the progress bar needs to be created in this env
pb = progressr::progressor(steps = n)
Expand Down Expand Up @@ -108,15 +109,15 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
lg$info("Running resample() sequentially in debug mode with %i iterations", n)
res = mapply(workhorse,
iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
MoreArgs = list(task = task, resampling = instance, store_models = store_models, lgr_threshold = lgr_threshold,
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold,
pb = pb), SIMPLIFY = FALSE
)
} else {
lg$debug("Running resample() via future with %i iterations", n)

res = future.apply::future_mapply(workhorse,
iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
MoreArgs = list(task = task, resampling = instance, store_models = store_models, lgr_threshold = lgr_threshold,
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold,
pb = pb),
SIMPLIFY = FALSE, future.globals = FALSE, future.scheduling = structure(TRUE, ordering = "random"),
future.packages = "mlr3", future.seed = TRUE, future.stdout = future_stdout()
Expand All @@ -127,7 +128,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
task = list(task),
learner = grid$learner,
learner_state = map(res, "learner_state"),
resampling = list(instance),
resampling = list(resampling),
iteration = seq_len(n),
prediction = map(res, "prediction"),
uhash = UUIDgenerate()
Expand Down
5 changes: 5 additions & 0 deletions man-roxygen/param_clone.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @param clone (`character()`)\cr
#' Select the input objects to be cloned before proceeding by
#' providing a set with possible values `"task"`, `"learner"` and
#' `"resampling"` for [Task], [Learner] and [Resampling], respectively.
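#' By default, all input objects are cloned.
#' Note that objects which are not cloned may be modified in place, e.g. an
#' uninstantiated [Resampling] passed to [resample()] gets instantiated.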
9 changes: 8 additions & 1 deletion man/benchmark.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion man/resample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,24 @@ test_that("encapsulatiion", {
expect_equal(learner$encapsulate[["predict"]], "evaluate")
}
})

test_that("disable cloning", {
  # Build a design with a single task, learner and resampling, and keep
  # handles to the very objects stored in the design.
  design = benchmark_grid(
    tasks = tsk("iris"),
    learners = lrn("classif.featureless"),
    resamplings = rsmp("holdout")
  )
  in_task = design$task[[1L]]
  in_learner = design$learner[[1L]]
  in_resampling = design$resampling[[1L]]

  # An empty `clone` set asks benchmark() to clone none of the inputs.
  bmr = benchmark(design, clone = c())

  out_task = bmr$tasks$task[[1L]]
  out_resampling = bmr$resamplings$resampling[[1L]]
  # For the address check, the learner is read from the private data store.
  stored_learner = get_private(bmr)$.data$data$learners$learner[[1L]]

  # The result must hold the input objects themselves, not copies.
  expect_same_address(in_task, out_task)
  expect_same_address(in_learner, stored_learner)
  expect_same_address(in_resampling, out_resampling)

  # Hashes reported through the public accessors must match the inputs.
  expect_identical(in_task$hash, out_task$hash)
  expect_identical(in_learner$hash, bmr$learners$learner[[1L]]$hash)
  expect_identical(in_resampling$hash, out_resampling$hash)
})
17 changes: 17 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,20 @@ test_that("encapsulation", {
expect_equal(rr$learner$encapsulate[["train"]], "evaluate")
expect_equal(rr$learner$encapsulate[["predict"]], "evaluate")
})

test_that("disable cloning", {
  in_task = tsk("iris")
  in_learner = lrn("classif.featureless")
  in_resampling = rsmp("holdout")

  # An empty `clone` set asks resample() to clone none of the inputs.
  rr = resample(in_task, in_learner, in_resampling, clone = c())

  # For the address check, the learner is read from the private data store.
  stored_learner = get_private(rr)$.data$data$learners$learner[[1L]]

  # The result must hold the input objects themselves, not copies.
  expect_same_address(in_task, rr$task)
  expect_same_address(in_learner, stored_learner)
  expect_same_address(in_resampling, rr$resampling)

  expect_identical(in_task$hash, rr$task$hash)
  expect_identical(in_learner$hash, rr$learner$hash)
  # Without cloning, the caller's resampling object has been instantiated.
  expect_true(in_resampling$is_instantiated)
  expect_identical(in_resampling$hash, rr$resampling$hash)
})

0 comments on commit abda590

Please sign in to comment.