Skip to content

Commit

Permalink
fix: reduce number of data table threads when running with future (#979)
Browse files Browse the repository at this point in the history
* fix: reduce number of data table threads when running with future

* refactor: pass `is_sequential` to worker

* test: data table threads are not changed in main session

* fix: use default

* feat: reduce blas threads to 1

* docs: namespace

* docs: namespace

* fix: roxygen

* chore: update news
  • Loading branch information
be-marc committed Dec 13, 2023
1 parent 318748b commit 959c9d6
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 2 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Imports:
parallelly,
palmerpenguins,
paradox (>= 0.10.0),
RhpcBLASctl,
uuid
Suggests:
Matrix,
Expand All @@ -73,7 +74,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.2.3
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ import(palmerpenguins)
import(paradox)
importFrom(R6,R6Class)
importFrom(R6,is.R6)
importFrom(RhpcBLASctl,blas_get_num_procs)
importFrom(RhpcBLASctl,blas_set_num_threads)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(future,nbrOfWorkers)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3 (development version)

* Reduce number of threads used by `data.table` and BLAS to 1 when running `resample()` or `benchmark()` in parallel.
* Optimize runtime of `resample()` and `benchmark()` by reducing the number of hashing operations.

# mlr3 0.17.0
Expand Down
2 changes: 2 additions & 0 deletions R/helper_exec.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
}
stdout = if (is_sequential) NA else TRUE

MoreArgs = c(MoreArgs, list(is_sequential = is_sequential))

lg$debug("Running resample() via future with %i iterations", n)
future.apply::future_mapply(
FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE,
Expand Down
10 changes: 9 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,19 @@ learner_predict = function(learner, task, row_ids = NULL) {
}


workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train") {
workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE) {
if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}

# reduce data.table and blas threads to 1
if (!is_sequential) {
setDTthreads(1, restore_after_fork = TRUE)
old_blas_threads = blas_get_num_procs()
on.exit(blas_set_num_threads(old_blas_threads), add = TRUE)
blas_set_num_threads(1)
}

# restore logger thresholds
for (package in names(lgr_threshold)) {
logger = lgr::get_logger(package)
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @importFrom uuid UUIDgenerate
#' @importFrom parallelly availableCores
#' @importFrom future nbrOfWorkers plan
#' @importFrom RhpcBLASctl blas_set_num_threads blas_get_num_procs
#'
#' @section Learn mlr3:
#' * Book on mlr3: \url{https://mlr3book.mlr-org.com}
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test_parallel.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,22 @@ test_that("parallel seed", {
})
expect_equal(rr1$prediction()$prob, rr2$prediction()$prob)
})

test_that("data table threads are not changed in main session", {
  # Remember the session-wide data.table thread setting and restore it when the
  # test finishes, so that other tests are not affected.
  old_dt_threads = getDTthreads()
  on.exit({
    setDTthreads(old_dt_threads)
  }, add = TRUE)
  setDTthreads(2L)

  # Compare against the value data.table reports after the request instead of a
  # literal 2L: data.table may grant fewer threads than requested (e.g. on a
  # single-core machine or a build without OpenMP), and this test is only about
  # resample() leaving the main session's setting untouched.
  expected_dt_threads = getDTthreads()

  task = tsk("sonar")
  learner = lrn("classif.debug", predict_type = "prob")
  resampling = rsmp("cv", folds = 3L)

  # Sequential plan: resampling runs in the main session.
  with_seed(123, with_future(future::sequential, resample(task, learner, resampling)))
  expect_equal(getDTthreads(), expected_dt_threads)

  # Multisession plan: resampling runs in separate R sessions; the main
  # session's setting must still be unchanged afterwards.
  with_seed(123, with_future(future::multisession, resample(task, learner, resampling)))
  expect_equal(getDTthreads(), expected_dt_threads)
})

0 comments on commit 959c9d6

Please sign in to comment.