diff --git a/R/Learner.R b/R/Learner.R index 8ca69a4b5..9f1863548 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -244,6 +244,11 @@ Learner = R6Class("Learner", learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode) + # store data prototype + proto = task$data(rows = integer()) + self$state$data_prototype = proto + self$state$task_prototype = proto + # store the task w/o the data self$state$train_task = task_rm_backend(task$clone(deep = TRUE)) diff --git a/R/worker.R b/R/worker.R index 324eb7652..8fa0008df 100644 --- a/R/worker.R +++ b/R/worker.R @@ -68,15 +68,13 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL log = append_log(NULL, "train", result$log$class, result$log$msg) train_time = result$elapsed - proto = task$data(rows = integer()) learner$state = insert_named(learner$state, list( model = result$result, log = log, train_time = train_time, param_vals = learner$param_set$values, task_hash = task$hash, - data_prototype = proto, - task_prototype = proto, + feature_names = task$feature_names, mlr3_version = mlr_reflections$package_version ))