diff --git a/NAMESPACE b/NAMESPACE index 47cf9aed..f2d4210b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,8 @@ S3method("cluster_names<-",mhmm) S3method("cluster_names<-",mnhmm) S3method("state_names<-",hmm) S3method("state_names<-",mhmm) +S3method("state_names<-",mnhmm) +S3method("state_names<-",nhmm) S3method(ame,mnhmm) S3method(ame,nhmm) S3method(cluster_names,mhmm) @@ -14,8 +16,22 @@ S3method(forward_backward,hmm) S3method(forward_backward,mhmm) S3method(forward_backward,mnhmm) S3method(forward_backward,nhmm) +S3method(get_cluster_probs,mhmm) +S3method(get_cluster_probs,mnhmm) +S3method(get_emission_probs,hmm) +S3method(get_emission_probs,mhmm) +S3method(get_emission_probs,mnhmm) +S3method(get_emission_probs,nhmm) +S3method(get_initial_probs,hmm) +S3method(get_initial_probs,mhmm) +S3method(get_initial_probs,mnhmm) +S3method(get_initial_probs,nhmm) S3method(get_probs,mnhmm) S3method(get_probs,nhmm) +S3method(get_transition_probs,hmm) +S3method(get_transition_probs,mhmm) +S3method(get_transition_probs,mnhmm) +S3method(get_transition_probs,nhmm) S3method(hidden_paths,hmm) S3method(hidden_paths,mhmm) S3method(hidden_paths,mnhmm) @@ -41,6 +57,8 @@ S3method(print,summary.mnhmm) S3method(print,summary.nhmm) S3method(state_names,hmm) S3method(state_names,mhmm) +S3method(state_names,mnhmm) +S3method(state_names,nhmm) S3method(summary,mhmm) S3method(summary,mnhmm) S3method(summary,nhmm) @@ -61,7 +79,11 @@ export(estimate_mnhmm) export(estimate_nhmm) export(fit_model) export(forward_backward) +export(get_cluster_probs) +export(get_emission_probs) +export(get_initial_probs) export(get_probs) +export(get_transition_probs) export(gridplot) export(hidden_paths) export(mc_to_sc) diff --git a/R/RcppExports.R b/R/RcppExports.R index f4a63786..d5f0f215 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -61,12 +61,12 @@ get_pi <- function(gamma_raw, X, logspace) { .Call(`_seqHMM_get_pi`, gamma_raw, X, logspace) } -get_A <- function(gamma_raw, X, logspace) { - .Call(`_seqHMM_get_A`, gamma_raw, X, logspace) +get_A <- function(gamma_raw, X, logspace, tv) { + .Call(`_seqHMM_get_A`, gamma_raw, X, logspace, tv) } -get_B <- function(gamma_raw, X, M, logspace, add_missing) { - .Call(`_seqHMM_get_B`, gamma_raw, X, M, logspace, add_missing) +get_B <- function(gamma_raw, X, M, logspace, add_missing, tv) { + .Call(`_seqHMM_get_B`, gamma_raw, X, M, logspace, add_missing, tv) } logLikHMM <- function(transition, emission, init, obs, threads) { @@ -109,20 +109,20 @@ log_objective <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymbo .Call(`_seqHMM_log_objective`, transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, threads) } -log_objective_nhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs) { - .Call(`_seqHMM_log_objective_nhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs) +log_objective_nhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B) { + .Call(`_seqHMM_log_objective_nhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B) } -log_objective_nhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M) { - .Call(`_seqHMM_log_objective_nhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M) +log_objective_nhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B) { + .Call(`_seqHMM_log_objective_nhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B) } -log_objective_mnhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs) { - .Call(`_seqHMM_log_objective_mnhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs) +log_objective_mnhmm_singlechannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) { + .Call(`_seqHMM_log_objective_mnhmm_singlechannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) } -log_objective_mnhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M) { - .Call(`_seqHMM_log_objective_mnhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M) +log_objective_mnhmm_multichannel <- function(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) { + .Call(`_seqHMM_log_objective_mnhmm_multichannel`, gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega) } log_objectivex <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, coef, X, numberOfStates, threads) { diff --git a/R/ame.R b/R/ame.R index b0b155e7..c709a175 100644 --- a/R/ame.R +++ b/R/ame.R @@ -72,9 +72,9 @@ ame.nhmm <- function( } # use same RNG seed so that the same samples of coefficients are drawn newdata[[variable]] <- values[1] - pred <- predict(model, newdata, dontchange_colnames = TRUE) + pred <- get_probs(model, newdata) newdata[[variable]] <- values[2] - pred2 <- predict(model, newdata, dontchange_colnames = TRUE) + pred2 <- get_probs(model, newdata) pars <- c("initial_probs", "transition_probs", "emission_probs") for (i in pars) { pred[[i]]$estimate <- pred[[i]]$estimate - pred2[[i]]$estimate @@ -170,10 +170,10 @@ ame.mnhmm <- function( } # use same RNG seed so that the same samples of coefficients are drawn newdata[[variable]] <- values[1] - pred <- predict(model, newdata, dontchange_colnames = TRUE) + pred <- get_probs(model, newdata) newdata[[variable]] <- values[2] - pred2 <- predict(model, newdata, dontchange_colnames = TRUE) + pred2 <- get_probs(model, newdata) pars <- c("initial_probs", "transition_probs", "emission_probs", "cluster_probs") for (i in pars) { @@ -186,18 +186,18 @@ ame.mnhmm <- function( "states" = c("cluster", "time", channel, "observation"), "sequences" = c("cluster", "time", "state", channel, "observation") ) - + out <- list() out$initial_probs <- pred$initial_probs |> - dplyr::group_by(cluster, state) |> - dplyr::summarise(estimate = mean(estimate)) |> - dplyr::ungroup() + dplyr::group_by(cluster, state) |> + dplyr::summarise(estimate = mean(estimate)) |> + dplyr::ungroup() out$transition_probs <- pred$transition_probs |> - dplyr::group_by(cluster, time, state_from, state_to) |> - dplyr::summarise(estimate = mean(estimate)) |> - dplyr::ungroup() |> - dplyr::rename(!!time := time) + dplyr::group_by(cluster, time, state_from, state_to) |> + dplyr::summarise(estimate = mean(estimate)) |> + dplyr::ungroup() |> + dplyr::rename(!!time := time) out$emission_probs <- pred$emission_probs |> dplyr::group_by(dplyr::pick(group_by_B)) |> diff --git a/R/build_mnhmm.R b/R/build_mnhmm.R index ba59aa3b..ecbd3483 100644 --- a/R/build_mnhmm.R +++ b/R/build_mnhmm.R @@ -27,8 +27,19 @@ build_mnhmm <- function( out$model, class = "mnhmm", nobs = attr(out$observations, "nobs"), - df = out$extras$n_pars, + df = out$extras$np_omega + + n_clusters * (out$extras$np_pi + out$extras$np_A + out$extras$np_B), type = paste0(out$extras$multichannel, "mnhmm"), - intercept_only = out$extras$intercept_only + intercept_only = out$extras$intercept_only, + iv_pi = out$extras$iv_pi, + iv_A = out$extras$iv_A, + iv_B = out$extras$iv_B, + iv_omega = out$extras$iv_omega, + tv_A = out$extras$tv_A, + tv_B = out$extras$tv_B, + np_pi = n_clusters * out$extras$np_pi, + np_A = n_clusters * out$extras$np_A, + np_B = n_clusters * out$extras$np_B, + np_omega = out$extras$np_omega ) } diff --git a/R/build_nhmm.R b/R/build_nhmm.R index 72f00a56..d30642db 100644 --- a/R/build_nhmm.R +++ b/R/build_nhmm.R @@ -14,8 +14,16 @@ build_nhmm <- function( out$model, class = "nhmm", nobs = attr(out$observations, "nobs"), - df = out$extras$n_pars, + df = out$extras$np_pi + out$extras$np_A + out$extras$np_B, type = paste0(out$extras$multichannel, "nhmm"), - intercept_only = out$extras$intercept_only + intercept_only = out$extras$intercept_only, + iv_pi = out$extras$iv_pi, + iv_A = out$extras$iv_A, + iv_B = out$extras$iv_B, + tv_A = out$extras$tv_A, + tv_B = out$extras$tv_B, + np_pi = out$extras$np_pi, + np_A = out$extras$np_A, + np_B = out$extras$np_B ) } diff --git a/R/coef.R b/R/coef.R index 97161e98..41eacd75 100644 --- a/R/coef.R +++ b/R/coef.R @@ -43,47 +43,47 @@ coef.nhmm <- function(object, probs = c(0.025, 0.975), ...) { ) } - stopifnot_( - checkmate::test_numeric( - x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L - ), - "Argument {.arg probs} must be a {.cls numeric} vector with values - between 0 and 1." - ) - p_i <- length(gamma_pi_raw) - p_s <- length(gamma_A_raw) - p_o <- length(gamma_B_raw) - sds <- try( - diag(solve(-object$estimation_results$hessian)), - silent = TRUE - ) - if (inherits(sds, "try-error")) { - warning_( - paste0( - "Standard errors could not be computed due to singular Hessian. ", - "Confidence intervals will not be provided." - ) - ) - sds <- rep(NA, p_i + p_s + p_o) - } else { - if (any(sds < 0)) { - warning_( - paste0( - "Standard errors could not be computed due to negative variances. ", - "Confidence intervals will not be provided." - ) - ) - sds <- rep(NA, p_i + p_s + p_o) - } else { - sds <- sqrt(sds) - } - } - for(i in seq_along(probs)) { - q <- qnorm(probs[i]) - gamma_pi[paste0("q", 100 * probs[i])] <- gamma_pi_raw + q * sds[seq_len(p_i)] - gamma_A[paste0("q", 100 * probs[i])] <- gamma_A_raw + q * sds[p_i + seq_len(p_s)] - gamma_B[paste0("q", 100 * probs[i])] <- gamma_B_raw + q * sds[p_i + p_s + seq_len(p_o)] - } + # stopifnot_( + # checkmate::test_numeric( + # x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L + # ), + # "Argument {.arg probs} must be a {.cls numeric} vector with values + # between 0 and 1." + # ) + # p_i <- length(gamma_pi_raw) + # p_s <- length(gamma_A_raw) + # p_o <- length(gamma_B_raw) + # sds <- try( + # diag(solve(-object$estimation_results$hessian)), + # silent = TRUE + # ) + # if (inherits(sds, "try-error")) { + # warning_( + # paste0( + # "Standard errors could not be computed due to singular Hessian. ", + # "Confidence intervals will not be provided." + # ) + # ) + # sds <- rep(NA, p_i + p_s + p_o) + # } else { + # if (any(sds < 0)) { + # warning_( + # paste0( + # "Standard errors could not be computed due to negative variances. ", + # "Confidence intervals will not be provided." + # ) + # ) + # sds <- rep(NA, p_i + p_s + p_o) + # } else { + # sds <- sqrt(sds) + # } + # } + # for(i in seq_along(probs)) { + # q <- qnorm(probs[i]) + # gamma_pi[paste0("q", 100 * probs[i])] <- gamma_pi_raw + q * sds[seq_len(p_i)] + # gamma_A[paste0("q", 100 * probs[i])] <- gamma_A_raw + q * sds[p_i + seq_len(p_s)] + # gamma_B[paste0("q", 100 * probs[i])] <- gamma_B_raw + q * sds[p_i + p_s + seq_len(p_o)] + # } list( gamma_pinitial = gamma_pi, @@ -101,38 +101,40 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) { gamma_pi_raw <- unlist(object$coefficients$gamma_pi_raw) K_i <- length(object$coef_names_initial) gamma_pi <- data.frame( - state = object$state_names[-1], + state = unlist(lapply(object$state_names, function(x) x[-1])), parameter = rep(object$coef_names_initial, each = (S - 1)), - estimate = gamma_pi_raw, + estimate = unlist(gamma_pi_raw), cluster = rep(object$cluster_names, each = (S - 1) * K_i) ) gamma_A_raw <- unlist(object$coefficients$gamma_A_raw) K_s <- length(object$coef_names_transition) gamma_A <- data.frame( - state_from = object$state_names, - state_to = rep(object$state_names[-1], each = S), + state_from = unlist(object$state_names), + state_to = rep( + unlist(lapply(object$state_names, function(x) x[-1])), + each = S + ), parameter = rep(object$coef_names_transition, each = S * (S - 1)), - estimate = gamma_A_raw, + estimate = unlist(gamma_A_raw), cluster = rep(object$cluster_names, each = (S - 1) * S * K_s) ) - gamma_B_raw <- unlist(object$coefficients$gamma_B_raw) K_o <- length(object$coef_names_emission) if (object$n_channels == 1) { gamma_B <- data.frame( - state = object$state_names, + state = unlist(object$state_names), observations = rep(object$symbol_names[-1], each = S), parameter = rep(object$coef_names_emission, each = S * (M - 1)), - estimate = gamma_B_raw, + estimate = unlist(gamma_B_raw), cluster = rep(object$cluster_names, each = S * (S - 1) * K_o) ) } else { gamma_B <- data.frame( - state = object$state_names, + state = unlist(object$state_names), observations = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S), parameter = unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$coef_names_emission, each = S * (M[i] - 1)) })), - estimate = gamma_B_raw, + estimate = unlist(gamma_B_raw), cluster = unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$cluster_names, each = S * (M[i] - 1) * K_o) })) @@ -145,36 +147,36 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) { estimate = gamma_omega_raw ) - stopifnot_( - checkmate::test_numeric( - x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L - ), - "Argument {.arg probs} must be a {.cls numeric} vector with values - between 0 and 1." - ) - p_i <- length(gamma_pi_raw) - p_s <- length(gamma_A_raw) - p_o <- length(gamma_B_raw) - p_d <- length(gamma_omega_raw) - sds <- try( - sqrt(diag(solve(-object$estimation_results$hessian))), - silent = TRUE - ) - if (inherits(sds, "try-error")) { - warning_( - "Standard errors could not be computed due to singular Hessian. - Confidence intervals will not be provided." - ) - sds <- rep(NA, p_i + p_s + p_o + p_d) - } - - for(i in seq_along(probs)) { - q <- qnorm(probs[i]) - gamma_pi[paste0("q", 100 * probs[i])] <- gamma_pi_raw + q * sds[seq_len(p_i)] - gamma_A[paste0("q", 100 * probs[i])] <- gamma_A_raw + q * sds[p_i + seq_len(p_s)] - gamma_B[paste0("q", 100 * probs[i])] <- gamma_B_raw + q * sds[p_i + p_s + seq_len(p_o)] - gamma_omega[paste0("q", 100 * probs[i])] <- gamma_omega_raw + q * sds[p_i + p_s + p_o + seq_len(p_d)] - } + # stopifnot_( + # checkmate::test_numeric( + # x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L + # ), + # "Argument {.arg probs} must be a {.cls numeric} vector with values + # between 0 and 1." + # ) + # p_i <- length(gamma_pi_raw) + # p_s <- length(gamma_A_raw) + # p_o <- length(gamma_B_raw) + # p_d <- length(gamma_omega_raw) + # sds <- try( + # sqrt(diag(solve(-object$estimation_results$hessian))), + # silent = TRUE + # ) + # if (inherits(sds, "try-error")) { + # warning_( + # "Standard errors could not be computed due to singular Hessian. + # Confidence intervals will not be provided." + # ) + # sds <- rep(NA, p_i + p_s + p_o + p_d) + # } + # + # for(i in seq_along(probs)) { + # q <- qnorm(probs[i]) + # gamma_pi[paste0("q", 100 * probs[i])] <- gamma_pi_raw + q * sds[seq_len(p_i)] + # gamma_A[paste0("q", 100 * probs[i])] <- gamma_A_raw + q * sds[p_i + seq_len(p_s)] + # gamma_B[paste0("q", 100 * probs[i])] <- gamma_B_raw + q * sds[p_i + p_s + seq_len(p_o)] + # gamma_omega[paste0("q", 100 * probs[i])] <- gamma_omega_raw + q * sds[p_i + p_s + p_o + seq_len(p_d)] + # } list( gamma_initial = gamma_pi, diff --git a/R/create_base_nhmm.R b/R/create_base_nhmm.R index 416e22cc..3215b8e9 100644 --- a/R/create_base_nhmm.R +++ b/R/create_base_nhmm.R @@ -10,16 +10,6 @@ create_base_nhmm <- function(observations, data, time, id, n_states, !missing(n_states) && checkmate::test_int(x = n_states, lower = 1L), "Argument {.arg n_states} must be a single positive integer." ) - n_states <- as.integer(n_states) - if (is.null(state_names)) { - state_names <- paste("State", seq_len(n_states)) - } else { - stopifnot_( - length(state_names) == n_states, - "Length of {.arg state_names} is not equal to the number of hidden - states." - ) - } stopifnot_( inherits(initial_formula, "formula"), "Argument {.arg initial_formula} must be a {.cls formula} object.") @@ -34,6 +24,49 @@ create_base_nhmm <- function(observations, data, time, id, n_states, stopifnot_( !mixture || inherits(cluster_formula, "formula"), "Argument {.arg cluster_formula} must be a {.cls formula} object.") + + n_states <- as.integer(n_states) + if (is.null(state_names)) { + state_names <- paste("State", seq_len(n_states)) + if (mixture) { + state_names <- replicate(n_cluster, state_names, simplify = FALSE) + names(state_names) <- cluster_names + } + } else { + if (mixture) { + names_is_vec <- !is.list(state_names) && length(state_names) == n_states + stopifnot_( + length(state_names) == n_clusters || names_is_vec, + paste0( + "For MNHMMs, {.arg state_names} should be a list of length ", + "{n_clusters}, the number of clusters, or a vector of length + {n_states}, number of hidden states." + ) + ) + if (names_is_vec) { + state_names <- rep(n_cluster, state_names, simplify = FALSE) + } else { + lapply(seq_len(n_states), function(i) { + stopifnot_( + length(state_names[[i]]) == n_states, + paste0( + "Length of {.arg state_names[[{i}]]} is not equal to ", + "{n_states}, the number of hidden states." + ) + ) + }) + } + } else { + stopifnot_( + length(state_names) == n_states, + paste0( + "Length of {.arg state_names} is not equal to {n_states}, the number", + " of hidden states." + ) + ) + } + } + icp_only_i <- intercept_only(initial_formula) icp_only_s <- intercept_only(transition_formula) icp_only_o <- intercept_only(emission_formula) @@ -95,7 +128,7 @@ create_base_nhmm <- function(observations, data, time, id, n_states, cluster_formula, data, n_sequences, n_clusters, time, id ) coefficients <- create_initial_values( - list(pi = NULL, A = NULL, B = NULL, omega = NULL), + list(gamma_pi = NULL, gamma_A = NULL, gamma_B = NULL, gamma_omega = NULL), n_states, n_symbols, 0, length(pi$coef_names), length(A$coef_names), length(B$coef_names), length(omega$coef_names), n_clusters @@ -106,18 +139,17 @@ create_base_nhmm <- function(observations, data, time, id, n_states, n_states, n_symbols, 0, length(pi$coef_names), length(A$coef_names), length(B$coef_names) ) + omega <- list(n_pars = 0, iv = FALSE) } - n_pars <- if (mixture) omega$n_pars else 0 - n_pars <- n_pars + n_clusters * (pi$n_pars + A$n_pars + B$n_pars) list( model = list( observations = observations, time_variable = if (is.null(time)) "time" else time, id_variable = if (is.null(id)) "id" else id, - X_initial = t(pi$X), - X_transition = aperm(A$X, c(3, 1, 2)), - X_emission = aperm(B$X, c(3, 1, 2)), - X_cluster = if(mixture) t(omega$X) else NULL, + X_initial = pi$X, + X_transition = A$X, + X_emission = B$X, + X_cluster = if(mixture) omega$X else NULL, initial_formula = pi$formula, transition_formula = A$formula, emission_formula = B$formula, @@ -140,8 +172,17 @@ create_base_nhmm <- function(observations, data, time, id, n_states, coef_names_cluster = if(mixture) omega$coef_names else NULL ), extras = list( - n_pars = n_pars, + np_pi = pi$n_pars, + np_A = A$n_pars, + np_B = B$n_pars, + np_omega = omega$n_pars, multichannel = ifelse(n_channels > 1, "multichannel_", ""), + iv_pi = pi$iv, + iv_A = A$iv, + iv_B = B$iv, + iv_omega = omega$iv, + tv_A = A$tv, + tv_B = B$tv, intercept_only = icp_only_i && icp_only_s && icp_only_o && icp_only_d ) ) diff --git a/R/create_initial_values.R b/R/create_initial_values.R index e2f6dd3d..32f1fd0c 100644 --- a/R/create_initial_values.R +++ b/R/create_initial_values.R @@ -34,10 +34,13 @@ create_gamma_B_raw_mnhmm <- function(x, S, M, K, D) { }) } create_gamma_multichannel_B_raw_mnhmm <- function(x, S, M, K, D) { - n <- (sum(M) - 1) * K * S - lapply(seq_len(D), function(i) { - create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K) - }) + n <- sum((M - 1) * K * S) + unlist( + lapply(seq_len(D), function(i) { + create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K) + }), + recursive = FALSE + ) } create_gamma_omega_raw_mnhmm <- function(x, D, K) { matrix(x, D - 1, K) @@ -106,14 +109,14 @@ create_gamma_B_inits <- function(x, S, M, K, init_sd = 0, D = 1) { if (D > 1) { if (is.null(x)) { create_gamma_multichannel_B_raw_mnhmm( - rnorm((M - 1) * K * S * D, sd = init_sd), S, M, K, D + rnorm(sum((M - 1) * K * S) * D, sd = init_sd), S, M, K, D ) } else { stopifnot_( - length(x) == (M - 1) * K * S * D, + length(x) == sum((M - 1) * K * S) * D, paste0( "Number of initial values for {.val gamma_B} is not equal to ", - "(M - 1) * K * S * D = {(M - 1) * K * S * D}." + "sum((M - 1) * K * S) * D = {sum((M - 1) * K * S) * D}." ) ) create_gamma_multichannel_B_raw_mnhmm(x, S, M, K, D) @@ -121,14 +124,14 @@ create_gamma_B_inits <- function(x, S, M, K, init_sd = 0, D = 1) { } else { if (is.null(x)) { create_gamma_multichannel_B_raw_nhmm( - rnorm((M - 1) * K * S, sd = init_sd), S, M, K + rnorm(sum((M - 1) * K * S), sd = init_sd), S, M, K ) } else { stopifnot_( - length(x) == (M - 1) * K * S, + length(x) == sum((M - 1) * K * S), paste0( "Number of initial values for {.val gamma_B} is not equal to ", - "(M - 1) * K * S = {(M - 1) * K * S}." + "sum((M - 1) * K * S) = {sum((M - 1) * K * S)}." ) ) create_gamma_multichannel_B_raw_nhmm(x, S, M, K) diff --git a/R/estimate_mnhmm.R b/R/estimate_mnhmm.R index 85e46eef..9dca8e50 100644 --- a/R/estimate_mnhmm.R +++ b/R/estimate_mnhmm.R @@ -53,7 +53,8 @@ estimate_mnhmm <- function( ) stopifnot_( checkmate::test_flag(x = store_data), - "Argument {.arg store_data} must be a single {.cls logical} value.") + "Argument {.arg store_data} must be a single {.cls logical} value." + ) if (store_data) { model$data <- data } diff --git a/R/estimate_nhmm.R b/R/estimate_nhmm.R index 2f3bf896..8ff20512 100644 --- a/R/estimate_nhmm.R +++ b/R/estimate_nhmm.R @@ -84,7 +84,8 @@ estimate_nhmm <- function( ) stopifnot_( checkmate::test_flag(x = store_data), - "Argument {.arg store_data} must be a single {.cls logical} value.") + "Argument {.arg store_data} must be a single {.cls logical} value." + ) if (store_data) { model$data <- data } diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 8c0e9c0a..2a04d463 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -29,10 +29,16 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { if (is.null(inits$emission_probs)) inits$emission_probs <- NULL } - n_i <- length(unlist(model$coefficients$gamma_pi_raw)) - n_s <- length(unlist(model$coefficients$gamma_A_raw)) - n_o <- length(unlist(model$coefficients$gamma_B_raw)) - n_d <- length(model$coefficients$gamma_omega_raw) + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + n_d <- attr(model, "np_omega") + iv_pi <- attr(model, "iv_pi") + iv_A <- attr(model, "iv_A") + iv_B <- attr(model, "iv_B") + iv_omega <- attr(model, "iv_omega") + tv_A <- attr(model, "tv_A") + tv_B <- attr(model, "tv_B") X_i <- model$X_initial X_s <- model$X_transition X_o <- model$X_emission @@ -58,11 +64,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d ) out <- log_objective_mnhmm_singlechannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - gamma_omega_raw, X_d, - obs) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, + gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega + ) list(objective = - out$loglik, gradient = - unlist(out[-1])) } @@ -77,11 +81,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d ) out <- forward_mnhmm_singlechannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - gamma_omega_raw, X_d, - obs) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, + gamma_omega_raw, X_d, obs + ) - sum(apply(out[, T_, ], 2, logSumExp)) } @@ -94,21 +96,16 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + seq_len(n_s)], S, K_s, D ) - gamma_B_raw <- unlist( - create_gamma_multichannel_B_raw_mnhmm( - pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D - ), - recursive = FALSE + gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d ) out <- log_objective_mnhmm_multichannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - gamma_omega_raw, X_d, - obs, M) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, + gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega + ) list(objective = - out$loglik, gradient = - unlist(out[-1])) } @@ -117,13 +114,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) gamma_A_raw <- create_gamma_A_raw_mnhmm( pars[n_i + seq_len(n_s)], S, K_s, D - ) - gamma_B_raw <- unlist( - create_gamma_multichannel_B_raw_mnhmm( - pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D - ), - recursive = FALSE ) + gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d ) @@ -187,10 +181,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars <- out$solution model$coefficients$gamma_pi_raw <- create_gamma_pi_raw_mnhmm( pars[seq_len(n_i)], S, K_i, D - ) + ) model$coefficients$gamma_A_raw <- create_gamma_A_raw_mnhmm( pars[n_i + seq_len(n_s)], S, K_s, D - ) + ) if (model$n_channels == 1L) { model$coefficients$gamma_B_raw <- create_gamma_B_raw_mnhmm( pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index 78bf390e..c3b3a7db 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -29,9 +29,14 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { if (is.null(inits$emission_probs)) inits$emission_probs <- NULL } - n_i <- length(model$coefficients$gamma_pi_raw) - n_s <- length(model$coefficients$gamma_A_raw) - n_o <- length(unlist(model$coefficients$gamma_B_raw)) + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + iv_pi <- attr(model, "iv_pi") + iv_A <- attr(model, "iv_A") + iv_B <- attr(model, "iv_B") + tv_A <- attr(model, "tv_A") + tv_B <- attr(model, "tv_B") X_i <- model$X_initial X_s <- model$X_transition X_o <- model$X_emission @@ -52,10 +57,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + seq_len(n_o)], S, M, K_o ) out <- log_objective_nhmm_singlechannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - obs) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, + iv_pi, iv_A, iv_B, tv_A, tv_B + ) list(objective = - out$loglik, gradient = - unlist(out[-1])) } @@ -67,10 +71,8 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + seq_len(n_o)], S, M, K_o ) out <- forward_nhmm_singlechannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - obs) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs + ) - sum(apply(out[, T_, ], 2, logSumExp)) } @@ -84,10 +86,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + seq_len(n_o)], S, M, K_o ) out <- log_objective_nhmm_multichannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - obs, M) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, + iv_pi, iv_A, iv_B, tv_A, tv_B + ) list(objective = - out$loglik, gradient = - unlist(out[-1])) } @@ -99,10 +100,8 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + n_s + seq_len(n_o)], S, M, K_o ) out <- forward_nhmm_multichannel( - gamma_pi_raw, X_i, - gamma_A_raw, X_s, - gamma_B_raw, X_o, - obs, M) + gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M + ) - sum(apply(out[, T_, ], 2, logSumExp)) } diff --git a/R/forwardBackward.R b/R/forwardBackward.R index 4518f268..a3af786b 100644 --- a/R/forwardBackward.R +++ b/R/forwardBackward.R @@ -260,7 +260,7 @@ forward_backward.mnhmm <- function(model, forward_only = FALSE, } state_names <- paste0( rep(model$cluster_names, each = model$n_states), ": ", - model$state_names + unlist(model$state_names) ) dimnames(out$forward_probs) <- list( "state" = state_names, "time" = time_names, "id" = sequence_names diff --git a/R/get_probs.R b/R/get_probs.R index 6ce9723b..84bad714 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -1,9 +1,417 @@ -#' Get the Estimated Initial, Transition, and Emission Probabilities for NHMM -#' or MNHMM +#' @rdname initial_probs +#' @export +get_initial_probs <- function(model, ...) { + UseMethod("get_initial_probs", model) +} +#' @rdname transition_probs +#' @export +get_transition_probs <- function(model, ...) { + UseMethod("get_transition_probs", model) +} +#' @rdname emission_probs +#' @export +get_emission_probs <- function(model, ...) { + UseMethod("get_emission_probs", model) +} +#' @rdname cluster_probs +#' @export +get_cluster_probs <- function(model, ...) { + UseMethod("get_cluster_probs", model) +} +#' Extract the Initial State Probabilities of Hidden Markov Model +#' @rdname initial_probs +#' @export +get_initial_probs.nhmm <- function(model, ...) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + } + if (!attr(model, "iv_pi")) { + d <- data.frame( + id = rep(ids, each = model$n_states), + state = model$state_names, + estimate = get_pi( + model$coefficients$gamma_pi_raw, model$X_initial[, 1L], FALSE + ) + ) + } else { + d <- data.frame( + id = rep(ids, each = model$n_states), + state = model$state_names, + estimate = c(apply( + model$X_initial, 2, function(z) { + get_pi(model$coefficients$gamma_pi_raw, z, FALSE) + } + )) + ) + } + stats::setNames(d, c(model$id_variable, "state", "estimate")) +} +#' @rdname initial_probs +#' @export +get_initial_probs.mnhmm <- function(x) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + } + if (!attr(model, "iv_pi")) { + d <- do.call( + rbind, + lapply(seq_len(model$n_clusters), function(i) { + data.frame( + cluster = model$cluster_names[i], + id = rep(ids, each = model$n_states), + state = model$state_names[[i]], + estimate = get_pi( + model$coefficients$gamma_pi_raw[[i]], model$X_initial[, 1L], FALSE + ) + ) + }) + ) + } else { + d <- do.call( + rbind, + lapply(seq_len(model$n_clusters), function(i) { + data.frame( + cluster = model$cluster_names[i], + id = rep(ids, each = model$n_states), + state = model$state_names[[i]], + estimate = c(apply( + model$X_initial, 2, function(z) { + get_pi(model$coefficients$gamma_pi_raw[[i]], z, FALSE) + } + )) + ) + }) + ) + } + stats::setNames(d, c("cluster", model$id_variable, "state", "estimate")) +} +#' @rdname initial_probs +#' @export +get_initial_probs.hmm <- function(x) { + x$initial_probs +} +#' @rdname initial_probs +#' @export +get_initial_probs.mhmm <- function(x) { + x$initial_probs +} +#' Extract the State Transition Probabilities of Hidden Markov Model +#' @rdname transition_probs +#' @export +get_transition_probs.nhmm <- function(x) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + times <- colnames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + times <- colnames(model$observations[[1]]) + } + S <- model$n_states + T_ <- length(times) + if (!attr(model, "iv_A")) { + X <- matrix(model$X_transition[, 1L, ], ncol = model$length_of_sequences) + d <- data.frame( + id = rep(ids, each = S^2 * T_), + time = rep(times, each = S^2), + state_from = object$state_names, + state_to = rep(object$state_names, each = S), + estimate = get_A( + model$coefficients$gamma_A_raw, X, FALSE, attr(model, "tv_A") + ) + ) + } else { + d <- data.frame( + id = rep(ids, each = S^2 * T_), + time = rep(times, each = S^2), + state_from = object$state_names, + state_to = rep(object$state_names, each = S), + estimate = c(apply( + model$X_transition, 2, function(z) { + get_A( + model$coefficients$gamma_A_raw, matrix(z, ncol = T), FALSE, + attr(model, "tv_A") + ) + } + )) + ) + } + stats::setNames( + d, + c( + model$id_variable, model$time_variable, + "state_from", "state_to", "estimate" + ) + ) +} +#' @rdname transition_probs +#' @export +get_transition_probs.mnhmm <- function(x) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + times <- colnames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + times <- colnames(model$observations[[1]]) + } + S <- model$n_states + T_ <- length(times) + D <- model$n_clusters + if (!attr(model, "iv_A")) { + X <- matrix(model$X_transition[, 1L, ], ncol = model$length_of_sequences) + d <- do.call( + rbind, + lapply(seq_len(D), function(i) { + data.frame( + id = rep(ids, each = S^2 * T_), + time = rep(times, each = S^2), + state_from = object$state_names[[i]], + state_to = rep(object$state_names[[i]], each = S), + estimate = get_A( + model$coefficients$gamma_A_raw[[i]], X, FALSE, attr(model, "tv_A") + ) + ) + }) + ) + } else { + d <- do.call( + rbind, + lapply(seq_len(D), function(i) { + data.frame( + id = rep(ids, each = S^2 * T_), + time = rep(times, each = S^2), + state_from = object$state_names[[i]], + state_to = rep(object$state_names[[i]], each = S), + estimate = c(apply( + model$X_transition, 2, function(z) { + get_A( + model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T), FALSE, + attr(model, "tv_A") + ) + } + )) + ) + }) + ) + } + stats::setNames( + d, + c("cluster", model$id_variable, model$time_variable, + "state_from", "state_to", "estimate") + ) +} +#' @rdname transition_probs +#' @export +get_transition_probs.hmm <- function(x) { + x$transition_probs +} +#' @rdname transition_probs +#' @export +get_transition_probs.mhmm <- function(x) { + x$transition_probs +} +#' Extract the Emission Probabilities of Hidden Markov Model +#' @rdname emission_probs +#' @export +get_emission_probs.nhmm <- function(x) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + times <- colnames(model$observations) + symbol_names <- list(model$symbol_names) + model$coefficients$gamma_B_raw <- list(model$coefficients$gamma_B_raw) + } else { + ids <- rownames(model$observations[[1]]) + times <- colnames(model$observations[[1]]) + symbol_names <- model$symbol_names + } + S <- model$n_states + C <- model$n_channels + T_ <- length(times) + if (!attr(model, "iv_B")) { + X <- matrix(model$X_emission[, 1L, ], ncol = model$length_of_sequences) + d <- do.call( + rbind, + lapply(seq_len(C), function(i) { + data.frame( + id = rep(ids, each = S * M[i] * T_), + time = rep(times, each = S * M[i]), + state = model$state_names, + channel = model$channel_names[i], + observation = rep(symbol_names[[i]], each = S), + estimate = get_B( + model$coefficients$gamma_B_raw[[i]], X, FALSE, FALSE, attr(model, "tv_B") + ) + ) + }) + ) + } else { + d <- do.call( + rbind, + lapply(seq_len(C), function(i) { + data.frame( + id = rep(ids, each = S * M[i] * T_), + time = rep(times, each = S * M[i]), + state = model$state_names, + channel = model$channel_names[i], + observation = rep(symbol_names[[i]], each = S), + estimate = c(apply( + model$X_emission, 2, function(z) { + get_B( + model$coefficients$gamma_B_raw[[i]], matrix(z, ncol = T), FALSE, + FALSE, attr(model, "tv_B") + ) + } + )) + ) + }) + ) + } + stats::setNames( + d, + c(model$id_variable, model$time_variable, "state", "channel", + "observation", "estimate") + ) +} +#' @rdname emission_probs +#' @export +get_emission_probs.mnhmm <- function(x) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + times <- colnames(model$observations) + symbol_names <- list(model$symbol_names) + model$coefficients$gamma_B_raw <- list(model$coefficients$gamma_B_raw) + } else { + ids <- rownames(model$observations[[1]]) + times <- colnames(model$observations[[1]]) + symbol_names <- model$symbol_names + } + S <- model$n_states + C <- model$n_channels + D <- model$n_clusters + T_ <- length(times) + if (!attr(model, "iv_B")) { + X <- matrix(model$X_emission[, 1L, ], ncol = model$length_of_sequences) + d <- do.call( + rbind, + lapply(seq_len(D), function(j) { + do.call( + rbind, + lapply(seq_len(C), function(i) { + data.frame( + cluster = cluster_names[[j]], + id = rep(ids, each = S * M[i] * T_), + time = rep(times, each = S * M[i]), + state = model$state_names[[j]], + channel = model$channel_names[i], + observation = rep(symbol_names[[i]], each = S), + estimate = get_B( + model$coefficients$gamma_B_raw[[j]][[i]], X, FALSE, FALSE, + attr(model, "tv_B") + ) + ) + }) + ) + }) + ) + } else { + d <- do.call( + rbind, + lapply(seq_len(D), function(j) { + do.call( + rbind, + lapply(seq_len(C), function(i) { + data.frame( + cluster = cluster_names[[j]], + id = rep(ids, each = S * M[i] * T_), + time = rep(times, each = S * M[i]), + state = model$state_names[[j]], + channel = model$channel_names[i], + observation = rep(symbol_names[[i]], each = S), + estimate = c(apply( + model$X_emission, 2, function(z) { + get_B( + model$coefficients$gamma_B_raw[[j]][[i]], + matrix(z, ncol = T), FALSE, FALSE, attr(model, "tv_B") + ) + } + )) + ) + }) + ) + }) + ) + } + stats::setNames( + d, + c("cluster", model$id_variable, model$time_variable, "state", "channel", + "observation", "estimate") + ) +} +#' @rdname emission_probs +#' @export +get_emission_probs.hmm <- function(x) { + x$emission_probs +} +#' @rdname emission_probs +#' @export +get_emission_probs.mhmm <- function(x) { + x$emission_probs +} +#' Extract the Prior Cluster Probabilities of MHMM or MNHMM +#' +#' @param model An object of class `mnhmm` or `mhmm. +#' @rdname cluster_probs +#' @export +#' @seealso [posterior_cluster_probabilities()]. +get_cluster_probs.mnhmm <- function(model, ...) { + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + } + if (!attr(model, "iv_omega")) { + d <- data.frame( + cluster = model$cluster_names, + id = rep(ids, each = model$n_clusters), + estimate = c(get_omega( + model$coefficients$gamma_omega_raw, model$X_cluster[, 1L], FALSE + )) + ) + } else { + d <- data.frame( + cluster = model$cluster_names, + id = rep(ids, each = model$n_clusters), + estimate = c(apply( + model$X_cluster, 2, function(z) { + get_omega(model$coefficients$gamma_omega_raw, z, FALSE) + } + )) + ) + } + stats::setNames(d, c("cluster", model$id_variable, "estimate")) +} +#' @rdname cluster_probs +#' @export +get_cluster_probs.mhmm <- function(model, ...) { + pr <- exp(model$X %*% model$coefficients) + prior_cluster_probabilities <- pr / rowSums(pr) + data.frame( + cluster = model$cluster_names, + id = rep(ids, each = model$n_clusters), + estimate = c(t(prior_cluster_probabilities)) + ) +} +#' Get the Estimated Initial, Transition, Emission and (Prior) Cluster +#' Probabilities for NHMM or MNHMM #' #' @param model An object of class `nhmm` or `mnhmm`. #' @param newdata An optional data frame containing the new data to be used in #' computing the probabilities. +#' @param remove_voids Should the time points corresponding to `TraMineR`'s +#' void in the observed sequences be removed? Default is `TRUE`. #' @param ... Ignored. #' @rdname get_probs #' @export @@ -12,30 +420,95 @@ get_probs <- function(model, ...) { } #' @rdname get_probs #' @export -get_probs.nhmm <- function(model, newdata = NULL, ...) { - out <- predict(model, newdata) +get_probs.nhmm <- function(model, newdata = NULL, remove_voids = TRUE, ...) { + stopifnot_( + checkmate::test_flag(x = remove_voids), + "Argument {.arg remove_voids} must be a single {.cls logical} value." + ) + if (!is.null(newdata)) { + time <- object$time_variable + id <- object$id_variable + stopifnot_( + is.data.frame(newdata), + "Argument {.arg newdata} must be a {.cls data.frame} object." + ) + stopifnot_( + !is.null(newdata[[id]]), + "Can't find grouping variable {.var {id}} in {.arg newdata}." + ) + stopifnot_( + !is.null(newdata[[time]]), + "Can't find time index variable {.var {time}} in {.arg newdata}." + ) + object <- update(object, newdata = newdata) + } + S <- object$n_states + M <- object$n_symbols + C <- object$n_channels + N <- object$n_sequences + T_ <- object$length_of_sequences + out <- list( + initial_probs = get_initial_probs(object), + transition_probs = get_transition_probs(object), + emission_probs = get_emission_probs(object) + ) rownames(out$initial_probs) <- NULL rownames(out$transition_probs) <- NULL rownames(out$emission_probs) <- NULL - list( - initial_probs = out$initial_probs, - transition_probs = remove_voids(model, out$transition_probs), - emission_probs = remove_voids(model, out$emission_probs) - ) + if (remove_voids) { + list( + initial_probs = out$initial_probs, + transition_probs = remove_voids(model, out$transition_probs), + emission_probs = remove_voids(model, out$emission_probs) + ) + } else out } #' @rdname get_probs #' @export get_probs.mnhmm <- function(model, newdata = NULL, ...) { - out <- predict(model, newdata) + stopifnot_( + checkmate::test_flag(x = remove_voids), + "Argument {.arg remove_voids} must be a single {.cls logical} value." + ) + if (!is.null(newdata)) { + time <- object$time_variable + id <- object$id_variable + stopifnot_( + is.data.frame(newdata), + "Argument {.arg newdata} must be a {.cls data.frame} object." + ) + stopifnot_( + !is.null(newdata[[id]]), + "Can't find grouping variable {.var {id}} in {.arg newdata}." + ) + stopifnot_( + !is.null(newdata[[time]]), + "Can't find time index variable {.var {time}} in {.arg newdata}." + ) + object <- update(object, newdata = newdata) + } + S <- object$n_states + M <- object$n_symbols + C <- object$n_channels + N <- object$n_sequences + T_ <- object$length_of_sequences + out <- list( + initial_probs = get_initial_probs(object), + transition_probs = get_transition_probs(object), + emission_probs = get_emission_probs(object), + cluster_probs = get_cluster_probs(object) + ) rownames(out$initial_probs) <- NULL rownames(out$transition_probs) <- NULL rownames(out$emission_probs) <- NULL rownames(out$cluster_probs) <- NULL - list( - initial_probs = out$initial_probs, - transition_probs = remove_voids(model, out$transition_probs), - emission_probs = remove_voids(model, out$emission_probs), - cluster_probs = out$cluster_probs - ) + if (remove_voids) { + list( + initial_probs = out$initial_probs, + transition_probs = remove_voids(model, out$transition_probs), + emission_probs = remove_voids(model, out$emission_probs), + cluster_probs = out$cluster_probs + ) + } else out } diff --git a/R/hidden_paths.R b/R/hidden_paths.R index 2c8fa42e..8e9d0c37 100644 --- a/R/hidden_paths.R +++ b/R/hidden_paths.R @@ -108,10 +108,14 @@ hidden_paths.mnhmm <- function(model, respect_void = TRUE, ...) { model$coefficients$gamma_omega_raw, model$X_cluster, obsArray, model$n_symbols) } - model$state_names <- paste0( - rep(model$cluster_names, each = model$n_states), ": ", - model$state_names - ) + if (identical(model$state_names[[1]], model$state_names[[2]])) { + model$state_names <- paste0( + rep(model$cluster_names, each = model$n_states), ": ", + model$state_names + ) + } else { + model$state_names <- unlist(model$state_names) + } model$n_states <- length(model$state_names) create_mpp_seq(out, model, respect_void) } diff --git a/R/model_matrix.R b/R/model_matrix.R index 249400bd..ca1417b6 100644 --- a/R/model_matrix.R +++ b/R/model_matrix.R @@ -1,3 +1,22 @@ +#' Does covariate values vary per ID? +#' @noRd +iv_X <- function(X) { + dim(unique(X, MARGIN = 3))[3] > 1L +} +#' Does covariate values vary in time? +#' @noRd +tv_X <- function(X) { + dim(unique(X, MARGIN = 2))[2] > 1L +} + +# Function to check uniqueness along the N dimension +check_unique_N <- function(arr) { + # Flatten along the T and K dimensions + flattened_N <- apply(arr, 2, function(x) as.vector(x)) + # Check for unique rows + length(unique(as.data.frame(t(flattened_N)))) > 1 +} + #' Create the Model Matrix based on NHMM Formulas #' #' @noRd @@ -9,6 +28,7 @@ model_matrix_initial_formula <- function(formula, data, n_sequences, n_pars <- n_states - 1L X <- matrix(1, n_sequences, 1) coef_names <- "(Intercept)" + iv <- FALSE } else { first_time_point <- min(data[[time]]) X <- stats::model.matrix.lm( @@ -17,6 +37,7 @@ model_matrix_initial_formula <- function(formula, data, n_sequences, na.action = stats::na.pass ) missing_values <- which(!complete.cases(X)) + iv <- nrow(unique(X[-missing_values, ])) > 1 stopifnot_( length(missing_values) == 0, c( @@ -31,7 +52,8 @@ model_matrix_initial_formula <- function(formula, data, n_sequences, coef_names <- colnames(X) n_pars <- (n_states - 1L) * ncol(X) } - list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names) + list(formula = formula, n_pars = n_pars, X = t(X), coef_names = coef_names, + iv = iv) } #' Create the Model Matrix based on NHMM Formulas #' @@ -44,6 +66,7 @@ model_matrix_transition_formula <- function(formula, data, n_sequences, n_pars <- n_states * (n_states - 1L) X <- array(1, c(length_of_sequences, n_sequences, 1L)) coef_names <- "(Intercept)" + iv <- tv <- FALSE } else { X <- stats::model.matrix.lm( formula, @@ -66,13 +89,16 @@ model_matrix_transition_formula <- function(formula, data, n_sequences, ) ) } - # Replace NAs in void cases with zero as we need to input these to Stan - X[missing_values] <- 0 + # Replace NAs in void cases with zero + X[is.na(X)] <- 0 coef_names <- colnames(X) dim(X) <- c(length_of_sequences, n_sequences, ncol(X)) n_pars <- n_states * (n_states - 1L) * dim(X)[3] + iv <- iv_X(X) + tv <- tv_X(X) } - list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names) + list(formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), + coef_names = coef_names, iv = iv, tv = tv) } #' Create the Model Matrix based on NHMM Formulas #' @@ -83,9 +109,10 @@ model_matrix_emission_formula <- function(formula, data, n_sequences, time, id, sequence_lengths) { icp_only <- intercept_only(formula) if (icp_only) { - n_pars <- n_channels * n_states * (n_symbols - 1L) + n_pars <- sum(n_states * (n_symbols - 1L)) X <- array(1, c(length_of_sequences, n_sequences, 1L)) coef_names <- "(Intercept)" + iv <- tv <- FALSE } else { X <- stats::model.matrix.lm( formula, @@ -108,13 +135,16 @@ model_matrix_emission_formula <- function(formula, data, n_sequences, ) ) } - # Replace NAs in void cases with zero as we need to input these to Stan - X[missing_values] <- 0 + # Replace NAs in void cases with zero + X[is.na(X)] <- 0 coef_names <- colnames(X) dim(X) <- c(length_of_sequences, n_sequences, ncol(X)) - n_pars <- n_channels * n_states * (n_symbols - 1L) * dim(X)[3] + n_pars <- sum(n_states * (n_symbols - 1L) * dim(X)[3]) + iv <- iv_X(X) + tv <- tv_X(X) } - list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names) + list(formula = formula, n_pars = n_pars, X = aperm(X, c(3, 1, 2)), + coef_names = coef_names, iv = iv, tv = tv) } #' Create the Model Matrix based on NHMM Formulas #' @@ -126,6 +156,7 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters, n_pars <- n_clusters - 1L X <- matrix(1, n_sequences, 1) coef_names <- "(Intercept)" + iv <- FALSE } else { first_time_point <- min(data[[time]]) X <- stats::model.matrix.lm( @@ -134,6 +165,7 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters, na.action = stats::na.pass ) missing_values <- which(!complete.cases(X)) + iv <- nrow(unique(X[-missing_values, ])) > 1 stopifnot_( length(missing_values) == 0, c( @@ -148,5 +180,6 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters, coef_names <- colnames(X) n_pars <- (n_clusters - 1L) * ncol(X) } - list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names) + list(formula = formula, n_pars = n_pars, X = t(X), coef_names = coef_names, + iv = iv) } diff --git a/R/predict.R b/R/predict.R deleted file mode 100644 index 85b350fc..00000000 --- a/R/predict.R +++ /dev/null @@ -1,201 +0,0 @@ -#' #' Predict method for non-homogeneous hidden Markov models -#' #' -#' #' This is essentially same as `get_probs` but with option to return samples. -#' #' -#' #' @param object A Hidden Markov Model of class `nhmm` or `mnhmm`. -#' #' @param newdata Optional data frame which is used for prediction. -#' #' @param ... Ignored. -#' #' @noRd -#' predict.nhmm <- function( -#' object, newdata = NULL, dontchange_colnames = FALSE, ...) { -#' -#' if (!is.null(newdata)) { -#' time <- object$time_variable -#' id <- object$id_variable -#' stopifnot_( -#' is.data.frame(newdata), -#' "Argument {.arg newdata} must be a {.cls data.frame} object." -#' ) -#' stopifnot_( -#' !is.null(newdata[[id]]), -#' "Can't find grouping variable {.var {id}} in {.arg newdata}." -#' ) -#' stopifnot_( -#' !is.null(newdata[[time]]), -#' "Can't find time index variable {.var {time}} in {.arg newdata}." -#' ) -#' object <- update(object, newdata = newdata) -#' } -#' S <- object$n_states -#' M <- object$n_symbols -#' C <- object$n_channels -#' N <- object$n_sequences -#' T_ <- object$length_of_sequences -#' initial_probs <- get_pi(object$coefficients$gamma_pi_raw, object$X_initial, 0) -#' transition_probs <- get_A(object$coefficients$gamma_A_raw, object$X_transition, 0) -#' emission_probs <- if (C == 1) { -#' get_B(object$coefficients$gamma_B_raw, object$X_emission, 0, 0) -#' } else { -#' get_multichannel_B(object$object$gamma_B_raw, object$X_emission, S, M, 0, 0) -#' } -#' if (C == 1) { -#' ids <- rownames(object$observations) -#' times <- colnames(object$observations) -#' symbol_names <- list(object$symbol_names) -#' } else { -#' ids <- rownames(object$observations[[1]]) -#' times <- colnames(object$observations[[1]]) -#' symbol_names <- object$symbol_names -#' } -#' out <- list() -#' out$initial_probs <- data.frame( -#' id = rep(ids, each = S), -#' state = object$state_names, -#' estimate = c(initial_probs) -#' ) -#' -#' out$transition_probs <- data.frame( -#' id = rep(ids, each = S^2 * T_), -#' time = rep(times, each = S^2), -#' state_from = object$state_names, -#' state_to = rep(object$state_names, each = S), -#' estimate = unlist(transition_probs) -#' ) -#' out$emission_probs <- do.call( -#' rbind, -#' lapply(seq_len(C), function(i) { -#' data.frame( -#' id = rep(ids, each = S * M[i] * T_), -#' time = rep(times, each = S * M[i]), -#' state = object$state_names, -#' channel = object$channel_names[i], -#' observation = rep(symbol_names[[i]], each = S), -#' estimate = unlist(emission_probs[((i - 1) * N + 1):(i * N)]) -#' ) -#' }) -#' ) -#' if (C == 1) out$emission_probs$channel <- NULL -#' -#' if (!dontchange_colnames) { -#' colnames(out$initial_probs)[1] <- object$id_variable -#' colnames(out$transition_probs)[1] <- object$id_variable -#' colnames(out$transition_probs)[2] <- object$time_variable -#' colnames(out$emission_probs)[1] <- object$id_variable -#' colnames(out$emission_probs)[2] <- object$time_variable -#' } -#' out -#' } -#' #' @noRd -#' predict.mnhmm <- function( -#' object, newdata = NULL, dontchange_colnames = FALSE, ...) { -#' -#' if (!is.null(newdata)) { -#' time <- object$time_variable -#' id <- object$id_variable -#' stopifnot_( -#' is.data.frame(newdata), -#' "Argument {.arg newdata} must be a {.cls data.frame} object." -#' ) -#' stopifnot_( -#' !is.null(newdata[[id]]), -#' "Can't find grouping variable {.var {id}} in {.arg newdata}." -#' ) -#' stopifnot_( -#' !is.null(newdata[[time]]), -#' "Can't find time index variable {.var {time}} in {.arg newdata}." -#' ) -#' object <- update(object, newdata = newdata) -#' } -#' T_ <- object$length_of_sequences -#' N <- object$n_sequences -#' S <- object$n_states -#' M <- object$n_symbols -#' C <- object$n_channels -#' D <- object$n_clusters -#' gamma_omega_raw <- object$coefficients$gamma_omega_raw -#' initial_probs <- vector("list", D) -#' transition_probs <- vector("list", D) -#' emission_probs <- vector("list", D) -#' for (d in seq_len(D)) { -#' gamma_pi_raw <- coef_to_cpp_initial( -#' matrix( -#' object$coefficients$gamma_pi_raw[d, ,], -#' S - 1, nrow(object$X_initial) -#' ) -#' ) -#' gamma_A_raw <- coef_to_cpp_transition( -#' array( -#' object$coefficients$gamma_A_raw[d, , ,], -#' dim = c(S, S - 1, nrow(object$X_transition)) -#' ) -#' ) -#' gamma_B_raw <- coef_to_cpp_emission( -#' if (C == 1) { -#' array( -#' object$coefficients$gamma_B_raw[d, , ,], -#' dim = c(S, M - 1, nrow(object$X_emission)) -#' ) -#' } else { -#' object$coefficients$gamma_B_raw[d, ] -#' }, -#' 1, -#' C > 1 -#' ) -#' initial_probs[[d]] <- get_pi(object$coefficients$gamma_pi_raw, object$X_initial, 0) -#' transition_probs[[d]] <- get_A(object$coefficients$gamma_A_raw, object$X_transition, 0) -#' emission_probs[[d]] <- if (C == 1) { -#' get_B(object$coefficients$gamma_B_raw, object$X_emission, 0, 0) -#' } else { -#' get_multichannel_B(object$coefficients$gamma_B_raw, object$X_emission, S, M, 0, 0) -#' } -#' } -#' if (C == 1) { -#' ids <- rownames(object$observations) -#' times <- colnames(object$observations) -#' symbol_names <- list(object$symbol_names) -#' } else { -#' ids <- rownames(object$observations[[1]]) -#' times <- colnames(object$observations[[1]]) -#' symbol_names <- object$symbol_names -#' } -#' out <- list() -#' out$initial_probs <- data.frame( -#' cluster = rep(object$cluster_names, each = S * N), -#' id = rep(ids, each = S), -#' state = object$state_names, -#' estimate = unlist(initial_probs) -#' ) -#' out$transition_probs <- data.frame( -#' cluster = rep(object$cluster_names, each = S^2 * T_ * N), -#' id = rep(ids, each = S^2 * T_), -#' time = rep(times, each = S^2), -#' state_from = object$state_names, -#' state_to = rep(object$state_names, each = S), -#' estimate = unlist(transition_probs) -#' ) -#' out$emission_probs <- data.frame( -#' cluster = rep(object$cluster_names, each = S * sum(M) * T_ * N), -#' id = unlist(lapply(seq_len(C), function(i) rep(ids, each = S * M[i] * T_))), -#' time = unlist(lapply(seq_len(C), function(i) rep(times, each = S * M[i]))), -#' state = object$state_names, -#' channel = rep(object$channel_names, S * M * T_ * N), -#' observation = rep(unlist(symbol_names), each = S), -#' estimate = unlist(emission_probs) -#' ) -#' if (C == 1) emission_probs$channel <- NULL -#' out$cluster_probs <- data.frame( -#' cluster = object$cluster_names, -#' id = rep(ids, each = D), -#' estimate = c(get_omega(object$coefficients$gamma_omega_raw, -#' object$coefficients$X_cluster, 0)) -#' ) -#' if (!dontchange_colnames) { -#' colnames(out$initial_probs)[2] <- object$id_variable -#' colnames(out$transition_probs)[2] <- object$id_variable -#' colnames(out$transition_probs)[3] <- object$time_variable -#' colnames(out$emission_probs)[2] <- object$id_variable -#' colnames(out$emission_probs)[3] <- object$time_variable -#' colnames(out$cluster_probs)[2] <- object$id_variable -#' } -#' out -#' } diff --git a/R/state_names.R b/R/state_names.R index cbf4aadc..2df102a5 100644 --- a/R/state_names.R +++ b/R/state_names.R @@ -1,34 +1,45 @@ -#' Get state names from hmm or mhmm object +#' Get State Names of Hidden Markov Model #' -#' @param object An object of class `hmm` or `mhmm`. +#' @param object An object of class `hmm`, `mhmm`, `nhmm`, or `mnhmm`. #' @return A character vector containing the state names, or a list of such -#' vectors in `mhmm` case. +#' vectors in case of mixture models. +#' @rdname state_names #' @export state_names <- function(object) { UseMethod("state_names") } - +#' @rdname state_names #' @export state_names.hmm <- function(object) { object$state_names } - +#' @rdname state_names #' @export state_names.mhmm <- function(object) { object$state_names } - -#' Set state names for hmm or mhmm object +#' @rdname state_names +#' @export +state_names.nhmm <- function(object) { + object$state_names +} +#' @rdname state_names +#' @export +state_names.mnhmm <- function(object) { + object$state_names +} +#' Set State Names of Hidden Markov Model #' -#' @param object An object of class `hmm` or `mhmm`. +#' @param object object An object of class `hmm`, `mhmm`, `nhmm`, or `mnhmm`. #' @param value A character vector containing the new state names, or a list of -#' such vectors in `mhmm` case. -#' @return The modified object with updated state names. +#' such vectors in case of mixture models. +#' @return The original object with updated state names. +#' @rdname state_names #' @export `state_names<-` <- function(object, value) { UseMethod("state_names<-") } - +#' @rdname state_names #' @export `state_names<-.hmm` <- function(object, value) { stopifnot_( @@ -47,7 +58,7 @@ state_names.mhmm <- function(object) { } object } - +#' @rdname state_names #' @export `state_names<-.mhmm` <- function(object, value) { stopifnot_( @@ -55,7 +66,7 @@ state_names.mhmm <- function(object) { "New state names should be a {.cls list} with length of {object$n_clusters}." ) - for (i in 1:object$n_clusters) { + for (i in seq_len(object$n_clusters)) { stopifnot_( length(value[[i]]) == object$n_states[i], "Number of new state names for cluster {i} is not equal to the number of @@ -74,3 +85,31 @@ state_names.mhmm <- function(object) { } object } +#' @rdname state_names +#' @export +`state_names<-.nhmm` <- function(object, value) { + stopifnot_( + length(value) == object$n_states, + "Number of state names does not match with the number of states." + ) + object$state_names <- value + object +} +#' @rdname state_names +#' @export +`state_names<-.mnhmm` <- function(object, value) { + stopifnot_( + length(value) != object$n_clusters, + "New state names should be a {.cls list} with length of + {object$n_clusters}." + ) + for (i in seq_len(object$n_clusters)) { + stopifnot_( + length(value[[i]]) == object$n_states[i], + "Number of new state names for cluster {i} is not equal to the number of + hidden states." + ) + object$state_names[[i]] <- value[[i]] + } + object +} \ No newline at end of file diff --git a/R/update.R b/R/update.R index e36441c1..64812527 100644 --- a/R/update.R +++ b/R/update.R @@ -12,22 +12,30 @@ update.nhmm <- function(object, newdata, ...) { object$n_sequences <- length(unique(newdata[[object$id_variable]])) object$length_of_sequences <- length(unique(newdata[[object$time_variable]])) if (!is.null(object$data)) object$data <- newdata - object$X_initial <- model_matrix_initial_formula( + X <- model_matrix_initial_formula( object$initial_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, object$id_variable - )$X - object$X_transition <- model_matrix_transition_formula( + ) + object$X_initial <- X$X + attr(object, "iv_pi") <- x$iv + X <- model_matrix_transition_formula( object$transition_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, object$id_variable, object$sequence_lengths - )$X - object$X_emission <- model_matrix_emission_formula( + ) + object$X_transition <- X$X + attr(object, "iv_A") <- X$iv + attr(object, "tv_A") <- X$tv + X <- model_matrix_emission_formula( object$emission_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$n_symbols, object$n_channels, object$time_variable, object$id_variable, object$sequence_lengths - )$X + ) + object$X_emission <- X$X + attr(object, "iv_B") <- X$iv + attr(object, "tv_B") <- X$tv object } #' @rdname update_nhmm @@ -37,25 +45,35 @@ update.mnhmm <- function(object, newdata, ...) { object$n_sequences <- length(unique(newdata[[object$id_variable]])) object$length_of_sequences <- length(unique(newdata[[object$time_variable]])) if (!is.null(object$data)) object$data <- newdata - object$X_initial <- model_matrix_initial_formula( + X <- model_matrix_initial_formula( object$initial_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, object$id_variable - )$X - object$X_transition <- model_matrix_transition_formula( + ) + object$X_initial <- X$X + attr(object, "iv_pi") <- x$iv + X <- model_matrix_transition_formula( object$transition_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, object$id_variable, object$sequence_lengths - )$X - object$X_emission <- model_matrix_emission_formula( + ) + object$X_transition <- X$X + attr(object, "iv_A") <- X$iv + attr(object, "tv_A") <- X$tv + X <- model_matrix_emission_formula( object$emission_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$n_symbols, object$n_channels, object$time_variable, object$id_variable, object$sequence_lengths - )$X - object$X_cluster <- model_matrix_cluster_formula( + ) + object$X_emission <- X$X + attr(object, "iv_B") <- X$iv + attr(object, "tv_B") <- X$tv + X <- model_matrix_cluster_formula( object$cluster_formula, newdata, object$n_sequences, object$n_clusters, object$time_variable, object$id_variable - )$X + ) + object$X_cluster <- X$X + attr(object, "iv_omega") <- x$iv object } \ No newline at end of file diff --git a/man/cluster_probs.Rd b/man/cluster_probs.Rd new file mode 100644 index 00000000..3aec0632 --- /dev/null +++ b/man/cluster_probs.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_probs.R +\name{get_cluster_probs} +\alias{get_cluster_probs} +\alias{get_cluster_probs.mnhmm} +\alias{get_cluster_probs.mhmm} +\title{Extract the Prior Cluster Probabilities of MHMM or MNHMM} +\usage{ +get_cluster_probs(model, ...) + +\method{get_cluster_probs}{mnhmm}(model, ...) + +\method{get_cluster_probs}{mhmm}(model, ...) +} +\arguments{ +\item{model}{An object of class \code{mnhmm} or `mhmm.} +} +\description{ +Extract the Prior Cluster Probabilities of MHMM or MNHMM +} +\seealso{ +\code{\link[=posterior_cluster_probabilities]{posterior_cluster_probabilities()}}. +} diff --git a/man/emission_probs.Rd b/man/emission_probs.Rd new file mode 100644 index 00000000..f978cb73 --- /dev/null +++ b/man/emission_probs.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_probs.R +\name{get_emission_probs} +\alias{get_emission_probs} +\alias{get_emission_probs.nhmm} +\alias{get_emission_probs.mnhmm} +\alias{get_emission_probs.hmm} +\alias{get_emission_probs.mhmm} +\title{Extract the Emission Probabilities of Hidden Markov Model} +\usage{ +get_emission_probs(model, ...) + +\method{get_emission_probs}{nhmm}(x) + +\method{get_emission_probs}{mnhmm}(x) + +\method{get_emission_probs}{hmm}(x) + +\method{get_emission_probs}{mhmm}(x) +} +\description{ +Extract the Emission Probabilities of Hidden Markov Model +} diff --git a/man/get_probs.Rd b/man/get_probs.Rd index fbaf0a06..6d629e61 100644 --- a/man/get_probs.Rd +++ b/man/get_probs.Rd @@ -4,12 +4,12 @@ \alias{get_probs} \alias{get_probs.nhmm} \alias{get_probs.mnhmm} -\title{Get the Estimated Initial, Transition, and Emission Probabilities for NHMM -or MNHMM} +\title{Get the Estimated Initial, Transition, Emission and (Prior) Cluster +Probabilities for NHMM or MNHMM} \usage{ get_probs(model, ...) -\method{get_probs}{nhmm}(model, newdata = NULL, ...) +\method{get_probs}{nhmm}(model, newdata = NULL, remove_voids = TRUE, ...) \method{get_probs}{mnhmm}(model, newdata = NULL, ...) } @@ -20,8 +20,11 @@ get_probs(model, ...) \item{newdata}{An optional data frame containing the new data to be used in computing the probabilities.} + +\item{remove_voids}{Should the time points corresponding to \code{TraMineR}'s +void in the observed sequences be removed? Default is \code{TRUE}.} } \description{ -Get the Estimated Initial, Transition, and Emission Probabilities for NHMM -or MNHMM +Get the Estimated Initial, Transition, Emission and (Prior) Cluster +Probabilities for NHMM or MNHMM } diff --git a/man/initial_probs.Rd b/man/initial_probs.Rd new file mode 100644 index 00000000..2d9967b1 --- /dev/null +++ b/man/initial_probs.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_probs.R +\name{get_initial_probs} +\alias{get_initial_probs} +\alias{get_initial_probs.nhmm} +\alias{get_initial_probs.mnhmm} +\alias{get_initial_probs.hmm} +\alias{get_initial_probs.mhmm} +\title{Extract the Initial State Probabilities of Hidden Markov Model} +\usage{ +get_initial_probs(model, ...) + +\method{get_initial_probs}{nhmm}(model, ...) + +\method{get_initial_probs}{mnhmm}(x) + +\method{get_initial_probs}{hmm}(x) + +\method{get_initial_probs}{mhmm}(x) +} +\description{ +Extract the Initial State Probabilities of Hidden Markov Model +} diff --git a/man/state_names-set.Rd b/man/state_names-set.Rd deleted file mode 100644 index dbab0946..00000000 --- a/man/state_names-set.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/state_names.R -\name{state_names<-} -\alias{state_names<-} -\title{Set state names for hmm or mhmm object} -\usage{ -state_names(object) <- value -} -\arguments{ -\item{object}{An object of class \code{hmm} or \code{mhmm}.} - -\item{value}{A character vector containing the new state names, or a list of -such vectors in \code{mhmm} case.} -} -\value{ -The modified object with updated state names. -} -\description{ -Set state names for hmm or mhmm object -} diff --git a/man/state_names.Rd b/man/state_names.Rd index b3a61e5b..76bc0f3c 100644 --- a/man/state_names.Rd +++ b/man/state_names.Rd @@ -2,17 +2,51 @@ % Please edit documentation in R/state_names.R \name{state_names} \alias{state_names} -\title{Get state names from hmm or mhmm object} +\alias{state_names.hmm} +\alias{state_names.mhmm} +\alias{state_names.nhmm} +\alias{state_names.mnhmm} +\alias{state_names<-} +\alias{state_names<-.hmm} +\alias{state_names<-.mhmm} +\alias{state_names<-.nhmm} +\alias{state_names<-.mnhmm} +\title{Get State Names of Hidden Markov Model} \usage{ state_names(object) + +\method{state_names}{hmm}(object) + +\method{state_names}{mhmm}(object) + +\method{state_names}{nhmm}(object) + +\method{state_names}{mnhmm}(object) + +state_names(object) <- value + +\method{state_names}{hmm}(object) <- value + +\method{state_names}{mhmm}(object) <- value + +\method{state_names}{nhmm}(object) <- value + +\method{state_names}{mnhmm}(object) <- value } \arguments{ -\item{object}{An object of class \code{hmm} or \code{mhmm}.} +\item{object}{object An object of class \code{hmm}, \code{mhmm}, \code{nhmm}, or \code{mnhmm}.} + +\item{value}{A character vector containing the new state names, or a list of +such vectors in case of mixture models.} } \value{ A character vector containing the state names, or a list of such -vectors in \code{mhmm} case. +vectors in case of mixture models. + +The original object with updated state names. } \description{ -Get state names from hmm or mhmm object +Get State Names of Hidden Markov Model + +Set State Names of Hidden Markov Model } diff --git a/man/transition_probs.Rd b/man/transition_probs.Rd new file mode 100644 index 00000000..0da2bd29 --- /dev/null +++ b/man/transition_probs.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_probs.R +\name{get_transition_probs} +\alias{get_transition_probs} +\alias{get_transition_probs.nhmm} +\alias{get_transition_probs.mnhmm} +\alias{get_transition_probs.hmm} +\alias{get_transition_probs.mhmm} +\title{Extract the State Transition Probabilities of Hidden Markov Model} +\usage{ +get_transition_probs(model, ...) + +\method{get_transition_probs}{nhmm}(x) + +\method{get_transition_probs}{mnhmm}(x) + +\method{get_transition_probs}{hmm}(x) + +\method{get_transition_probs}{mhmm}(x) +} +\description{ +Extract the State Transition Probabilities of Hidden Markov Model +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 896a02af..d995b17f 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -240,56 +240,58 @@ BEGIN_RCPP END_RCPP } // get_omega -arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const int logspace); +arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace); RcppExport SEXP _seqHMM_get_omega(SEXP gamma_omega_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::mat& >::type gamma_omega_raw(gamma_omega_rawSEXP); Rcpp::traits::input_parameter< const arma::vec >::type X(XSEXP); - Rcpp::traits::input_parameter< const int >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); rcpp_result_gen = Rcpp::wrap(get_omega(gamma_omega_raw, X, logspace)); return rcpp_result_gen; END_RCPP } // get_pi -arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const int logspace); +arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const bool logspace); RcppExport SEXP _seqHMM_get_pi(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::mat& >::type gamma_raw(gamma_rawSEXP); Rcpp::traits::input_parameter< const arma::vec >::type X(XSEXP); - Rcpp::traits::input_parameter< const int >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); rcpp_result_gen = Rcpp::wrap(get_pi(gamma_raw, X, logspace)); return rcpp_result_gen; END_RCPP } // get_A -arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, const int logspace); -RcppExport SEXP _seqHMM_get_A(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP) { +arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, const bool logspace, const bool tv); +RcppExport SEXP _seqHMM_get_A(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP logspaceSEXP, SEXP tvSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::cube& >::type gamma_raw(gamma_rawSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); - Rcpp::traits::input_parameter< const int >::type logspace(logspaceSEXP); - rcpp_result_gen = Rcpp::wrap(get_A(gamma_raw, X, logspace)); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP); + rcpp_result_gen = Rcpp::wrap(get_A(gamma_raw, X, logspace, tv)); return rcpp_result_gen; END_RCPP } // get_B -arma::field get_B(const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, const int logspace, const int add_missing); -RcppExport SEXP _seqHMM_get_B(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP MSEXP, SEXP logspaceSEXP, SEXP add_missingSEXP) { +arma::field get_B(const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, const bool logspace, const bool add_missing, const bool tv); +RcppExport SEXP _seqHMM_get_B(SEXP gamma_rawSEXP, SEXP XSEXP, SEXP MSEXP, SEXP logspaceSEXP, SEXP add_missingSEXP, SEXP tvSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::field& >::type gamma_raw(gamma_rawSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); Rcpp::traits::input_parameter< const arma::uvec& >::type M(MSEXP); - Rcpp::traits::input_parameter< const int >::type logspace(logspaceSEXP); - Rcpp::traits::input_parameter< const int >::type add_missing(add_missingSEXP); - rcpp_result_gen = Rcpp::wrap(get_B(gamma_raw, X, M, logspace, add_missing)); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type add_missing(add_missingSEXP); + Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP); + rcpp_result_gen = Rcpp::wrap(get_B(gamma_raw, X, M, logspace, add_missing, tv)); return rcpp_result_gen; END_RCPP } @@ -466,8 +468,8 @@ BEGIN_RCPP END_RCPP } // log_objective_nhmm_singlechannel -Rcpp::List log_objective_nhmm_singlechannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, const arma::mat& obs); -RcppExport SEXP _seqHMM_log_objective_nhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP) { +Rcpp::List log_objective_nhmm_singlechannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B); +RcppExport SEXP _seqHMM_log_objective_nhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -478,13 +480,18 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const arma::cube& >::type gamma_B_raw(gamma_B_rawSEXP); Rcpp::traits::input_parameter< const arma::cube& >::type X_o(X_oSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type obs(obsSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs)); + Rcpp::traits::input_parameter< const bool >::type iv_pi(iv_piSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, iv_pi, iv_A, iv_B, tv_A, tv_B)); return rcpp_result_gen; END_RCPP } // log_objective_nhmm_multichannel -Rcpp::List log_objective_nhmm_multichannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::cube& obs, const arma::uvec& M); -RcppExport SEXP _seqHMM_log_objective_nhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP MSEXP) { +Rcpp::List log_objective_nhmm_multichannel(const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B); +RcppExport SEXP _seqHMM_log_objective_nhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -496,13 +503,18 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const arma::cube& >::type X_o(X_oSEXP); Rcpp::traits::input_parameter< const arma::cube& >::type obs(obsSEXP); Rcpp::traits::input_parameter< const arma::uvec& >::type M(MSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M)); + Rcpp::traits::input_parameter< const bool >::type iv_pi(iv_piSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_nhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B)); return rcpp_result_gen; END_RCPP } // log_objective_mnhmm_singlechannel -Rcpp::List log_objective_mnhmm_singlechannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::mat& obs); -RcppExport SEXP _seqHMM_log_objective_mnhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP) { +Rcpp::List log_objective_mnhmm_singlechannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega); +RcppExport SEXP _seqHMM_log_objective_mnhmm_singlechannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -515,13 +527,19 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const arma::mat& >::type gamma_omega_raw(gamma_omega_rawSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type X_d(X_dSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type obs(obsSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs)); + Rcpp::traits::input_parameter< const bool >::type iv_pi(iv_piSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_omega(iv_omegaSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_singlechannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega)); return rcpp_result_gen; END_RCPP } // log_objective_mnhmm_multichannel -Rcpp::List log_objective_mnhmm_multichannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::cube& obs, const arma::uvec& M); -RcppExport SEXP _seqHMM_log_objective_mnhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP MSEXP) { +Rcpp::List log_objective_mnhmm_multichannel(const arma::field& gamma_pi_raw, const arma::mat& X_i, const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, const arma::cube& obs, const arma::uvec& M, const bool iv_pi, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const bool iv_omega); +RcppExport SEXP _seqHMM_log_objective_mnhmm_multichannel(SEXP gamma_pi_rawSEXP, SEXP X_iSEXP, SEXP gamma_A_rawSEXP, SEXP X_sSEXP, SEXP gamma_B_rawSEXP, SEXP X_oSEXP, SEXP gamma_omega_rawSEXP, SEXP X_dSEXP, SEXP obsSEXP, SEXP MSEXP, SEXP iv_piSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP iv_omegaSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -535,7 +553,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const arma::mat& >::type X_d(X_dSEXP); Rcpp::traits::input_parameter< const arma::cube& >::type obs(obsSEXP); Rcpp::traits::input_parameter< const arma::uvec& >::type M(MSEXP); - rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M)); + Rcpp::traits::input_parameter< const bool >::type iv_pi(iv_piSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_omega(iv_omegaSEXP); + rcpp_result_gen = Rcpp::wrap(log_objective_mnhmm_multichannel(gamma_pi_raw, X_i, gamma_A_raw, X_s, gamma_B_raw, X_o, gamma_omega_raw, X_d, obs, M, iv_pi, iv_A, iv_B, tv_A, tv_B, iv_omega)); return rcpp_result_gen; END_RCPP } @@ -603,13 +627,13 @@ BEGIN_RCPP END_RCPP } // softmax -arma::vec softmax(const arma::vec& x, const int logspace); +arma::vec softmax(const arma::vec& x, const bool logspace); RcppExport SEXP _seqHMM_softmax(SEXP xSEXP, SEXP logspaceSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::vec& >::type x(xSEXP); - Rcpp::traits::input_parameter< const int >::type logspace(logspaceSEXP); + Rcpp::traits::input_parameter< const bool >::type logspace(logspaceSEXP); rcpp_result_gen = Rcpp::wrap(softmax(x, logspace)); return rcpp_result_gen; END_RCPP @@ -748,8 +772,8 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_forwardbackwardx", (DL_FUNC) &_seqHMM_forwardbackwardx, 9}, {"_seqHMM_get_omega", (DL_FUNC) &_seqHMM_get_omega, 3}, {"_seqHMM_get_pi", (DL_FUNC) &_seqHMM_get_pi, 3}, - {"_seqHMM_get_A", (DL_FUNC) &_seqHMM_get_A, 3}, - {"_seqHMM_get_B", (DL_FUNC) &_seqHMM_get_B, 5}, + {"_seqHMM_get_A", (DL_FUNC) &_seqHMM_get_A, 4}, + {"_seqHMM_get_B", (DL_FUNC) &_seqHMM_get_B, 6}, {"_seqHMM_logLikHMM", (DL_FUNC) &_seqHMM_logLikHMM, 5}, {"_seqHMM_logLikMixHMM", (DL_FUNC) &_seqHMM_logLikMixHMM, 8}, {"_seqHMM_logSumExp", (DL_FUNC) &_seqHMM_logSumExp, 1}, @@ -760,10 +784,10 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_log_logLikHMM", (DL_FUNC) &_seqHMM_log_logLikHMM, 5}, {"_seqHMM_log_logLikMixHMM", (DL_FUNC) &_seqHMM_log_logLikMixHMM, 8}, {"_seqHMM_log_objective", (DL_FUNC) &_seqHMM_log_objective, 9}, - {"_seqHMM_log_objective_nhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_singlechannel, 7}, - {"_seqHMM_log_objective_nhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_multichannel, 8}, - {"_seqHMM_log_objective_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_singlechannel, 9}, - {"_seqHMM_log_objective_mnhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_multichannel, 10}, + {"_seqHMM_log_objective_nhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_singlechannel, 12}, + {"_seqHMM_log_objective_nhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_nhmm_multichannel, 13}, + {"_seqHMM_log_objective_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_singlechannel, 15}, + {"_seqHMM_log_objective_mnhmm_multichannel", (DL_FUNC) &_seqHMM_log_objective_mnhmm_multichannel, 16}, {"_seqHMM_log_objectivex", (DL_FUNC) &_seqHMM_log_objectivex, 12}, {"_seqHMM_objective", (DL_FUNC) &_seqHMM_objective, 9}, {"_seqHMM_objectivex", (DL_FUNC) &_seqHMM_objectivex, 12}, diff --git a/src/backward_nhmm.cpp b/src/backward_nhmm.cpp index 73bbb2be..5ffc3f65 100644 --- a/src/backward_nhmm.cpp +++ b/src/backward_nhmm.cpp @@ -38,11 +38,11 @@ arma::cube backward_nhmm_singlechannel( arma::cube log_A(S, S, T); arma::cube log_B(S, M + 1, T); for (unsigned int i = 0; i < N; i++) { - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), 1, 1); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { - log_py(s, t) = log_B.at(s, obs(t, i), t); + log_py(s, t) = log_B(s, obs(t, i), t); } } log_beta.slice(i) = univariate_backward_nhmm(log_A, log_py); @@ -65,13 +65,13 @@ arma::cube backward_nhmm_multichannel( arma::cube log_A(S, S, T); arma::field log_B(C); for (unsigned int i = 0; i < N; i++) { - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), M, 1, 1); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); } } } @@ -98,8 +98,8 @@ arma::cube backward_mnhmm_singlechannel( arma::cube log_B(S, M + 1, T); for (unsigned int i = 0; i < N; i++) { for (unsigned int d = 0; d < D; d++) { - log_A = get_A(gamma_A_raw(d), X_s.slice(i), 1); - log_B = get_B(gamma_B_raw(d), X_o.slice(i), 1, 1); + log_A = get_A(gamma_A_raw(d), X_s.slice(i), true); + log_B = get_B(gamma_B_raw(d), X_o.slice(i), true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = log_B(s, obs(t, i), t); @@ -128,15 +128,15 @@ arma::cube backward_mnhmm_multichannel( arma::field log_B(C); for (unsigned int i = 0; i < N; i++) { for (unsigned int d = 0; d < D; d++) { - log_A = get_A(gamma_A_raw(d), X_s.slice(i), 1); + log_A = get_A(gamma_A_raw(d), X_s.slice(i), true); log_B = get_B( - gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, 1, 1 + gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, true, true ); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); } } } diff --git a/src/forward_nhmm.cpp b/src/forward_nhmm.cpp index 3b4aaf04..aeb16cb0 100644 --- a/src/forward_nhmm.cpp +++ b/src/forward_nhmm.cpp @@ -40,12 +40,12 @@ arma::cube forward_nhmm_singlechannel( arma::cube log_A(S, S, T); arma::cube log_B(S, M + 1, T); for (unsigned int i = 0; i < N; i++) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), 1); - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), 1, 1); + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { - log_py(s, t) = log_B.at(s, obs(t, i), t); + log_py(s, t) = log_B(s, obs(t, i), t); } } log_alpha.slice(i) = univariate_forward_nhmm(log_Pi, log_A, log_py); @@ -70,14 +70,14 @@ arma::cube forward_nhmm_multichannel( arma::cube log_A(S, S, T); arma::field log_B(C); for (unsigned int i = 0; i < N; i++) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), 1); - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), M, 1, 1); + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); } } } @@ -107,11 +107,11 @@ arma::cube forward_mnhmm_singlechannel( arma::cube log_B(S, M + 1, T); arma::vec log_omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); for (unsigned int d = 0; d < D; d++) { - log_Pi = get_pi(gamma_pi_raw(d), X_i.col(i), 1); - log_A = get_A(gamma_A_raw(d), X_s.slice(i), 1); - log_B = get_B(gamma_B_raw(d), X_o.slice(i), 1, 1); + log_Pi = get_pi(gamma_pi_raw(d), X_i.col(i), true); + log_A = get_A(gamma_A_raw(d), X_s.slice(i), true); + log_B = get_B(gamma_B_raw(d), X_o.slice(i), true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = log_B(s, obs(t, i), t); @@ -143,18 +143,18 @@ arma::cube forward_mnhmm_multichannel( arma::field log_B(C); arma::vec log_omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); for (unsigned int d = 0; d < D; d++) { - log_Pi = get_pi(gamma_pi_raw(d), X_i.col(i), 1); - log_A = get_A(gamma_A_raw(d), X_s.slice(i), 1); + log_Pi = get_pi(gamma_pi_raw(d), X_i.col(i), true); + log_A = get_A(gamma_A_raw(d), X_s.slice(i), true); log_B = get_B( - gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, 1, 1 + gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, true, true ); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); } } } diff --git a/src/get_parameters.cpp b/src/get_parameters.cpp index 38811509..8b7d468e 100644 --- a/src/get_parameters.cpp +++ b/src/get_parameters.cpp @@ -5,7 +5,7 @@ // gamma_omega_raw is (D - 1) x K (start from, covariates) // X a vector of length K // [[Rcpp::export]] -arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const int logspace) { +arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace) { arma::mat gamma_omega = arma::join_cols(arma::zeros(gamma_omega_raw.n_cols), gamma_omega_raw); return softmax(gamma_omega * X, logspace); } @@ -13,7 +13,7 @@ arma::vec get_omega(const arma::mat& gamma_omega_raw, const arma::vec X, const i // gamma_raw is (S - 1) x K (start from, covariates) // X a vector of length K // [[Rcpp::export]] -arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const int logspace) { +arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const bool logspace) { arma::mat beta = arma::join_cols(arma::zeros(gamma_raw.n_cols), gamma_raw); return softmax(beta * X, logspace); } @@ -21,7 +21,7 @@ arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec X, const int logspa // X is K x T matrix (covariates, time points) // [[Rcpp::export]] arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, - const int logspace) { + const bool logspace, const bool tv) { unsigned int S = gamma_raw.n_slices; unsigned int K = X.n_rows; unsigned int T = X.n_cols; @@ -31,18 +31,25 @@ arma::cube get_A(const arma::cube& gamma_raw, const arma::mat& X, } arma::cube A(S, S, T); arma::mat Atmp(S, S); - for (unsigned int t = 0; t < T; t++) { // time + if (tv) { + for (unsigned int t = 0; t < T; t++) { // time + for (unsigned int j = 0; j < S; j ++) { // from states + Atmp.col(j) = softmax(beta.slice(j) * X.col(t), logspace); + } + A.slice(t) = Atmp.t(); + } + } else { for (unsigned int j = 0; j < S; j ++) { // from states - Atmp.col(j) = softmax(beta.slice(j) * X.col(t), logspace); + Atmp.col(j) = softmax(beta.slice(j) * X.col(0), logspace); } - A.slice(t) = Atmp.t(); + A.each_slice() = Atmp.t(); } return A; } // gamma_raw is (M - 1) x K x S (symbols, covariates, transition from) // X is K x T (covariates, time points) arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, - const int logspace, const int add_missing) { + const bool logspace, const bool add_missing, const bool tv) { unsigned int S = gamma_raw.n_slices; unsigned int M = gamma_raw.n_rows + 1; unsigned int K = X.n_rows; @@ -56,13 +63,22 @@ arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, if (add_missing) { Btmp.row(M).fill(1.0 - logspace); } - for (unsigned int t = 0; t < T; t++) { // time + if (tv) { + for (unsigned int t = 0; t < T; t++) { // time + for (unsigned int j = 0; j < S; j ++) { // from states + Btmp.col(j).rows(0, M - 1) = softmax( + beta.slice(j) * X.col(t), logspace + ); + } + B.slice(t) = Btmp.t(); + } + } else { for (unsigned int j = 0; j < S; j ++) { // from states Btmp.col(j).rows(0, M - 1) = softmax( - beta.slice(j) * X.col(t), logspace + beta.slice(j) * X.col(0), logspace ); } - B.slice(t) = Btmp.t(); + B.each_slice() = Btmp.t(); } return B; } @@ -72,11 +88,11 @@ arma::cube get_B(const arma::cube& gamma_raw, const arma::mat& X, arma::field get_B( const arma::field& gamma_raw, const arma::mat& X, const arma::uvec& M, - const int logspace, const int add_missing) { + const bool logspace, const bool add_missing, const bool tv) { unsigned int C = M.n_elem; arma::field B(C); // C field of cubes, each S x M_c x T for (unsigned int c = 0; c < C; c++) { - B(c) = get_B(gamma_raw(c), X, logspace, add_missing); + B(c) = get_B(gamma_raw(c), X, logspace, add_missing, tv); } return B; } diff --git a/src/get_parameters.h b/src/get_parameters.h index 54dd907f..e61c03cc 100644 --- a/src/get_parameters.h +++ b/src/get_parameters.h @@ -4,20 +4,22 @@ #include arma::vec get_omega( - const arma::mat& gamma_omega_raw, const arma::vec X, const int logspace + const arma::mat& gamma_omega_raw, const arma::vec X, const bool logspace ); arma::vec get_pi( - const arma::mat& beta_raw, const arma::vec X, const int logspace + const arma::mat& beta_raw, const arma::vec X, const bool logspace ); arma::cube get_A( - const arma::cube& beta_raw, const arma::mat& X, const int logspace + const arma::cube& beta_raw, const arma::mat& X, const bool logspace, + const bool tv = true ); arma::cube get_B( - const arma::cube& beta_raw, const arma::mat& X, const int logspace, - const int add_missing = 0 + const arma::cube& beta_raw, const arma::mat& X, const bool logspace, + const bool add_missing = false, const bool tv = true ); arma::field get_B( const arma::field& beta_raw, const arma::mat& X, - const arma::uvec& M, const int logspace, const int add_missing = 0 + const arma::uvec& M, const bool logspace, const bool add_missing = false, + const bool tv = true ); #endif diff --git a/src/log_objective_nhmm.cpp b/src/log_objective_nhmm.cpp index 8b37f920..58e24dc5 100644 --- a/src/log_objective_nhmm.cpp +++ b/src/log_objective_nhmm.cpp @@ -9,7 +9,8 @@ Rcpp::List log_objective_nhmm_singlechannel( const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::cube& gamma_B_raw, const arma::cube& X_o, - const arma::mat& obs) { + const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, + const bool tv_A, const bool tv_B) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -34,9 +35,15 @@ Rcpp::List log_objective_nhmm_singlechannel( arma::mat A(S, S); arma::rowvec Brow(M); for (unsigned int i = 0; i < N; i++) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), 1); - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), 1, 1); + if (iv_pi || i == 0) { + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + } + if (iv_A || i == 0) { + log_A = get_A(gamma_A_raw, X_s.slice(i), true, tv_A); + } + if (iv_B || i == 0) { + log_B = get_B(gamma_B_raw, X_o.slice(i), true, true, tv_B); + } for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = log_B(s, obs(t, i), t); @@ -104,109 +111,121 @@ Rcpp::List log_objective_nhmm_multichannel( const arma::mat& gamma_pi_raw, const arma::mat& X_i, const arma::cube& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, - const arma::cube& obs, const arma::uvec& M) { - - unsigned int C = M.n_elem; - unsigned int N = X_s.n_slices; - unsigned int T = X_s.n_cols; - unsigned int S = gamma_A_raw.n_slices; - arma::vec loglik(N); - arma::mat log_alpha(S, T); - arma::mat log_beta(S, T); - arma::mat log_py(S, T); - arma::vec log_Pi(S); - arma::cube log_A(S, S, T); - arma::field log_B(C); - arma::mat grad_pi(S - 1, X_i.n_rows, arma::fill::zeros); - arma::cube grad_A(S - 1, X_s.n_rows, S, arma::fill::zeros); - arma::field grad_B(C); - for (unsigned int c = 0; c < C; c++) { - grad_B(c) = arma::cube(M(c) - 1, X_o.n_rows, S, arma::fill::zeros); - } - arma::vec gradvec_S(S); - arma::mat gradmat_S(S, S); - arma::mat A(S, S); - - for (unsigned int i = 0; i < N; i++) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), 1); - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), M, 1, 1); - for (unsigned int t = 0; t < T; t++) { - for (unsigned int s = 0; s < S; s++) { - log_py(s, t) = 0; - for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); - } - } - } - log_alpha = univariate_forward_nhmm(log_Pi, log_A, log_py); - log_beta = univariate_backward_nhmm(log_A, log_py); - double ll = logSumExp(log_alpha.col(T - 1)); - loglik(i) = ll; - // gradient wrt gamma_pi - // d loglik / d pi - gradvec_S = exp(log_py.col(0) + log_beta.col(0) - ll); - // d pi / d gamma_pi - arma::vec Pi = exp(log_Pi); - gradmat_S = -Pi * Pi.t(); - gradmat_S.diag() += Pi; - grad_pi += gradmat_S.rows(1, S - 1) * gradvec_S * X_i.col(i).t(); - - // gradient wrt gamma_A - for (unsigned int t = 0; t < (T - 1); t++) { - A = exp(log_A.slice(t)); - for (unsigned int s = 0; s < S; s++) { - // d loglik / d a_s - gradvec_S = exp(log_alpha(s, t) + log_py.col(t + 1) + log_beta.col(t + 1) - ll); - // d a_s / d gamma_A - gradmat_S = -A.row(s).t() * A.row(s); - gradmat_S.diag() += A.row(s); - grad_A.slice(s) += gradmat_S.rows(1, S - 1) * gradvec_S * X_s.slice(i).col(t).t(); - } - } - for (unsigned int c = 0; c < C; c++) { - arma::mat gradmat_M(M(c), M(c)); - arma::rowvec Brow(M(c)); - double logpy; - for (unsigned int s = 0; s < S; s++) { - if (obs(c, 0, i) < M(c)) { - Brow = exp(log_B(c).slice(0).row(s).cols(0, M(c) - 1)); - gradmat_M = -Brow.t() * Brow; - gradmat_M.diag() += Brow; - logpy = 0; - for (unsigned int cc = 0; cc < C; cc++) { - if (cc != c) { - logpy += log_B(cc).at(s, obs(cc, 0, i), 0); - } - } - double grad = exp(log_Pi(s) + logpy + log_beta(s, 0) - ll); - grad_B(c).slice(s) += gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * grad * X_o.slice(i).col(0).t(); - } - for (unsigned int t = 0; t < (T - 1); t++) { - if (obs(c, t + 1, i) < M(c)) { - Brow = exp(log_B(c).slice(t + 1).row(s).cols(0, M(c) - 1)); - gradmat_M = -Brow.t() * Brow; - gradmat_M.diag() += Brow; - logpy = 0; - for (unsigned int cc = 0; cc < C; cc++) { - if (cc != c) { - logpy += log_B(cc).at(s, obs(cc, t + 1, i), t + 1); - } - } - double grad = arma::accu( - exp(log_alpha.col(t) + log_A.slice(t).col(s) + logpy + log_beta(s, t + 1) - ll)); - grad_B(c).slice(s) += gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * grad * X_o.slice(i).col(t + 1).t(); - } - } - } - } - } - return Rcpp::List::create( - Rcpp::Named("loglik") = sum(loglik), - Rcpp::Named("gradient_pi") = Rcpp::wrap(grad_pi), - Rcpp::Named("gradient_A") = Rcpp::wrap(grad_A), - Rcpp::Named("gradient_B") = Rcpp::wrap(grad_B) - ); + const arma::cube& obs, const arma::uvec& M, const bool iv_pi, + const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B) { + + unsigned int C = M.n_elem; + unsigned int N = X_s.n_slices; + unsigned int T = X_s.n_cols; + unsigned int S = gamma_A_raw.n_slices; + arma::vec loglik(N); + arma::mat log_alpha(S, T); + arma::mat log_beta(S, T); + arma::mat log_py(S, T); + arma::vec log_Pi(S); + arma::cube log_A(S, S, T); + arma::field log_B(C); + arma::mat grad_pi(S - 1, X_i.n_rows, arma::fill::zeros); + arma::cube grad_A(S - 1, X_s.n_rows, S, arma::fill::zeros); + arma::field grad_B(C); + for (unsigned int c = 0; c < C; c++) { + grad_B(c) = arma::cube(M(c) - 1, X_o.n_rows, S, arma::fill::zeros); + } + arma::vec gradvec_S(S); + arma::mat gradmat_S(S, S); + arma::mat A(S, S); + for (unsigned int i = 0; i < N; i++) { + if (iv_pi || i == 0) { + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + } + if (iv_A || i == 0) { + log_A = get_A(gamma_A_raw, X_s.slice(i), true, tv_A); + } + if (iv_B || i == 0) { + log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true, tv_B); + } + for (unsigned int t = 0; t < T; t++) { + for (unsigned int s = 0; s < S; s++) { + log_py(s, t) = 0; + for (unsigned int c = 0; c < C; c++) { + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); + } + } + } + log_alpha = univariate_forward_nhmm(log_Pi, log_A, log_py); + log_beta = univariate_backward_nhmm(log_A, log_py); + double ll = logSumExp(log_alpha.col(T - 1)); + loglik(i) = ll; + // gradient wrt gamma_pi + // d loglik / d pi + gradvec_S = exp(log_py.col(0) + log_beta.col(0) - ll); + // d pi / d gamma_pi + arma::vec Pi = exp(log_Pi); + gradmat_S = -Pi * Pi.t(); + gradmat_S.diag() += Pi; + grad_pi += gradmat_S.rows(1, S - 1) * gradvec_S * X_i.col(i).t(); + // gradient wrt gamma_A + for (unsigned int t = 0; t < (T - 1); t++) { + A = exp(log_A.slice(t)); + for (unsigned int s = 0; s < S; s++) { + // d loglik / d a_s + gradvec_S = exp(log_alpha(s, t) + log_py.col(t + 1) + + log_beta.col(t + 1) - ll); + // d a_s / d gamma_A + gradmat_S = -A.row(s).t() * A.row(s); + gradmat_S.diag() += A.row(s); + grad_A.slice(s) += gradmat_S.rows(1, S - 1) * + gradvec_S * X_s.slice(i).col(t).t(); + } + } + for (unsigned int c = 0; c < C; c++) { + arma::mat gradmat_M(M(c), M(c)); + arma::rowvec Brow(M(c)); + double logpy; + for (unsigned int s = 0; s < S; s++) { + if (obs(c, 0, i) < M(c)) { + Brow = exp(log_B(c).slice(0).row(s).cols(0, M(c) - 1)); + gradmat_M = -Brow.t() * Brow; + gradmat_M.diag() += Brow; + logpy = 0; + for (unsigned int cc = 0; cc < C; cc++) { + if (cc != c) { + logpy += log_B(cc)(s, obs(cc, 0, i), 0); + } + } + double grad = exp(log_Pi(s) + logpy + log_beta(s, 0) - ll); + grad_B(c).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * + grad * X_o.slice(i).col(0).t(); + } + for (unsigned int t = 0; t < (T - 1); t++) { + if (obs(c, t + 1, i) < M(c)) { + Brow = exp(log_B(c).slice(t + 1).row(s).cols(0, M(c) - 1)); + gradmat_M = -Brow.t() * Brow; + gradmat_M.diag() += Brow; + logpy = 0; + for (unsigned int cc = 0; cc < C; cc++) { + if (cc != c) { + logpy += log_B(cc)(s, obs(cc, t + 1, i), t + 1); + } + } + double grad = arma::accu( + exp(log_alpha.col(t) + log_A.slice(t).col(s) + + logpy + log_beta(s, t + 1) - ll)); + grad_B(c).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * + grad * X_o.slice(i).col(t + 1).t(); + } + } + } + } + } + return Rcpp::List::create( + Rcpp::Named("loglik") = sum(loglik), + Rcpp::Named("gradient_pi") = Rcpp::wrap(grad_pi), + Rcpp::Named("gradient_A") = Rcpp::wrap(grad_A), + Rcpp::Named("gradient_B") = Rcpp::wrap(grad_B) + ); } // [[Rcpp::export]] @@ -215,7 +234,8 @@ Rcpp::List log_objective_mnhmm_singlechannel( const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, - const arma::mat& obs) { + const arma::mat& obs, const bool iv_pi, const bool iv_A, const bool iv_B, + const bool tv_A, const bool tv_B, const bool iv_omega) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -252,17 +272,26 @@ Rcpp::List log_objective_mnhmm_singlechannel( arma::rowvec Brow(M); arma::vec omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + if (iv_omega || i == 0) { + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); + } for (unsigned int d = 0; d < D; d++) { - log_Pi(d) = get_pi(gamma_pi_raw(d), X_i.col(i), 1); - log_A(d) = get_A(gamma_A_raw(d), X_s.slice(i), 1); - log_B(d) = get_B(gamma_B_raw(d), X_o.slice(i), 1, 1); + if (iv_pi || i == 0) { + log_Pi(d) = get_pi(gamma_pi_raw(d), X_i.col(i), true); + } + if (iv_A || i == 0) { + log_A(d) = get_A(gamma_A_raw(d), X_s.slice(i), true, tv_A); + } + if (iv_B || i == 0) { + log_B(d) = get_B(gamma_B_raw(d), X_o.slice(i), true, true, tv_B); + } for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { - log_py(s, t, d) = log_B(d).at(s, obs(t, i), t); + log_py(s, t, d) = log_B(d)(s, obs(t, i), t); } } - log_alpha.slice(d) = univariate_forward_nhmm(log_Pi(d), log_A(d), log_py.slice(d)); + log_alpha.slice(d) = univariate_forward_nhmm( + log_Pi(d), log_A(d), log_py.slice(d)); log_beta.slice(d) = univariate_backward_nhmm(log_A(d), log_py.slice(d)); loglik_i(d) = logSumExp(log_alpha.slice(d).col(T - 1)); } @@ -270,7 +299,8 @@ Rcpp::List log_objective_mnhmm_singlechannel( // gradient wrt gamma_pi // d loglik / d pi for (unsigned int d = 0; d < D; d++) { - gradvec_S = exp(log_omega(d) + log_py.slice(d).col(0) + log_beta.slice(d).col(0) - loglik(i)); + gradvec_S = exp(log_omega(d) + log_py.slice(d).col(0) + + log_beta.slice(d).col(0) - loglik(i)); // d pi / d gamma_pi Pi = exp(log_Pi(d)); gradmat_S = -Pi * Pi.t(); @@ -282,12 +312,14 @@ Rcpp::List log_objective_mnhmm_singlechannel( A = exp(log_A(d).slice(t)); for (unsigned int s = 0; s < S; s++) { // d loglik / d a_s - gradvec_S = exp(log_omega(d) + log_alpha(s, t, d) + log_py.slice(d).col(t + 1) + + gradvec_S = exp(log_omega(d) + log_alpha(s, t, d) + + log_py.slice(d).col(t + 1) + log_beta.slice(d).col(t + 1) - loglik(i)); // d a_s / d gamma_A gradmat_S = -A.row(s).t() * A.row(s); gradmat_S.diag() += A.row(s); - grad_A(d).slice(s) += gradmat_S.rows(1, S - 1) * gradvec_S * X_s.slice(i).col(t).t(); + grad_A(d).slice(s) += gradmat_S.rows(1, S - 1) * + gradvec_S * X_s.slice(i).col(t).t(); } } for (unsigned int s = 0; s < S; s++) { @@ -295,8 +327,10 @@ Rcpp::List log_objective_mnhmm_singlechannel( Brow = exp(log_B(d).slice(0).row(s).cols(0, M - 1)); gradmat_M = -Brow.t() * Brow; gradmat_M.diag() += Brow; - double grad = exp(log_omega(d) + log_Pi(d).at(s) + log_beta(s, 0, d) - loglik(i)); - grad_B(d).slice(s) += gradmat_M.rows(1, M - 1).col(obs(0, i)) * grad * X_o.slice(i).col(0).t(); + double grad = exp(log_omega(d) + log_Pi(d)(s) + + log_beta(s, 0, d) - loglik(i)); + grad_B(d).slice(s) += gradmat_M.rows(1, M - 1).col(obs(0, i)) * + grad * X_o.slice(i).col(0).t(); } for (unsigned int t = 0; t < (T - 1); t++) { if (obs(t + 1, i) < M) { @@ -304,8 +338,10 @@ Rcpp::List log_objective_mnhmm_singlechannel( gradmat_M = -Brow.t() * Brow; gradmat_M.diag() += Brow; double grad = arma::accu( - exp(log_omega(d) + log_alpha.slice(d).col(t) + log_A(d).slice(t).col(s) + log_beta(s, t + 1, d) - loglik(i))); - grad_B(d).slice(s) += gradmat_M.rows(1, M - 1).col(obs(t + 1, i)) * grad * X_o.slice(i).col(t + 1).t(); + exp(log_omega(d) + log_alpha.slice(d).col(t) + + log_A(d).slice(t).col(s) + log_beta(s, t + 1, d) - loglik(i))); + grad_B(d).slice(s) += gradmat_M.rows(1, M - 1).col(obs(t + 1, i)) * + grad * X_o.slice(i).col(t + 1).t(); } } } @@ -331,7 +367,9 @@ Rcpp::List log_objective_mnhmm_multichannel( const arma::field& gamma_A_raw, const arma::cube& X_s, const arma::field& gamma_B_raw, const arma::cube& X_o, const arma::mat& gamma_omega_raw, const arma::mat& X_d, - const arma::cube& obs, const arma::uvec& M) { + const arma::cube& obs, const arma::uvec& M, const bool iv_pi, + const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, + const bool iv_omega) { unsigned int N = X_s.n_slices; unsigned int T = X_s.n_cols; @@ -352,7 +390,7 @@ Rcpp::List log_objective_mnhmm_multichannel( arma::mat grad_omega(D - 1, X_d.n_rows, arma::fill::zeros); arma::field grad_pi(D); arma::field grad_A(D); - arma::field grad_B(D, C); + arma::field grad_B(C, D); for (unsigned int d = 0; d < D; d++) { grad_pi(d) = arma::mat(S - 1, X_i.n_rows, arma::fill::zeros); grad_A(d) = arma::cube(S - 1, X_s.n_rows, S, arma::fill::zeros); @@ -368,22 +406,32 @@ Rcpp::List log_objective_mnhmm_multichannel( arma::mat A(S, S); arma::vec omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + if (iv_omega || i == 0) { + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); + } for (unsigned int d = 0; d < D; d++) { - log_Pi(d) = get_pi(gamma_pi_raw(d), X_i.col(i), 1); - log_A(d) = get_A(gamma_A_raw(d), X_s.slice(i), 1); - log_B.rows(d * C, (d + 1) * C - 1) = get_B( - gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, 1, 1 - ); + if (iv_pi || i == 0) { + log_Pi(d) = get_pi(gamma_pi_raw(d), X_i.col(i), true); + } + if (iv_A || i == 0) { + log_A(d) = get_A(gamma_A_raw(d), X_s.slice(i), true, tv_A); + } + if (iv_B || i == 0) { + log_B.rows(d * C, (d + 1) * C - 1) = get_B( + gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, true, + true, tv_B + ); + } for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t, d) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t, d) += log_B(d * C + c).at(s, obs(c, t, i), t); + log_py(s, t, d) += log_B(d * C + c)(s, obs(c, t, i), t); } } } - log_alpha.slice(d) = univariate_forward_nhmm(log_Pi(d), log_A(d), log_py.slice(d)); + log_alpha.slice(d) = univariate_forward_nhmm(log_Pi(d), log_A(d), + log_py.slice(d)); log_beta.slice(d) = univariate_backward_nhmm(log_A(d), log_py.slice(d)); loglik_i(d) = logSumExp(log_alpha.slice(d).col(T - 1)); } @@ -391,7 +439,8 @@ Rcpp::List log_objective_mnhmm_multichannel( // gradient wrt gamma_pi // d loglik / d pi for (unsigned int d = 0; d < D; d++) { - gradvec_S = exp(log_omega(d) + log_py.slice(d).col(0) + log_beta.slice(d).col(0) - loglik(i)); + gradvec_S = exp(log_omega(d) + log_py.slice(d).col(0) + + log_beta.slice(d).col(0) - loglik(i)); // d pi / d gamma_pi Pi = exp(log_Pi(d)); gradmat_S = -Pi * Pi.t(); @@ -403,12 +452,14 @@ Rcpp::List log_objective_mnhmm_multichannel( A = exp(log_A(d).slice(t)); for (unsigned int s = 0; s < S; s++) { // d loglik / d a_s - gradvec_S = exp(log_omega(d) + log_alpha(s, t, d) + log_py.slice(d).col(t + 1) + + gradvec_S = exp(log_omega(d) + log_alpha(s, t, d) + + log_py.slice(d).col(t + 1) + log_beta.slice(d).col(t + 1) - loglik(i)); // d a_s / d gamma_A gradmat_S = -A.row(s).t() * A.row(s); gradmat_S.diag() += A.row(s); - grad_A(d).slice(s) += gradmat_S.rows(1, S - 1) * gradvec_S * X_s.slice(i).col(t).t(); + grad_A(d).slice(s) += gradmat_S.rows(1, S - 1) * gradvec_S * + X_s.slice(i).col(t).t(); } } // gradient wrt gamma_B @@ -424,11 +475,14 @@ Rcpp::List log_objective_mnhmm_multichannel( logpy = 0; for (unsigned int cc = 0; cc < C; cc++) { if (cc != c) { - logpy += log_B(d * C + cc).at(s, obs(cc, 0, i), 0); + logpy += log_B(d * C + cc)(s, obs(cc, 0, i), 0); } } - double grad = exp(log_omega(d) + log_Pi(d).at(s) + logpy + log_beta(s, 0, d) - loglik(i)); - grad_B(c, d).slice(s) += gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * grad * X_o.slice(i).col(0).t(); + double grad = exp(log_omega(d) + log_Pi(d)(s) + logpy + + log_beta(s, 0, d) - loglik(i)); + grad_B(c, d).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, 0, i)) * + grad * X_o.slice(i).col(0).t(); } for (unsigned int t = 0; t < (T - 1); t++) { if (obs(c, t + 1, i) < M(c)) { @@ -438,12 +492,15 @@ Rcpp::List log_objective_mnhmm_multichannel( logpy = 0; for (unsigned int cc = 0; cc < C; cc++) { if (cc != c) { - logpy += log_B(d * C + cc).at(s, obs(cc, t + 1, i), t + 1); + logpy += log_B(d * C + cc)(s, obs(cc, t + 1, i), t + 1); } } double grad = arma::accu( - exp(log_omega(d) + log_alpha.slice(d).col(t) + log_A(d).slice(t).col(s) + logpy + log_beta(s, t + 1, d) - loglik(i))); - grad_B(c, d).slice(s) += gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * grad * X_o.slice(i).col(t + 1).t(); + exp(log_omega(d) + log_alpha.slice(d).col(t) + + log_A(d).slice(t).col(s) + logpy + log_beta(s, t + 1, d) - loglik(i))); + grad_B(c, d).slice(s) += + gradmat_M.rows(1, M(c) - 1).col(obs(c, t + 1, i)) * + grad * X_o.slice(i).col(t + 1).t(); } } } diff --git a/src/softmax.cpp b/src/softmax.cpp index 6067e914..9edab1a7 100644 --- a/src/softmax.cpp +++ b/src/softmax.cpp @@ -3,7 +3,7 @@ #include "softmax.h" // [[Rcpp::export]] -arma::vec softmax(const arma::vec& x, const int logspace) { +arma::vec softmax(const arma::vec& x, const bool logspace) { arma::vec result; if (logspace == 0) { double x_max = arma::max(x); diff --git a/src/softmax.h b/src/softmax.h index 2cac5406..319a10a6 100644 --- a/src/softmax.h +++ b/src/softmax.h @@ -2,5 +2,5 @@ #define SOFTMAX_H #include -arma::vec softmax(const arma::vec& x, const int logspace); +arma::vec softmax(const arma::vec& x, const bool logspace); #endif diff --git a/src/viterbi_nhmm.cpp b/src/viterbi_nhmm.cpp index f47cec19..2ecb5013 100644 --- a/src/viterbi_nhmm.cpp +++ b/src/viterbi_nhmm.cpp @@ -45,12 +45,12 @@ Rcpp::List viterbi_nhmm_singlechannel( arma::cube log_A(S, S, T); arma::cube log_B(S, M + 1, T); for (unsigned int i = 0; i < N; i++) { - log_Pi = get_pi(gamma_pi_raw, X_i.col(i), 1); - log_A = get_A(gamma_A_raw, X_s.slice(i), 1); - log_B = get_B(gamma_B_raw, X_o.slice(i), 1, 1); + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { - log_py(s, t) = log_B.at(s, obs(t, i), t); + log_py(s, t) = log_B(s, obs(t, i), t); } } logp(i) = univariate_viterbi_nhmm(log_Pi, log_A, log_py, q.col(i)); @@ -79,11 +79,14 @@ Rcpp::List viterbi_nhmm_multichannel( arma::cube log_A(S, S, T); arma::field log_B(C); for (unsigned int i = 0; i < N; i++) { + log_Pi = get_pi(gamma_pi_raw, X_i.col(i), true); + log_A = get_A(gamma_A_raw, X_s.slice(i), true); + log_B = get_B(gamma_B_raw, X_o.slice(i), M, true, true); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(s, t) += log_B(c)(s, obs(c, t, i), t); } } } @@ -118,16 +121,16 @@ Rcpp::List viterbi_mnhmm_singlechannel( arma::cube log_B(SD, M + 1, T); arma::vec log_omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); for (unsigned int d = 0; d < D; d++) { log_Pi.rows(d * S, (d + 1) * S - 1) = log_omega(d) + get_pi( - gamma_pi_raw(d), X_i.col(i), 1 + gamma_pi_raw(d), X_i.col(i), true ); log_A.tube(d * S, d * S, (d + 1) * S - 1, (d + 1) * S - 1) = get_A( - gamma_A_raw(d), X_s.slice(i), 1 + gamma_A_raw(d), X_s.slice(i), true ); log_B.rows(d * S, (d + 1) * S - 1) = get_B( - gamma_B_raw(d), X_o.slice(i), 1, 1 + gamma_B_raw(d), X_o.slice(i), true, true ); } for (unsigned int t = 0; t < T; t++) { @@ -163,22 +166,22 @@ Rcpp::List viterbi_mnhmm_multichannel( arma::field log_B(C); arma::vec log_omega(D); for (unsigned int i = 0; i < N; i++) { - log_omega = get_omega(gamma_omega_raw, X_d.col(i), 1); + log_omega = get_omega(gamma_omega_raw, X_d.col(i), true); for (unsigned int d = 0; d < D; d++) { log_Pi.rows(d * S, (d + 1) * S - 1) = log_omega(d) + get_pi( - gamma_pi_raw(d), X_i.col(i), 1 + gamma_pi_raw(d), X_i.col(i), true ); log_A.tube(d * S, d * S, (d + 1) * S - 1, (d + 1) * S - 1) = get_A( - gamma_A_raw(d), X_s.slice(i), 1 + gamma_A_raw(d), X_s.slice(i), true ); log_B = get_B( - gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, 1, 1 + gamma_B_raw.rows(d * C, (d + 1) * C - 1), X_o.slice(i), M, true, true ); for (unsigned int t = 0; t < T; t++) { for (unsigned int s = 0; s < S; s++) { log_py(d * S + s, t) = 0; for (unsigned int c = 0; c < C; c++) { - log_py(d * S + s, t) += log_B(c).at(s, obs(c, t, i), t); + log_py(d * S + s, t) += log_B(c)(s, obs(c, t, i), t); } } } diff --git a/tests/testthat/test-build_mnhmm.R b/tests/testthat/test-build_mnhmm.R index a2508ced..e263ecdd 100644 --- a/tests/testthat/test-build_mnhmm.R +++ b/tests/testthat/test-build_mnhmm.R @@ -37,7 +37,7 @@ test_that("estimate_mnhmm returns object of class 'mnhmm'", { fit <- estimate_mnhmm( "y", s, d, initial_formula = ~ x, transition_formula = ~z, emission_formula = ~ z, cluster_formula = ~ x, - data = data, time = "time", id = "id"), + data = data, time = "time", id = "id", maxeval = 1), NA ) expect_s3_class( @@ -107,7 +107,8 @@ test_that("estimate_mnhmm errors with incorrect observations", { }) test_that("build_mnhmm works with vector of characters as observations", { expect_error( - model <- estimate_mnhmm("y", s, d, data = data, time = "time", id = "id"), + model <- estimate_mnhmm("y", s, d, data = data, time = "time", id = "id", + maxeval = 1), NA ) expect_error( @@ -125,7 +126,7 @@ test_that("build_mnhmm works with missing observations", { data$y[50:55] <- NA expect_error( model <- estimate_mnhmm( - "y", s, d, data = data, time = "time", id = "id"), + "y", s, d, data = data, time = "time", id = "id", maxeval = 1), NA ) expect_equal( diff --git a/tests/testthat/test-build_nhmm.R b/tests/testthat/test-build_nhmm.R index a0dd3fc7..975493f9 100644 --- a/tests/testthat/test-build_nhmm.R +++ b/tests/testthat/test-build_nhmm.R @@ -112,4 +112,4 @@ test_that("build_nhmm works with missing observations", { c(41L, 42L, 43L, 44L, 45L, 46L, 47L, 48L, 49L, 50L, 60L, 61L, 62L, 63L, 64L, 65L) ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-forward_backward.R b/tests/testthat/test-forward_backward.R index 5256e080..c97e3ac8 100644 --- a/tests/testthat/test-forward_backward.R +++ b/tests/testthat/test-forward_backward.R @@ -91,7 +91,8 @@ test_that("'forward_backward' works for multichannel 'mnhmm'", { set.seed(1) expect_error( fit <- estimate_mnhmm( - hmm_biofam$observations, n_states = 3, n_clusters = 2 + hmm_biofam$observations, n_states = 3, n_clusters = 2, + maxeval = 1 ), NA ) @@ -117,7 +118,7 @@ test_that("'forward_backward' works for single-channel 'mnhmm'", { expect_error( fit <- estimate_mnhmm( hmm_biofam$observations[[1]], n_states = 4, n_clusters = 2, - restarts = 2, threads = 1 + restarts = 2, threads = 1, maxeval = 1 ), NA ) diff --git a/tests/testthat/test-gradients.R b/tests/testthat/test-gradients.R new file mode 100644 index 00000000..2005ba16 --- /dev/null +++ b/tests/testthat/test-gradients.R @@ -0,0 +1,310 @@ +test_that("Gradients for singlechannel-NHMM are correct", { + set.seed(123) + M <- 4 + S <- 3 + n_id <- 5 + n_time <- 10 + obs <- suppressMessages(seqdef( + matrix( + sample(letters[1:M], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + data <- data.frame( + y = unlist(obs), + x = rnorm(n_id * n_time), + z = rnorm(n_id * n_time), + time = rep(1:n_time, each = n_id), + id = rep(1:n_id, n_time) + ) + data <- data[-10L, ] + data$y[10:15] <- NA + data$x[12] <- NA + model <- build_nhmm( + "y", S, initial_formula = ~ x, transition_formula = ~z, + emission_formula = ~ z, data = data, time = "time", id = "id") + + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + X_i <- model$X_initial + X_s <- model$X_transition + X_o <- model$X_emission + K_i <- nrow(X_i) + K_s <- nrow(X_s) + K_o <- nrow(X_o) + obs <- create_obsArray(model) + obs <- array(obs, dim(obs)[2:3]) + pars <- rnorm(n_i + n_s + n_o) + + f <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i) + gamma_A_raw <- create_gamma_A_raw_nhmm(pars[n_i + seq_len(n_s)], S, K_s) + gamma_B_raw <- create_gamma_B_raw_nhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o + ) + -log_objective_nhmm_singlechannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + obs, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + + } + g <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i) + gamma_A_raw <- create_gamma_A_raw_nhmm(pars[n_i + seq_len(n_s)], S, K_s) + gamma_B_raw <- create_gamma_B_raw_nhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o + ) + -unname(unlist(log_objective_nhmm_singlechannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + obs, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + } + expect_equal(g(pars), numDeriv::grad(f, pars)) +}) + +test_that("Gradients for multichannel-NHMM are correct", { + set.seed(123) + M <- c(2, 5) + S <- 3 + n_id <- 5 + n_time <- 15 + obs1 <- suppressMessages(seqdef( + matrix( + sample(letters[1:M[1]], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + obs2 <- suppressMessages(seqdef( + matrix( + sample(LETTERS[1:M[2]], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + data <- data.frame( + y1 = unlist(obs1), + y2 = unlist(obs2), + x = rnorm(n_id * n_time), + z = rnorm(n_id * n_time), + time = rep(1:n_time, each = n_id), + id = rep(1:n_id, n_time) + ) + + data <- data[-10L, ] + data$y1[10:15] <- NA + data$y2[12:20] <- NA + data$x[12] <- NA + + model <- build_nhmm( + c("y1", "y2"), S, initial_formula = ~ x, transition_formula = ~z, + emission_formula = ~ z, data = data, time = "time", id = "id") + + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + X_i <- model$X_initial + X_s <- model$X_transition + X_o <- model$X_emission + K_i <- nrow(X_i) + K_s <- nrow(X_s) + K_o <- nrow(X_o) + obs <- create_obsArray(model) + pars <- rnorm(n_i + n_s + n_o) + + f <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i) + gamma_A_raw <- create_gamma_A_raw_nhmm(pars[n_i + seq_len(n_s)], S, K_s) + gamma_B_raw <- create_gamma_multichannel_B_raw_nhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o + ) + -log_objective_nhmm_multichannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + + } + g <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i) + gamma_A_raw <- create_gamma_A_raw_nhmm(pars[n_i + seq_len(n_s)], S, K_s) + gamma_B_raw <- create_gamma_multichannel_B_raw_nhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o + ) + -unname(unlist(log_objective_nhmm_multichannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + } + expect_equal(g(pars), numDeriv::grad(f, pars)) +}) + +test_that("Gradients for singlechannel-NHMM are correct", { + set.seed(123) + M <- 4 + S <- 3 + D <- 2 + n_id <- 5 + n_time <- 10 + obs <- suppressMessages(seqdef( + matrix( + sample(letters[1:M], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + data <- data.frame( + y = unlist(obs), + x = rnorm(n_id * n_time), + z = rnorm(n_id * n_time), + time = rep(1:n_time, each = n_id), + id = rep(1:n_id, n_time) + ) + data <- data[-10L, ] + data$y[10:15] <- NA + data$x[12] <- NA + model <- build_mnhmm( + "y", S, D, initial_formula = ~ x, transition_formula = ~z, + emission_formula = ~ z, cluster_formula = ~ z, data = data, + time = "time", id = "id") + + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + n_d <- attr(model, "np_omega") + X_i <- model$X_initial + X_s <- model$X_transition + X_o <- model$X_emission + X_d <- model$X_cluster + K_i <- nrow(X_i) + K_s <- nrow(X_s) + K_o <- nrow(X_o) + K_d <- nrow(X_d) + obs <- create_obsArray(model) + obs <- array(obs, dim(obs)[2:3]) + pars <- rnorm(n_i + n_s + n_o + n_d) + + f <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) + gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) + gamma_B_raw <- create_gamma_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ) + gamma_omega_raw <- create_gamma_omega_raw_mnhmm( + pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d + ) + -log_objective_mnhmm_singlechannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + gamma_omega_raw, X_d, + obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + + } + g <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) + gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) + gamma_B_raw <- create_gamma_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ) + gamma_omega_raw <- create_gamma_omega_raw_mnhmm( + pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d + ) + -unname(unlist(log_objective_mnhmm_singlechannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + gamma_omega_raw, X_d, + obs, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + } + expect_equal(g(pars), numDeriv::grad(f, pars)) +}) +test_that("Gradients for multichannel-MNHMM are correct", { + set.seed(123) + M <- c(2, 5) + S <- 3 + D <- 4 + n_id <- 5 + n_time <- 15 + obs1 <- suppressMessages(seqdef( + matrix( + sample(letters[1:M[1]], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + obs2 <- suppressMessages(seqdef( + matrix( + sample(LETTERS[1:M[2]], n_id * n_time, replace = TRUE), + n_id, n_time + ) + )) + data <- data.frame( + y1 = unlist(obs1), + y2 = unlist(obs2), + x = rnorm(n_id * n_time), + z = rnorm(n_id * n_time), + time = rep(1:n_time, each = n_id), + id = rep(1:n_id, n_time) + ) + + data <- data[-10L, ] + data$y1[10:15] <- NA + data$x[12] <- NA + + model <- build_mnhmm( + c("y1", "y2"), S, D, initial_formula = ~ x, transition_formula = ~z, + emission_formula = ~ z, cluster_formula = ~ z, data = data, + time = "time", id = "id") + + n_i <- attr(model, "np_pi") + n_s <- attr(model, "np_A") + n_o <- attr(model, "np_B") + n_d <- attr(model, "np_omega") + X_i <- model$X_initial + X_s <- model$X_transition + X_o <- model$X_emission + X_d <- model$X_cluster + K_i <- nrow(X_i) + K_s <- nrow(X_s) + K_o <- nrow(X_o) + K_d <- nrow(X_d) + obs <- create_obsArray(model) + pars <- rnorm(n_i + n_s + n_o + n_d) + + f <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) + gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) + gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ) + gamma_omega_raw <- create_gamma_omega_raw_mnhmm( + pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d + ) + -log_objective_mnhmm_multichannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + gamma_omega_raw, X_d, + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)$loglik + + } + g <- function(pars) { + gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) + gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) + gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ) + gamma_omega_raw <- create_gamma_omega_raw_mnhmm( + pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d + ) + -unname(unlist(log_objective_mnhmm_multichannel( + gamma_pi_raw, X_i, + gamma_A_raw, X_s, + gamma_B_raw, X_o, + gamma_omega_raw, X_d, + obs, M, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)[-1])) + } + expect_equal(g(pars), numDeriv::grad(f, pars)) +}) + diff --git a/tests/testthat/test-ivX_and_tvX.R b/tests/testthat/test-ivX_and_tvX.R new file mode 100644 index 00000000..95488c29 --- /dev/null +++ b/tests/testthat/test-ivX_and_tvX.R @@ -0,0 +1,22 @@ +test_that("'iv_X' and 'tv_X' works", { + x1 <- array( + c(rbind(1:4, 1:4, 1:4), rbind(1:4, 1:4, 1:4)), c(3, 4, 2) + ) + expect_true(tv_X(x1)) + expect_false(iv_X(x1)) + x2 <- array( + c(rbind(1:4, 1:4, 1:4), -rbind(1:4, 1:4, 1:4)), c(3, 4, 2) + ) + expect_true(tv_X(x2)) + expect_true(iv_X(x2)) + x3 <- array( + c(cbind(1:4, 1:4, 1:4), -cbind(1:4, 1:4, 1:4)), c(4, 3, 2) + ) + expect_false(tv_X(x3)) + expect_true(iv_X(x3)) + x4 <- array( + c(cbind(1:4, 1:4, 1:4), cbind(1:4, 1:4, 1:4)), c(4, 3, 2) + ) + expect_false(tv_X(x4)) + expect_false(iv_X(x4)) +})