From 959c9d61bfbf442a4a1755b672733f4fa7d0ecfe Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Wed, 13 Dec 2023 16:27:48 +0100 Subject: [PATCH] fix: reduce number of data table threads when running with future (#979) * fix: reduce number of data table threads when running with future * refactor: pass is sequential to worker * test: data table threads are not changed in main session * fix: use default * feat: reduce blas threads to 1 * docs: namespace * docs: namespace * fix: roxygen * chore: update news --- DESCRIPTION | 3 ++- NAMESPACE | 2 ++ NEWS.md | 1 + R/helper_exec.R | 2 ++ R/worker.R | 10 +++++++++- R/zzz.R | 1 + tests/testthat/test_parallel.R | 19 +++++++++++++++++++ 7 files changed, 36 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6a209ec99..298b875c4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,6 +56,7 @@ Imports: parallelly, palmerpenguins, paradox (>= 0.10.0), + RhpcBLASctl, uuid Suggests: Matrix, @@ -73,7 +74,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3.9000 +RoxygenNote: 7.2.3 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' diff --git a/NAMESPACE b/NAMESPACE index 1e4db46ba..f900058f0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -234,6 +234,8 @@ import(palmerpenguins) import(paradox) importFrom(R6,R6Class) importFrom(R6,is.R6) +importFrom(RhpcBLASctl,blas_get_num_procs) +importFrom(RhpcBLASctl,blas_set_num_threads) importFrom(data.table,as.data.table) importFrom(data.table,data.table) importFrom(future,nbrOfWorkers) diff --git a/NEWS.md b/NEWS.md index 52e5d53c0..7f95d6c97 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # mlr3 (development version) +* Reduce number of threads used by `data.table` and BLAS to 1 when running `resample()` or `benchmark()` in parallel. * Optimize runtime of `resample()` and `benchmark()` by reducing the number of hashing operations. # mlr3 0.17.0 diff --git a/R/helper_exec.R b/R/helper_exec.R index dc6a034b1..bf68f13a9 100644 --- a/R/helper_exec.R +++ b/R/helper_exec.R @@ -36,6 +36,8 @@ future_map = function(n, FUN, ..., MoreArgs = list()) { } stdout = if (is_sequential) NA else TRUE + MoreArgs = c(MoreArgs, list(is_sequential = is_sequential)) + lg$debug("Running resample() via future with %i iterations", n) future.apply::future_mapply( FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE, diff --git a/R/worker.R b/R/worker.R index e2d8e312a..324eb7652 100644 --- a/R/worker.R +++ b/R/worker.R @@ -219,11 +219,19 @@ learner_predict = function(learner, task, row_ids = NULL) { } -workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train") { +workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE) { if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } + # reduce data.table and blas threads to 1 + if (!is_sequential) { + setDTthreads(1, restore_after_fork = TRUE) + old_blas_threads = blas_get_num_procs() + on.exit(blas_set_num_threads(old_blas_threads), add = TRUE) + blas_set_num_threads(1) + } + # restore logger thresholds for (package in names(lgr_threshold)) { logger = lgr::get_logger(package) diff --git a/R/zzz.R b/R/zzz.R index caef148c6..9b2e08e85 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -10,6 +10,7 @@ #' @importFrom uuid UUIDgenerate #' @importFrom parallelly availableCores #' @importFrom future nbrOfWorkers plan +#' @importFrom RhpcBLASctl blas_set_num_threads blas_get_num_procs #' #' @section Learn mlr3: #' * Book on mlr3: \url{https://mlr3book.mlr-org.com} diff --git a/tests/testthat/test_parallel.R b/tests/testthat/test_parallel.R index 4e19f4720..6f3acdf23 100644 --- a/tests/testthat/test_parallel.R +++ b/tests/testthat/test_parallel.R @@ -85,3 +85,22 @@ test_that("parallel seed", { }) expect_equal(rr1$prediction()$prob, rr2$prediction()$prob) }) + +test_that("data table threads are not changed in main session", { + old_dt_threads = getDTthreads() + on.exit({ + setDTthreads(old_dt_threads) + }, add = TRUE) + setDTthreads(2L) + + task = tsk("sonar") + learner = lrn("classif.debug", predict_type = "prob") + resampling = rsmp("cv", folds = 3L) + measure = msr("classif.auc") + + rr1 = with_seed(123, with_future(future::sequential, resample(task, learner, resampling))) + expect_equal(getDTthreads(), 2L) + + rr2 = with_seed(123, with_future(future::multisession, resample(task, learner, resampling))) + expect_equal(getDTthreads(), 2L) +})