diff --git a/R/model_matrix.R b/R/model_matrix.R index 974f4c6..8d3ccdd 100644 --- a/R/model_matrix.R +++ b/R/model_matrix.R @@ -1,20 +1,26 @@ #' Combine `model.matrix` Objects of All Formulas of a `dynamiteformula` #' -#' @inheritParams dynamite +#' @inheritParams prepare_stan_input #' @srrstats {RE1.3, RE1.3a} `full_model.matrix` preserves relevant attributes. #' @noRd -full_model.matrix <- function(dformula, data, verbose) { +full_model.matrix <- function(dformula, data, group_var, fixed, verbose) { model_matrices <- vector(mode = "list", length = length(dformula)) model_matrices_type <- vector(mode = "list", length = length(dformula)) types <- c("fixed", "varying", "random") + idx <- data[, + .I[base::seq.int(fixed + 1L, .N)], + by = group, + env = list(fixed = fixed, group = group_var) + ]$V1 + data_nonfixed <- droplevels(data[idx, , env = list(idx = idx)]) for (i in seq_along(dformula)) { mm <- stats::model.matrix.lm( dformula[[i]]$formula, - data = data, + data = data_nonfixed, na.action = na.pass ) if (verbose) { - test_collinearity(dformula[[i]]$resp, mm, data) + test_collinearity(dformula[[i]]$resp, mm, data_nonfixed) } model_matrices_type[[i]] <- list() for (type in c("fixed", "varying", "random")) { @@ -22,14 +28,14 @@ full_model.matrix <- function(dformula, data, verbose) { if (!is.null(type_formula)) { model_matrices_type[[i]][[type]] <- stats::model.matrix.lm( type_formula, - data = data, + data = data_nonfixed, na.action = na.pass ) } } tmp <- do.call(cbind, model_matrices_type[[i]]) ifelse_(identical(length(tmp), 0L), - model_matrices[[i]] <- matrix(nrow = nrow(mm), ncol = 0), + model_matrices[[i]] <- matrix(nrow = nrow(mm), ncol = 0L), model_matrices[[i]] <- tmp ) } diff --git a/R/prepare_stan_input.R b/R/prepare_stan_input.R index ecaeb8c..c28295f 100644 --- a/R/prepare_stan_input.R +++ b/R/prepare_stan_input.R @@ -32,7 +32,7 @@ prepare_stan_input <- function(dformula, data, group_var, time_var, "Can't find variable{?s} {.var {resp[resp_missing]}} in {.arg data}." ) specials <- lapply(dformula, evaluate_specials, data = data) - model_matrix <- full_model.matrix(dformula, data, verbose) + model_matrix <- full_model.matrix(dformula, data, group_var, fixed, verbose) cg <- attr(dformula, "channel_groups") n_cg <- n_unique(cg) n_channels <- length(resp_names) @@ -85,8 +85,7 @@ prepare_stan_input <- function(dformula, data, group_var, time_var, N <- n_unique(group) K <- ncol(model_matrix) X <- model_matrix[, ] - dim(X) <- c(T_full, N, K) - X <- X[T_idx, , , drop = FALSE] + dim(X) <- c(T_full - fixed, N, K) x_tmp <- X[1L, , , drop = FALSE] sd_x <- pmax( stats::setNames(apply(X, 3L, sd, na.rm = TRUE), colnames(model_matrix)), diff --git a/data/categorical_example_fit.rda b/data/categorical_example_fit.rda index fc9e5d8..a4006db 100644 Binary files a/data/categorical_example_fit.rda and b/data/categorical_example_fit.rda differ diff --git a/data/gaussian_example_fit.rda b/data/gaussian_example_fit.rda index 35f8436..44d668a 100644 Binary files a/data/gaussian_example_fit.rda and b/data/gaussian_example_fit.rda differ diff --git a/data/multichannel_example_fit.rda b/data/multichannel_example_fit.rda index 332ec8b..d3ca89e 100644 Binary files a/data/multichannel_example_fit.rda and b/data/multichannel_example_fit.rda differ diff --git a/tests/testthat/test-warnings.R b/tests/testthat/test-warnings.R index 8b2a8cc..7bf2250 100644 --- a/tests/testthat/test-warnings.R +++ b/tests/testthat/test-warnings.R @@ -24,14 +24,24 @@ test_that("factor time conversion warns", { test_that("perfect collinearity warns", { f1 <- obs(y ~ -1 + x + z, family = "gaussian") f2 <- obs(y ~ z, family = "gaussian") - test_data1 <- data.frame(y = rnorm(10), x = rep(1, 10), z = rep(2, 10)) - test_data2 <- data.frame(y = rep(1, 10), x = rep(1, 10), z = rnorm(10)) + test_data1 <- data.table::data.table( + y = rnorm(10), + x = rep(1, 10), + z = rep(2, 10), + id = 1L + ) + test_data2 <- data.table::data.table( + y = rep(1, 10), + x = rep(1, 10), + z = rnorm(10), + id = 1L + ) expect_warning( - full_model.matrix(f1, test_data1, TRUE), + full_model.matrix(f1, test_data1, "id", 0L, TRUE), "Perfect collinearity found between predictor variables of channel `y`\\." ) expect_warning( - full_model.matrix(f2, test_data2, TRUE), + full_model.matrix(f2, test_data2, "id", 0L, TRUE), paste0( "Perfect collinearity found between response and predictor variable:\n", "i Response variable `y` is perfectly collinear ", @@ -39,7 +49,7 @@ test_that("perfect collinearity warns", { ) ) expect_warning( - full_model.matrix(f1, test_data2, TRUE), + full_model.matrix(f1, test_data2, "id", 0L, TRUE), paste0( "Perfect collinearity found between response and predictor variable:\n", "i Response variable `y` is perfectly collinear ", @@ -50,14 +60,15 @@ test_that("perfect collinearity warns", { test_that("too few observations warns", { f <- obs(y ~ x + z + w, family = "gaussian") - test_data <- data.frame( + test_data <- data.table::data.table( y = rnorm(3), x = rnorm(3), z = rnorm(3), - w = rnorm(3) + w = rnorm(3), + id = 1L ) expect_warning( - full_model.matrix(f, test_data, TRUE), + full_model.matrix(f, test_data, "id", 0L, TRUE), paste0( "Number of non-missing observations 3 in channel `y` ", "is less than 4, the number of predictors \\(including possible ", @@ -68,13 +79,14 @@ test_that("too few observations warns", { test_that("zero predictor warns", { f <- obs(y ~ -1 + x + z, family = "gaussian") - test_data <- data.frame( + test_data <- data.table::data.table( y = rnorm(6), x = c(NA, rnorm(2), NA, rnorm(2)), - z = factor(1:3) + z = factor(1:3), + id = 1L ) expect_warning( - full_model.matrix(f, test_data, TRUE), + full_model.matrix(f, test_data, "id", 0L, TRUE), paste0( "Predictor `z1` contains only zeros in the complete case rows of the ", "design matrix for the channel `y`\\."