From b285c5129c210bc20763fb70847fb19aabf5f240 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Thu, 21 Dec 2023 14:29:09 +0100 Subject: [PATCH] fix: remove task prototype when resample (#981) * refactor: remove task prototype when resample * refactor: add option to store prototype * fix: braket * refactor: null * fix: browser * keep prototypes in state when store_models is TRUE * refactor: only store data_prototype when train * refactor: feature names --------- Co-authored-by: Sebastian Fischer --- R/Learner.R | 5 +++++ R/worker.R | 4 +--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/R/Learner.R b/R/Learner.R index 8ca69a4b5..9f1863548 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -244,6 +244,11 @@ Learner = R6Class("Learner", learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode) + # store data prototype + proto = task$data(rows = integer()) + self$state$data_prototype = proto + self$state$task_prototype = proto + # store the task w/o the data self$state$train_task = task_rm_backend(task$clone(deep = TRUE)) diff --git a/R/worker.R b/R/worker.R index 324eb7652..8fa0008df 100644 --- a/R/worker.R +++ b/R/worker.R @@ -68,15 +68,13 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL log = append_log(NULL, "train", result$log$class, result$log$msg) train_time = result$elapsed - proto = task$data(rows = integer()) learner$state = insert_named(learner$state, list( model = result$result, log = log, train_time = train_time, param_vals = learner$param_set$values, task_hash = task$hash, - data_prototype = proto, - task_prototype = proto, + feature_names = task$feature_names, mlr3_version = mlr_reflections$package_version ))