Skip to content

Commit

Permalink
Merge pull request #1410 from tidymodels/estimate-sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Jan 15, 2025
2 parents 7209781 + 8585981 commit eed3579
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 7 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ S3method(.get_data_types,logical)
S3method(.get_data_types,numeric)
S3method(.get_data_types,ordered)
S3method(.get_data_types,textrecipes_tokenlist)
S3method(.recipes_estimate_sparsity,default)
S3method(.recipes_estimate_sparsity,recipe)
S3method(.recipes_estimate_sparsity,step_dummy)
S3method(bake,check_class)
S3method(bake,check_cols)
S3method(bake,check_missing)
Expand Down Expand Up @@ -521,6 +524,7 @@ S3method(tune_args,step)
S3method(update,step)
export("%>%")
export(.get_data_types)
export(.recipes_estimate_sparsity)
export(.recipes_toggle_sparse_args)
export(add_check)
export(add_role)
Expand Down
77 changes: 76 additions & 1 deletion R/sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,79 @@ is_sparse_matrix <- function(x) {
}
}
x
}
}

#' Estimate sparsity of a recipe
#'
#' @param x An object. Methods exist for recipes and for individual steps;
#'   objects without a dedicated method fall through to the default method.
#' @param ... Passed on to methods. The recipe method passes the columns
#'   selected by a step as the second argument when it dispatches on that step.
#'
#' @details
#' Takes an untrained recipe and provides a rough estimate of the sparsity of
#' the prepped version of the recipe.
#'
#' Sampling of the input is done to avoid slowdown for larger data sets.
#'
#' An estimated sparsity of the input data is calculated. Then, for each step
#' where `sparse = "auto"` or `sparse = "yes"` is set, an estimate of how many
#' predictors will be created is used to modify the estimate.
#'
#' An initial sparsity of 0 will be used if a zero-row data frame is used in
#' the specification of the recipe. This is likely an under-estimate of the
#' true sparsity of the data.
#'
#' @return For a recipe, a single number: the estimated proportion of zeroes
#'   in the prepped data. For a step with a method (e.g. `step_dummy()`), a
#'   list with one element per selected column, each holding `n_cols` and
#'   `sparsity`. `NULL` for objects without a dedicated method.
#'
#' @keywords internal
#'
#' @export
.recipes_estimate_sparsity <- function(x, ...) {
UseMethod(".recipes_estimate_sparsity")
}

#' @export
.recipes_estimate_sparsity.default <- function(x, ...) {
  # Fallback for objects without a dedicated method: report no adjustments.
  # The recipe method iterates over the returned value, and iterating over
  # `NULL` is a no-op, so such steps leave the estimate untouched.
  NULL
}

#' @export
.recipes_estimate_sparsity.recipe <- function(x, ...) {
  # Rough estimate of the proportion of zeroes in the data this recipe would
  # produce, computed from the recipe's template without prepping it.
  input <- x$template
  row_count <- nrow(input)
  col_count <- ncol(input)

  # Baseline sparsity of the raw template. A zero-row template has nothing to
  # measure, so start from 0 and pretend there is a single row; otherwise the
  # final division would be 0 / 0.
  template_is_empty <- row_count == 0
  if (template_is_empty) {
    base_sparsity <- 0
    row_count <- 1
  } else {
    # Only a sample of the template is inspected to keep this cheap.
    base_sparsity <- sparsevctrs::sparsity(input, sample = 1000)
  }

  # Running totals: estimated number of zero cells and number of columns.
  zero_cells <- base_sparsity * row_count * col_count

  for (current_step in x$steps) {
    sparse_arg <- current_step$sparse

    # Only steps that expose a `sparse` argument not set to "no" take part.
    if (is.null(sparse_arg) || sparse_arg == "no") {
      next
    }

    selected <- recipes_eval_select(current_step$terms, input, x$term_info)
    per_column <- .recipes_estimate_sparsity(current_step, input[selected])

    # Each adjustment describes the columns generated from one input column:
    # add their expected zeroes, and swap the one input column for the
    # `n_cols` generated ones in the column count.
    for (adj in per_column) {
      zero_cells <- zero_cells +
        row_count * adj[["sparsity"]] * adj[["n_cols"]]
      col_count <- col_count + adj[["n_cols"]] - 1
    }
  }

  zero_cells / (row_count * col_count)
}

#' @export
.recipes_estimate_sparsity.step_dummy <- function(x, data, ...) {
  # Estimate, for every column in `data`, how many indicator columns dummy
  # encoding would create and how sparse those columns would be.
  #
  # `x` is the step; only `x$one_hot` is read. `data` holds the columns the
  # step selects. Returns a list with one element per column of `data`, each a
  # numeric vector with elements `n_cols` and `sparsity`.

  # `levels()` is NULL for anything that is not a factor (e.g. a character
  # column), which would give 0 levels and a sparsity of 1 - 1 / 0 = -Inf.
  # Count distinct non-missing values for those columns instead.
  count_levels <- function(col) {
    if (is.factor(col)) {
      length(levels(col))
    } else {
      length(unique(col[!is.na(col)]))
    }
  }

  one_hot <- isTRUE(x$one_hot)

  lapply(data, function(col) {
    n_lvl <- count_levels(col)

    # No observable levels (e.g. a zero-row character column): nothing can be
    # said about the generated columns, so contribute no columns and no zeroes
    # rather than dividing by zero.
    if (n_lvl == 0) {
      return(c(n_cols = 0, sparsity = 0))
    }

    c(
      # One-hot encoding keeps every level; otherwise one level is dropped.
      n_cols = if (one_hot) n_lvl else n_lvl - 1,
      # Assumes levels are equally frequent, so each indicator column is
      # non-zero in about 1 / n_lvl of the rows.
      sparsity = 1 - 1 / n_lvl
    )
  })
}
32 changes: 32 additions & 0 deletions man/dot-recipes_estimate_sparsity.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions man/roles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,33 @@ test_that(".recipes_toggle_sparse_args works", {
.recipes_toggle_sparse_args(rec_spec_auto_no, "no"),
rec_spec_no_no
)
})

test_that(".recipes_estimate_sparsity works", {
  # Without steps, the estimate is just the sparsity of the input data.
  rec <- recipe(~., mtcars)

  expect_identical(
    .recipes_estimate_sparsity(rec),
    sparsevctrs::sparsity(mtcars)
  )

  # With a sparse-capable step, the estimate should match the sparsity of the
  # data actually produced by prepping and baking the recipe.
  rec <- recipe(~., iris) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_dummy(all_nominal_predictors())

  exp_sparsity <- rec %>%
    prep() %>%
    bake(NULL) %>%
    sparsevctrs::sparsity()

  expect_equal(
    .recipes_estimate_sparsity(rec),
    exp_sparsity
  )

  # A zero-row template still carries the factor levels, so the estimate is
  # expected to be the same as for the full data.
  rec <- recipe(~., iris[0, ]) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_dummy(all_nominal_predictors())

  expect_equal(
    .recipes_estimate_sparsity(rec),
    exp_sparsity
  )
})

0 comments on commit eed3579

Please sign in to comment.