From 6454a9e4057d5684c92a4ad3b77d0e5f75ad5bbd Mon Sep 17 00:00:00 2001 From: mb706 Date: Wed, 13 Nov 2024 22:34:15 +0100 Subject: [PATCH] get_values() with dependencies and tokens closes #415 --- R/ParamSet.R | 24 ++++++++++++------------ tests/testthat/test_ParamSet.R | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/R/ParamSet.R b/R/ParamSet.R index 1ea05ae6..2827bf56 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -203,6 +203,18 @@ ParamSet = R6Class("ParamSet", values = self$values ns = names(values) + deps = self$deps + if (remove_dependencies && nrow(deps)) { + for (j in seq_row(deps)) { + p1id = deps$id[[j]] + p2id = deps$on[[j]] + cond = deps$cond[[j]] + if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(condition_test(cond, values[[p2id]]))) { + values[p1id] = NULL + } + } + } + if (type == "without_token") { values = discard(values, is, "TuneToken") } else if (type == "only_token") { @@ -218,18 +230,6 @@ ParamSet = R6Class("ParamSet", } } - deps = self$deps - if (remove_dependencies && nrow(deps)) { - for (j in seq_row(deps)) { - p1id = deps$id[[j]] - p2id = deps$on[[j]] - cond = deps$cond[[j]] - if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(condition_test(cond, values[[p2id]]))) { - values[p1id] = NULL - } - } - } - values[match(self$ids(class = class, tags = tags, any_tags = any_tags), names(values), nomatch = 0)] }, diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index f5ae483c..fcddadde 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -513,3 +513,25 @@ test_that("disable internal tuning", { expect_error(param_set$disable_internal_tuning("c")) expect_error(param_set$disable_internal_tuning("b")) }) + +test_that("get_values works with tokens and dependencies", { + ps = ps( + cost = p_dbl(0, default = 1, tags = "train", depends = quote(type == "C-classification")), + kernel = p_fct(c("linear", "polynomial", "radial", "sigmoid"), default = "radial", tags = "train"), + type = p_fct(c("C-classification", "nu-classification"), default = "C-classification", tags = "train") + ) + + ps$set_values(cost = to_tune(1e-5, 1e5, logscale = TRUE), kernel = "radial", type = "C-classification") + expect_equal(ps$get_values(type = "only_token"), list(cost = to_tune(1e-5, 1e5, logscale = TRUE))) + + ps = ps( + cost = p_dbl(0, default = 1, tags = "train"), + kernel = p_fct(c("linear", "polynomial", "radial", "sigmoid"), default = "radial", tags = "train"), + type = p_fct(c("C-classification", "nu-classification"), default = "C-classification", tags = "train") + ) + + ps$set_values(cost = to_tune(1e-5, 1e5, logscale = TRUE), kernel = "radial", type = "C-classification") + ps$get_values(type = "only_token") + + expect_equal(ps$get_values(type = "only_token"), list(cost = to_tune(1e-5, 1e5, logscale = TRUE))) +})