From 91260f7af446c46eb485246867c13136dd3f16c7 Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Tue, 24 Sep 2024 10:07:56 +0300 Subject: [PATCH] indexing, fixing bugs based on tests --- R/ame.R | 4 -- R/coef.R | 5 ++- R/create_initial_values.R | 19 ++++----- R/fit_mnhmm.R | 14 +++++-- R/forwardBackward.R | 5 ++- R/get_probs.R | 73 ++++++++++++++++++++------------- R/hidden_paths.R | 3 +- R/plot.ame.R | 1 + R/simulate_mnhmm.R | 11 +++-- R/simulate_nhmm.R | 1 - R/update.R | 6 +-- tests/testthat/test-gradients.R | 14 +++++-- 12 files changed, 90 insertions(+), 66 deletions(-) diff --git a/R/ame.R b/R/ame.R index c709a175..42892ad4 100644 --- a/R/ame.R +++ b/R/ame.R @@ -31,10 +31,6 @@ ame.nhmm <- function( marginalize_B_over != "clusters", "Cannot marginalize over clusters as {.arg model} is not a {.cls mnhmm} object." ) - stopifnot_( - checkmate::test_count(nsim), - "Argument {.arg nsim} should be a single non-negative integer." - ) stopifnot_( checkmate::test_string(x = variable), "Argument {.arg variable} must be a single character string." diff --git a/R/coef.R b/R/coef.R index 41eacd75..563d7602 100644 --- a/R/coef.R +++ b/R/coef.R @@ -119,12 +119,13 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) { cluster = rep(object$cluster_names, each = (S - 1) * S * K_s) ) K_o <- length(object$coef_names_emission) + gamma_B_raw <- unlist(object$coefficients$gamma_B_raw) if (object$n_channels == 1) { gamma_B <- data.frame( state = unlist(object$state_names), observations = rep(object$symbol_names[-1], each = S), parameter = rep(object$coef_names_emission, each = S * (M - 1)), - estimate = unlist(gamma_B_raw), + estimate = gamma_B_raw, cluster = rep(object$cluster_names, each = S * (S - 1) * K_o) ) } else { @@ -134,7 +135,7 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) { parameter = unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$coef_names_emission, each = S * (M[i] - 1)) })), - estimate = unlist(gamma_B_raw), + estimate = gamma_B_raw, cluster = unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$cluster_names, each = S * (M[i] - 1) * K_o) })) diff --git a/R/create_initial_values.R b/R/create_initial_values.R index 32f1fd0c..3fdc26e9 100644 --- a/R/create_initial_values.R +++ b/R/create_initial_values.R @@ -35,12 +35,9 @@ create_gamma_B_raw_mnhmm <- function(x, S, M, K, D) { } create_gamma_multichannel_B_raw_mnhmm <- function(x, S, M, K, D) { n <- sum((M - 1) * K * S) - unlist( - lapply(seq_len(D), function(i) { - create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K) - }), - recursive = FALSE - ) + lapply(seq_len(D), function(i) { + create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K) + }) } create_gamma_omega_raw_mnhmm <- function(x, D, K) { matrix(x, D - 1, K) @@ -181,7 +178,7 @@ create_gamma_omega_inits <- function(x, D, K, init_sd = 0) { "(D - 1) * K = {(D - 1) * K}." ) ) - create_gamma_omega_raw_nhmm(x, D, K) + create_gamma_omega_raw_mnhmm(x, D, K) } } #' Convert Initial Values for Inverse Softmax Scale @@ -206,7 +203,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0, if(!is.null(inits$initial_probs)) { if (D > 1) { gamma_pi_raw <- lapply( - seq_len(d), function(i) { + seq_len(D), function(i) { create_inits_vector(inits$initial_probs[[i]], S, K_i, init_sd) } ) @@ -222,7 +219,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0, if(!is.null(inits$transition_probs)) { if (D > 1) { gamma_A_raw <- lapply( - seq_len(d), function(i) { + seq_len(D), function(i) { create_inits_matrix(inits$transition_probs[[i]], S, S, K_s, init_sd) } ) @@ -239,7 +236,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0, if (D > 1) { if (length(M) > 1) { gamma_B_raw <- lapply( - seq_len(d), function(i) { + seq_len(D), function(i) { lapply(seq_len(length(M)), function(j) { create_inits_matrix( inits$emission_probs[[i]][[j]], S, M[j], K_o, init_sd) @@ -247,7 +244,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0, }) } else { gamma_B_raw <- lapply( - seq_len(d), function(i) { + seq_len(D), function(i) { create_inits_matrix(inits$emission_probs[[i]], S, M, K_o, init_sd) } ) diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 2a04d463..41339f25 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -96,8 +96,11 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { pars[n_i + seq_len(n_s)], S, K_s, D ) - gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( - pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + gamma_B_raw <- unlist( + create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ), + recursive = FALSE ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d @@ -115,9 +118,12 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { gamma_A_raw <- create_gamma_A_raw_mnhmm( pars[n_i + seq_len(n_s)], S, K_s, D ) - gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( + gamma_B_raw <- unlist( + create_gamma_multichannel_B_raw_mnhmm( pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D - ) + ), + recursive = FALSE + ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d ) diff --git a/R/forwardBackward.R b/R/forwardBackward.R index a3af786b..f416ef88 100644 --- a/R/forwardBackward.R +++ b/R/forwardBackward.R @@ -238,16 +238,17 @@ forward_backward.mnhmm <- function(model, forward_only = FALSE, sequence_names <- seq_len(model$n_sequences) } } else { + gamma_B_raw <- unlist(model$coefficients$gamma_B_raw, recursive = FALSE) out$forward_probs <- forward_mnhmm_multichannel( model$coefficients$gamma_pi_raw, model$X_initial, model$coefficients$gamma_A_raw, model$X_transition, - model$coefficients$gamma_B_raw, model$X_emission, + gamma_B_raw, model$X_emission, model$coefficients$gamma_omega_raw, model$X_cluster, obsArray, model$n_symbols) if (!forward_only) { out$backward_probs <- backward_mnhmm_multichannel( model$coefficients$gamma_A_raw, model$X_transition, - model$coefficients$gamma_B_raw, model$X_emission, + gamma_B_raw, model$X_emission, model$coefficients$gamma_omega_raw, model$X_cluster, obsArray, model$n_symbols) } diff --git a/R/get_probs.R b/R/get_probs.R index c5089b03..617e87e4 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -19,6 +19,8 @@ get_cluster_probs <- function(model, ...) { UseMethod("get_cluster_probs", model) } #' Extract the Initial State Probabilities of Hidden Markov Model +#' @param model A hidden Markov model. +#' @param ... Ignored. #' @rdname initial_probs #' @export get_initial_probs.nhmm <- function(model, ...) { @@ -100,9 +102,13 @@ get_initial_probs.mhmm <- function(model, ...) { model$initial_probs } #' Extract the State Transition Probabilities of Hidden Markov Model +#' @param model A hidden Markov model. +#' @param ... Ignored. #' @rdname transition_probs #' @export get_transition_probs.nhmm <- function(model, ...) { + S <- model$n_states + T_ <- model$length_of_sequences if (model$n_channels == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) @@ -110,8 +116,6 @@ get_transition_probs.nhmm <- function(model, ...) { ids <- rownames(model$observations[[1]]) times <- colnames(model$observations[[1]]) } - S <- model$n_states - T_ <- model$length_of_sequences if (!attr(model, "iv_A")) { X <- matrix(model$X_transition[, , 1L], ncol = model$length_of_sequences) d <- data.frame( @@ -130,9 +134,9 @@ get_transition_probs.nhmm <- function(model, ...) { state_from = model$state_names, state_to = rep(model$state_names, each = S), estimate = c(apply( - model$X_transition, 2, function(z) { + model$X_transition, 3, function(z) { get_A( - model$coefficients$gamma_A_raw, matrix(z, ncol = T), FALSE, + model$coefficients$gamma_A_raw, matrix(z, ncol = T_), FALSE, attr(model, "tv_A") ) } @@ -150,6 +154,9 @@ get_transition_probs.nhmm <- function(model, ...) { #' @rdname transition_probs #' @export get_transition_probs.mnhmm <- function(model, ...) { + S <- model$n_states + T_ <- model$length_of_sequences + D <- model$n_clusters if (model$n_channels == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) @@ -157,15 +164,13 @@ get_transition_probs.mnhmm <- function(model, ...) { ids <- rownames(model$observations[[1]]) times <- colnames(model$observations[[1]]) } - S <- model$n_states - T_ <- model$length_of_sequences - D <- model$n_clusters if (!attr(model, "iv_A")) { X <- matrix(model$X_transition[, , 1L], ncol = model$length_of_sequences) d <- do.call( rbind, lapply(seq_len(D), function(i) { data.frame( + cluster = model$cluster_names[i], id = rep(ids, each = S^2 * T_), time = rep(times, each = S^2), state_from = model$state_names[[i]], @@ -181,14 +186,15 @@ get_transition_probs.mnhmm <- function(model, ...) { rbind, lapply(seq_len(D), function(i) { data.frame( + cluster = model$cluster_names[i], id = rep(ids, each = S^2 * T_), time = rep(times, each = S^2), state_from = model$state_names[[i]], state_to = rep(model$state_names[[i]], each = S), estimate = c(apply( - model$X_transition, 2, function(z) { + model$X_transition, 3, function(z) { get_A( - model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T), FALSE, + model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T_), FALSE, attr(model, "tv_A") ) } @@ -214,10 +220,16 @@ get_transition_probs.mhmm <- function(model, ...) { model$transition_probs } #' Extract the Emission Probabilities of Hidden Markov Model +#' @param model A hidden Markov model. +#' @param ... Ignored. #' @rdname emission_probs #' @export get_emission_probs.nhmm <- function(model, ...) { - if (model$n_channels == 1L) { + S <- model$n_states + C <- model$n_channels + T_ <- model$length_of_sequences + M <- model$n_symbols + if (C == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) symbol_names <- list(model$symbol_names) @@ -227,10 +239,6 @@ get_emission_probs.nhmm <- function(model, ...) { times <- colnames(model$observations[[1]]) symbol_names <- model$symbol_names } - S <- model$n_states - C <- model$n_channels - T_ <- model$length_of_sequences - M <- model$n_symbols if (!attr(model, "iv_B")) { X <- matrix(model$X_emission[, , 1L], ncol = model$length_of_sequences) d <- do.call( @@ -260,9 +268,9 @@ get_emission_probs.nhmm <- function(model, ...) { channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), estimate = apply( - model$X_emission, 2, function(z) { + model$X_emission, 3, function(z) { unlist(get_B( - model$coefficients$gamma_B_raw[i], matrix(z, ncol = T), M[i], + model$coefficients$gamma_B_raw[i], matrix(z, ncol = T_), M[i], FALSE, FALSE, attr(model, "tv_B") )) } @@ -280,21 +288,24 @@ get_emission_probs.nhmm <- function(model, ...) { #' @rdname emission_probs #' @export get_emission_probs.mnhmm <- function(model, ...) { - if (model$n_channels == 1L) { + S <- model$n_states + C <- model$n_channels + D <- model$n_clusters + T_ <- model$length_of_sequences + M <- model$n_symbols + if (C == 1L) { ids <- rownames(model$observations) times <- colnames(model$observations) symbol_names <- list(model$symbol_names) - model$coefficients$gamma_B_raw <- list(model$coefficients$gamma_B_raw) + for (i in seq_len(D)) { + model$coefficients$gamma_B_raw[[i]] <- + list(model$coefficients$gamma_B_raw[[i]]) + } } else { ids <- rownames(model$observations[[1]]) times <- colnames(model$observations[[1]]) symbol_names <- model$symbol_names } - S <- model$n_states - C <- model$n_channels - D <- model$n_clusters - T_ <- model$length_of_sequences - M <- model$n_symbols if (!attr(model, "iv_B")) { X <- matrix(model$X_emission[, , 1L], ncol = model$length_of_sequences) d <- do.call( @@ -304,7 +315,7 @@ get_emission_probs.mnhmm <- function(model, ...) { rbind, lapply(seq_len(C), function(i) { data.frame( - cluster = cluster_names[[j]], + cluster = model$cluster_names[j], id = rep(ids, each = S * M[i] * T_), time = rep(times, each = S * M[i]), state = model$state_names[[j]], @@ -327,17 +338,17 @@ get_emission_probs.mnhmm <- function(model, ...) { rbind, lapply(seq_len(C), function(i) { data.frame( - cluster = cluster_names[[j]], + cluster = model$cluster_names[j], id = rep(ids, each = S * M[i] * T_), time = rep(times, each = S * M[i]), state = model$state_names[[j]], channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), estimate = apply( - model$X_emission, 2, function(z) { + model$X_emission, 3, function(z) { unlist(get_B( model$coefficients$gamma_B_raw[[j]][i], - matrix(z, ncol = T), M[i], FALSE, FALSE, attr(model, "tv_B") + matrix(z, ncol = T_), M[i], FALSE, FALSE, attr(model, "tv_B") )) } ) @@ -365,7 +376,8 @@ get_emission_probs.mhmm <- function(model, ...) { } #' Extract the Prior Cluster Probabilities of MHMM or MNHMM #' -#' @param model An object of class `mnhmm` or `mhmm. +#' @param model A mixture hidden Markov model. +#' @param ... Ignored. #' @rdname cluster_probs #' @export #' @seealso [posterior_cluster_probabilities()]. @@ -401,6 +413,11 @@ get_cluster_probs.mnhmm <- function(model, ...) { get_cluster_probs.mhmm <- function(model, ...) { pr <- exp(model$X %*% model$coefficients) prior_cluster_probabilities <- pr / rowSums(pr) + if (model$n_channels == 1L) { + ids <- rownames(model$observations) + } else { + ids <- rownames(model$observations[[1]]) + } data.frame( cluster = model$cluster_names, id = rep(ids, each = model$n_clusters), diff --git a/R/hidden_paths.R b/R/hidden_paths.R index 6c957da2..c3d08704 100644 --- a/R/hidden_paths.R +++ b/R/hidden_paths.R @@ -104,7 +104,8 @@ hidden_paths.mnhmm <- function(model, respect_void = TRUE, ...) { out <- viterbi_mnhmm_multichannel( model$coefficients$gamma_pi_raw, model$X_initial, model$coefficients$gamma_A_raw, model$X_transition, - model$coefficients$gamma_B_raw, model$X_emission, + unlist(model$coefficients$gamma_B_raw, recursive = FALSE), + model$X_emission, model$coefficients$gamma_omega_raw, model$X_cluster, obsArray, model$n_symbols) } diff --git a/R/plot.ame.R b/R/plot.ame.R index 253fe034..4168ce3d 100644 --- a/R/plot.ame.R +++ b/R/plot.ame.R @@ -6,6 +6,7 @@ #' @param probs A numeric vector of length 2 with the lower and upper limits for #' confidence intervals. Default is `c(0.025, 0.975)`. If the limits are not #' found in the input object `x`, an error is thrown. +#' @param ... Ignored. #' @export plot.ame <- function(x, type, probs = c(0.025, 0.975), ...) { diff --git a/R/simulate_mnhmm.R b/R/simulate_mnhmm.R index 168e0daa..623c0586 100644 --- a/R/simulate_mnhmm.R +++ b/R/simulate_mnhmm.R @@ -77,14 +77,13 @@ simulate_mnhmm <- function( if (is.null(coefs$emission_probs)) coefs$emission_probs <- NULL if (is.null(coefs$cluster_probs)) coefs$cluster_probs <- NULL } - K_i <- dim(model$X_initial)[2] - K_s <- dim(model$X_transition)[3] - K_o <- dim(model$X_emission)[3] - K_d <- dim(model$X_cluster)[2] + K_i <- nrow(model$X_initial) + K_s <- nrow(model$X_transition) + K_o <- nrow(model$X_emission) + K_d <- nrow(model$X_cluster) model$coefficients <- create_initial_values( coefs, n_states, n_symbols, init_sd, K_i, K_s, K_o, K_d, n_clusters ) - probs <- get_probs(model) states <- array(NA_character_, c(max(sequence_lengths), n_sequences)) obs <- array(NA_character_, c(max(sequence_lengths), n_channels, n_sequences)) @@ -93,7 +92,7 @@ simulate_mnhmm <- function( clusters <- character(n_sequences) cluster_names <- model$cluster_names state_names <- paste0( - rep(cluster_names, each = model$n_states), ": ", model$state_names + rep(cluster_names, each = model$n_states), ": ", unlist(model$state_names) ) for (i in seq_len(n_sequences)) { p_cluster <- probs$cluster[ diff --git a/R/simulate_nhmm.R b/R/simulate_nhmm.R index 9c3a2aa8..1cbc146e 100644 --- a/R/simulate_nhmm.R +++ b/R/simulate_nhmm.R @@ -75,7 +75,6 @@ simulate_nhmm <- function( model$coefficients <- create_initial_values( coefs, n_states, n_symbols, init_sd, K_i, K_s, K_o, 0, 0 ) - model$stan_model <- stanmodels[[attr(model, "type")]] probs <- get_probs(model) states <- array(NA_character_, c(max(sequence_lengths), n_sequences)) obs <- array(NA_character_, c(max(sequence_lengths), n_channels, n_sequences)) diff --git a/R/update.R b/R/update.R index 64812527..2fcaccee 100644 --- a/R/update.R +++ b/R/update.R @@ -18,7 +18,7 @@ update.nhmm <- function(object, newdata, ...) { object$id_variable ) object$X_initial <- X$X - attr(object, "iv_pi") <- x$iv + attr(object, "iv_pi") <- X$iv X <- model_matrix_transition_formula( object$transition_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, @@ -51,7 +51,7 @@ update.mnhmm <- function(object, newdata, ...) { object$id_variable ) object$X_initial <- X$X - attr(object, "iv_pi") <- x$iv + attr(object, "iv_pi") <- X$iv X <- model_matrix_transition_formula( object$transition_formula, newdata, object$n_sequences, object$length_of_sequences, object$n_states, object$time_variable, @@ -74,6 +74,6 @@ update.mnhmm <- function(object, newdata, ...) { object$time_variable, object$id_variable ) object$X_cluster <- X$X - attr(object, "iv_omega") <- x$iv + attr(object, "iv_omega") <- X$iv object } \ No newline at end of file diff --git a/tests/testthat/test-gradients.R b/tests/testthat/test-gradients.R index 2005ba16..90beeec4 100644 --- a/tests/testthat/test-gradients.R +++ b/tests/testthat/test-gradients.R @@ -275,8 +275,11 @@ test_that("Gradients for multichannel-MNHMM are correct", { f <- function(pars) { gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) - gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( - pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + gamma_B_raw <- unlist( + create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ), + recursive = FALSE ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d @@ -292,8 +295,11 @@ test_that("Gradients for multichannel-MNHMM are correct", { g <- function(pars) { gamma_pi_raw <- create_gamma_pi_raw_mnhmm(pars[seq_len(n_i)], S, K_i, D) gamma_A_raw <- create_gamma_A_raw_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D) - gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm( - pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + gamma_B_raw <- unlist( + create_gamma_multichannel_B_raw_mnhmm( + pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D + ), + recursive = FALSE ) gamma_omega_raw <- create_gamma_omega_raw_mnhmm( pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d