Skip to content

Commit

Permalink
fix: remove task prototype when resample (#981)
Browse files Browse the repository at this point in the history
* refactor: remove task prototype when resample

* refactor: add option to store prototype

* fix: braket

* refactor: null

* fix: browser

* keep prototypes in state when store_models is TRUE

* refactor: only store data_prototype when train

* refactor: feature names

---------

Co-authored-by: Sebastian Fischer <sebf.fischer@gmail.com>
  • Loading branch information
be-marc and sebffischer committed Dec 21, 2023
1 parent 959c9d6 commit b285c51
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 5 additions & 0 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ Learner = R6Class("Learner",

learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)

# store data prototype
proto = task$data(rows = integer())
self$state$data_prototype = proto
self$state$task_prototype = proto

# store the task w/o the data
self$state$train_task = task_rm_backend(task$clone(deep = TRUE))

Expand Down
4 changes: 1 addition & 3 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,13 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
log = append_log(NULL, "train", result$log$class, result$log$msg)
train_time = result$elapsed

proto = task$data(rows = integer())
learner$state = insert_named(learner$state, list(
model = result$result,
log = log,
train_time = train_time,
param_vals = learner$param_set$values,
task_hash = task$hash,
data_prototype = proto,
task_prototype = proto,
feature_names = task$feature_names,
mlr3_version = mlr_reflections$package_version
))

Expand Down

0 comments on commit b285c51

Please sign in to comment.