Skip to content

Commit

Permalink
Merge pull request #17 from harvard-ufds/dev
Browse files Browse the repository at this point in the history
new function for N_j x B looping
  • Loading branch information
joshyam-k authored Feb 9, 2024
2 parents b257050 + 1fce6fc commit e9a5455
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 65 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

dom_preds_calc <- function(beta_lm, beta_glm, J, u_lm, u_glm, design_mats) {
.Call(`_saeczi_dom_preds_calc`, beta_lm, beta_glm, J, u_lm, u_glm, design_mats)
generate_preds <- function(beta_lm, beta_glm, u_lm, u_glm, design_mats, J) {
.Call(`_saeczi_generate_preds`, beta_lm, beta_glm, u_lm, u_glm, design_mats, J)
}

41 changes: 4 additions & 37 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ samp_by_grp <- function(samp, pop, dom_nm, B) {


# fit_zi function

fit_zi <- function(samp_dat,
lin_formula,
log_formula,
Expand Down Expand Up @@ -79,7 +78,6 @@ fit_zi <- function(samp_dat,
lme4::glmer(log_reg_formula, data = samp_dat, family = 'binomial')
)


return(list(lmer = lmer_nz, glmer = glmer_z))

}
Expand Down Expand Up @@ -110,12 +108,12 @@ generate_mse <- function(.data,
u_lm <- u_lm[, order(match(colnames(u_lm), dom_order))]
u_glm <- u_glm[, order(match(colnames(u_glm), dom_order))]

dom_res_wide <- dom_preds_calc(beta_lm = beta_lm_mat,
dom_res_wide <- generate_preds(beta_lm = beta_lm_mat,
beta_glm = beta_glm_mat,
J = n_doms,
u_lm = u_lm,
u_glm = u_glm,
design_mats = design_mat_ls)
design_mats = design_mat_ls,
J = n_doms)


truth_ordered <- truth[order(match(truth$domain, dom_order)), ]
Expand Down Expand Up @@ -294,35 +292,4 @@ capture_all <- function(.f){

}

}

truncateText <- function(x) {
if (length(x) > 1)
x <- paste(x, collapse = "")
w <- options("width")$width
if (nchar(x) <= w)
return(x)

cont <- TRUE
out <- x
while (cont) {
tmp <- out[length(out)]
tmp2 <- substring(tmp, 1, w)

spaceIndex <- gregexpr("[[:space:]]", tmp2)[[1]]
stopIndex <- spaceIndex[length(spaceIndex) - 1] - 1
tmp <- c(substring(tmp2, 1, stopIndex),
substring(tmp, stopIndex + 1))
out <-
if (length(out) == 1)
tmp
else
c(out[1:(length(x) - 1)], tmp)
if (all(nchar(out) <= w))
cont <- FALSE
}

paste(out, collapse = "\n")
}


}
12 changes: 6 additions & 6 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// dom_preds_calc
SEXP dom_preds_calc(const Eigen::MatrixXd& beta_lm, const Eigen::MatrixXd& beta_glm, const int J, const Eigen::MatrixXd& u_lm, const Eigen::MatrixXd& u_glm, const Rcpp::List& design_mats);
RcppExport SEXP _saeczi_dom_preds_calc(SEXP beta_lmSEXP, SEXP beta_glmSEXP, SEXP JSEXP, SEXP u_lmSEXP, SEXP u_glmSEXP, SEXP design_matsSEXP) {
// generate_preds
SEXP generate_preds(const Eigen::MatrixXd& beta_lm, const Eigen::MatrixXd& beta_glm, const Eigen::MatrixXd& u_lm, const Eigen::MatrixXd& u_glm, const Rcpp::List& design_mats, int J);
RcppExport SEXP _saeczi_generate_preds(SEXP beta_lmSEXP, SEXP beta_glmSEXP, SEXP u_lmSEXP, SEXP u_glmSEXP, SEXP design_matsSEXP, SEXP JSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const Eigen::MatrixXd& >::type beta_lm(beta_lmSEXP);
Rcpp::traits::input_parameter< const Eigen::MatrixXd& >::type beta_glm(beta_glmSEXP);
Rcpp::traits::input_parameter< const int >::type J(JSEXP);
Rcpp::traits::input_parameter< const Eigen::MatrixXd& >::type u_lm(u_lmSEXP);
Rcpp::traits::input_parameter< const Eigen::MatrixXd& >::type u_glm(u_glmSEXP);
Rcpp::traits::input_parameter< const Rcpp::List& >::type design_mats(design_matsSEXP);
rcpp_result_gen = Rcpp::wrap(dom_preds_calc(beta_lm, beta_glm, J, u_lm, u_glm, design_mats));
Rcpp::traits::input_parameter< int >::type J(JSEXP);
rcpp_result_gen = Rcpp::wrap(generate_preds(beta_lm, beta_glm, u_lm, u_glm, design_mats, J));
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_saeczi_dom_preds_calc", (DL_FUNC) &_saeczi_dom_preds_calc, 6},
{"_saeczi_generate_preds", (DL_FUNC) &_saeczi_generate_preds, 6},
{NULL, NULL, 0}
};

Expand Down
63 changes: 43 additions & 20 deletions src/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,62 @@ double sigmoid(double x) {
return 1.0 / (1.0 + std::exp(-x));
}

void preds_calc(Eigen::MatrixXd& res,
const Eigen::MatrixXd& beta_lm,
const Eigen::MatrixXd& beta_glm,
const Eigen::MatrixXd& dmat_lm,
const Eigen::MatrixXd& dmat_glm,
const Eigen::MatrixXd& u_lm,
const Eigen::MatrixXd& u_glm,
int j,
int B) {

// these are N_j x B
Eigen::MatrixXd pred_lm_j = (dmat_lm * beta_lm.transpose());
Eigen::MatrixXd pred_glm_j = (dmat_glm * beta_glm.transpose());

pred_lm_j.rowwise() += u_lm.col(j).transpose();
pred_glm_j.rowwise() += u_glm.col(j).transpose();

pred_glm_j = pred_glm_j.unaryExpr(&sigmoid);
Eigen::MatrixXd unit_preds_j = pred_lm_j.cwiseProduct(pred_glm_j);

Eigen::MatrixXd dom_preds_j = unit_preds_j.colwise().mean();
res.block(j, 0, 1, B) = dom_preds_j;

}

//[[Rcpp::export]]
SEXP dom_preds_calc(const Eigen::MatrixXd &beta_lm,
const Eigen::MatrixXd &beta_glm,
const int J,
const Eigen::MatrixXd &u_lm,
const Eigen::MatrixXd &u_glm,
const Rcpp::List &design_mats) {
SEXP generate_preds(const Eigen::MatrixXd& beta_lm,
const Eigen::MatrixXd& beta_glm,
const Eigen::MatrixXd& u_lm,
const Eigen::MatrixXd& u_glm,
const Rcpp::List& design_mats,
int J) {

int B = u_lm.rows();

// initialize result matrix
Eigen::MatrixXd result = Eigen::MatrixXd::Zero(J, B);

for (int j = 0; j < J; ++j) {

Rcpp::List dmats_j = design_mats[j];
Eigen::MatrixXd dmat_lm_j = dmats_j[0];
Eigen::MatrixXd dmat_glm_j = dmats_j[1];
// these are N_j x B
Eigen::MatrixXd pred_lm_j = (dmat_lm_j * beta_lm.transpose());
Eigen::MatrixXd pred_glm_j = (dmat_glm_j * beta_glm.transpose());
// now need to add u's
pred_lm_j.rowwise() += u_lm.col(j).transpose();
pred_glm_j.rowwise() += u_glm.col(j).transpose();

pred_glm_j = pred_glm_j.unaryExpr(&sigmoid);
Eigen::MatrixXd unit_preds_j = pred_lm_j.cwiseProduct(pred_glm_j);

Eigen::MatrixXd dom_preds_j = unit_preds_j.colwise().mean();
result.block(j, 0, 1, B) = dom_preds_j;

preds_calc(result,
beta_lm,
beta_glm,
dmat_lm_j,
dmat_glm_j,
u_lm,
u_glm,
j,
B);

}

return Rcpp::wrap(result);

}

0 comments on commit e9a5455

Please sign in to comment.