Skip to content

Commit

Permalink
Merge pull request #22 from harvard-ufds/dev
Browse files Browse the repository at this point in the history
adding option for totals
  • Loading branch information
graysonwhite authored Feb 22, 2024
2 parents 69c3c46 + 1ddf40b commit 768e44f
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 26 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

generate_preds <- function(beta_lm, beta_glm, u_lm, u_glm, design_mats, J) {
.Call(`_saeczi_generate_preds`, beta_lm, beta_glm, u_lm, u_glm, design_mats, J)
generate_preds <- function(beta_lm, beta_glm, u_lm, u_glm, design_mats, J, estimand) {
.Call(`_saeczi_generate_preds`, beta_lm, beta_glm, u_lm, u_glm, design_mats, J, estimand)
}

34 changes: 26 additions & 8 deletions R/saeczi.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#' @param domain_level A string of the column name in the dataframes that reflect the domain level
#' @param B An integer of the number of reps desired for the bootstrap
#' @param mse_est A boolean that specifies if the user
#' @param estimand A string specifying whether the estimates should be 'totals' or 'means'.
#' @param parallel Compute MSE estimation in parallel
#'
#' @details The arguments `lin_formula`, and `log_formula`
Expand Down Expand Up @@ -62,6 +63,7 @@ saeczi <- function(samp_dat,
domain_level,
B = 100,
mse_est = FALSE,
estimand = "means",
parallel = FALSE) {

funcCall <- match.call()
Expand All @@ -76,6 +78,10 @@ saeczi <- function(samp_dat,
message("log_formula was converted to class 'formula'")
}

if(!(estimand %in% c("means", "totals"))) {
stop("Invalid estimand, must be either 'means' or 'totals'")
}

if (parallel && is(future::plan(), "sequential")) {
message("In order for the internal processes to be run in parallel a `future::plan()` must be specified by the user")
message("See <https://future.futureverse.org/reference/plan.html> for reference on how to use `future::plan()`")
Expand Down Expand Up @@ -112,10 +118,15 @@ saeczi <- function(samp_dat,
as.character(.data[ , domain_level, drop = T])
)

zi_domain_preds <- aggregate(unit_level_preds, by = list(names(unit_level_preds)), FUN = mean)
names(zi_domain_preds) <- c("domain", "Y_hat_j")

original_pred <- zi_domain_preds
if (estimand == "means") {
zi_domain_means <- aggregate(unit_level_preds, by = list(names(unit_level_preds)), FUN = mean)
names(zi_domain_means) <- c("domain", "Y_hat_j")
original_pred <- zi_domain_means
} else {
zi_domain_totals <- aggregate(unit_level_preds, by = list(names(unit_level_preds)), FUN = sum)
names(zi_domain_totals) <- c("domain", "Y_hat_j")
original_pred <- zi_domain_totals
}

if (mse_est) {

Expand Down Expand Up @@ -224,9 +235,15 @@ saeczi <- function(samp_dat,
paste(log_X, collapse = " + ")
)
)
# define these before bootstrap
boot_truth <- stats::setNames(stats::aggregate(response ~ domain, data = boot_pop_data,
FUN = mean), c("domain", "domain_est"))

if (estimand == "means") {
boot_truth <- stats::setNames(stats::aggregate(response ~ domain, data = boot_pop_data,
FUN = mean), c("domain", "domain_est"))
} else {
boot_truth <- stats::setNames(stats::aggregate(response ~ domain, data = boot_pop_data,
FUN = sum), c("domain", "domain_est"))
}


# create bootstrap samples
boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B)
Expand Down Expand Up @@ -336,7 +353,8 @@ saeczi <- function(samp_dat,
u_lm = u_lm,
u_glm = u_glm,
lin_X = lin_X,
log_X = log_X)
log_X = log_X,
estimand = estimand)

log_lst <- res |>
map(.f = ~ .x$log)
Expand Down
13 changes: 9 additions & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ generate_mse <- function(.data,
u_lm,
u_glm,
lin_X,
log_X) {
log_X,
estimand) {

boot_pop_by_dom <- split(.data, f = .data$domain)

Expand All @@ -113,18 +114,22 @@ generate_mse <- function(.data,
u_lm = u_lm,
u_glm = u_glm,
design_mats = design_mat_ls,
J = n_doms)
J = n_doms,
estimand = estimand)




truth_ordered <- truth[order(match(truth$domain, dom_order)), ]
truth_vec <- truth_ordered$domain_est

mean_sq_err <- (dom_res_wide - truth_vec)^2 |>
mse <- (dom_res_wide - truth_vec)^2 |>
rowMeans(na.rm = TRUE)


res_doms <- data.frame(
domain = truth_ordered$domain,
mse = mean_sq_err
mse = mse
)

return(res_doms)
Expand Down
3 changes: 3 additions & 0 deletions man/saeczi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// generate_preds
SEXP generate_preds(const Eigen::MatrixXd& beta_lm, const Eigen::MatrixXd& beta_glm, const Eigen::MatrixXd& u_lm, const Eigen::MatrixXd& u_glm, const Rcpp::List& design_mats, int J);
RcppExport SEXP _saeczi_generate_preds(SEXP beta_lmSEXP, SEXP beta_glmSEXP, SEXP u_lmSEXP, SEXP u_glmSEXP, SEXP design_matsSEXP, SEXP JSEXP) {
SEXP generate_preds(const Eigen::MatrixXd& beta_lm, const Eigen::MatrixXd& beta_glm, const Eigen::MatrixXd& u_lm, const Eigen::MatrixXd& u_glm, const Rcpp::List& design_mats, int J, std::string estimand);
RcppExport SEXP _saeczi_generate_preds(SEXP beta_lmSEXP, SEXP beta_glmSEXP, SEXP u_lmSEXP, SEXP u_glmSEXP, SEXP design_matsSEXP, SEXP JSEXP, SEXP estimandSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -23,13 +23,14 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const Eigen::MatrixXd& >::type u_glm(u_glmSEXP);
Rcpp::traits::input_parameter< const Rcpp::List& >::type design_mats(design_matsSEXP);
Rcpp::traits::input_parameter< int >::type J(JSEXP);
rcpp_result_gen = Rcpp::wrap(generate_preds(beta_lm, beta_glm, u_lm, u_glm, design_mats, J));
Rcpp::traits::input_parameter< std::string >::type estimand(estimandSEXP);
rcpp_result_gen = Rcpp::wrap(generate_preds(beta_lm, beta_glm, u_lm, u_glm, design_mats, J, estimand));
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_saeczi_generate_preds", (DL_FUNC) &_saeczi_generate_preds, 6},
{"_saeczi_generate_preds", (DL_FUNC) &_saeczi_generate_preds, 7},
{NULL, NULL, 0}
};

Expand Down
27 changes: 19 additions & 8 deletions src/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ double sigmoid(double x) {
return 1.0 / (1.0 + std::exp(-x));
}

void preds_calc(Eigen::MatrixXd& res,
void preds_calc(Eigen::MatrixXd& result,
const Eigen::MatrixXd& beta_lm,
const Eigen::MatrixXd& beta_glm,
const Eigen::MatrixXd& dmat_lm,
const Eigen::MatrixXd& dmat_glm,
const Eigen::MatrixXd& u_lm,
const Eigen::MatrixXd& u_glm,
int j,
int B) {
int B,
std::string estimand) {

// these are N_j x B
Eigen::MatrixXd pred_lm_j = (dmat_lm * beta_lm.transpose());
Expand All @@ -28,8 +29,16 @@ void preds_calc(Eigen::MatrixXd& res,
pred_glm_j = pred_glm_j.unaryExpr(&sigmoid);
Eigen::MatrixXd unit_preds_j = pred_lm_j.cwiseProduct(pred_glm_j);

Eigen::MatrixXd dom_preds_j = unit_preds_j.colwise().mean();
res.block(j, 0, 1, B) = dom_preds_j;
int N_j = unit_preds_j.rows();

Eigen::MatrixXd dom_preds_j;
if (estimand == "means") {
dom_preds_j = unit_preds_j.colwise().mean();
} else {
dom_preds_j = unit_preds_j.colwise().sum();
}

result.block(j, 0, 1, B) = dom_preds_j;

}

Expand All @@ -39,11 +48,12 @@ SEXP generate_preds(const Eigen::MatrixXd& beta_lm,
const Eigen::MatrixXd& u_lm,
const Eigen::MatrixXd& u_glm,
const Rcpp::List& design_mats,
int J) {
int J,
std::string estimand) {

int B = u_lm.rows();

// initialize result matrix
// initialize result matrices
Eigen::MatrixXd result = Eigen::MatrixXd::Zero(J, B);

for (int j = 0; j < J; ++j) {
Expand All @@ -60,10 +70,11 @@ SEXP generate_preds(const Eigen::MatrixXd& beta_lm,
u_lm,
u_glm,
j,
B);
B,
estimand);

}

return Rcpp::wrap(result);

}

0 comments on commit 768e44f

Please sign in to comment.