Skip to content

Commit

Permalink
indexing, fixing bugs based on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 24, 2024
1 parent 4dcc696 commit 91260f7
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 66 deletions.
4 changes: 0 additions & 4 deletions R/ame.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ ame.nhmm <- function(
marginalize_B_over != "clusters",
"Cannot marginalize over clusters as {.arg model} is not a {.cls mnhmm} object."
)
stopifnot_(
checkmate::test_count(nsim),
"Argument {.arg nsim} should be a single non-negative integer."
)
stopifnot_(
checkmate::test_string(x = variable),
"Argument {.arg variable} must be a single character string."
Expand Down
5 changes: 3 additions & 2 deletions R/coef.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) {
cluster = rep(object$cluster_names, each = (S - 1) * S * K_s)
)
K_o <- length(object$coef_names_emission)
gamma_B_raw <- unlist(object$coefficients$gamma_B_raw)
if (object$n_channels == 1) {
gamma_B <- data.frame(
state = unlist(object$state_names),
observations = rep(object$symbol_names[-1], each = S),
parameter = rep(object$coef_names_emission, each = S * (M - 1)),
estimate = unlist(gamma_B_raw),
estimate = gamma_B_raw,
cluster = rep(object$cluster_names, each = S * (S - 1) * K_o)
)
} else {
Expand All @@ -134,7 +135,7 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) {
parameter = unlist(lapply(seq_len(object$n_channels), function(i) {
rep(object$coef_names_emission, each = S * (M[i] - 1))
})),
estimate = unlist(gamma_B_raw),
estimate = gamma_B_raw,
cluster = unlist(lapply(seq_len(object$n_channels), function(i) {
rep(object$cluster_names, each = S * (M[i] - 1) * K_o)
}))
Expand Down
19 changes: 8 additions & 11 deletions R/create_initial_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,9 @@ create_gamma_B_raw_mnhmm <- function(x, S, M, K, D) {
}
create_gamma_multichannel_B_raw_mnhmm <- function(x, S, M, K, D) {
n <- sum((M - 1) * K * S)
unlist(
lapply(seq_len(D), function(i) {
create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K)
}),
recursive = FALSE
)
lapply(seq_len(D), function(i) {
create_gamma_multichannel_B_raw_nhmm(x[(i - 1) * n + 1:n], S, M, K)
})
}
create_gamma_omega_raw_mnhmm <- function(x, D, K) {
matrix(x, D - 1, K)
Expand Down Expand Up @@ -181,7 +178,7 @@ create_gamma_omega_inits <- function(x, D, K, init_sd = 0) {
"(D - 1) * K = {(D - 1) * K}."
)
)
create_gamma_omega_raw_nhmm(x, D, K)
create_gamma_omega_raw_mnhmm(x, D, K)
}
}
#' Convert Initial Values for Inverse Softmax Scale
Expand All @@ -206,7 +203,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0,
if(!is.null(inits$initial_probs)) {
if (D > 1) {
gamma_pi_raw <- lapply(
seq_len(d), function(i) {
seq_len(D), function(i) {
create_inits_vector(inits$initial_probs[[i]], S, K_i, init_sd)
}
)
Expand All @@ -222,7 +219,7 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0,
if(!is.null(inits$transition_probs)) {
if (D > 1) {
gamma_A_raw <- lapply(
seq_len(d), function(i) {
seq_len(D), function(i) {
create_inits_matrix(inits$transition_probs[[i]], S, S, K_s, init_sd)
}
)
Expand All @@ -239,15 +236,15 @@ create_initial_values <- function(inits, S, M, init_sd, K_i, K_s, K_o, K_d = 0,
if (D > 1) {
if (length(M) > 1) {
gamma_B_raw <- lapply(
seq_len(d), function(i) {
seq_len(D), function(i) {
lapply(seq_len(length(M)), function(j) {
create_inits_matrix(
inits$emission_probs[[i]][[j]], S, M[j], K_o, init_sd)
})
})
} else {
gamma_B_raw <- lapply(
seq_len(d), function(i) {
seq_len(D), function(i) {
create_inits_matrix(inits$emission_probs[[i]], S, M, K_o, init_sd)
}
)
Expand Down
14 changes: 10 additions & 4 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
pars[n_i + seq_len(n_s)],
S, K_s, D
)
gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm(
pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D
gamma_B_raw <- unlist(
create_gamma_multichannel_B_raw_mnhmm(
pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D
),
recursive = FALSE
)
gamma_omega_raw <- create_gamma_omega_raw_mnhmm(
pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d
Expand All @@ -115,9 +118,12 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
gamma_A_raw <- create_gamma_A_raw_mnhmm(
pars[n_i + seq_len(n_s)], S, K_s, D
)
gamma_B_raw <- create_gamma_multichannel_B_raw_mnhmm(
gamma_B_raw <- unlist(
create_gamma_multichannel_B_raw_mnhmm(
pars[n_i + n_s + seq_len(n_o)], S, M, K_o, D
)
),
recursive = FALSE
)
gamma_omega_raw <- create_gamma_omega_raw_mnhmm(
pars[n_i + n_s + n_o + seq_len(n_d)], D, K_d
)
Expand Down
5 changes: 3 additions & 2 deletions R/forwardBackward.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,17 @@ forward_backward.mnhmm <- function(model, forward_only = FALSE,
sequence_names <- seq_len(model$n_sequences)
}
} else {
gamma_B_raw <- unlist(model$coefficients$gamma_B_raw, recursive = FALSE)
out$forward_probs <- forward_mnhmm_multichannel(
model$coefficients$gamma_pi_raw, model$X_initial,
model$coefficients$gamma_A_raw, model$X_transition,
model$coefficients$gamma_B_raw, model$X_emission,
gamma_B_raw, model$X_emission,
model$coefficients$gamma_omega_raw, model$X_cluster,
obsArray, model$n_symbols)
if (!forward_only) {
out$backward_probs <- backward_mnhmm_multichannel(
model$coefficients$gamma_A_raw, model$X_transition,
model$coefficients$gamma_B_raw, model$X_emission,
gamma_B_raw, model$X_emission,
model$coefficients$gamma_omega_raw, model$X_cluster,
obsArray, model$n_symbols)
}
Expand Down
73 changes: 45 additions & 28 deletions R/get_probs.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ get_cluster_probs <- function(model, ...) {
UseMethod("get_cluster_probs", model)
}
#' Extract the Initial State Probabilities of Hidden Markov Model
#' @param model A hidden Markov model.
#' @param ... Ignored.
#' @rdname initial_probs
#' @export
get_initial_probs.nhmm <- function(model, ...) {
Expand Down Expand Up @@ -100,18 +102,20 @@ get_initial_probs.mhmm <- function(model, ...) {
model$initial_probs
}
#' Extract the State Transition Probabilities of Hidden Markov Model
#' @param model A hidden Markov model.
#' @param ... Ignored.
#' @rdname transition_probs
#' @export
get_transition_probs.nhmm <- function(model, ...) {
S <- model$n_states
T_ <- model$length_of_sequences
if (model$n_channels == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
}
S <- model$n_states
T_ <- model$length_of_sequences
if (!attr(model, "iv_A")) {
X <- matrix(model$X_transition[, , 1L], ncol = model$length_of_sequences)
d <- data.frame(
Expand All @@ -130,9 +134,9 @@ get_transition_probs.nhmm <- function(model, ...) {
state_from = model$state_names,
state_to = rep(model$state_names, each = S),
estimate = c(apply(
model$X_transition, 2, function(z) {
model$X_transition, 3, function(z) {
get_A(
model$coefficients$gamma_A_raw, matrix(z, ncol = T), FALSE,
model$coefficients$gamma_A_raw, matrix(z, ncol = T_), FALSE,
attr(model, "tv_A")
)
}
Expand All @@ -150,22 +154,23 @@ get_transition_probs.nhmm <- function(model, ...) {
#' @rdname transition_probs
#' @export
get_transition_probs.mnhmm <- function(model, ...) {
S <- model$n_states
T_ <- model$length_of_sequences
D <- model$n_clusters
if (model$n_channels == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
}
S <- model$n_states
T_ <- model$length_of_sequences
D <- model$n_clusters
if (!attr(model, "iv_A")) {
X <- matrix(model$X_transition[, , 1L], ncol = model$length_of_sequences)
d <- do.call(
rbind,
lapply(seq_len(D), function(i) {
data.frame(
cluster = model$cluster_names[i],
id = rep(ids, each = S^2 * T_),
time = rep(times, each = S^2),
state_from = model$state_names[[i]],
Expand All @@ -181,14 +186,15 @@ get_transition_probs.mnhmm <- function(model, ...) {
rbind,
lapply(seq_len(D), function(i) {
data.frame(
cluster = model$cluster_names[i],
id = rep(ids, each = S^2 * T_),
time = rep(times, each = S^2),
state_from = model$state_names[[i]],
state_to = rep(model$state_names[[i]], each = S),
estimate = c(apply(
model$X_transition, 2, function(z) {
model$X_transition, 3, function(z) {
get_A(
model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T), FALSE,
model$coefficients$gamma_A_raw[[i]], matrix(z, ncol = T_), FALSE,
attr(model, "tv_A")
)
}
Expand All @@ -214,10 +220,16 @@ get_transition_probs.mhmm <- function(model, ...) {
model$transition_probs
}
#' Extract the Emission Probabilities of Hidden Markov Model
#' @param model A hidden Markov model.
#' @param ... Ignored.
#' @rdname emission_probs
#' @export
get_emission_probs.nhmm <- function(model, ...) {
if (model$n_channels == 1L) {
S <- model$n_states
C <- model$n_channels
T_ <- model$length_of_sequences
M <- model$n_symbols
if (C == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
symbol_names <- list(model$symbol_names)
Expand All @@ -227,10 +239,6 @@ get_emission_probs.nhmm <- function(model, ...) {
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names
}
S <- model$n_states
C <- model$n_channels
T_ <- model$length_of_sequences
M <- model$n_symbols
if (!attr(model, "iv_B")) {
X <- matrix(model$X_emission[, , 1L], ncol = model$length_of_sequences)
d <- do.call(
Expand Down Expand Up @@ -260,9 +268,9 @@ get_emission_probs.nhmm <- function(model, ...) {
channel = model$channel_names[i],
observation = rep(symbol_names[[i]], each = S),
estimate = apply(
model$X_emission, 2, function(z) {
model$X_emission, 3, function(z) {
unlist(get_B(
model$coefficients$gamma_B_raw[i], matrix(z, ncol = T), M[i],
model$coefficients$gamma_B_raw[i], matrix(z, ncol = T_), M[i],
FALSE, FALSE, attr(model, "tv_B")
))
}
Expand All @@ -280,21 +288,24 @@ get_emission_probs.nhmm <- function(model, ...) {
#' @rdname emission_probs
#' @export
get_emission_probs.mnhmm <- function(model, ...) {
if (model$n_channels == 1L) {
S <- model$n_states
C <- model$n_channels
D <- model$n_clusters
T_ <- model$length_of_sequences
M <- model$n_symbols
if (C == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
symbol_names <- list(model$symbol_names)
model$coefficients$gamma_B_raw <- list(model$coefficients$gamma_B_raw)
for (i in seq_len(D)) {
model$coefficients$gamma_B_raw[[i]] <-
list(model$coefficients$gamma_B_raw[[i]])
}
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names
}
S <- model$n_states
C <- model$n_channels
D <- model$n_clusters
T_ <- model$length_of_sequences
M <- model$n_symbols
if (!attr(model, "iv_B")) {
X <- matrix(model$X_emission[, , 1L], ncol = model$length_of_sequences)
d <- do.call(
Expand All @@ -304,7 +315,7 @@ get_emission_probs.mnhmm <- function(model, ...) {
rbind,
lapply(seq_len(C), function(i) {
data.frame(
cluster = cluster_names[[j]],
cluster = model$cluster_names[j],
id = rep(ids, each = S * M[i] * T_),
time = rep(times, each = S * M[i]),
state = model$state_names[[j]],
Expand All @@ -327,17 +338,17 @@ get_emission_probs.mnhmm <- function(model, ...) {
rbind,
lapply(seq_len(C), function(i) {
data.frame(
cluster = cluster_names[[j]],
cluster = model$cluster_names[j],
id = rep(ids, each = S * M[i] * T_),
time = rep(times, each = S * M[i]),
state = model$state_names[[j]],
channel = model$channel_names[i],
observation = rep(symbol_names[[i]], each = S),
estimate = apply(
model$X_emission, 2, function(z) {
model$X_emission, 3, function(z) {
unlist(get_B(
model$coefficients$gamma_B_raw[[j]][i],
matrix(z, ncol = T), M[i], FALSE, FALSE, attr(model, "tv_B")
matrix(z, ncol = T_), M[i], FALSE, FALSE, attr(model, "tv_B")
))
}
)
Expand Down Expand Up @@ -365,7 +376,8 @@ get_emission_probs.mhmm <- function(model, ...) {
}
#' Extract the Prior Cluster Probabilities of MHMM or MNHMM
#'
#' @param model An object of class `mnhmm` or `mhmm.
#' @param model A mixture hidden Markov model.
#' @param ... Ignored.
#' @rdname cluster_probs
#' @export
#' @seealso [posterior_cluster_probabilities()].
Expand Down Expand Up @@ -401,6 +413,11 @@ get_cluster_probs.mnhmm <- function(model, ...) {
get_cluster_probs.mhmm <- function(model, ...) {
pr <- exp(model$X %*% model$coefficients)
prior_cluster_probabilities <- pr / rowSums(pr)
if (model$n_channels == 1L) {
ids <- rownames(model$observations)
} else {
ids <- rownames(model$observations[[1]])
}
data.frame(
cluster = model$cluster_names,
id = rep(ids, each = model$n_clusters),
Expand Down
3 changes: 2 additions & 1 deletion R/hidden_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ hidden_paths.mnhmm <- function(model, respect_void = TRUE, ...) {
out <- viterbi_mnhmm_multichannel(
model$coefficients$gamma_pi_raw, model$X_initial,
model$coefficients$gamma_A_raw, model$X_transition,
model$coefficients$gamma_B_raw, model$X_emission,
unlist(model$coefficients$gamma_B_raw, recursive = FALSE),
model$X_emission,
model$coefficients$gamma_omega_raw, model$X_cluster,
obsArray, model$n_symbols)
}
Expand Down
1 change: 1 addition & 0 deletions R/plot.ame.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#' @param probs A numeric vector of length 2 with the lower and upper limits for
#' confidence intervals. Default is `c(0.025, 0.975)`. If the limits are not
#' found in the input object `x`, an error is thrown.
#' @param ... Ignored.
#' @export
plot.ame <- function(x, type, probs = c(0.025, 0.975), ...) {

Expand Down
Loading

0 comments on commit 91260f7

Please sign in to comment.