Skip to content

Commit

Permalink
Merge pull request #1337 from tidymodels/fix-splines2-errors
Browse files Browse the repository at this point in the history
Fix splines2 errors
  • Loading branch information
EmilHvitfeldt authored Jun 7, 2024
2 parents d488714 + 87389e5 commit 3ac64d1
Show file tree
Hide file tree
Showing 15 changed files with 283 additions and 67 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the recipe. (#1071)

* `step_spline_b()`, `step_spline_convex()`, `step_spline_monotone()`, and `step_spline_nonnegative()` now throw informative errors if the `degree`, `deg_free`, and `complete_set` arguments cause an error. (#1170)

* Developer helper function `recipes_ptype()` has been added, returning expected input data for `prep()` and `bake()` for a given recipe object. (#1329)

* The `prefix` argument of `step_dummy_multi_choice()` is now properly documented. (#1298)
Expand Down
25 changes: 12 additions & 13 deletions R/spline_b.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,20 @@ step_spline_b_new <-
prep.step_spline_b <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

res <- list()

res <-
purrr::map2(
training[, col_names],
col_names,
~ spline2_create(
.x,
nm = .y,
.fn = "bSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
for (col_name in col_names) {
res[[col_name]] <- spline2_create(
training[[col_name]],
nm = col_name,
.fn = "bSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
}
# check for errors
bas_res <- purrr::map_lgl(res, is.null)
res <- res[!bas_res]
Expand Down
26 changes: 13 additions & 13 deletions R/spline_convex.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,20 @@ prep.step_spline_convex <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

res <-
purrr::map2(
training[, col_names],
col_names,
~ spline2_create(
.x,
nm = .y,
.fn = "cSpline",
df = x$deg_free,
degree = x$degree,
complete_set = x$complete_set,
fn_opts = x$options
)
res <- list()

for (col_name in col_names) {
res[[col_name]] <- spline2_create(
training[[col_name]],
nm = col_name,
.fn = "cSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
}

# check for errors
bas_res <- purrr::map_lgl(res, is.null)
res <- res[!bas_res]
Expand Down
21 changes: 18 additions & 3 deletions R/spline_helpers.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
spline2_create <- function(x, nm = "pred", .fn = "bSpline", df = 3, complete_set = TRUE,
degree = NULL, fn_opts = NULL) {
degree = NULL, fn_opts = NULL, call = rlang::caller_env()) {
vals <- c("bSpline", "cSpline", "iSpline", "mSpline", "naturalSpline", "bernsteinPoly")
.fn <- rlang::arg_match(.fn, vals)
fn_opts <- c(fn_opts, degree = degree)

# Validate `degree` against `df` before building the basis call, so the user
# gets an actionable message that names the recipe step arguments.
#
# `complete_set` is a logical, so `df - complete_set` is `df - 1` when TRUE
# and `df` when FALSE:
#   * complete_set = TRUE  -> `degree` must be strictly less than `df`
#   * complete_set = FALSE -> `degree` must be less than or equal to `df`
#
# `isTRUE()` makes the check a no-op when `degree` is NULL, because the
# comparison then yields a zero-length logical. "bernsteinPoly" is exempt
# from this constraint.
#
# NOTE(review): the snapshot tests pin the exact wording of these messages;
# regenerate the snapshots after changing the text.
if (.fn != "bernsteinPoly" && isTRUE(degree > (df - complete_set))) {
  if (complete_set) {
    cli::cli_abort(
      "{.arg degree} ({degree}) must be less than {.arg deg_free} \\
      ({df}) when {.code complete_set = TRUE}.",
      call = call
    )
  } else {
    cli::cli_abort(
      "{.arg degree} ({degree}) must be less than or equal to \\
      {.arg deg_free} ({df}) when {.code complete_set = FALSE}.",
      call = call
    )
  }
}

.cl <-
rlang::call2(
Expand Down Expand Up @@ -31,8 +47,7 @@ spline2_create <- function(x, nm = "pred", .fn = "bSpline", df = 3, complete_set
spline_msg <- function(x) {
x <- as.character(x)
x <- strsplit(x, "\\n")[[1]]
x <- paste0(x[-1], collapse = ". ")
cli::cli_warn(trimws(x, which = "left"))
cli::cli_abort(x)
}

spline2_apply <- function(object, new_data) {
Expand Down
25 changes: 12 additions & 13 deletions R/spline_monotone.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,19 @@ prep.step_spline_monotone <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

res <-
purrr::map2(
training[, col_names],
col_names,
~ spline2_create(
.x,
nm = .y,
.fn = "iSpline",
df = x$deg_free,
degree = x$degree,
complete_set = x$complete_set,
fn_opts = x$options
)
res <- list()

for (col_name in col_names) {
res[[col_name]] <- spline2_create(
training[[col_name]],
nm = col_name,
.fn = "iSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
}
# check for errors
bas_res <- purrr::map_lgl(res, is.null)
res <- res[!bas_res]
Expand Down
25 changes: 13 additions & 12 deletions R/spline_natural.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,20 @@ prep.step_spline_natural <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

res <-
purrr::map2(
training[, col_names],
col_names,
~ spline2_create(
.x,
nm = .y,
.fn = "naturalSpline",
df = max(x$deg_free, 2),
complete_set = x$complete_set,
fn_opts = x$options
)
res <- list()

for (col_name in col_names) {
res[[col_name]] <- spline2_create(
training[[col_name]],
nm = col_name,
.fn = "naturalSpline",
df = max(x$deg_free, 2),
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
}

# check for errors
bas_res <- purrr::map_lgl(res, is.null)
res <- res[!bas_res]
Expand Down
26 changes: 13 additions & 13 deletions R/spline_nonnegative.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,20 @@ prep.step_spline_nonnegative <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))

res <-
purrr::map2(
training[, col_names],
col_names,
~ spline2_create(
.x,
nm = .y,
.fn = "mSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
res <- list()

for (col_name in col_names) {
res[[col_name]] <- spline2_create(
training[[col_name]],
nm = col_name,
.fn = "mSpline",
df = x$deg_free,
complete_set = x$complete_set,
degree = x$degree,
fn_opts = x$options
)
}

# check for errors
bas_res <- purrr::map_lgl(res, is.null)
res <- res[!bas_res]
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/spline_b.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
# errors if degree > deg_free (#1170)

Code
recipe(~., data = mtcars) %>% step_spline_b(mpg, degree = 3, deg_free = 3,
complete_set = TRUE) %>% prep()
Condition
Error in `step_spline_b()`:
Caused by error in `prep()`:
! `degree` (3) must be less than to `deg_free` (3) when `complete_set = FALSE`.

---

Code
recipe(~., data = mtcars) %>% step_spline_b(mpg, degree = 4, deg_free = 3,
complete_set = FALSE) %>% prep()
Condition
Error in `step_spline_b()`:
Caused by error in `prep()`:
! `degree` (4) must be less than or equal to `deg_free` (3) when `complete_set = TRUE`.

# check_name() is used

Code
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/spline_convex.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
# errors if degree > deg_free (#1170)

Code
recipe(~., data = mtcars) %>% step_spline_convex(mpg, degree = 3, deg_free = 3,
complete_set = TRUE) %>% prep()
Condition
Error in `step_spline_convex()`:
Caused by error in `prep()`:
! `degree` (3) must be less than to `deg_free` (3) when `complete_set = FALSE`.

---

Code
recipe(~., data = mtcars) %>% step_spline_convex(mpg, degree = 4, deg_free = 3,
complete_set = FALSE) %>% prep()
Condition
Error in `step_spline_convex()`:
Caused by error in `prep()`:
! `degree` (4) must be less than or equal to `deg_free` (3) when `complete_set = TRUE`.

# check_name() is used

Code
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/spline_monotone.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
# errors if degree > deg_free (#1170)

Code
recipe(~., data = mtcars) %>% step_spline_monotone(mpg, degree = 3, deg_free = 3,
complete_set = TRUE) %>% prep()
Condition
Error in `step_spline_monotone()`:
Caused by error in `prep()`:
! `degree` (3) must be less than to `deg_free` (3) when `complete_set = FALSE`.

---

Code
recipe(~., data = mtcars) %>% step_spline_monotone(mpg, degree = 4, deg_free = 3,
complete_set = FALSE) %>% prep()
Condition
Error in `step_spline_monotone()`:
Caused by error in `prep()`:
! `degree` (4) must be less than or equal to `deg_free` (3) when `complete_set = TRUE`.

# check_name() is used

Code
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/spline_nonnegative.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
# errors if degree > deg_free (#1170)

Code
recipe(~., data = mtcars) %>% step_spline_nonnegative(mpg, degree = 3,
deg_free = 3, complete_set = TRUE) %>% prep()
Condition
Error in `step_spline_nonnegative()`:
Caused by error in `prep()`:
! `degree` (3) must be less than to `deg_free` (3) when `complete_set = FALSE`.

---

Code
recipe(~., data = mtcars) %>% step_spline_nonnegative(mpg, degree = 4,
deg_free = 3, complete_set = FALSE) %>% prep()
Condition
Error in `step_spline_nonnegative()`:
Caused by error in `prep()`:
! `degree` (4) must be less than or equal to `deg_free` (3) when `complete_set = TRUE`.

# check_name() is used

Code
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-spline_b.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,36 @@ test_that("correct basis functions", {
expect_equal(hydrogen_ns_te_res, hydrogen_ns_te_exp)
})

# Checks how `step_spline_b()` reacts, at `prep()` time, to different
# combinations of `degree`, `deg_free`, and `complete_set` (issue #1170).
test_that("errors if degree > deg_free (#1170)", {
skip_if_not_installed("splines2")

# degree (2) below deg_free (3) with complete_set = TRUE: must not error.
expect_no_error(
recipe(~., data = mtcars) %>%
step_spline_b(mpg, degree = 2, deg_free = 3, complete_set = TRUE) %>%
prep()
)

# degree (3) equal to deg_free (3) with complete_set = FALSE: must not error.
expect_no_error(
recipe(~., data = mtcars) %>%
step_spline_b(mpg, degree = 3, deg_free = 3, complete_set = FALSE) %>%
prep()
)

# degree (3) equal to deg_free (3) with complete_set = TRUE: must error;
# the message is pinned by the snapshot.
expect_snapshot(
error = TRUE,
recipe(~., data = mtcars) %>%
step_spline_b(mpg, degree = 3, deg_free = 3, complete_set = TRUE) %>%
prep()
)

# degree (4) above deg_free (3) with complete_set = FALSE: must error;
# the message is pinned by the snapshot.
expect_snapshot(
error = TRUE,
recipe(~., data = mtcars) %>%
step_spline_b(mpg, degree = 4, deg_free = 3, complete_set = FALSE) %>%
prep()
)
})

test_that("check_name() is used", {
dat <- mtcars
dat$mpg_01 <- dat$mpg
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-spline_convex.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,36 @@ test_that("correct convex functions", {
expect_equal(hydrogen_ns_te_res, hydrogen_ns_te_exp)
})

# Checks how `step_spline_convex()` reacts, at `prep()` time, to different
# combinations of `degree`, `deg_free`, and `complete_set` (issue #1170).
test_that("errors if degree > deg_free (#1170)", {
skip_if_not_installed("splines2")

# degree (2) below deg_free (3) with complete_set = TRUE: must not error.
expect_no_error(
recipe(~., data = mtcars) %>%
step_spline_convex(mpg, degree = 2, deg_free = 3, complete_set = TRUE) %>%
prep()
)

# degree (3) equal to deg_free (3) with complete_set = FALSE: must not error.
expect_no_error(
recipe(~., data = mtcars) %>%
step_spline_convex(mpg, degree = 3, deg_free = 3, complete_set = FALSE) %>%
prep()
)

# degree (3) equal to deg_free (3) with complete_set = TRUE: must error;
# the message is pinned by the snapshot.
expect_snapshot(
error = TRUE,
recipe(~., data = mtcars) %>%
step_spline_convex(mpg, degree = 3, deg_free = 3, complete_set = TRUE) %>%
prep()
)

# degree (4) above deg_free (3) with complete_set = FALSE: must error;
# the message is pinned by the snapshot.
expect_snapshot(
error = TRUE,
recipe(~., data = mtcars) %>%
step_spline_convex(mpg, degree = 4, deg_free = 3, complete_set = FALSE) %>%
prep()
)
})

test_that("check_name() is used", {
dat <- mtcars
dat$mpg_01 <- dat$mpg
Expand Down
Loading

0 comments on commit 3ac64d1

Please sign in to comment.