Skip to content

Commit

Permalink
Merge pull request #1412 from tidymodels/estimate-sparsity-dummy-bug
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Jan 16, 2025
2 parents e78ed1c + 1e5e317 commit 8c7bc29
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 12 deletions.
20 changes: 20 additions & 0 deletions R/dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,23 @@ tidy.step_dummy <- function(x, ...) {
res$id <- x$id
res
}

# Estimate, per input column, how many dummy columns `step_dummy()` would
# create and how sparse they would be.
#
# Arguments:
#   x    - the step object; only `x$one_hot` is read here.
#   data - data frame (or list) of columns; presumably the columns selected
#          for this step (TODO confirm against the caller).
#   ...  - unused.
#
# Returns a list with one element per column of `data` (names preserved).
# Each element is a named numeric vector:
#   n_cols   - `n_levels` when `one_hot` is TRUE, otherwise `n_levels - 1`.
#   sparsity - `1 - 1 / n_levels`.
#' @export
.recipes_estimate_sparsity.step_dummy <- function(x, data, ...) {
  # Factors contribute every declared level; any other vector (e.g.
  # character) has no levels, so its distinct values are counted instead.
  count_levels <- function(col) {
    if (is.factor(col)) {
      length(levels(col))
    } else {
      vctrs::vec_unique_count(col)
    }
  }

  # `one_hot` is a scalar flag, so a plain `if` is used rather than the
  # vectorised `ifelse()`.
  one_hot <- isTRUE(x$one_hot)

  lapply(data, function(col) {
    n_lvl <- count_levels(col)

    # With zero levels the formulas below would give `n_cols = -1` and
    # `sparsity = -Inf`; report an empty contribution instead.
    if (n_lvl == 0) {
      return(c(n_cols = 0, sparsity = 0))
    }

    c(
      n_cols = if (one_hot) n_lvl else n_lvl - 1,
      sparsity = 1 - 1 / n_lvl
    )
  })
}
12 changes: 0 additions & 12 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,3 @@ is_sparse_matrix <- function(x) {

zeroes / (n_rows * n_cols)
}

# Estimate, per input column, how many dummy columns `step_dummy()` would
# create and how sparse they would be.
#
# Arguments:
#   x    - the step object; only `x$one_hot` is read here.
#   data - data frame (or list) of columns to estimate for.
#   ...  - unused.
#
# Returns a list with one element per column of `data` (names preserved).
# Each element is a named numeric vector:
#   n_cols   - `n_levels` when `one_hot` is TRUE, otherwise `n_levels - 1`.
#   sparsity - `1 - 1 / n_levels`.
#' @export
.recipes_estimate_sparsity.step_dummy <- function(x, data, ...) {
  # `levels()` is NULL for non-factors, so a character column used to be
  # counted as having zero levels. Count its distinct values instead.
  count_levels <- function(col) {
    if (is.factor(col)) {
      length(levels(col))
    } else {
      length(unique(col))
    }
  }

  # `one_hot` is a scalar flag, so a plain `if` is used rather than the
  # vectorised `ifelse()`.
  one_hot <- isTRUE(x$one_hot)

  lapply(data, function(col) {
    n_lvl <- count_levels(col)

    # With zero levels the formulas below would give `n_cols = -1` and
    # `sparsity = -Inf`; report an empty contribution instead.
    if (n_lvl == 0) {
      return(c(n_cols = 0, sparsity = 0))
    }

    c(
      n_cols = if (one_hot) n_lvl else n_lvl - 1,
      sparsity = 1 - 1 / n_lvl
    )
  })
}
24 changes: 24 additions & 0 deletions tests/testthat/test-dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,30 @@ test_that("sparse = 'yes' errors on unsupported contrasts", {
)
})

test_that(".recipes_estimate_sparsity works for factor and character columns", {
  # Factor input: the estimate must equal the sparsity actually observed
  # after prepping and baking the recipe.
  rec <- recipe(~., iris) %>%
    step_dummy(all_nominal_predictors())

  expected <- rec %>%
    prep() %>%
    bake(NULL) %>%
    sparsevctrs::sparsity()

  expect_equal(
    .recipes_estimate_sparsity(rec),
    expected
  )

  # Same check with `Species` stored as character instead of factor. A named
  # copy is used so the built-in `iris` is not masked for the rest of the test.
  iris_chr <- iris
  iris_chr$Species <- as.character(iris_chr$Species)

  rec_chr <- recipe(~., iris_chr) %>%
    step_dummy(all_nominal_predictors())

  expected_chr <- rec_chr %>%
    prep() %>%
    bake(NULL) %>%
    sparsevctrs::sparsity()

  expect_equal(
    .recipes_estimate_sparsity(rec_chr),
    expected_chr
  )
})

# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
Expand Down

0 comments on commit 8c7bc29

Please sign in to comment.