diff --git a/R/RcppExports.R b/R/RcppExports.R index d5f0f215..5933210d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -53,20 +53,36 @@ forwardbackwardx <- function(transition, emission, init, obs, coef, X, numberOfS .Call(`_seqHMM_forwardbackwardx`, transition, emission, init, obs, coef, X, numberOfStates, forwardonly, threads) } -get_omega <- function(gamma_omega_raw, X, logspace) { - .Call(`_seqHMM_get_omega`, gamma_omega_raw, X, logspace) +get_omega <- function(gamma_raw, X, logspace) { + .Call(`_seqHMM_get_omega`, gamma_raw, X, logspace) +} + +get_omega_all <- function(gamma_raw, X, logspace) { + .Call(`_seqHMM_get_omega_all`, gamma_raw, X, logspace) } get_pi <- function(gamma_raw, X, logspace) { .Call(`_seqHMM_get_pi`, gamma_raw, X, logspace) } +get_pi_all <- function(gamma_raw, X, logspace) { + .Call(`_seqHMM_get_pi_all`, gamma_raw, X, logspace) +} + get_A <- function(gamma_raw, X, logspace, tv) { .Call(`_seqHMM_get_A`, gamma_raw, X, logspace, tv) } -get_B <- function(gamma_raw, X, M, logspace, add_missing, tv) { - .Call(`_seqHMM_get_B`, gamma_raw, X, M, logspace, add_missing, tv) +get_A_all <- function(gamma_raw, X, logspace, tv) { + .Call(`_seqHMM_get_A_all`, gamma_raw, X, logspace, tv) +} + +get_B <- function(gamma_raw, X, logspace, add_missing, tv) { + .Call(`_seqHMM_get_B`, gamma_raw, X, logspace, add_missing, tv) +} + +get_B_all <- function(gamma_raw, X, logspace, add_missing, tv) { + .Call(`_seqHMM_get_B_all`, gamma_raw, X, logspace, add_missing, tv) } logLikHMM <- function(transition, emission, init, obs, threads) { @@ -109,20 +125,20 @@ log_objective <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymbo .Call(`_seqHMM_log_objective`, transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, threads) } -log_objective_nhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B) { - .Call(`_seqHMM_log_objective_nhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, 
gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B) +log_objective_nhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti) { + .Call(`_seqHMM_log_objective_nhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti) } -log_objective_nhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B) { - .Call(`_seqHMM_log_objective_nhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B) +log_objective_nhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti) { + .Call(`_seqHMM_log_objective_nhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti) } -log_objective_mnhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) { - .Call(`_seqHMM_log_objective_mnhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) +log_objective_mnhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti) { + .Call(`_seqHMM_log_objective_mnhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti) } -log_objective_mnhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) { - .Call(`_seqHMM_log_objective_mnhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) 
+log_objective_mnhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti) { + .Call(`_seqHMM_log_objective_mnhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti) } log_objectivex <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, coef, X, numberOfStates, threads) { diff --git a/R/build_mnhmm.R b/R/build_mnhmm.R index ecbd3483..13941b44 100644 --- a/R/build_mnhmm.R +++ b/R/build_mnhmm.R @@ -40,6 +40,8 @@ build_mnhmm <- function( np_pi = n_clusters * out$extras$np_pi, np_A = n_clusters * out$extras$np_A, np_B = n_clusters * out$extras$np_B, - np_omega = out$extras$np_omega + np_omega = out$extras$np_omega, + missing_X_transition = out$extras$missing_X_transition, + missing_X_emission = out$extras$missing_X_emission ) } diff --git a/R/build_nhmm.R b/R/build_nhmm.R index d30642db..0d6ca62a 100644 --- a/R/build_nhmm.R +++ b/R/build_nhmm.R @@ -24,6 +24,8 @@ build_nhmm <- function( tv_B = out$extras$tv_B, np_pi = out$extras$np_pi, np_A = out$extras$np_A, - np_B = out$extras$np_B + np_B = out$extras$np_B, + missing_X_transition = out$extras$missing_X_transition, + missing_X_emission = out$extras$missing_X_emission ) } diff --git a/R/check_build_arguments.R b/R/check_build_arguments.R index 1888b0c6..bc3d3727 100644 --- a/R/check_build_arguments.R +++ b/R/check_build_arguments.R @@ -1,8 +1,8 @@ #' Create observations for the model objects #' -#' Note that for historical reasons `length_of_sequences` refers to the maximum -#' length of sequences, whereas `sequence_lengths` refers to the actual non-void -#' lengths of each sequence. +#' Note that for backward compatibility reasons `length_of_sequences` refers +#' to the maximum length of sequences, whereas `sequence_lengths` refers to +#' the actual non-void lengths of each sequence. 
#'@noRd #' .check_observations <- function(x, channel_names = NULL, diff --git a/R/create_base_nhmm.R b/R/create_base_nhmm.R index 9f44c403..868c9498 100644 --- a/R/create_base_nhmm.R +++ b/R/create_base_nhmm.R @@ -172,6 +172,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states, coef_names_cluster = if(mixture) omega$coef_names else NULL ), extras = list( + missing_X_transition = A$missing, + missing_X_emission = B$missing, np_pi = pi$n_pars, np_A = A$n_pars, np_B = B$n_pars, diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 41339f25..eda37e1b 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -47,6 +47,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { K_s <- nrow(X_s) K_o <- nrow(X_o) K_d <- nrow(X_d) + Ti <- model$sequence_lengths dots <- list(...) if (is.null(dots$algorithm)) dots$algorithm <- "NLOPT_LD_LBFGS" @@ -65,7 +66,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { ) out <- log_objective_mnhmm_singlechannel( gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega + gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, + Ti ) list(objective = - out$loglik, gradient = - unlist(out[-1])) @@ -107,7 +109,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { ) out <- log_objective_mnhmm_multichannel( gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega + gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, + Ti ) list(objective = - out$loglik, gradient = - unlist(out[-1])) diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index c3b3a7db..03a07794 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -43,6 +43,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) 
{ K_i <- nrow(X_i) K_s <- nrow(X_s) K_o <- nrow(X_o) + Ti <- model$sequence_lengths dots <- list(...) if (is.null(dots$algorithm)) dots$algorithm <- "NLOPT_LD_LBFGS" @@ -58,7 +59,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { ) out <- log_objective_nhmm_singlechannel( gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, - iv_pi, iv_A, iv_B, tv_A, tv_B + iv_pi, iv_A, iv_B, tv_A, tv_B, Ti ) list(objective = - out$loglik, gradient = - unlist(out[-1])) @@ -87,7 +88,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { ) out <- log_objective_nhmm_multichannel( gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, - iv_pi, iv_A, iv_B, tv_A, tv_B + iv_pi, iv_A, iv_B, tv_A, tv_B, Ti ) list(objective = - out$loglik, gradient = - unlist(out[-1])) diff --git a/R/get_probs.R b/R/get_probs.R index 617e87e4..61431199 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -41,11 +41,9 @@ get_initial_probs.nhmm <- function(model, ...) { d <- data.frame( id = rep(ids, each = model$n_states), state = model$state_names, - estimate = c(apply( - model$X_initial, 2, function(z) { - get_pi(model$coefficients$gamma_pi_raw, z, FALSE) - } - )) + estimate = c( + get_pi_all(model$coefficients$gamma_pi_raw, model$X_initial, FALSE) + ) ) } stats::setNames(d, c(model$id_variable, "state", "estimate")) @@ -66,8 +64,9 @@ get_initial_probs.mnhmm <- function(model, ...) { cluster = model$cluster_names[i], id = rep(ids, each = model$n_states), state = model$state_names[[i]], - estimate = get_pi( - model$coefficients$gamma_pi_raw[[i]], model$X_initial[, 1L], FALSE + estimate = c( + get_pi(model$coefficients$gamma_pi_raw[[i]], model$X_initial[, 1L], + FALSE) ) ) }) @@ -80,11 +79,11 @@ get_initial_probs.mnhmm <- function(model, ...) 
{ cluster = model$cluster_names[i], id = rep(ids, each = model$n_states), state = model$state_names[[i]], - estimate = c(apply( - model$X_initial, 2, function(z) { - get_pi(model$coefficients$gamma_pi_raw[[i]], z, FALSE) - } - )) + estimate = c( + get_pi_all( + model$coefficients$gamma_pi_raw[[i]], model$X_initial, FALSE + ) + ) ) }) ) @@ -109,6 +108,7 @@ get_initial_probs.mhmm <- function(model, ...) { get_transition_probs.nhmm <- function(model, ...) { S <- model$n_states T_ <- model$length_of_sequences + model$X_transition[attr(model, "missing_X_transition")] <- NA if (model$n_channels == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) @@ -133,14 +133,12 @@ get_transition_probs.nhmm <- function(model, ...) { time = rep(times, each = S^2), state_from = model$state_names, state_to = rep(model$state_names, each = S), - estimate = c(apply( - model$X_transition, 3, function(z) { - get_A( - model$coefficients$gamma_A_raw, matrix(z, ncol = T_), FALSE, - attr(model, "tv_A") - ) - } - )) + estimate = c( + get_A_all( + model$coefficients$gamma_A_raw, model$X_transition, FALSE, + attr(model, "tv_A") + ) + ) ) } stats::setNames( @@ -157,6 +155,7 @@ get_transition_probs.mnhmm <- function(model, ...) { S <- model$n_states T_ <- model$length_of_sequences D <- model$n_clusters + model$X_transition[attr(model, "missing_X_transition")] <- NA if (model$n_channels == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) @@ -191,14 +190,12 @@ get_transition_probs.mnhmm <- function(model, ...) 
{ time = rep(times, each = S^2), state_from = model$state_names[[i]], state_to = rep(model$state_names[[i]], each = S), - estimate = c(apply( - model$X_transition, 3, function(z) { - get_A( - model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T_), FALSE, - attr(model, "tv_A") - ) - } - )) + estimate = c( + get_A_all( + model$coefficients$gamma_A_raw[[i]], model$X_transition, FALSE, + attr(model, "tv_A") + ) + ) ) }) ) @@ -229,6 +226,7 @@ get_emission_probs.nhmm <- function(model, ...) { C <- model$n_channels T_ <- model$length_of_sequences M <- model$n_symbols + model$X_emission[attr(model, "missing_X_emission")] <- NA if (C == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) @@ -250,8 +248,8 @@ get_emission_probs.nhmm <- function(model, ...) { state = model$state_names, channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), - estimate = unlist(get_B( - model$coefficients$gamma_B_raw[i], X, M[i], FALSE, FALSE, + estimate = c(get_B( + model$coefficients$gamma_B_raw[[i]], X, FALSE, FALSE, attr(model, "tv_B")) ) ) @@ -267,13 +265,11 @@ get_emission_probs.nhmm <- function(model, ...) { state = model$state_names, channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), - estimate = apply( - model$X_emission, 3, function(z) { - unlist(get_B( - model$coefficients$gamma_B_raw[i], matrix(z, ncol = T_), M[i], - FALSE, FALSE, attr(model, "tv_B") - )) - } + estimate = c( + get_B_all( + model$coefficients$gamma_B_raw[[i]], model$X_emission, + FALSE, FALSE, attr(model, "tv_B") + ) ) ) }) @@ -293,14 +289,14 @@ get_emission_probs.mnhmm <- function(model, ...) 
{ D <- model$n_clusters T_ <- model$length_of_sequences M <- model$n_symbols + model$X_emission[attr(model, "missing_X_emission")] <- NA if (C == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) symbol_names <- list(model$symbol_names) - for (i in seq_len(D)) { - model$coefficients$gamma_B_raw[[i]] <- - list(model$coefficients$gamma_B_raw[[i]]) - } + model$coefficients$gamma_B_raw <- lapply( + model$coefficients$gamma_B_raw, list + ) } else { ids <- rownames(model$observations[[1]]) times <- colnames(model$observations[[1]]) @@ -322,7 +318,7 @@ get_emission_probs.mnhmm <- function(model, ...) { channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), - estimate = unlist(get_B( - model$coefficients$gamma_B_raw[[j]][i], X, M[i], FALSE, FALSE, + estimate = c(get_B( + model$coefficients$gamma_B_raw[[j]][[i]], X, FALSE, FALSE, attr(model, "tv_B")) ) ) @@ -344,13 +340,11 @@ get_emission_probs.mnhmm <- function(model, ...) { state = model$state_names[[j]], channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), - estimate = apply( - model$X_emission, 3, function(z) { - unlist(get_B( - model$coefficients$gamma_B_raw[[j]][i], - matrix(z, ncol = T_), M[i], FALSE, FALSE, attr(model, "tv_B") - )) - } + estimate = c( + get_B_all( + model$coefficients$gamma_B_raw[[j]][[i]], model$X_emission, + FALSE, FALSE, attr(model, "tv_B") + ) ) ) }) @@ -399,11 +393,11 @@ get_cluster_probs.mnhmm <- function(model, ...) { d <- data.frame( cluster = model$cluster_names, id = rep(ids, each = model$n_clusters), - estimate = c(apply( - model$X_cluster, 2, function(z) { - get_omega(model$coefficients$gamma_omega_raw, z, FALSE) - } - )) + estimate = c( + get_omega_all( + model$coefficients$gamma_omega_raw, model$X_cluster, FALSE + ) + ) ) } stats::setNames(d, c("cluster", model$id_variable, "estimate")) @@ -478,8 +472,8 @@ get_probs.nhmm <- function(model, newdata = NULL, remove_voids = TRUE, ...) 
{ if (remove_voids) { list( initial_probs = out$initial_probs, - transition_probs = remove_voids(model, out$transition_probs), - emission_probs = remove_voids(model, out$emission_probs) + transition_probs = out$transition_probs[stats::complete.cases(out$transition_probs), ], + emission_probs = out$emission_probs[stats::complete.cases(out$emission_probs), ] ) } else out } @@ -526,8 +520,8 @@ get_probs.mnhmm <- function(model, newdata = NULL, remove_voids = TRUE, ...) { if (remove_voids) { list( initial_probs = out$initial_probs, - transition_probs = remove_voids(model, out$transition_probs), - emission_probs = remove_voids(model, out$emission_probs), + transition_probs = out$transition_probs[stats::complete.cases(out$transition_probs), ], + emission_probs = out$emission_probs[stats::complete.cases(out$emission_probs), ], cluster_probs = out$cluster_probs ) } else out diff --git a/R/model_matrix.R b/R/model_matrix.R index ca1417b6..0817623d 100644 --- a/R/model_matrix.R +++ b/R/model_matrix.R @@ -67,6 +67,7 @@ model_matrix_transition_formula <- function(formula, data, n_sequences, X <- array(1, c(length_of_sequences, n_sequences, 1L)) coef_names <- "(Intercept)" iv <- tv <- FALSE + missing_values <- integer(0) } else { X <- stats::model.matrix.lm( formula, @@ -77,10 +78,14 @@ model_matrix_transition_formula <- function(formula, data, n_sequences, if (length(missing_values) > 0) { ends <- sequence_lengths[match(data[[id]], unique(data[[id]]))] stopifnot_( - all(z <- data[missing_values, time] <= ends[missing_values]), + all(z <- data[missing_values, time] > ends[missing_values]), c( - "Missing cases are not allowed in covariates of `transition_formula`.", - "Use {.fn complete.cases} to detect them, then fix or impute them.", + paste0( + "Missing cases are not allowed in covariates of ", + "{.arg transition_formula}, unless they correspond to void ", + "response values at the end of the sequences. ", + "Use {.fn complete.cases} to detect them, then fix or impute them." 
+ ), paste0( "First missing value found for ID ", "{data[missing_values, id][which(!z)[1]]} at time point ", @@ -89,16 +94,19 @@ model_matrix_transition_formula <- function(formula, data, n_sequences, ) ) } - # Replace NAs in void cases with zero - X[is.na(X)] <- 0 coef_names <- colnames(X) dim(X) <- c(length_of_sequences, n_sequences, ncol(X)) + missing_values <- which(is.na(aperm(X, c(3, 1, 2)))) + # Replace NAs in void cases with zero + X[is.na(X)] <- 0 n_pars <- n_states * (n_states - 1L) * dim(X)[3] iv <- iv_X(X) tv <- tv_X(X) } - list(formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), - coef_names = coef_names, iv = iv, tv = tv) + list( + formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), + coef_names = coef_names, iv = iv, tv = tv, missing = missing_values + ) } #' Create the Model Matrix based on NHMM Formulas #' @@ -113,6 +121,7 @@ model_matrix_emission_formula <- function(formula, data, n_sequences, X <- array(1, c(length_of_sequences, n_sequences, 1L)) coef_names <- "(Intercept)" iv <- tv <- FALSE + missing_values <- integer(0) } else { X <- stats::model.matrix.lm( formula, @@ -123,28 +132,37 @@ model_matrix_emission_formula <- function(formula, data, n_sequences, if (length(missing_values) > 0) { ends <- sequence_lengths[match(data[[id]], unique(data[[id]]))] stopifnot_( - all(z <- data[missing_values, time] <= ends[missing_values]), - c( - "Missing cases are not allowed in covariates of `emission_formula`.", - "Use {.fn complete.cases} to detect them, then fix or impute them.", - paste0( - "First missing value found for ID ", - "{data[missing_values, id][which(!z)[1]]} at time point ", - "{data[missing_values, time][which(!z)[1]]}." - ) + all(z <- data[missing_values, time] > ends[missing_values]), + c(paste0( + "Missing cases are not allowed in covariates of ", + "{.arg emission_formula}, unless they correspond to missing ", + "void responses at the end of the sequences. ", + "Use {.fn complete.cases} to detect them, then fix or impute them. 
", + "Note that the missing covariates in {.arg emission_formula} ", + "corresponding to time points where response variables are also ", + "missing can be set to arbitrary value." + ), + paste0( + "First missing value found for ID ", + "{data[missing_values, id][which(!z)[1]]} at time point ", + "{data[missing_values, time][which(!z)[1]]}." + ) ) ) } - # Replace NAs in void cases with zero - X[is.na(X)] <- 0 coef_names <- colnames(X) dim(X) <- c(length_of_sequences, n_sequences, ncol(X)) + missing_values <- which(is.na(aperm(X, c(3, 1, 2)))) + # Replace NAs in void cases with zero + X[is.na(X)] <- 0 n_pars <- sum(n_states * (n_symbols - 1L) * dim(X)[3]) iv <- iv_X(X) tv <- tv_X(X) } - list(formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), - coef_names = coef_names, iv = iv, tv = tv) + list( + formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), + coef_names = coef_names, iv = iv, tv = tv, missing = missing_values + ) } #' Create the Model Matrix based on NHMM Formulas #' diff --git a/R/update.R b/R/update.R index 2fcaccee..c26ea01d 100644 --- a/R/update.R +++ b/R/update.R @@ -1,7 +1,7 @@ #' Update Covariate Values of NHMM #' #' This function can be used to replace original covariate values of NHMMs. -#' The model formulae and estimated coefficients are not altered. +#' The responses, model formulae and estimated coefficients are not altered. #' @param object An object of class `nhmm` or `mnhmm`. #' @param newdata A data frame containing the new covariate values. #' @param ... Ignored. @@ -9,8 +9,6 @@ #' @export update.nhmm <- function(object, newdata, ...) 
{ newdata <- .check_data(newdata, object$time_variable, object$id_variable) - object$n_sequences <- length(unique(newdata[[object$id_variable]])) - object$length_of_sequences <- length(unique(newdata[[object$time_variable]])) if (!is.null(object$data)) object$data <- newdata X <- model_matrix_initial_formula( object$initial_formula, newdata, object$n_sequences, @@ -27,6 +25,7 @@ update.nhmm <- function(object, newdata, ...) { object$X_transition <- X$X attr(object, "iv_A") <- X$iv attr(object, "tv_A") <- X$tv + attr(object, "missing_X_transition") <- X$missing X <- model_matrix_emission_formula( object$emission_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$n_symbols, @@ -36,14 +35,13 @@ update.nhmm <- function(object, newdata, ...) { object$X_emission <- X$X attr(object, "iv_B") <- X$iv attr(object, "tv_B") <- X$tv + attr(object, "missing_X_emission") <- X$missing object } #' @rdname update_nhmm #' @export update.mnhmm <- function(object, newdata, ...) { newdata <- .check_data(newdata, object$time_variable, object$id_variable) - object$n_sequences <- length(unique(newdata[[object$id_variable]])) - object$length_of_sequences <- length(unique(newdata[[object$time_variable]])) if (!is.null(object$data)) object$data <- newdata X <- model_matrix_initial_formula( object$initial_formula, newdata, object$n_sequences, @@ -60,6 +58,7 @@ update.mnhmm <- function(object, newdata, ...) { object$X_transition <- X$X attr(object, "iv_A") <- X$iv attr(object, "tv_A") <- X$tv + attr(object, "missing_X_transition") <- X$missing X <- model_matrix_emission_formula( object$emission_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$n_symbols, @@ -69,6 +68,7 @@ update.mnhmm <- function(object, newdata, ...) 
{ object$X_emission <- X$X attr(object, "iv_B") <- X$iv attr(object, "tv_B") <- X$tv + attr(object, "missing_X_emission") <- X$missing X <- model_matrix_cluster_formula( object$cluster_formula, newdata, object$n_sequences, object$n_clusters, object$time_variable, object$id_variable diff --git a/R/utilities.R b/R/utilities.R index 1b622b18..de0ffa69 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -121,27 +121,3 @@ create_emissionArray <- function(model) { } emissionArray } - -#' Remove rows corresponding to void symbols -#' @noRd -remove_voids <- function(model, x) { - - x_id <- x[[model$id_variable]] - x_time <- x[[model$time_variable]] - if (model$n_channels == 1) { - time <- colnames(model$observations) - id <- rownames(model$observations) - } else { - time <- colnames(model$observations[[1]]) - id <- rownames(model$observations[[1]]) - } - do.call( - rbind, - lapply(seq_len(model$n_sequences), function(i) { - idx <- which( - x_id == id[i] & x_time %in% time[seq_len(model$sequence_lengths[i])] - ) - x[idx, ] - }) - ) -} diff --git a/man/cluster_probs.Rd b/man/cluster_probs.Rd index 3aec0632..c1571969 100644 --- a/man/cluster_probs.Rd +++ b/man/cluster_probs.Rd @@ -13,7 +13,9 @@ get_cluster_probs(model, ...) \method{get_cluster_probs}{mhmm}(model, ...) } \arguments{ -\item{model}{An object of class \code{mnhmm} or `mhmm.} +\item{model}{A mixture hidden Markov model.} + +\item{...}{Ignored.} } \description{ Extract the Prior Cluster Probabilities of MHMM or MNHMM diff --git a/man/emission_probs.Rd b/man/emission_probs.Rd index b87aafb7..729c44d5 100644 --- a/man/emission_probs.Rd +++ b/man/emission_probs.Rd @@ -18,6 +18,11 @@ get_emission_probs(model, ...) \method{get_emission_probs}{mhmm}(model, ...) 
} +\arguments{ +\item{model}{A hidden Markov model.} + +\item{...}{Ignored.} +} \description{ Extract the Emission Probabilities of Hidden Markov Model } diff --git a/man/initial_probs.Rd b/man/initial_probs.Rd index 00cf048b..e6ca2571 100644 --- a/man/initial_probs.Rd +++ b/man/initial_probs.Rd @@ -18,6 +18,11 @@ get_initial_probs(model, ...) \method{get_initial_probs}{mhmm}(model, ...) } +\arguments{ +\item{model}{A hidden Markov model.} + +\item{...}{Ignored.} +} \description{ Extract the Initial State Probabilities of Hidden Markov Model } diff --git a/man/plot.ame.Rd b/man/plot.ame.Rd index d2baa097..cfae6637 100644 --- a/man/plot.ame.Rd +++ b/man/plot.ame.Rd @@ -15,6 +15,8 @@ \item{probs}{A numeric vector of length 2 with the lower and upper limits for confidence intervals. Default is \code{c(0.025, 0.975)}. If the limits are not found in the input object \code{x}, an error is thrown.} + +\item{...}{Ignored.} } \description{ Visualize Average Marginal Effects diff --git a/man/transition_probs.Rd b/man/transition_probs.Rd index 6b6099e7..a8d8bfbf 100644 --- a/man/transition_probs.Rd +++ b/man/transition_probs.Rd @@ -18,6 +18,11 @@ get_transition_probs(model, ...) \method{get_transition_probs}{mhmm}(model, ...) 
} +\arguments{ +\item{model}{A hidden Markov model.} + +\item{...}{Ignored.} +} \description{ Extract the State Transition Probabilities of Hidden Markov Model } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index d995b17f..792dae85 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -240,31 +240,57 @@ BEGIN_RCPP END_RCPP } // get_omega -arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace); -RcppExport SEXP _seqHMM_get_omega(SEXP gamma_omega_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { +arma::vec get_omega(const arma::mat& gamma_raw, const arma::vec& X, const bool logspace); +RcppExport SEXP _seqHMM_get_omega(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const arma::mat& >::type gamma_omega_raw(gamma_omega_rawSEXP); - Rcpp::traits::input_parameter< const arma::vec >::type X(XSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + rcpp_result_gen = Rcpp::wrap(get_omega(gamma_raw, X, logspace)); + return rcpp_result_gen; +END_RCPP +} +// get_omega_all +arma::mat get_omega_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace); +RcppExport SEXP _seqHMM_get_omega_all(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); - rcpp_result_gen = Rcpp::wrap(get_omega(gamma_omega_raw, X, logspace)); + rcpp_result_gen = Rcpp::wrap(get_omega_all(gamma_raw, X, logspace)); return rcpp_result_gen; END_RCPP } // 
get_pi -arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const bool logspace); +arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec& X, const bool logspace); RcppExport SEXP _seqHMM_get_pi(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::mat& >::type gamma_raw(gamma_rawSEXP); - Rcpp::traits::input_parameter< const arma::vec >::type X(XSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP); Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); rcpp_result_gen = Rcpp::wrap(get_pi(gamma_raw, X, logspace)); return rcpp_result_gen; END_RCPP } +// get_pi_all +arma::mat get_pi_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace); +RcppExport SEXP _seqHMM_get_pi_all(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + rcpp_result_gen = Rcpp::wrap(get_pi_all(gamma_raw, X, logspace)); + return rcpp_result_gen; +END_RCPP +} // get_A arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool tv); RcppExport SEXP _seqHMM_get_A(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP, SEXP tvSEXP) { @@ -279,19 +305,47 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// get_A_all +arma::field get_A_all(const arma::cube& gamma_raw, const arma::cube& X, const bool logspace, const bool tv); +RcppExport SEXP _seqHMM_get_A_all(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP, SEXP tvSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::cube& >::type 
gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP); + rcpp_result_gen = Rcpp::wrap(get_A_all(gamma_raw, X, logspace, tv)); + return rcpp_result_gen; +END_RCPP +} // get_B -arma::field get_B(const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, const bool logspace, const bool add_missing, const bool tv); -RcppExport SEXP _seqHMM_get_B(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP MSEXP, SEXP logspaceSEXP, SEXP add_missingSEXP, SEXP tvSEXP) { +arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool add_missing, const bool tv); +RcppExport SEXP _seqHMM_get_B(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP, SEXP add_missingSEXP, SEXP tvSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const arma::field& >::type gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type gamma_raw(gamma_rawSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); - Rcpp::traits::input_parameter< const arma::uvec& >::type M(MSEXP); Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); Rcpp::traits::input_parameter< const bool >::type add_missing(add_missingSEXP); Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP); - rcpp_result_gen = Rcpp::wrap(get_B(gamma_raw, X, M, logspace, add_missing, tv)); + rcpp_result_gen = Rcpp::wrap(get_B(gamma_raw, X, logspace, add_missing, tv)); + return rcpp_result_gen; +END_RCPP +} +// get_B_all +arma::field get_B_all(const arma::cube& gamma_raw, const arma::cube& X, const bool logspace, const bool add_missing, const bool tv); +RcppExport SEXP _seqHMM_get_B_all(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP, SEXP add_missingSEXP, SEXP tvSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; 
+ Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::cube& >::type gamma_raw(gamma_rawSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X(XSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type add_missing(add_missingSEXP); + Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP); + rcpp_result_gen = Rcpp::wrap(get_B_all(gamma_raw, X, logspace, add_missing, tv)); return rcpp_result_gen; END_RCPP } @@ -468,8 +522,8 @@ BEGIN_RCPP END_RCPP } // log_objective_nhmm_singlechannel -Rcpp::List log_objective_nhmm_singlechannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B); -RcppExport SEXP _seqHMM_log_objective_nhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP) { +Rcpp::List log_objective_nhmm_singlechannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uvec& Ti); +RcppExport SEXP _seqHMM_log_objective_nhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP TiSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -485,13 +539,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); Rcpp::traits::input_parameter< 
const bool >::type tv_A(tv_ASEXP); Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B)); + Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti)); return rcpp_result_gen; END_RCPP } // log_objective_nhmm_multichannel -Rcpp::List log_objective_nhmm_multichannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B); -RcppExport SEXP _seqHMM_log_objective_nhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP) { +Rcpp::List log_objective_nhmm_multichannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uvec& Ti); +RcppExport SEXP _seqHMM_log_objective_nhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP TiSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -508,13 +563,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); 
Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B)); + Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, Ti)); return rcpp_result_gen; END_RCPP } // log_objective_mnhmm_singlechannel -Rcpp::List log_objective_mnhmm_singlechannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega); -RcppExport SEXP _seqHMM_log_objective_mnhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP) { +Rcpp::List log_objective_mnhmm_singlechannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega, const arma::uvec& Ti); +RcppExport SEXP _seqHMM_log_objective_mnhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, 
SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP, SEXP TiSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -533,13 +589,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); Rcpp::traits::input_parameter< const bool >::type iv_omega(iv_omegaSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega)); + Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti)); return rcpp_result_gen; END_RCPP } // log_objective_mnhmm_multichannel -Rcpp::List log_objective_mnhmm_multichannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega); -RcppExport SEXP _seqHMM_log_objective_mnhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP) { +Rcpp::List log_objective_mnhmm_multichannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, 
const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega, const arma::uvec& Ti); +RcppExport SEXP _seqHMM_log_objective_mnhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP, SEXP TiSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -559,7 +616,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); Rcpp::traits::input_parameter< const bool >::type iv_omega(iv_omegaSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega)); + Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega, Ti)); return rcpp_result_gen; END_RCPP } @@ -771,9 +829,13 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_forwardbackward", (DL_FUNC) &_seqHMM_forwardbackward, 6}, {"_seqHMM_forwardbackwardx", (DL_FUNC) &_seqHMM_forwardbackwardx, 9}, {"_seqHMM_get_omega", (DL_FUNC) &_seqHMM_get_omega, 3}, + {"_seqHMM_get_omega_all", (DL_FUNC) &_seqHMM_get_omega_all, 3}, {"_seqHMM_get_pi", (DL_FUNC) &_seqHMM_get_pi, 3}, + {"_seqHMM_get_pi_all", (DL_FUNC) &_seqHMM_get_pi_all, 3}, {"_seqHMM_get_A", (DL_FUNC) &_seqHMM_get_A, 4}, - {"_seqHMM_get_B", (DL_FUNC) &_seqHMM_get_B, 6}, + {"_seqHMM_get_A_all", (DL_FUNC) &_seqHMM_get_A_all, 4}, + {"_seqHMM_get_B", (DL_FUNC) &_seqHMM_get_B, 5}, + {"_seqHMM_get_B_all", (DL_FUNC) &_seqHMM_get_B_all, 5}, {"_seqHMM_logLikHMM", (DL_FUNC) &_seqHMM_logLikHMM, 
5}, {"_seqHMM_logLikMixHMM", (DL_FUNC) &_seqHMM_logLikMixHMM, 8}, {"_seqHMM_logSumExp", (DL_FUNC) &_seqHMM_logSumExp, 1}, @@ -784,10 +846,10 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_log_logLikHMM", (DL_FUNC) &_seqHMM_log_logLikHMM, 5}, {"_seqHMM_log_logLikMixHMM", (DL_FUNC) &_seqHMM_log_logLikMixHMM, 8}, {"_seqHMM_log_objective", (DL_FUNC) &_seqHMM_log_objective, 9}, - {"_seqHMM_log_objective_nhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_singlechannel, 12}, - {"_seqHMM_log_objective_nhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_multichannel, 13}, - {"_seqHMM_log_objective_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_singlechannel, 15}, - {"_seqHMM_log_objective_mnhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_multichannel, 16}, + {"_seqHMM_log_objective_nhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_singlechannel, 13}, + {"_seqHMM_log_objective_nhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_multichannel, 14}, + {"_seqHMM_log_objective_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_singlechannel, 16}, + {"_seqHMM_log_objective_mnhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_multichannel, 17}, {"_seqHMM_log_objectivex", (DL_FUNC) &_seqHMM_log_objectivex, 12}, {"_seqHMM_objective", (DL_FUNC) &_seqHMM_objective, 9}, {"_seqHMM_objectivex", (DL_FUNC) &_seqHMM_objectivex, 12}, diff --git a/src/get_parameters.cpp b/src/get_parameters.cpp index 8b7d468e..7c27b4ea 100644 --- a/src/get_parameters.cpp +++ b/src/get_parameters.cpp @@ -5,18 +5,30 @@ // gamma_omega_raw is (D - 1) x K (start from, covariates) // X a vector of length K // [[Rcpp::export]] -arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace) { - arma::mat gamma_omega = arma::join_cols(arma::zeros(gamma_omega_raw.n_cols), gamma_omega_raw); - return softmax(gamma_omega * X, logspace); +arma::vec get_omega(const arma::mat& gamma_raw, const arma::vec& X, const 
bool logspace) { + arma::mat beta = arma::join_cols(arma::zeros(gamma_raw.n_cols), gamma_raw); + return softmax(beta * X, logspace); +} +// [[Rcpp::export]] +arma::mat get_omega_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace) { + arma::mat beta = arma::join_cols(arma::zeros(gamma_raw.n_cols), gamma_raw); + return softmax(beta * X, logspace).t(); } - // gamma_raw is (S - 1) x K (start from, covariates) // X a vector of length K // [[Rcpp::export]] -arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const bool logspace) { +arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec& X, const bool logspace) { arma::mat beta = arma::join_cols(arma::zeros(gamma_raw.n_cols), gamma_raw); return softmax(beta * X, logspace); } +// gamma_raw is (S - 1) x K (start from, covariates) +// X a K x N matrix +// [[Rcpp::export]] +arma::mat get_pi_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace) { + arma::mat beta = arma::join_cols(arma::zeros(gamma_raw.n_cols), gamma_raw); + return softmax(beta * X, logspace).t(); +} + // gamma_raw is (S - 1) x K x S (transition to, covariates, transition from) // X is K x T matrix (covariates, time points) // [[Rcpp::export]] @@ -46,8 +58,48 @@ arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, } return A; } +// gamma_raw is (S - 1) x K x S (transition to, covariates, transition from) +// X is K x T x N cube (covariates, time points, sequences) +// [[Rcpp::export]] +arma::field get_A_all(const arma::cube& gamma_raw, + const arma::cube& X, const bool logspace, + const bool tv) { + unsigned int S = gamma_raw.n_slices; + unsigned int K = X.n_rows; + unsigned int T = X.n_cols; + unsigned int N = X.n_slices; + arma::cube beta(S, K, S); + for (unsigned int i = 0; i < S; i++) { + beta.slice(i) = arma::join_cols( + arma::zeros(K), gamma_raw.slice(i) + ); + } + arma::field A(N); + arma::mat Atmp(S, S); + if (tv) { + for (unsigned int i = 0; i < N; i++) { + A(i) = arma::cube(S, S, 
T); + for (unsigned int t = 0; t < T; t++) { // time + for (unsigned int j = 0; j < S; j ++) { // from states + Atmp.col(j) = softmax(beta.slice(j) * X.slice(i).col(t), logspace); + } + A(i).slice(t) = Atmp.t(); + } + } + } else { + for (unsigned int i = 0; i < N; i++) { + A(i) = arma::cube(S, S, T); + for (unsigned int j = 0; j < S; j ++) { // from states + Atmp.col(j) = softmax(beta.slice(j) * X.slice(i).col(0), logspace); + } + A(i).each_slice() = Atmp.t(); + } + } + return A; +} // gamma_raw is (M - 1) x K x S (symbols, covariates, transition from) // X is K x T (covariates, time points) +// [[Rcpp::export]] arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool add_missing, const bool tv) { unsigned int S = gamma_raw.n_slices; @@ -84,7 +136,6 @@ arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, } // gamma_raw is a a field of (M_c - 1) x K x S cubes // X is K x T (covariates, time point) -// [[Rcpp::export]] arma::field get_B( const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, @@ -96,3 +147,51 @@ arma::field get_B( } return B; } + +// gamma_raw is (M - 1) x K x S (symbols, covariates, transition from) +// X is K x T (covariates, time points) +// [[Rcpp::export]] +arma::field get_B_all( + const arma::cube& gamma_raw, const arma::cube& X, const bool logspace, + const bool add_missing, const bool tv) { + unsigned int S = gamma_raw.n_slices; + unsigned int M = gamma_raw.n_rows + 1; + unsigned int K = X.n_rows; + unsigned int T = X.n_cols; + unsigned int N = X.n_slices; + arma::cube beta(M, K, S); + for (unsigned int i = 0; i < S; i++) { + beta.slice(i) = arma::join_cols( + arma::zeros(K), gamma_raw.slice(i) + ); + } + arma::field B(N); + arma::mat Btmp(M + add_missing, S); + if (add_missing) { + Btmp.row(M).fill(1.0 - logspace); + } + if (tv) { + for (unsigned int i = 0; i < N; i++) { + B(i) = arma::cube(S, M + add_missing, T); + for (unsigned int t = 0; t < T; t++) { // time + for 
(unsigned int j = 0; j < S; j ++) { // from states + Btmp.col(j).rows(0, M - 1) = softmax( + beta.slice(j) * X.slice(i).col(t), logspace + ); + } + B(i).slice(t) = Btmp.t(); + } + } + } else { + for (unsigned int i = 0; i < N; i++) { + B(i) = arma::cube(S, M + add_missing, T); + for (unsigned int j = 0; j < S; j ++) { // from states + Btmp.col(j).rows(0, M - 1) = softmax( + beta.slice(j) * X.slice(i).col(0), logspace + ); + } + B(i).each_slice() = Btmp.t(); + } + } + return B; +} \ No newline at end of file diff --git a/src/get_parameters.h b/src/get_parameters.h index e61c03cc..f1913a9c 100644 --- a/src/get_parameters.h +++ b/src/get_parameters.h @@ -4,21 +4,21 @@ #include arma::vec get_omega( - const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace + const arma::mat& gamma_raw, const arma::vec& X, const bool logspace ); arma::vec get_pi( - const arma::mat& beta_raw, const arma::vec X, const bool logspace + const arma::mat& gamma_raw, const arma::vec& X, const bool logspace ); arma::cube get_A( - const arma::cube& beta_raw, const arma::mat& X, const bool logspace, + const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool tv = true ); arma::cube get_B( - const arma::cube& beta_raw, const arma::mat& X, const bool logspace, + const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool add_missing = false, const bool tv = true ); arma::field get_B( - const arma::field& beta_raw, const arma::mat& X, + const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, const bool logspace, const bool add_missing = false, const bool tv = true ); diff --git a/src/log_objective_nhmm.cpp b/src/log_objective_nhmm.cpp index 58e24dc5..bcd973e4 100644 --- a/src/log_objective_nhmm.cpp +++ b/src/log_objective_nhmm.cpp @@ -10,7 +10,7 @@ Rcpp::List log_objective_nhmm_singlechannel( const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, const arma::mat& obs, const 
bool iv_pi, const bool iv_A, const bool iv_B, - const bool tv_A, const bool tv_B) { + const bool tv_A, const bool tv_B, const arma::uvec& Ti) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -44,7 +44,7 @@ Rcpp::List log_objective_nhmm_singlechannel( if (iv_B || i == 0) { log_B = get_B(gamma_B_raw, X_o.slice(i), true, true, tv_B); } - for (unsigned int t = 0; t < T; t++) { + for (unsigned int t = 0; t < Ti[i]; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = log_B(s, obs(t, i), t); } @@ -64,7 +64,7 @@ Rcpp::List log_objective_nhmm_singlechannel( grad_pi += gradmat_S.rows(1, S - 1) * gradvec_S * X_i.col(i).t(); // gradient wrt gamma_A - for (unsigned int t = 0; t < (T - 1); t++) { + for (unsigned int t = 0; t < (Ti[i] - 1); t++) { A = exp(log_A.slice(t)); for (unsigned int s = 0; s < S; s++) { // d loglik / d a_s @@ -85,7 +85,7 @@ Rcpp::List log_objective_nhmm_singlechannel( double grad = exp(log_Pi(s) + log_beta(s, 0) - ll); grad_B.slice(s) += gradmat_M.rows(1, M - 1).col(obs(0, i)) * grad * X_o.slice(i).col(0).t(); } - for (unsigned int t = 0; t < (T - 1); t++) { + for (unsigned int t = 0; t < (Ti[i] - 1); t++) { if (obs(t + 1, i) < M) { Brow = exp(log_B.slice(t + 1).row(s).cols(0, M - 1)); gradmat_M = -Brow.t() * Brow; @@ -112,120 +112,121 @@ Rcpp::List log_objective_nhmm_multichannel( const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, - const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B) { - - unsigned int C = M.n_elem; - unsigned int N = X_s.n_slices; - unsigned int T = X_s.n_cols; - unsigned int S = gamma_A_raw.n_slices; - arma::vec loglik(N); - arma::mat log_alpha(S, T); - arma::mat log_beta(S, T); - arma::mat log_py(S, T); - arma::vec log_Pi(S); - arma::cube log_A(S, S, T); - arma::field log_B(C); - arma::mat grad_pi(S - 1, X_i.n_rows, arma::fill::zeros); - arma::cube grad_A(S - 1, X_s.n_rows, S, 
arma::fill::zeros); - arma::field grad_B(C); - for (unsigned int c = 0; c < C; c++) { - grad_B(c) = arma::cube(M(c) - 1, X_o.n_rows, S, arma::fill::zeros); - } - arma::vec gradvec_S(S); - arma::mat gradmat_S(S, S); - arma::mat A(S, S); - for (unsigned int i = 0; i < N; i++) { - if (iv_pi || i == 0) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); - } - if (iv_A || i == 0) { - log_A = get_A(gamma_A_raw, X_s.slice(i), true, tv_A); - } - if (iv_B || i == 0) { - log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true, tv_B); - } - for (unsigned int t = 0; t < T; t++) { - for (unsigned int s = 0; s < S; s++) { - log_py(s, t) = 0; - for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c)(s, obs(c, t, i), t); - } - } - } - log_alpha = univariate_forward_nhmm(log_Pi, log_A, log_py); - log_beta = univariate_backward_nhmm(log_A, log_py); - double ll = logSumExp(log_alpha.col(T - 1)); - loglik(i) = ll; - // gradient wrt gamma_pi - // d loglik / d pi - gradvec_S = exp(log_py.col(0) + log_beta.col(0) - ll); - // d pi / d gamma_pi - arma::vec Pi = exp(log_Pi); - gradmat_S = -Pi * Pi.t(); - gradmat_S.diag() += Pi; - grad_pi += gradmat_S.rows(1, S - 1) * gradvec_S * X_i.col(i).t(); - // gradient wrt gamma_A - for (unsigned int t = 0; t < (T - 1); t++) { - A = exp(log_A.slice(t)); - for (unsigned int s = 0; s < S; s++) { - // d loglik / d a_s - gradvec_S = exp(log_alpha(s, t) + log_py.col(t + 1) + - log_beta.col(t + 1) - ll); - // d a_s / d gamma_A - gradmat_S = -A.row(s).t() * A.row(s); - gradmat_S.diag() += A.row(s); - grad_A.slice(s) += gradmat_S.rows(1, S - 1) * - gradvec_S * X_s.slice(i).col(t).t(); - } - } - for (unsigned int c = 0; c < C; c++) { - arma::mat gradmat_M(M(c), M(c)); - arma::rowvec Brow(M(c)); - double logpy; - for (unsigned int s = 0; s < S; s++) { - if (obs(c, 0, i) < M(c)) { - Brow = exp(log_B(c).slice(0).row(s).cols(0, M(c) - 1)); - gradmat_M = -Brow.t() * Brow; - gradmat_M.diag() += Brow; - logpy = 0; - for (unsigned int cc = 0; cc < C; cc++) { 
- if (cc != c) { - logpy += log_B(cc)(s, obs(cc, 0, i), 0); - } - } - double grad = exp(log_Pi(s) + logpy + log_beta(s, 0) - ll); - grad_B(c).slice(s) += - gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * - grad * X_o.slice(i).col(0).t(); - } - for (unsigned int t = 0; t < (T - 1); t++) { - if (obs(c, t + 1, i) < M(c)) { - Brow = exp(log_B(c).slice(t + 1).row(s).cols(0, M(c) - 1)); - gradmat_M = -Brow.t() * Brow; - gradmat_M.diag() += Brow; - logpy = 0; - for (unsigned int cc = 0; cc < C; cc++) { - if (cc != c) { - logpy += log_B(cc)(s, obs(cc, t + 1, i), t + 1); - } - } - double grad = arma::accu( - exp(log_alpha.col(t) + log_A.slice(t).col(s) + - logpy + log_beta(s, t + 1) - ll)); - grad_B(c).slice(s) += - gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * - grad * X_o.slice(i).col(t + 1).t(); - } - } - } - } - } - return Rcpp::List::create( - Rcpp::Named("loglik") = sum(loglik), - Rcpp::Named("gradient_pi") = Rcpp::wrap(grad_pi), - Rcpp::Named("gradient_A") = Rcpp::wrap(grad_A), - Rcpp::Named("gradient_B") = Rcpp::wrap(grad_B) - ); + const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, + const arma::uvec& Ti) { + + unsigned int C = M.n_elem; + unsigned int N = X_s.n_slices; + unsigned int T = X_s.n_cols; + unsigned int S = gamma_A_raw.n_slices; + arma::vec loglik(N); + arma::mat log_alpha(S, T); + arma::mat log_beta(S, T); + arma::mat log_py(S, T); + arma::vec log_Pi(S); + arma::cube log_A(S, S, T); + arma::field log_B(C); + arma::mat grad_pi(S - 1, X_i.n_rows, arma::fill::zeros); + arma::cube grad_A(S - 1, X_s.n_rows, S, arma::fill::zeros); + arma::field grad_B(C); + for (unsigned int c = 0; c < C; c++) { + grad_B(c) = arma::cube(M(c) - 1, X_o.n_rows, S, arma::fill::zeros); + } + arma::vec gradvec_S(S); + arma::mat gradmat_S(S, S); + arma::mat A(S, S); + for (unsigned int i = 0; i < N; i++) { + if (iv_pi || i == 0) { + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + } + if (iv_A || i == 0) { + log_A = get_A(gamma_A_raw, X_s.slice(i), true, 
tv_A); + } + if (iv_B || i == 0) { + log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true, tv_B); + } + for (unsigned int t = 0; t < Ti[i]; t++) { + for (unsigned int s = 0; s < S; s++) { + log_py(s, t) = 0; + for (unsigned int c = 0; c < C; c++) { + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); + } + } + } + log_alpha = univariate_forward_nhmm(log_Pi, log_A, log_py); + log_beta = univariate_backward_nhmm(log_A, log_py); + double ll = logSumExp(log_alpha.col(T - 1)); + loglik(i) = ll; + // gradient wrt gamma_pi + // d loglik / d pi + gradvec_S = exp(log_py.col(0) + log_beta.col(0) - ll); + // d pi / d gamma_pi + arma::vec Pi = exp(log_Pi); + gradmat_S = -Pi * Pi.t(); + gradmat_S.diag() += Pi; + grad_pi += gradmat_S.rows(1, S - 1) * gradvec_S * X_i.col(i).t(); + // gradient wrt gamma_A + for (unsigned int t = 0; t < (Ti[i] - 1); t++) { + A = exp(log_A.slice(t)); + for (unsigned int s = 0; s < S; s++) { + // d loglik / d a_s + gradvec_S = exp(log_alpha(s, t) + log_py.col(t + 1) + + log_beta.col(t + 1) - ll); + // d a_s / d gamma_A + gradmat_S = -A.row(s).t() * A.row(s); + gradmat_S.diag() += A.row(s); + grad_A.slice(s) += gradmat_S.rows(1, S - 1) * + gradvec_S * X_s.slice(i).col(t).t(); + } + } + for (unsigned int c = 0; c < C; c++) { + arma::mat gradmat_M(M(c), M(c)); + arma::rowvec Brow(M(c)); + double logpy; + for (unsigned int s = 0; s < S; s++) { + if (obs(c, 0, i) < M(c)) { + Brow = exp(log_B(c).slice(0).row(s).cols(0, M(c) - 1)); + gradmat_M = -Brow.t() * Brow; + gradmat_M.diag() += Brow; + logpy = 0; + for (unsigned int cc = 0; cc < C; cc++) { + if (cc != c) { + logpy += log_B(cc)(s, obs(cc, 0, i), 0); + } + } + double grad = exp(log_Pi(s) + logpy + log_beta(s, 0) - ll); + grad_B(c).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * + grad * X_o.slice(i).col(0).t(); + } + for (unsigned int t = 0; t < (Ti[i] - 1); t++) { + if (obs(c, t + 1, i) < M(c)) { + Brow = exp(log_B(c).slice(t + 1).row(s).cols(0, M(c) - 1)); + gradmat_M = -Brow.t() * Brow; + 
gradmat_M.diag() += Brow; + logpy = 0; + for (unsigned int cc = 0; cc < C; cc++) { + if (cc != c) { + logpy += log_B(cc)(s, obs(cc, t + 1, i), t + 1); + } + } + double grad = arma::accu( + exp(log_alpha.col(t) + log_A.slice(t).col(s) + + logpy + log_beta(s, t + 1) - ll)); + grad_B(c).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * + grad * X_o.slice(i).col(t + 1).t(); + } + } + } + } + } + return Rcpp::List::create( + Rcpp::Named("loglik") = sum(loglik), + Rcpp::Named("gradient_pi") = Rcpp::wrap(grad_pi), + Rcpp::Named("gradient_A") = Rcpp::wrap(grad_A), + Rcpp::Named("gradient_B") = Rcpp::wrap(grad_B) + ); } // [[Rcpp::export]] @@ -235,7 +236,8 @@ Rcpp::List log_objective_mnhmm_singlechannel( const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, - const bool tv_A, const bool tv_B, const bool iv_omega) { + const bool tv_A, const bool tv_B, const bool iv_omega, + const arma::uvec& Ti) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -285,7 +287,7 @@ Rcpp::List log_objective_mnhmm_singlechannel( if (iv_B || i == 0) { log_B(d) = get_B(gamma_B_raw(d), X_o.slice(i), true, true, tv_B); } - for (unsigned int t = 0; t < T; t++) { + for (unsigned int t = 0; t < Ti[i]; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t, d) = log_B(d)(s, obs(t, i), t); } @@ -369,7 +371,7 @@ Rcpp::List log_objective_mnhmm_multichannel( const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, - const bool iv_omega) { + const bool iv_omega, const arma::uvec& Ti) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -422,7 +424,7 @@ Rcpp::List log_objective_mnhmm_multichannel( true, tv_B ); } - for (unsigned int t = 0; t < T; t++) { + for (unsigned int t = 0; t < Ti[i]; t++) { for (unsigned 
int s = 0; s < S; s++) { log_py(s, t, d) = 0; for (unsigned int c = 0; c < C; c++) { diff --git a/tests/testthat/test-gradients.R b/tests/testthat/test-gradients.R index 90beeec4..04626367 100644 --- a/tests/testthat/test-gradients.R +++ b/tests/testthat/test-gradients.R @@ -17,12 +17,13 @@ test_that("Gradients for singlechannel-NHMM are correct", { time = rep(1:n_time, each = n_id), id = rep(1:n_id, n_time) ) - data <- data[-10L, ] - data$y[10:15] <- NA - data$x[12] <- NA + data[data$id < 3 & data$time > 6, c("y", "x", "z")] <- NA + data[data$time > 9, c("y", "x", "z")] <- NA + data$x[12:15] <- 0 + data$y[10:25] <- NA model <- build_nhmm( "y", S, initial_formula = ~ x, transition_formula = ~z, - emission_formula = ~ z, data = data, time = "time", id = "id") + emission_formula = ~ x, data = data, time = "time", id = "id") n_i <- attr(model, "np_pi") n_s <- attr(model, "np_A") @@ -47,7 +48,7 @@ test_that("Gradients for singlechannel-NHMM are correct", { gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - obs, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + obs, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)$loglik } g <- function(pars) { @@ -60,7 +61,7 @@ test_that("Gradients for singlechannel-NHMM are correct", { gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - obs, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + obs, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)[-1])) } expect_equal(g(pars), numDeriv::grad(f, pars)) }) @@ -92,10 +93,11 @@ test_that("Gradients for multichannel-NHMM are correct", { id = rep(1:n_id, n_time) ) - data <- data[-10L, ] - data$y1[10:15] <- NA - data$y2[12:20] <- NA - data$x[12] <- NA + data[data$id < 3 & data$time > 6, c("y1", "y2", "x", "z")] <- NA + data[data$time > 9, c("y1", "y2", "x", "z")] <- NA + data$x[12:15] <- 0 + data$y1[10:25] <- NA + data$y2[10:35] <- NA model <- build_nhmm( c("y1", "y2"), S, initial_formula = ~ x, transition_formula = ~z, @@ -123,7 +125,7 @@ test_that("Gradients for multichannel-NHMM are correct", { 
gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - obs, M, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)$loglik } g <- function(pars) { @@ -136,7 +138,7 @@ test_that("Gradients for multichannel-NHMM are correct", { gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, - obs, M, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)[-1])) } expect_equal(g(pars), numDeriv::grad(f, pars)) }) @@ -161,9 +163,10 @@ test_that("Gradients for singlechannel-NHMM are correct", { time = rep(1:n_time, each = n_id), id = rep(1:n_id, n_time) ) - data <- data[-10L, ] - data$y[10:15] <- NA - data$x[12] <- NA + data[data$id < 3 & data$time > 6, c("y", "x", "z")] <- NA + data[data$time > 9, c("y", "x", "z")] <- NA + data$x[12:15] <- 0 + data$y[10:25] <- NA model <- build_mnhmm( "y", S, D, initial_formula = ~ x, transition_formula = ~z, emission_formula = ~ z, cluster_formula = ~ z, data = data, @@ -199,7 +202,7 @@ test_that("Gradients for singlechannel-NHMM are correct", { gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, - obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)$loglik } g <- function(pars) { @@ -216,10 +219,11 @@ test_that("Gradients for singlechannel-NHMM are correct", { gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, - obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)[-1])) } expect_equal(g(pars), numDeriv::grad(f, pars)) }) + test_that("Gradients for multichannel-MNHMM are correct", { set.seed(123) M <- c(2, 5) @@ -248,9 +252,11 @@ test_that("Gradients for multichannel-MNHMM are correct", { id = rep(1:n_id, n_time) ) - data <- data[-10L, ] - data$y1[10:15] <- NA - data$x[12] <- NA + data[data$id < 3 & data$time > 6, c("y1", "y2", "x", "z")] <- NA + data[data$time > 9, c("y1", "y2", "x", "z")] <- NA + data$x[12:15] <- 0 
+ data$y1[10:25] <- NA + data$y2[10:35] <- NA model <- build_mnhmm( c("y1", "y2"), S, D, initial_formula = ~ x, transition_formula = ~z, @@ -289,7 +295,7 @@ test_that("Gradients for multichannel-MNHMM are correct", { gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, - obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)$loglik } g <- function(pars) { @@ -309,7 +315,7 @@ test_that("Gradients for multichannel-MNHMM are correct", { gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, - obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, model$sequence_lengths)[-1])) } expect_equal(g(pars), numDeriv::grad(f, pars)) })