diff --git a/DESCRIPTION b/DESCRIPTION index 800bf5d..2fd86af 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,6 +19,7 @@ Encoding: UTF-8 LazyData: true Imports: stats, + dplyr, lme4, purrr, progressr, diff --git a/NAMESPACE b/NAMESPACE index 6a6dc20..3e2a8a5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,7 +6,9 @@ export(unit_zi) import(stats) importFrom(furrr,furrr_options) importFrom(furrr,future_map) +importFrom(furrr,future_map2) importFrom(methods,is) importFrom(progressr,progressor) importFrom(progressr,with_progress) importFrom(purrr,map) +importFrom(purrr,map2) diff --git a/R/unit_zi.R b/R/unit_zi.R index 8ab3e70..aa4b871 100644 --- a/R/unit_zi.R +++ b/R/unit_zi.R @@ -23,8 +23,8 @@ #' @export unit_zi #' @import stats #' @importFrom progressr progressor with_progress -#' @importFrom furrr future_map furrr_options -#' @importFrom purrr map +#' @importFrom furrr future_map furrr_options future_map2 +#' @importFrom purrr map map2 #' @importFrom methods is unit_zi <- function(samp_dat, @@ -180,37 +180,35 @@ unit_zi <- function(samp_dat, boot_truth <- stats::setNames(stats::aggregate(response ~ domain, data = boot_pop_data, FUN = mean), c("domain", "domain_est")) - by_domains <- split(boot_pop_data, f = boot_pop_data$domain) + # create bootstrap samples + boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B) - num_plots <- data.frame(table(samp_dat[ , domain_level])) - + # goal is to not pass boot_pop_data to the map at all + + # still need to implement here... # furrr with progress bar - boot_rep_with_progress_bar <- function(x) { + boot_rep_with_progress_bar <- function(x, boot_lst) { p <- progressor(steps = length(x)) - res <- x |> future_map( ~{ - p() - out <- boot_rep( - boot_pop_data, - samp_dat, - domain_level, - num_plots, - boot_lin_formula, - boot_log_formula, - boot_truth, - by_domains - ) - out - }, - .options = furrr_options(seed = TRUE)) + res <- + furrr::future_map(.x = boot_lst, + .f = \(.x) { + p() + boot_rep(boot_samp = .x, + pop_boot = boot_pop_data, + domain_level, + boot_lin_formula, + boot_log_formula, + boot_truth) + }, + .options = furrr_options(seed = TRUE)) res_lst <- res |> map(.f = ~ .x$sqerr) res_df <- do.call("rbind", res_lst) - - # res_df <- do.call("rbind", res) + log_lst <- res |> map(.f = ~ .x$log) @@ -221,21 +219,22 @@ unit_zi <- function(samp_dat, if (parallel) { with_progress({ - boot_res <- boot_rep_with_progress_bar(1:B) + boot_res <- boot_rep_with_progress_bar(x = 1:B, + boot_lst = boot_samp_ls) }) } else { res <- - map(.x = 1:B, - .f = \(i) boot_rep(boot_pop_data, - samp_dat, - domain_level, - num_plots, - boot_lin_formula, - boot_log_formula, - boot_truth, - by_domains)) + purrr::map(.x = boot_samp_ls, + .f = \(.x) { + boot_rep(boot_samp = .x, + pop_boot = boot_pop_data, + domain_level, + boot_lin_formula, + boot_log_formula, + boot_truth) + }) res_lst <- res |> map(.f = ~ .x$sqerr) diff --git a/R/utils.R b/R/utils.R index f0adb6d..48a6e82 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,4 +1,41 @@ +# fast samp-by-grp +samp_by_grp <- function(samp, pop, dom_nm, B) { + + num_plots <- dplyr::count(samp, !!rlang::sym(dom_nm)) + # our boot_pop_data has column name domain as its group variable + setup <- dplyr::count(pop, domain) |> + dplyr::left_join(num_plots, by = c(domain = dom_nm)) |> + dplyr::mutate(add_to = dplyr::lag(cumsum(n.x), default = 0)) |> + dplyr::rowwise() |> + dplyr::mutate(map_args = list(list(n.x, n.y, add_to))) + + all_samps <- vector("list", length = B) + + for (i in 1:B) { + ids <- setup |> + dplyr::mutate(samps = purrr::pmap(.l = map_args, .f = \ (x, y, z) { + sample(1:x, size = y, replace = TRUE) + z + })) |> + dplyr::pull(samps) |> + unlist() + + out <- pop[ids, ] + all_samps[[i]] <- out + } + + return(all_samps) +} + + + # fit_zi function + +# don't do prediction here +# predict_zi + +# take the mean of the pixels in that county +# then predict on those means + fit_zi <- function(samp_dat, pop_dat, lin_formula, @@ -35,11 +72,15 @@ fit_zi <- function(samp_dat, # Fit logistic mixed effects on ALL data glmer_z <- suppressMessages(lme4::glmer(log_reg_formula, data = samp_dat, family = "binomial")) + # dont do this unit_level_preds <- setNames( stats::predict(lmer_nz, pop_dat, allow.new.levels = TRUE) * stats::predict(glmer_z, pop_dat, type = "response"), as.character(pop_dat[ , domain_level, drop = T]) ) + # idea: just return model params and fit later + + zi_domain_preds <- aggregate(unit_level_preds, by = list(names(unit_level_preds)), FUN = mean) names(zi_domain_preds) <- c("domain", "Y_hat_j") @@ -48,6 +89,11 @@ fit_zi <- function(samp_dat, } + +# predict_zi <- function(mod1, mod2, data) { +# +# } + # base version of dplyr::slice_sample slice_samp <- function(.data, n, replace = TRUE) { .data[sample(nrow(.data), n, replace = replace),] @@ -61,45 +107,40 @@ str_extract_all_base <- function(string, pattern) { # bootstrap rep helper -boot_rep <- function(pop_boot, - samp_dat, +boot_rep <- function(boot_samp, + pop_boot, domain_level, - num_plots, boot_lin_formula, boot_log_formula, - boot_truth, - by_domains) { - - boot_data_ls <- purrr::map2(.x = by_domains, .y = num_plots$Freq, slice_samp) - boot_data <- do.call("rbind", boot_data_ls) + boot_truth) { # capture warnings and messages silently when bootstrapping fit_zi_capture <- capture_all(fit_zi) - # nested tryCatch - # tries resampling once and if it fails again returns properly structured output filled with NAs boot_samp_fit <- tryCatch( { - fit_zi_capture(boot_data, pop_boot, boot_lin_formula, boot_log_formula, domain_level) + fit_zi_capture(boot_samp, + pop_boot, + boot_lin_formula, + boot_log_formula, + domain_level) }, error = function(cond) { - boot_data_ls <- purrr::map2(.x = by_domains, .y = num_plots$Freq, slice_samp) - boot_data <- do.call("rbind", boot_data_ls) - tryCatch( - { - fit_zi_capture(boot_data, pop_boot, boot_lin_formula, boot_log_formula, domain_level) - }, - error = function(cond) { - zi_domain_preds <- boot_truth - zi_domain_preds$domain_est <- NA - names(zi_domain_preds) <- c("domain", "Y_hat_j") - list(result = list(lmer = NA, glmer = NA, pred = zi_domain_preds), log = cond) - } - ) + zi_domain_preds <- boot_truth + zi_domain_preds$domain_est <- NA + names(zi_domain_preds) <- c("domain", "Y_hat_j") + list(result = list(lmer = NA, + glmer = NA, + pred = zi_domain_preds), + log = cond) + } ) - squared_error <- merge(x = boot_samp_fit$result$pred, y = boot_truth, by = "domain", all.x = TRUE) |> + squared_error <- merge(x = boot_samp_fit$result$pred, + y = boot_truth, + by = "domain", + all.x = TRUE) |> transform(sq_error = (Y_hat_j - domain_est)^2) squared_error <- squared_error[ , c("domain", "sq_error")] diff --git a/R/zzz.R b/R/zzz.R index 2ea41a1..d1ea7f7 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,2 +1,3 @@ utils::globalVariables(c("domain", "response", "Y_hat_j", "domain_est", - "sq_error", "grp", "mse")) \ No newline at end of file + "sq_error", "grp", "mse", "map_args", "n.x", + "n.y", "samps", "add_to")) \ No newline at end of file diff --git a/tests/testthat/_snaps/unit_zi.md b/tests/testthat/_snaps/unit_zi.md index 6061d55..ba691df 100644 --- a/tests/testthat/_snaps/unit_zi.md +++ b/tests/testthat/_snaps/unit_zi.md @@ -1,4 +1,4 @@ -# result is as expected +# printed result is as expected Code result @@ -27,3 +27,46 @@ COUNTYFIPS (Intercept) 0.87583 +# result is as expected + + Code + result$res + Output + domain mse est + 1 41001 52.079431 14.8549464 + 2 41003 80.502624 97.7496673 + 3 41005 157.956667 86.0220677 + 4 41007 316.436197 76.2475194 + 5 41009 91.341855 70.2862446 + 6 41011 111.134371 87.6507212 + 7 41013 221.068461 11.0312390 + 8 41015 7.931719 104.4564778 + 9 41017 274.547503 25.6193318 + 10 41019 278.269298 89.7724802 + 11 41021 17.427338 0.5406902 + 12 41023 144.475154 23.6541480 + 13 41025 85.969011 1.9659769 + 14 41027 18.103483 73.6439139 + 15 41029 61.009924 67.6980088 + 16 41031 111.606598 19.8731946 + 17 41033 84.149689 66.7685216 + 18 41035 348.333311 35.4898212 + 19 41037 169.634020 9.2608227 + 20 41039 125.914934 120.8521093 + 21 41041 58.846876 107.7729877 + 22 41043 41.556125 81.6518967 + 23 41045 65.143036 0.4838125 + 24 41047 52.844388 62.1872275 + 25 41049 21.282101 6.8828560 + 26 41051 344.608395 72.1014683 + 27 41053 168.157246 85.3369556 + 28 41055 6.807745 0.5432959 + 29 41057 190.890004 101.2380754 + 30 41059 68.151056 13.4063726 + 31 41061 308.413912 27.6249648 + 32 41063 157.857125 21.3592092 + 33 41065 74.227749 17.1744016 + 34 41067 40.250553 56.9720489 + 35 41069 132.581564 14.3059884 + 36 41071 116.645285 58.7331579 + diff --git a/tests/testthat/test-unit_zi.R b/tests/testthat/test-unit_zi.R index 0795b57..03e0205 100644 --- a/tests/testthat/test-unit_zi.R +++ b/tests/testthat/test-unit_zi.R @@ -1,5 +1,4 @@ library(saeczi) - data(pop) data(samp) @@ -14,10 +13,14 @@ result <- unit_zi(samp, B = 5, parallel = FALSE) -test_that("result is as expected", { +test_that("printed result is as expected", { expect_snapshot(result) }) +test_that("result is as expected", { + expect_snapshot(result$res) +}) + test_that("result[[2]] is a df", { expect_s3_class(result[[2]], "data.frame") })