Skip to content

Commit

Permalink
Merge pull request #1387 from tidymodels/type-checkers
Browse files Browse the repository at this point in the history
Add rlang type checkers
  • Loading branch information
EmilHvitfeldt authored Nov 1, 2024
2 parents a0b50c6 + 4902983 commit 28e3e4d
Show file tree
Hide file tree
Showing 103 changed files with 1,455 additions and 152 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

* `prep.recipe(..., strings_as_factors = TRUE)` now only converts string variables that have role "predictor" or "outcome". (@dajmcdon, #1358, #1376)

* All steps and checks now require arguments `trained`, `skip`, `role`, and `id` at all times.

# recipes 1.1.0

## Improvements
Expand Down
4 changes: 3 additions & 1 deletion R/bin2factor.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ step_bin2factor <-
if (length(levels) != 2) {
msg <- c(
msg,
i = "{length(levels)} element{?s} were supplied."
i = "{length(levels)} element{?s} were supplied; two were expected."
)
}
if (!is.character(levels)) {
Expand All @@ -76,6 +76,8 @@ step_bin2factor <-
}
cli::cli_abort(msg)
}
check_bool(ref_first)

add_step(
recipe,
step_bin2factor_new(
Expand Down
2 changes: 2 additions & 0 deletions R/class.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ check_class <-
check_class_new <-
function(terms, role, trained, class_nm,
allow_additional, class_list, skip, id) {
check_character(class_nm, allow_null = TRUE, call = rlang::caller_env(2))
check_bool(allow_additional, call = rlang::caller_env(2))
check(
subclass = "class",
terms = terms,
Expand Down
5 changes: 5 additions & 0 deletions R/classdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ prep.step_classdist <- function(x, training, info = NULL, ...) {
wts <- NULL
}

check_function(x$mean_func)
check_function(x$cov_func)
check_bool(x$pool)
check_string(x$prefix)

x_dat <- split(training[, x_names], training[[class_var]])
if (is.null(wts)) {
wts_split <- map(x_dat, ~NULL)
Expand Down
2 changes: 2 additions & 0 deletions R/classdist_shrunken.R
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ prep.step_classdist_shrunken <- function(x, training, info = NULL, ...) {

sd_offset <- x$sd_offset
check_number_decimal(sd_offset, min = 0, max = 1)
check_bool(x$log)
check_string(x$prefix)

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts)
Expand Down
6 changes: 6 additions & 0 deletions R/corr.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ step_corr_new <-
prep.step_corr <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))
check_number_decimal(x$threshold, min = 0, max = 1, arg = "threshold")
use <- x$use
rlang::arg_match(use, c("all.obs", "complete.obs", "pairwise.complete.obs",
"everything", "na.or.complete"))
method <- x$method
rlang::arg_match(method, c("pearson", "kendall", "spearman"))

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts, unsupervised = TRUE)
Expand Down
3 changes: 3 additions & 0 deletions R/count.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ step_count_new <-
prep.step_count <- function(x, training, info = NULL, ...) {
col_name <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_name], types = c("string", "factor", "ordered"))
check_string(x$pattern, allow_empty = TRUE, arg = "pattern")
check_string(x$result, allow_empty = FALSE, arg = "result")
check_bool(x$normalize, arg = "normalize")

step_count_new(
terms = x$terms,
Expand Down
4 changes: 2 additions & 2 deletions R/cut.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ prep.step_cut <- function(x, training, info = NULL, ...) {

if (!is.numeric(x$breaks)) {
cli::cli_abort(
"{.arg breaks} must be a numeric vector, \\
not {.obj_type_friendly {x$breaks}}."
"{.arg breaks} must be a numeric vector, not {.obj_type_friendly {x$breaks}}."
)
}
check_bool(x$include_outside_range, arg = "include_outside_range")

all_breaks <- vector("list", length(col_names))
names(all_breaks) <- col_names
Expand Down
3 changes: 3 additions & 0 deletions R/date.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ step_date_new <-
prep.step_date <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("date", "datetime"))
check_bool(x$abbr, arg = "abbr")
check_bool(x$label, arg = "label")
check_bool(x$ordinal, arg = "ordinal")

step_date_new(
terms = x$terms,
Expand Down
6 changes: 6 additions & 0 deletions R/depth.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,16 @@ step_depth_new <-
)
}

depth_metric <- c("potential", "halfspace", "Mahalanobis", "simplicialVolume",
"spatial", "zonoid")

#' @export
prep.step_depth <- function(x, training, info = NULL, ...) {
x_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, x_names], types = c("double", "integer"))
metric <- x$metric
rlang::arg_match(metric, depth_metric)
check_string(x$prefix, allow_empty = FALSE, arg = "prefix")

class_var <- x$class[1]

