Skip to content

Commit

Permalink
speed up data parsing, keep data column order
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Nov 28, 2023
1 parent c2d0d88 commit 14d94b3
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 32 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: dynamite
Title: Bayesian Modeling and Causal Inference for Multivariate
Longitudinal Data
Version: 1.4.7
Version: 1.4.8
Authors@R: c(
person("Santtu", "Tikka", email = "santtuth@gmail.com",
role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4039-4342")),
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# dynamite 1.4.8

* Made several performance improvements to data parsing.
* `dynamite()` will now retain the original column order of `data` in all circumstances.

# dynamite 1.4.7

* Added a note on priors vignette regarding default priors for $\tau$ parameters.
Expand Down
20 changes: 12 additions & 8 deletions R/dynamite.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
#' needs the actual `CmdStan` software. See https://mc-stan.org/cmdstanr/ for
#' details.
#' @param verbose \[`logical(1)`]\cr All warnings and messages are suppressed
#' if set to `FALSE`. Defaults to `TRUE`.
#' if set to `FALSE`. Defaults to `TRUE`. Setting this to `FALSE` will also
#' disable checks for perfect collinearity in the model matrix.
#' @param verbose_stan \[`logical(1)`]\cr This is the `verbose` argument for
#' [rstan::sampling()]. Defaults to `FALSE`.
#' @param stanc_options \[`list()`]\cr This is the `stanc_options` argument
Expand Down Expand Up @@ -846,9 +847,9 @@ parse_data <- function(dformula, data, group_var, time_var, verbose) {
"Non-finite values were found in variable{?s}
{.var {data_names[!finite_cols]}} of {.arg data}."
)
data.table::setkeyv(data, c(group_var, time_var))
data <- fill_time(data, group_var, time_var)
drop_unused(dformula, data, group_var, time_var)
data.table::setkeyv(data, c(group_var, time_var))
data
}

Expand Down Expand Up @@ -1129,8 +1130,10 @@ fill_time <- function(data, group_var, time_var) {
n_group <- length(group)
time_duplicated <- logical(n_group)
time_missing <- logical(n_group)
group_bounds <- c(0, data[, max(.I), by = group_var]$V1)
for (i in seq_len(n_group)) {
idx_group <- which(data_groups == group[i])
idx_group <- seq(group_bounds[i] + 1, group_bounds[i + 1])
#idx_group <- which(data_groups == group[i])
sub <- data[idx_group, ]
time_duplicated[i] <- any(duplicated(sub[[time_var]]))
time_missing[i] <- !identical(sub[[time_var]], full_time)
Expand All @@ -1148,26 +1151,27 @@ fill_time <- function(data, group_var, time_var) {
all(time_ivals[!is.na(time_ivals)] %% time_scale == 0),
"Observations must occur at regular time intervals."
)

# time_missing <- data[,
# !identical(time_var, full_time),
# by = group_var,
# env = list(time_var = time_var, full_time = full_time)
# ]$V1
if (any(time_missing)) {
data_names <- names(data)
full_data_template <- data.table::as.data.table(
expand.grid(
time = full_time,
group = unique(data[[group_var]])
group = unique(data[[group_var]]),
time = full_time
)
)
names(full_data_template) <- c(time_var, group_var)
names(full_data_template) <- c(group_var, time_var)
data <- data.table::merge.data.table(
full_data_template,
data,
by = c(time_var, group_var),
by = c(group_var, time_var),
all.x = TRUE
)
data.table::setcolorder(data, data_names)
}
data
}
16 changes: 9 additions & 7 deletions R/predict_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ parse_newdata <- function(dformulas, newdata, data, type, eval_type,
newdata[[i]] <- factor(newdata[[i]], levels = l_orig)
}
}
data.table::setDT(newdata, key = c(group_var, time_var))
newdata <- fill_time_predict(
newdata,
group_var,
time_var,
time_scale = original_times[2L] - original_times[1L]
)
data.table::setDT(newdata, key = c(group_var, time_var))
clear_names <- intersect(names(newdata), clear_names)
if (length(clear_names) > 0L) {
# TODO no need check length when data.table package is updated
Expand Down Expand Up @@ -270,8 +270,9 @@ fill_time_predict <- function(data, group_var, time_var, time_scale) {
has_gaps = logical(n_group)
)
time_missing <- logical(n_group)
group_bounds <- c(0, data[, max(.I), by = group_var]$V1)
for (i in seq_len(n_group)) {
idx_group <- which(data_groups == group[i])
idx_group <- seq(group_bounds[i] + 1, group_bounds[i + 1])
sub <- data[idx_group, ]
time_duplicated[i] <- any(duplicated(sub[[time_var]]))
time_groups$has_missing[i] <- !identical(sub[[time_var]], full_time)
Expand All @@ -288,7 +289,6 @@ fill_time_predict <- function(data, group_var, time_var, time_scale) {
)
)
if (length(time) > 1L) {
original_order <- colnames(data)
# time_groups <- data[,
# {
# has_missing = !identical(time_var, full_time)
Expand All @@ -303,6 +303,7 @@ fill_time_predict <- function(data, group_var, time_var, time_scale) {
# )
# ]
if (any(time_groups$has_missing)) {
data_names <- colnames(data)
if (any(time_groups$has_gaps)) {
warning_(c(
"Time index variable {.var {time_var}} of {.arg newdata} has gaps:",
Expand All @@ -312,16 +313,17 @@ fill_time_predict <- function(data, group_var, time_var, time_scale) {
))
}
full_data_template <- data.table::as.data.table(expand.grid(
time = full_time,
group = unique(data[[group_var]])
group = unique(data[[group_var]]),
time = full_time
))
names(full_data_template) <- c(time_var, group_var)
names(full_data_template) <- c(group_var, time_var)
data <- data.table::merge.data.table(
full_data_template,
data,
by = c(time_var, group_var),
by = c(group_var, time_var),
all.x = TRUE
)
data.table::setcolorder(data, data_names)
}
}
data
Expand Down
37 changes: 22 additions & 15 deletions R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ prepare_stan_input <- function(dformula, data, group_var, time_var,
)
N <- n_unique(group)
K <- ncol(model_matrix)
X <- aperm(
array(
as.numeric(unlist(split(model_matrix, gl(T_full, 1L, N * T_full)))),
dim = c(N, K, T_full)
),
c(3L, 1L, 2L)
)[T_idx, , , drop = FALSE]
#X <- aperm(
# array(
# as.numeric(unlist(split(model_matrix, gl(T_full, 1L, N * T_full)))),
# dim = c(N, K, T_full)
# ),
# c(3L, 1L, 2L)
#)[T_idx, , , drop = FALSE]
X <- model_matrix[,]
dim(X) <- c(T_full, N, K)
X <- X[T_idx, , , drop = FALSE]
x_tmp <- X[1L, , , drop = FALSE]
sd_x <- pmax(
setNames(apply(X, 3L, sd, na.rm = TRUE), colnames(model_matrix)),
Expand All @@ -107,12 +110,14 @@ prepare_stan_input <- function(dformula, data, group_var, time_var,
for (i in seq_len(n_channels)) {
y <- resp[i]
y_name <- resp_names[i]
y_split <- split(
data[, .SD, .SDcols = c(y, group_var)],
by = group_var,
keep.by = FALSE
)
Y <- array(as.numeric(unlist(y_split)), dim = c(T_full, N))
#y_split <- split(
# data[, .SD, .SDcols = c(y, group_var)],
# by = group_var,
# keep.by = FALSE
#)
#Y <- array(as.numeric(unlist(y_split)), dim = c(T_full, N))
Y <- as.numeric(data[[y]])
dim(Y) <- c(T_full, N)
Y <- Y[T_idx, , drop = FALSE]
tmp <- initialize_univariate_channel(
dformula = dformula[[i]],
Expand Down Expand Up @@ -323,8 +328,10 @@ initialize_univariate_channel <- function(dformula, specials, fixed_pars,
for (spec in formula_special_funs) {
if (!is.null(specials[[spec]])) {
spec_idx <- seq.int(fixed + 1L, T_full)
spec_split <- split(specials[[spec]], group)
spec_array <- array(as.numeric(unlist(spec_split)), dim = c(T_full, N))
#spec_split <- split(specials[[spec]], group)
spec_array <- as.numeric(specials[[spec]])
dim(spec_array) <- c(T_full, N)
#spec_array <- array(as.numeric(unlist(spec_split)), dim = c(T_full, N))
spec_na <- spec_na | is.na(spec_array[spec_idx, , drop = FALSE])
spec_name <- paste0(spec, "_", y_name)
sampling[[spec_name]] <- ifelse_(
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
Binary file modified data/categorical_example_fit.rda
Binary file not shown.
Binary file modified data/gaussian_example_fit.rda
Binary file not shown.
Binary file modified data/gaussian_simulation_fit.rda
Binary file not shown.
Binary file modified data/multichannel_example_fit.rda
Binary file not shown.
4 changes: 3 additions & 1 deletion tests/testthat/test-edgecases.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ individuals <- 5
total_obs <- timepoints * individuals

test_data <- data.frame(
time = 1:timepoints,
group = gl(individuals, timepoints),
time = 1:timepoints,
offset = sample(50:100, size = total_obs, replace = TRUE),
trials = sample(50:100, size = total_obs, replace = TRUE)
) |>
Expand Down Expand Up @@ -487,6 +487,8 @@ test_that("data expansion to full time scale works", {
expected_data_single <- droplevels(expected_data_single)
expected_data_single$trials <- NULL
expected_data_single$offset <- NULL
expected_data_single$group <- NULL
expected_data_single$.group <- 1L
data.table::setDT(expected_data_single, key = c("time"))
expect_equal(fit_single$data, expected_data_single, ignore_attr = TRUE)
})
Expand Down

0 comments on commit 14d94b3

Please sign in to comment.