Skip to content

Commit

Permalink
get_values() with dependencies and tokens
Browse files Browse the repository at this point in the history
closes #415
  • Loading branch information
mb706 committed Nov 13, 2024
1 parent dfe06da commit 6454a9e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
24 changes: 12 additions & 12 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ ParamSet = R6Class("ParamSet",
values = self$values
ns = names(values)

deps = self$deps
if (remove_dependencies && nrow(deps)) {
for (j in seq_row(deps)) {
p1id = deps$id[[j]]
p2id = deps$on[[j]]
cond = deps$cond[[j]]
if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(condition_test(cond, values[[p2id]]))) {
values[p1id] = NULL
}
}
}

if (type == "without_token") {
values = discard(values, is, "TuneToken")
} else if (type == "only_token") {
Expand All @@ -218,18 +230,6 @@ ParamSet = R6Class("ParamSet",
}
}

deps = self$deps
if (remove_dependencies && nrow(deps)) {
for (j in seq_row(deps)) {
p1id = deps$id[[j]]
p2id = deps$on[[j]]
cond = deps$cond[[j]]
if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(condition_test(cond, values[[p2id]]))) {
values[p1id] = NULL
}
}
}

values[match(self$ids(class = class, tags = tags, any_tags = any_tags), names(values), nomatch = 0)]
},

Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,25 @@ test_that("disable internal tuning", {
expect_error(param_set$disable_internal_tuning("c"))
expect_error(param_set$disable_internal_tuning("b"))
})

test_that("get_values works with tokens and dependencies", {
ps = ps(
cost = p_dbl(0, default = 1, tags = "train", depends = quote(type == "C-classification")),
kernel = p_fct(c("linear", "polynomial", "radial", "sigmoid"), default = "radial", tags = "train"),
type = p_fct(c("C-classification", "nu-classification"), default = "C-classification", tags = "train")
)

ps$set_values(cost = to_tune(1e-5, 1e5, logscale = TRUE), kernel = "radial", type = "C-classification")
expect_equal(ps$get_values(type = "only_token"), list(cost = to_tune(1e-5, 1e5, logscale = TRUE)))

ps = ps(
cost = p_dbl(0, default = 1, tags = "train"),
kernel = p_fct(c("linear", "polynomial", "radial", "sigmoid"), default = "radial", tags = "train"),
type = p_fct(c("C-classification", "nu-classification"), default = "C-classification", tags = "train")
)

ps$set_values(cost = to_tune(1e-5, 1e5, logscale = TRUE), kernel = "radial", type = "C-classification")
ps$get_values(type = "only_token")

expect_equal(ps$get_values(type = "only_token"), list(cost = to_tune(1e-5, 1e5, logscale = TRUE)))
})

0 comments on commit 6454a9e

Please sign in to comment.