Expand Down
25 changes: 12 additions & 13 deletions R/discretize.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,7 @@ discretize.numeric <-
...) {
unique_vals <- length(unique(x))
missing_lab <- "_missing"

if (cuts < 2) {
cli::cli_abort(
"There should be at least 2 {.arg cuts} but {.val {cuts}} was supplied."
)
}
check_number_whole(cuts, min = 2)

dots <- list(...)
if (keep_na) {
Expand All @@ -115,8 +110,8 @@ discretize.numeric <-
breaks <- unique(breaks)
if (num_breaks > length(breaks)) {
cli::cli_warn(
"Not enough data for {cuts} breaks. \\
Only {length(breaks)} breaks were used."
"Not enough data for {cuts} breaks. Only {length(breaks)} breaks
were used."
)
}
if (infs) {
Expand All @@ -129,8 +124,8 @@ discretize.numeric <-
prefix <- prefix[1]
if (make.names(prefix) != prefix && !is.null(prefix)) {
cli::cli_warn(
"The prefix {.val {prefix}} is not a valid R name. \\
It has been changed to {.val {make.names(prefix)}}."
"The prefix {.val {prefix}} is not a valid R name. It has been
changed to {.val {make.names(prefix)}}."
)
prefix <- make.names(prefix)
}
Expand All @@ -150,8 +145,8 @@ discretize.numeric <-
} else {
out <- list(bins = 0)
cli::cli_warn(
"Data not binned; too few unique values per bin. \\
Adjust {.arg min_unique} as needed."
"Data not binned; too few unique values per bin. Adjust
{.arg min_unique} as needed."
)
}
class(out) <- "discretize"
Expand Down Expand Up @@ -301,8 +296,10 @@ step_discretize <- function(recipe,
options = list(prefix = "bin"),
skip = FALSE,
id = rand_id("discretize")) {
if (any(names(options) %in% c("cuts", "min_unique"))) {
if (any(names(options) == "cuts")) {
num_breaks <- options$cuts
}
if (any(names(options) == "min_unique")) {
min_unique <- options$min_unique
}

Expand Down Expand Up @@ -348,6 +345,8 @@ bin_wrapper <- function(x, args) {
prep.step_discretize <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))
check_number_whole(x$num_breaks, min = 1, arg = "num_breaks")
check_number_whole(x$min_unique, min = 2, arg = "min_unique")

if (length(col_names) > 1 & any(names(x$options) %in% c("prefix", "labels"))) {
cli::cli_warn(
Expand Down
14 changes: 8 additions & 6 deletions R/dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ step_dummy <-
keep_original_cols = FALSE,
skip = FALSE,
id = rand_id("dummy")) {

if (lifecycle::is_present(preserve)) {
lifecycle::deprecate_stop(
"0.1.16",
Expand Down Expand Up @@ -172,6 +172,8 @@ step_dummy_new <-
prep.step_dummy <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("factor", "ordered"))
check_bool(x$one_hot, arg = "one_hot")
check_function(x$naming, arg = "naming", allow_empty = FALSE)

if (length(col_names) > 0) {
## I hate doing this but currently we are going to have
Expand Down Expand Up @@ -229,14 +231,14 @@ warn_new_levels <- function(dat, lvl, column, step, details = NULL) {
msg <- c("!" = "There are new levels in {.var {column}}: {.val {lvl2}}.")
if (any(is.na(lvl2))) {
msg <- c(
msg,
msg,
"i" = "Consider using {.help [step_unknown()](recipes::step_unknown)} \\
before {.fn {step}} to handle missing values."
)
}
if (!all(is.na(lvl2))) {
msg <- c(
msg,
msg,
"i" = "Consider using {.help [step_novel()](recipes::step_novel)} \\
before {.fn {step}} to handle unseen values."
)
Expand Down Expand Up @@ -278,9 +280,9 @@ bake.step_dummy <- function(object, new_data, ...) {
}

warn_new_levels(
new_data[[col_name]],
levels_values,
col_name,
new_data[[col_name]],
levels_values,
col_name,
step = "step_dummy"
)

Expand Down
25 changes: 13 additions & 12 deletions R/dummy_extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
#' @inheritParams step_center
#' @inheritParams step_other
#' @inheritParams step_dummy
#' @param sep Character vector containing a regular expression to use
#' @param sep Character string containing a regular expression to use
#' for splitting. [strsplit()] is used to perform the split. `sep` takes
#' priority if `pattern` is also specified.
#' @param pattern Character vector containing a regular expression used
#' @param pattern Character string containing a regular expression used
#' for extraction. [gregexpr()] and [regmatches()] are used to perform
#' pattern extraction using `perl = TRUE`.
#' @template step-return
Expand Down Expand Up @@ -88,10 +88,10 @@
#' step_dummy_extract(colors, pattern = "(?<=')[^',]+(?=')") %>%
#' prep()
#'
#' dommies_data_color <- dummies_color %>%
#' dummies_data_color <- dummies_color %>%
#' bake(new_data = NULL)
#'
#' dommies_data_color
#' dummies_data_color
step_dummy_extract <-
function(recipe,
...,
Expand All @@ -107,14 +107,6 @@ step_dummy_extract <-
skip = FALSE,
id = rand_id("dummy_extract")) {

if (!is_tune(threshold)) {
if (threshold >= 1) {
check_number_whole(threshold)
} else {
check_number_decimal(threshold, min = 0)
}
}

add_step(
recipe,
step_dummy_extract_new(
Expand Down Expand Up @@ -160,6 +152,15 @@ step_dummy_extract_new <-
prep.step_dummy_extract <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("string", "factor", "ordered"))
if (x$threshold >= 1) {
check_number_whole(x$threshold, arg = "threshold")
} else {
check_number_decimal(x$threshold, min = 0, arg = "threshold")
}
check_string(x$other, arg = "other", allow_null = TRUE)
check_string(x$sep, arg = "sep", allow_null = TRUE)
check_string(x$pattern, arg = "pattern", allow_null = TRUE)
check_function(x$naming, arg = "naming", allow_empty = FALSE)

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts, unsupervised = TRUE)
Expand Down
33 changes: 16 additions & 17 deletions R/dummy_multi_choice.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
#' This is `NULL` until the step is trained by [prep()].
#' @template step-return
#' @family dummy variable and encoding steps
#'
#'
#' @details
#' The overall proportion (or total counts) of the categories are computed. The
#' `"other"` category is used in place of any categorical levels whose
#' The overall proportion (or total counts) of the categories are computed. The
#' `"other"` category is used in place of any categorical levels whose
#' individual proportion (or frequency) in the training set is less than
#' `threshold`.
#'
#'
#' This step produces a number of columns, based on the number of categories it
#' finds. The naming of the columns is determined by the function based on the
#' `naming` argument. The default is to return `<prefix>_<category name>`. By
#' default `prefix` is `NULL`, which means the name of the first column
#' finds. The naming of the columns is determined by the function based on the
#' `naming` argument. The default is to return `<prefix>_<category name>`. By
#' default `prefix` is `NULL`, which means the name of the first column
#' selected will be used in place.
#'
#' @template dummy-naming
#'
#' @details
#'
#'
#' ```{r, echo = FALSE, results="asis"}
#' step <- "step_dummy_multi_choice"
#' result <- knitr::knit_child("man/rmd/tunable-args.Rmd")
Expand Down Expand Up @@ -74,7 +74,7 @@
#'
#' bake(dummy_multi_choice_rec2, new_data = NULL)
#' tidy(dummy_multi_choice_rec2, number = 1)
#'
#'
#' @export
step_dummy_multi_choice <- function(recipe,
...,
Expand All @@ -90,14 +90,6 @@ step_dummy_multi_choice <- function(recipe,
skip = FALSE,
id = rand_id("dummy_multi_choice")) {

if (!is_tune(threshold)) {
if (threshold >= 1) {
check_number_whole(threshold)
} else {
check_number_decimal(threshold, min = 0)
}
}

add_step(
recipe,
step_dummy_multi_choice_new(
Expand Down Expand Up @@ -141,6 +133,13 @@ step_dummy_multi_choice_new <-
prep.step_dummy_multi_choice <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("nominal", "logical"))
if (x$threshold >= 1) {
check_number_whole(x$threshold, arg = "threshold")
} else {
check_number_decimal(x$threshold, min = 0, arg = "threshold")
}
check_string(x$other, arg = "other", allow_null = TRUE)
check_function(x$naming, arg = "naming", allow_empty = FALSE)

levels <- purrr::map(training[, col_names], levels)
levels <- vctrs::list_unchop(levels, ptype = character(), name_spec = rlang::zap())
Expand Down
1 change: 1 addition & 0 deletions R/filter_missing.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ step_filter_missing_new <-
#' @export
prep.step_filter_missing <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_number_decimal(x$threshold, min = 0, max = 1, arg = "threshold")

wts <- get_case_weights(info, training)
were_weights_used <- are_weights_used(wts, unsupervised = TRUE)
Expand Down
8 changes: 6 additions & 2 deletions R/hyperbolic.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ step_hyperbolic <-
skip = FALSE,
id = rand_id("hyperbolic")) {

func <- rlang::arg_match(func)

if (!is_tune(func)) {
func <- rlang::arg_match(func)
}
add_step(
recipe,
step_hyperbolic_new(
Expand Down Expand Up @@ -94,6 +95,9 @@ step_hyperbolic_new <-
prep.step_hyperbolic <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)
check_type(training[, col_names], types = c("double", "integer"))
func <- x$func
x$func <- rlang::arg_match(func, c("sinh", "cosh", "tanh"), error_arg = "func")
check_bool(x$inverse, error_arg = "inverse")

step_hyperbolic_new(
terms = x$terms,
Expand Down
Loading

0 comments on commit 28e3e4d

Please sign in to comment.