From c2d5d4729da826034ce24a44e3b0b0ebd0fac2ed Mon Sep 17 00:00:00 2001 From: Josh Yamamoto Date: Tue, 12 Mar 2024 21:34:27 -0500 Subject: [PATCH] doc updates and checks --- R/saeczi.R | 68 ++++++++++-------------------------- R/utils.R | 36 ++++++++++++++++++- README.Rmd | 5 +-- README.md | 14 ++++---- man/saeczi.Rd | 2 +- tests/testthat/test-saeczi.R | 2 +- 6 files changed, 65 insertions(+), 62 deletions(-) diff --git a/R/saeczi.R b/R/saeczi.R index 6579eba..28c3240 100644 --- a/R/saeczi.R +++ b/R/saeczi.R @@ -62,53 +62,38 @@ saeczi <- function(samp_dat, lin_formula, log_formula = lin_formula, domain_level, - B = 100, + B = 100L, mse_est = FALSE, estimand = "means", parallel = FALSE) { funcCall <- match.call() - if(!("formula" %in% class(lin_formula))) { - lin_formula <- as.formula(lin_formula) - message("lin_formula was converted to class 'formula'") - } + check_inherits(list(samp_dat, pop_dat), "data.frame") + check_inherits(list(lin_formula, log_formula), "formula") + check_inherits(list(domain_level, estimand), "character") + check_inherits(B, "integer") + check_inherits(list(mse_est, parallel), "logical") - if(!("formula" %in% class(log_formula))) { - log_formula <- as.formula(log_formula) - message("log_formula was converted to class 'formula'") - } + check_parallel(parallel) if(!(estimand %in% c("means", "totals"))) { stop("Invalid estimand, must be either 'means' or 'totals'") } - if (parallel && ("sequential" %in% class(future::plan()))) { - message("In order for the internal processes to be run in parallel a `future::plan()` must be specified by the user") - message("See for reference on how to use `future::plan()`") - } - # creating strings of original X, Y names Y <- deparse(lin_formula[[2]]) - lin_X <- unlist(str_extract_all_base( - deparse(lin_formula[[3]]), - "\\w+" - )) + lin_X <- unlist(str_extract_all_base(deparse(lin_formula[[3]]), "\\w+")) - log_X <- unlist(str_extract_all_base( - deparse(log_formula[[3]]), - "\\w+" - )) + log_X <- unlist(str_extract_all_base(deparse(log_formula[[3]]), "\\w+")) all_preds <- unique(lin_X, log_X) - original_out <- fit_zi( - samp_dat, - lin_formula, - log_formula, - domain_level - ) + original_out <- fit_zi(samp_dat, + lin_formula, + log_formula, + domain_level) mod1 <- original_out$lmer mod2 <- original_out$glmer @@ -187,21 +172,9 @@ saeczi <- function(samp_dat, response = linear_preds * boot_dat_params$delta_i_star ) - ## bootstrapping ------------------------------------------------------------- - - boot_lin_formula <- as.formula( - paste0( - "response ~ ", - paste(lin_X, collapse = " + ") - ) - ) + boot_lin_formula <- as.formula(paste0("response ~ ", paste(lin_X, collapse = " + "))) - boot_log_formula <- as.formula( - paste0( - "response ~ ", - paste(log_X, collapse = " + ") - ) - ) + boot_log_formula <- as.formula(paste0("response ~ ", paste(log_X, collapse = " + "))) if (estimand == "means") { boot_truth <- boot_pop_data |> @@ -217,7 +190,6 @@ saeczi <- function(samp_dat, boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B) if (parallel) { - with_progress({ boot_res <- boot_rep_par(x = 1:B, boot_lst = boot_samp_ls, @@ -230,9 +202,7 @@ saeczi <- function(samp_dat, lin_X, log_X) }) - } else { - res <- purrr::map(.x = boot_samp_ls, .f = \(.x) { @@ -277,6 +247,7 @@ saeczi <- function(samp_dat, log_X = log_X, estimand = estimand) + log_lst <- res |> map(.f = ~ .x$log) @@ -284,11 +255,8 @@ saeczi <- function(samp_dat, } - - mse_df <- setNames( - boot_res$preds, - c(domain_level, "mse") - ) + mse_df <- setNames(boot_res$preds, + c(domain_level, "mse")) final_df <- mse_df |> left_join(original_pred, by = domain_level) diff --git a/R/utils.R b/R/utils.R index 840deab..51d6bcc 100644 --- a/R/utils.R +++ b/R/utils.R @@ -172,7 +172,6 @@ generate_mse <- function(.data, return(res_doms) - } #' Bootstrap procedure for the parallel option @@ -458,4 +457,39 @@ capture_all <- function(.f){ } +} + +#' Checking if a param inherits a class +#' +#' @param x The parameter input(s) to check +#' @param what What class to check if the parameter input inherits +#' +#' @return Nothing if the check is passed, but an error if the check fails +#' @noRd +check_inherits <- function(x, what) { + for (i in seq_along(x)) { + if (!inherits(x[[i]], what)) { + stop(paste0(x[[i]], " needs to be of class ", what)) + } + } + invisible(x) +} + +#' Checking if parallel functionality is properly set up +#' +#' @param x The parameter input to check +#' @param call The caller environment to check in +#' +#' @return Nothing if the check is passed, but an error if the check fails +#' @noRd +check_parallel <- function(x, call = rlang::caller_env()) { + + if (x) { + if (eval(!inherits(future::plan(), "sequential"), envir = call)) { + message("In order for the internal processes to be run in parallel a `future::plan()` must be specified by the user") + message("See for reference on how to use `future::plan()`") + } + } + + invisible(x) } \ No newline at end of file diff --git a/README.Rmd b/README.Rmd index 3fa074b..58aaf14 100644 --- a/README.Rmd +++ b/README.Rmd @@ -60,14 +60,15 @@ library(saeczi) data(pop) data(samp) +future::plan('multisession', workers = 6) result <- saeczi(samp_dat = samp, pop_dat = pop, lin_formula = DRYBIO_AG_TPA_live_ADJ ~ tcc16 + elev, log_formula = DRYBIO_AG_TPA_live_ADJ ~ tcc16 + elev, domain_level = "COUNTYFIPS", mse_est = TRUE, - B = 500, - parallel = FALSE) + B = 500L, + parallel = TRUE) ``` diff --git a/README.md b/README.md index 36b10da..de9e755 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ result <- saeczi(samp_dat = samp, log_formula = DRYBIO_AG_TPA_live_ADJ ~ tcc16 + elev, domain_level = "COUNTYFIPS", mse_est = TRUE, - B = 500, + B = 500L, parallel = FALSE) ``` @@ -98,10 +98,10 @@ few rows of the results: ``` r result$res |> head() #> COUNTYFIPS mse est -#> 1 41001 453.74637 14.85495 -#> 2 41003 35.01620 97.74967 -#> 3 41005 295.83622 86.02207 -#> 4 41007 78.80944 76.24752 -#> 5 41009 91.07024 70.28624 -#> 6 41011 277.73623 87.65072 +#> 1 41001 226.91454 14.85495 +#> 2 41003 89.32900 97.74967 +#> 3 41005 350.67805 86.02207 +#> 4 41007 608.49682 76.24752 +#> 5 41009 97.27606 70.28624 +#> 6 41011 81.05661 87.65072 ``` diff --git a/man/saeczi.Rd b/man/saeczi.Rd index f207e0c..63087de 100644 --- a/man/saeczi.Rd +++ b/man/saeczi.Rd @@ -10,7 +10,7 @@ saeczi( lin_formula, log_formula = lin_formula, domain_level, - B = 100, + B = 100L, mse_est = FALSE, estimand = "means", parallel = FALSE diff --git a/tests/testthat/test-saeczi.R b/tests/testthat/test-saeczi.R index e59e1f7..8d62cc8 100644 --- a/tests/testthat/test-saeczi.R +++ b/tests/testthat/test-saeczi.R @@ -13,7 +13,7 @@ result <- saeczi(samp, lin_formula, domain_level = "COUNTYFIPS", mse_est = TRUE, - B = 10, + B = 10L, parallel = FALSE) test_that("result$res is a df", {