Skip to content

Commit

Permalink
refactor: reduce number of hashes (#978)
Browse files (browse the repository at this point in the history)
  • Loading branch information
be-marc authored Dec 13, 2023
1 parent 244572f commit 318748b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3 (development version)

* Optimize the runtime of `resample()` and `benchmark()` by reducing the number of hashing operations.

# mlr3 0.17.0

* Learners can no longer be added to the `HotstartStack` when the model is missing.
Expand Down
34 changes: 17 additions & 17 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,33 @@ ResultData = R6Class("ResultData",
if (nrow(data) == 0L) {
self$data = star_init()
} else {
fact = data[, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"),
with = FALSE]
set(fact, j = "task_hash", value = hashes(fact$task))
set(fact, j = "learner_phash", value = phashes(fact$learner))
set(fact, j = "resampling_hash", value = hashes(fact$resampling))

uhashes = data.table(uhash = unique(fact$uhash))
tasks = fact[, list(task = .SD$task[1L]),
setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"))
uhashes = data.table(uhash = unique(data$uhash))
setkeyv(data, c("uhash", "iteration"))

data[, task_hash := task[[1]]$hash, by = "uhash"]
data[, learner_phash := learner[[1]]$phash, by = "uhash"]
data[, resampling_hash := resampling[[1]]$hash, by = "uhash"]

tasks = data[, list(task = .SD$task[1L]),
keyby = "task_hash"]
learners = fact[, list(learner = list(.SD$learner[[1L]]$reset())),
learners = data[, list(learner = list(.SD$learner[[1L]]$reset())),
keyby = "learner_phash"]
resamplings = fact[, list(resampling = .SD$resampling[1L]),
resamplings = data[, list(resampling = .SD$resampling[1L]),
keyby = "resampling_hash"]
learner_components = fact[, list(learner_param_vals = list(.SD$param_values[[1]])),
learner_components = data[, list(learner_param_vals = list(.SD$param_values[[1]])),
keyby = "learner_hash"]

set(fact, j = "task", value = NULL)
set(fact, j = "learner", value = NULL)
set(fact, j = "resampling", value = NULL)
set(fact, j = "param_values", value = NULL)
setkeyv(fact, c("uhash", "iteration"))
set(data, j = "task", value = NULL)
set(data, j = "learner", value = NULL)
set(data, j = "resampling", value = NULL)
set(data, j = "param_values", value = NULL)

if (!store_backends) {
set(tasks, j = "task", value = lapply(tasks$task, task_rm_backend))
}

self$data = list(fact = fact, uhashes = uhashes, tasks = tasks, learners = learners,
self$data = list(fact = data, uhashes = uhashes, tasks = tasks, learners = learners,
resamplings = resamplings, learner_components = learner_components)
}
}
Expand Down

0 comments on commit 318748b

Please sign in to comment.