Skip to content

Commit

Permalink
Merge pull request #1378 from tidymodels/misc-sparsevctrs
Browse files Browse the repository at this point in the history
Misc sparsevctrs
  • Loading branch information
EmilHvitfeldt authored Oct 4, 2024
2 parents 104d30d + 183f78f commit 3181a7e
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 31 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Imports:
Matrix,
purrr (>= 1.0.0),
rlang (>= 1.1.0),
sparsevctrs (>= 0.1.0.9001),
sparsevctrs (>= 0.1.0.9002),
stats,
tibble,
tidyr (>= 1.0.0),
Expand Down
2 changes: 1 addition & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ validate_training_data <- function(x, rec, fresh, call = rlang::caller_env()) {
x <- rec$template
} else {
if (is_sparse_matrix(x)) {
x <- sparsevctrs::coerce_to_sparse_tibble(x)
x <- sparsevctrs::coerce_to_sparse_tibble(x, call = call)
}
if (!is_tibble(x)) {
x <- as_tibble(x)
Expand Down
9 changes: 6 additions & 3 deletions R/recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ recipe.default <- function(x, ...) {

# Doing this here since it should work for all types of Matrix classes
if (is_sparse_matrix(x)) {
x <- sparsevctrs::coerce_to_sparse_tibble(x)
x <- sparsevctrs::coerce_to_sparse_tibble(x, call = caller_env(0))
return(recipe(x, ...))
}

Expand Down Expand Up @@ -218,7 +218,7 @@ recipe.formula <- function(formula, data, ...) {
}

if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data)
data <- sparsevctrs::coerce_to_sparse_tibble(data, call = caller_env(0))
}

if (!is_tibble(data)) {
Expand Down Expand Up @@ -705,7 +705,10 @@ bake.recipe <- function(object, new_data, ..., composition = "tibble") {
}

if (is_sparse_matrix(new_data)) {
new_data <- sparsevctrs::coerce_to_sparse_tibble(new_data)
new_data <- sparsevctrs::coerce_to_sparse_tibble(
new_data,
call = caller_env(0)
)
}

if (!is_tibble(new_data)) {
Expand Down
4 changes: 0 additions & 4 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
#' @name sparse_data
NULL

is_sparse_tibble <- function(x) {
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
}

is_sparse_matrix <- function(x) {
methods::is(x, "sparseMatrix")
}
16 changes: 16 additions & 0 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# recipe() errors if sparse matrix has no colnames

Code
recipe(~., data = hotel_data)
Condition
Error in `recipe()`:
! `x` must have column names.

---

Code
recipe(hotel_data)
Condition
Error in `recipe()`:
! `x` must have column names.

14 changes: 12 additions & 2 deletions tests/testthat/helper-sparsevctrs.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ------------------------------------------------------------------------------
# For sparse tibble testing

sparse_hotel_rates <- function() {
sparse_hotel_rates <- function(tibble = FALSE) {
# 99.2 sparsity
hotel_rates <- modeldata::hotel_rates

Expand All @@ -22,5 +22,15 @@ sparse_hotel_rates <- function() {
)

res <- as.matrix(res)
Matrix::Matrix(res, sparse = TRUE)
res <- Matrix::Matrix(res, sparse = TRUE)

if (tibble) {
res <- sparsevctrs::coerce_to_sparse_tibble(res)

# materialize outcome
withr::local_options("sparsevctrs.verbose_materialize" = NULL)
res$avg_price_per_room <- res$avg_price_per_room[]
}

res
}
54 changes: 34 additions & 20 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,38 @@ test_that("recipe() accepts sparse tibbles", {
skip_if_not_installed("modeldata")
withr::local_options("sparsevctrs.verbose_materialize" = 3)

hotel_data <- sparse_hotel_rates()
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
hotel_data <- sparse_hotel_rates(tibble = TRUE)

expect_no_condition(
rec_spec <- recipe(avg_price_per_room ~ ., data = hotel_data)
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)

expect_no_condition(
rec_spec <- recipe(hotel_data)
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)

expect_no_condition(
rec_spec <- recipe(hotel_data, avg_price_per_room ~ .)
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)
})

test_that("prep() accepts sparse tibbles", {
skip_if_not_installed("modeldata")
withr::local_options("sparsevctrs.verbose_materialize" = 3)

hotel_data <- sparse_hotel_rates()
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
hotel_data <- sparse_hotel_rates(tibble = TRUE)

rec_spec <- recipe(avg_price_per_room ~ ., data = hotel_data)

Expand All @@ -44,24 +42,23 @@ test_that("prep() accepts sparse tibbles", {
)

expect_true(
is_sparse_tibble(rec$template)
sparsevctrs::has_sparse_elements(rec$template)
)

expect_no_error(
rec <- prep(rec_spec, training = hotel_data)
)

expect_true(
is_sparse_tibble(rec$template)
sparsevctrs::has_sparse_elements(rec$template)
)
})

test_that("bake() accepts sparse tibbles", {
skip_if_not_installed("modeldata")
withr::local_options("sparsevctrs.verbose_materialize" = 3)

hotel_data <- sparse_hotel_rates()
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
hotel_data <- sparse_hotel_rates(tibble = TRUE)

rec_spec <- recipe(avg_price_per_room ~ ., data = hotel_data) %>%
prep()
Expand All @@ -71,15 +68,15 @@ test_that("bake() accepts sparse tibbles", {
)

expect_true(
is_sparse_tibble(res)
sparsevctrs::has_sparse_elements(res)
)

expect_no_error(
res <- bake(rec_spec, new_data = hotel_data)
)

expect_true(
is_sparse_tibble(res)
sparsevctrs::has_sparse_elements(res)
)
})

Expand All @@ -94,23 +91,23 @@ test_that("recipe() accepts sparse matrices", {
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)

expect_no_condition(
rec_spec <- recipe(hotel_data)
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)

expect_no_condition(
rec_spec <- recipe(hotel_data, avg_price_per_room ~ .)
)

expect_true(
is_sparse_tibble(rec_spec$template)
sparsevctrs::has_sparse_elements(rec_spec$template)
)
})

Expand All @@ -127,15 +124,15 @@ test_that("prep() accepts sparse matrices", {
)

expect_true(
is_sparse_tibble(rec$template)
sparsevctrs::has_sparse_elements(rec$template)
)

expect_no_error(
rec <- prep(rec_spec, training = hotel_data)
)

expect_true(
is_sparse_tibble(rec$template)
sparsevctrs::has_sparse_elements(rec$template)
)
})

Expand All @@ -153,14 +150,31 @@ test_that("bake() accepts sparse matrices", {
)

expect_true(
is_sparse_tibble(res)
sparsevctrs::has_sparse_elements(res)
)

expect_no_error(
res <- bake(rec_spec, new_data = hotel_data)
)

expect_true(
is_sparse_tibble(res)
sparsevctrs::has_sparse_elements(res)
)
})

test_that("recipe() errors if sparse matrix has no colnames", {
skip_if_not_installed("modeldata")

hotel_data <- sparse_hotel_rates()
colnames(hotel_data) <- NULL

expect_snapshot(
error = TRUE,
recipe(~ ., data = hotel_data)
)

expect_snapshot(
error = TRUE,
recipe(hotel_data)
)
})

0 comments on commit 3181a7e

Please sign in to comment.