From 3aab8ae3769355b9a9f42f145fb146bd6d44a40d Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Fri, 15 Nov 2024 10:25:11 +0200 Subject: [PATCH] ame probs and obs --- .gitignore | 1 + NAMESPACE | 6 +- R/RcppExports.R | 16 + R/ame_obs.R | 127 +++++ R/{ame.R => ame_probs.R} | 42 +- man/ame.Rd | 30 -- man/ame_probs.Rd | 46 ++ src/RcppExports.cpp | 114 +++++ src/mnhmm_EM.cpp | 677 ++++++++++++++++++++++++++ src/mnhmm_base.h | 2 +- src/mnhmm_mc.h | 9 +- src/mnhmm_sc.h | 9 +- src/nhmm_base.h | 1 + src/nhmm_forward.h | 23 +- src/nhmm_mc.h | 11 +- src/nhmm_sc.h | 10 +- src/nhmm_viterbi.h | 2 +- src/state_obs_probs.cpp | 111 +++++ src/state_obs_probs.h | 262 ++++++++++ tests/testthat/test-state_obs_probs.R | 118 +++++ 20 files changed, 1526 insertions(+), 91 deletions(-) create mode 100644 R/ame_obs.R rename R/{ame.R => ame_probs.R} (81%) delete mode 100644 man/ame.Rd create mode 100644 man/ame_probs.Rd create mode 100644 src/mnhmm_EM.cpp create mode 100644 src/state_obs_probs.cpp create mode 100644 src/state_obs_probs.h create mode 100644 tests/testthat/test-state_obs_probs.R diff --git a/.gitignore b/.gitignore index eaeae9df..41548e91 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ .project .settings # Extra source folders and compiled code +vignettes/cache src-i386 src-x64 *.o diff --git a/NAMESPACE b/NAMESPACE index 38da0b31..cffabc2e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,8 +6,8 @@ S3method("state_names<-",hmm) S3method("state_names<-",mhmm) S3method("state_names<-",mnhmm) S3method("state_names<-",nhmm) -S3method(ame,mnhmm) -S3method(ame,nhmm) +S3method(ame_probs,mnhmm) +S3method(ame_probs,nhmm) S3method(bootstrap_coefs,mnhmm) S3method(bootstrap_coefs,nhmm) S3method(cluster_names,mhmm) @@ -74,7 +74,7 @@ S3method(vcov,mhmm) export("cluster_names<-") export("state_names<-") export(alphabet) -export(ame) +export(ame_probs) export(bootstrap_coefs) export(build_hmm) export(build_lcm) diff --git a/R/RcppExports.R b/R/RcppExports.R index f0b5d084..95aca09c 
100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -289,3 +289,19 @@ softmax <- function(x) { .Call(`_seqHMM_softmax`, x) } +state_obs_probs_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) { + .Call(`_seqHMM_state_obs_probs_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) +} + +state_obs_probs_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) { + .Call(`_seqHMM_state_obs_probs_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) +} + +state_obs_probs_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) { + .Call(`_seqHMM_state_obs_probs_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) +} + +state_obs_probs_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) { + .Call(`_seqHMM_state_obs_probs_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) +} + diff --git a/R/ame_obs.R b/R/ame_obs.R new file mode 100644 index 00000000..c13a623e --- /dev/null +++ b/R/ame_obs.R @@ -0,0 +1,127 @@ +#' Average Marginal Effects for NHMM Responses +#' +#' The function `ame_obs` computes the average marginal effect (AME) of the +#' model covariate \eqn{X} at time t on the current and future responses by +#' 
marginalizing over the sequences and latent states. Under the assumption of +#' no unobserved confounding (i.e., there are no unobserved variables that +#' influence the covariate \eqn{X} and the outcome \eqn{Y}), these can be +#' regarded as the causal effects. In case the `values` argument is a single value +#' \eqn{x}, the function returns the interventional distributions +#' \deqn{P(y_{t+k} | do(X_t = x))} +#' and in case `values` contains two values \eqn{x} and \eqn{w}, a shift in +#' interventional distributions, i.e., +#' \deqn{P(y_{t+k} | do(X_t = x)) - P(y_{t+k} | do(X_t = w))}. +#' +#' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`. +#' @param variable Name of the variable of interest. +#' @param values Vector containing one or two values for `variable`. +#' See details. +#' @param start_time Time(s) of intervention. Either a scalar or vector. +#' Intervention is applied to all provided time points. +#' @param newdata Optional data frame which is used for marginalization. +#' @param probs Quantiles of interest of average marginal effect. +#' @param ... Ignored. +#' @rdname ame_obs +#' @export +ame_obs <- function(model, variable, values, start_time, ...) { + UseMethod("ame_obs", model) +} +#' @rdname ame_obs +#' @export +ame_obs.nhmm <- function( + model, variable, values, start_time, newdata = NULL, probs = c(0.05, 0.95), + ...) { + stopifnot_( + attr(model, "intercept_only") == FALSE, + "Model does not contain any covariates." + ) + stopifnot_( + checkmate::test_string(x = variable), + "Argument {.arg variable} must be a single character string." + ) + stopifnot_( + length(values) == 2, + "Argument {.arg values} should contain two values for + variable {.var variable}.") + time <- model$time_variable + id <- model$id_variable + if (!is.null(newdata)) { + stopifnot_( + is.data.frame(newdata), + "Argument {.arg newdata} must be a {.cls data.frame} object."
+ ) + stopifnot_( + !is.null(newdata[[id]]), + "Can't find grouping variable {.var {id}} in {.arg newdata}." + ) + stopifnot_( + !is.null(newdata[[time]]), + "Can't find time index variable {.var {time}} in {.arg newdata}." + ) + stopifnot_( + !is.null(newdata[[variable]]), + "Can't find time variable {.var {variable}} in {.arg newdata}." + ) + } else { + stopifnot_( + !is.null(model$data), + "Model does not contain original data and argument {.arg newdata} is + {.var NULL}." + ) + newdata <- model$data + } + stopifnot_( + !is.null(model$boot), + paste0( + "Model does not contain bootstrap samples of coefficients. ", + "Run {.fn bootstrap_coefs} first." + ) + ) + newdata[[variable]] <- values[1] + model1 <- update(model, newdata) + newdata[[variable]] <- values[2] + model2 <- update(model, newdata) + C <- model$n_channels + if (C == 1L) { + times <- colnames(model$observations) + symbol_names <- list(model$symbol_names) + } else { + times <- colnames(model$observations[[1]]) + symbol_names <- model$symbol_names + } + stop("WIP") + if (model$n_channels == 1) { + + obs <- create_obsArray(model)[1L, , ] + out1 <- state_obs_probs_nhmm_singlechannel( + model1$etas$pi, model1$X_pi, model1$etas$A, model1$X_A, + model1$etas$B, model1$X_B, obs, model1$sequence_lengths, + attr(model1$X_pi, "icpt_only"), attr(model1$X_A, "icpt_only"), + attr(model1$X_B, "icpt_only"), attr(model1$X_A, "iv"), + attr(model1$X_B, "iv"), attr(model1$X_A, "tv"), attr(model1$X_B, "tv"), + start = start_time) + out2 <- state_obs_probs_nhmm_singlechannel( + model2$etas$pi, model2$X_pi, model2$etas$A, model2$X_A, + model2$etas$B, model2$X_B, obs, model2$sequence_lengths, + attr(model2$X_pi, "icpt_only"), attr(model2$X_A, "icpt_only"), + attr(model2$X_B, "icpt_only"), attr(model2$X_A, "iv"), + attr(model2$X_B, "iv"), attr(model2$X_A, "tv"), attr(model2$X_B, "tv"), + start = start_time) + } + + class(out) <- "ame_obs" + attr(out, "model") <- "nhmm" + out +} + +#' @rdname ame_obs +#' @export 
+ame_obs.mnhmm <- function( + model, variable, values, start_time, newdata = NULL, probs = c(0.05, 0.95), + ...) { + + stop("Not yet implemented") + class(out) <- "ame_obs" + attr(out, "model") <- "mnhmm" + out +} diff --git a/R/ame.R b/R/ame_probs.R similarity index 81% rename from R/ame.R rename to R/ame_probs.R index e7ae6dad..92aa0f99 100644 --- a/R/ame.R +++ b/R/ame_probs.R @@ -1,19 +1,37 @@ -#' Average Marginal Effects for Non-homogenous Hidden Markov Models +#' Average Marginal Effects on NHMM Parameters +#' +#' The function `ame_probs` computes the average marginal effect (AME) of the +#' model covariate \eqn{X} on the model parameters by marginalizing over the sequences. +#' Under the assumption of no unobserved confounding (i.e., there are no +#' unobserved variables that influence the covariate \eqn{X} and the outcome), +#' these can be regarded as the causal effects of the covariate on +#' the initial, emission, and transition probabilities of the model. In case the +#' `values` argument is a single value \eqn{x}, the function returns the +#' interventional initial, transition, and emission probabilities +#' \deqn{P(z_1 | do(X_1 = x))} +#' \deqn{P(z_t | do(X_{t-1} = x), z_{t-1})} +#' \deqn{P(y_t | do(X_t = x), z_t)} +#' and in case `values` contains two values \eqn{x} and \eqn{w}, a shift in +#' interventional distributions, i.e., +#' \deqn{P(z_1 | do(X_1 = x)) - P(z_1 | do(X_1 = w))} +#' \deqn{P(z_t | do(X_{t-1} = x), z_{t-1}) - P(z_t | do(X_{t-1} = w), z_{t-1})} +#' \deqn{P(y_t | do(X_t = x), z_t) - P(y_t | do(X_t = w), z_t)}. #' #' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`. #' @param variable Name of the variable of interest. -#' @param values Vector containing one or two values for `variable`. +#' @param values Vector containing one or two values for `variable`. +#' See details. #' @param newdata Optional data frame which is used for marginalization. #' @param probs Quantiles of interest of average marginal effect. #' @param ...
Ignored. -#' @rdname ame +#' @rdname ame_probs #' @export -ame <- function(model, variable, values, ...) { - UseMethod("ame", model) +ame_probs <- function(model, variable, values, ...) { + UseMethod("ame_probs", model) } -#' @rdname ame +#' @rdname ame_probs #' @export -ame.nhmm <- function( +ame_probs.nhmm <- function( model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) { stopifnot_( @@ -184,19 +202,19 @@ ame.nhmm <- function( transition = ame_A, emission = ame_B ) - class(out) <- "amp" + class(out) <- "ace_prob" attr(out, "model") <- "nhmm" out } -#' @rdname ame +#' @rdname ame_probs #' @export -ame.mnhmm <- function( +ame_probs.mnhmm <- function( model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) { x <- lapply( - split_mnhmm(model), ame, variable = variable, values = values, + split_mnhmm(model), ame_probs, variable = variable, values = values, newdata = newdata, probs = probs ) out <- lapply(c("pi", "A", "B"), function(z) { @@ -235,7 +253,7 @@ ame.mnhmm <- function( ), qs_omega ) - class(out) <- "amp" + class(out) <- "ace_prob" attr(out, "model") <- "mnhmm" out } diff --git a/man/ame.Rd b/man/ame.Rd deleted file mode 100644 index 00215b49..00000000 --- a/man/ame.Rd +++ /dev/null @@ -1,30 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/ame.R -\name{ame} -\alias{ame} -\alias{ame.nhmm} -\alias{ame.mnhmm} -\title{Average Marginal Effects for Non-homogenous Hidden Markov Models} -\usage{ -ame(model, variable, values, ...) - -\method{ame}{nhmm}(model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) - -\method{ame}{mnhmm}(model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) 
-} -\arguments{ -\item{model}{A Hidden Markov Model of class \code{nhmm} or \code{mnhmm}.} - -\item{variable}{Name of the variable of interest.} - -\item{values}{Vector containing one or two values for \code{variable}.} - -\item{...}{Ignored.} - -\item{newdata}{Optional data frame which is used for marginalization.} - -\item{probs}{Quantiles of interest of average marginal effect.} -} -\description{ -Average Marginal Effects for Non-homogenous Hidden Markov Models -} diff --git a/man/ame_probs.Rd b/man/ame_probs.Rd new file mode 100644 index 00000000..b811fa9a --- /dev/null +++ b/man/ame_probs.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ame_probs.R +\name{ame_probs} +\alias{ame_probs} +\alias{ame_probs.nhmm} +\alias{ame_probs.mnhmm} +\title{Average Marginal Effects for Non-homogenous Hidden Markov Models} +\usage{ +ame_probs(model, variable, values, ...) + +\method{ame_probs}{nhmm}(model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) + +\method{ame_probs}{mnhmm}(model, variable, values, newdata = NULL, probs = c(0.05, 0.95), ...) +} +\arguments{ +\item{model}{A Hidden Markov Model of class \code{nhmm} or \code{mnhmm}.} + +\item{variable}{Name of the variable of interest.} + +\item{values}{Vector containing one or two values for \code{variable}. +See details.} + +\item{...}{Ignored.} + +\item{newdata}{Optional data frame which is used for marginalization.} + +\item{probs}{Quantiles of interest of average marginal effect.} +} +\description{ +The function \code{ame_probs} computes the average marginal effect (AME) of the +model covariate \eqn{X} on the model parameters by marginalizing over the sequences. +Under the assumption of no unobserved confounding (i.e., there are no +unobserved variables that influence the covariate \eqn{X} and the outcome +\eqn{Y}), these can be regarded as the causal effects of the covariate on +the initial, emission, and transition probabilities of the model.
In case +\code{values} argument is a single value \eqn{x}, the function returns the +interventional initial, transition, emission probabilities +\deqn{P(z_1 | do(X_1 = x))} +\deqn{P(z_t | do(X_{t-1} = x), z_{t-1})} +\deqn{P(y_t | do(X_t = x), z_t)} +and in a case \code{values} contains two values \eqn{x} and \eqn{w} shift in +interventional distributions, i.e., +\deqn{P(z_1 | do(X_1 = x)) - P(z_1 | do(X_1 = w))} +\deqn{P(z_t | do(X_{t-1} = x), z_{t-1}) - P(z_t | do(X_{t-1} = w), z_{t-1})} +\deqn{P(y_t | do(X_t = x), z_t) - P(y_t | do(X_t = w), z_t)}. +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index d6df8057..f4611e35 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -1325,6 +1325,116 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// state_obs_probs_nhmm_singlechannel +Rcpp::List state_obs_probs_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword start); +RcppExport SEXP _seqHMM_state_obs_probs_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP); + Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP); + Rcpp::traits::input_parameter< arma::cube& >::type eta_B(eta_BSEXP); + Rcpp::traits::input_parameter< const 
arma::cube& >::type X_B(X_BSEXP); + Rcpp::traits::input_parameter< const arma::umat& >::type obs(obsSEXP); + Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const arma::uword >::type start(startSEXP); + rcpp_result_gen = Rcpp::wrap(state_obs_probs_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)); + return rcpp_result_gen; +END_RCPP +} +// state_obs_probs_nhmm_multichannel +Rcpp::List state_obs_probs_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword start); +RcppExport SEXP _seqHMM_state_obs_probs_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP); + 
Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_B(eta_BSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP); + Rcpp::traits::input_parameter< const arma::ucube& >::type obs(obsSEXP); + Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const arma::uword >::type start(startSEXP); + rcpp_result_gen = Rcpp::wrap(state_obs_probs_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)); + return rcpp_result_gen; +END_RCPP +} +// state_obs_probs_mnhmm_singlechannel +Rcpp::List state_obs_probs_mnhmm_singlechannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword start); +RcppExport SEXP _seqHMM_state_obs_probs_mnhmm_singlechannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP 
icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type eta_omega(eta_omegaSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_omega(X_omegaSEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_pi(eta_piSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_A(eta_ASEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_B(eta_BSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP); + Rcpp::traits::input_parameter< const arma::umat& >::type obs(obsSEXP); + Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_omega(icpt_only_omegaSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP); + Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const arma::uword >::type start(startSEXP); + rcpp_result_gen = Rcpp::wrap(state_obs_probs_mnhmm_singlechannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)); + return rcpp_result_gen; +END_RCPP +} +// state_obs_probs_mnhmm_multichannel +Rcpp::List 
state_obs_probs_mnhmm_multichannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword start); +RcppExport SEXP _seqHMM_state_obs_probs_mnhmm_multichannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type eta_omega(eta_omegaSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_omega(X_omegaSEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_pi(eta_piSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_A(eta_ASEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP); + Rcpp::traits::input_parameter< arma::field& >::type eta_B(eta_BSEXP); + Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP); + Rcpp::traits::input_parameter< const arma::ucube& >::type obs(obsSEXP); + Rcpp::traits::input_parameter< const arma::uvec >::type Ti(TiSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_omega(icpt_only_omegaSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_pi(icpt_only_piSEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_A(icpt_only_ASEXP); + Rcpp::traits::input_parameter< const bool >::type icpt_only_B(icpt_only_BSEXP); 
+ Rcpp::traits::input_parameter< const bool >::type iv_A(iv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type iv_B(iv_BSEXP); + Rcpp::traits::input_parameter< const bool >::type tv_A(tv_ASEXP); + Rcpp::traits::input_parameter< const bool >::type tv_B(tv_BSEXP); + Rcpp::traits::input_parameter< const arma::uword >::type start(startSEXP); + rcpp_result_gen = Rcpp::wrap(state_obs_probs_mnhmm_multichannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)); + return rcpp_result_gen; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { {"_seqHMM_cost_matrix_singlechannel", (DL_FUNC) &_seqHMM_cost_matrix_singlechannel, 6}, @@ -1399,6 +1509,10 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_viterbi", (DL_FUNC) &_seqHMM_viterbi, 4}, {"_seqHMM_viterbix", (DL_FUNC) &_seqHMM_viterbix, 7}, {"_seqHMM_softmax", (DL_FUNC) &_seqHMM_softmax, 1}, + {"_seqHMM_state_obs_probs_nhmm_singlechannel", (DL_FUNC) &_seqHMM_state_obs_probs_nhmm_singlechannel, 16}, + {"_seqHMM_state_obs_probs_nhmm_multichannel", (DL_FUNC) &_seqHMM_state_obs_probs_nhmm_multichannel, 16}, + {"_seqHMM_state_obs_probs_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_state_obs_probs_mnhmm_singlechannel, 19}, + {"_seqHMM_state_obs_probs_mnhmm_multichannel", (DL_FUNC) &_seqHMM_state_obs_probs_mnhmm_multichannel, 19}, {NULL, NULL, 0} }; diff --git a/src/mnhmm_EM.cpp b/src/mnhmm_EM.cpp new file mode 100644 index 00000000..18fc88de --- /dev/null +++ b/src/mnhmm_EM.cpp @@ -0,0 +1,677 @@ +// // EM algorithm for NHMMs +// +// #include "nhmm_forward.h" +// #include "nhmm_backward.h" +// #include "mmnhmm_sc.h" +// #include "mmnhmm_mc.h" +// #include "logsumexp.h" +// #include "sum_to_zero.h" +// #include +// +// double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) { +// if (!grad.is_empty()) { +// grad.zeros(); +// } +// double value = 0; +// +// eta_pi = arma::mat(x.memptr(), S - 1, 
K_pi); +// gamma_pi = sum_to_zero(eta_pi, Qs); +// arma::mat Qt = Qs.t(); +// arma::vec tmp(S); +// for (arma::uword i = 0; i < N; i++) { +// if (iv_pi || i == 0) { +// tmp = gamma_pi * X_pi.col(i); +// Pi = softmax(tmp); +// } +// double sum_e = sum(E_Pi(current_mixture).col(i)); +// +// double val = arma::as_scalar(E_Pi(current_mixture).col(i).t() * tmp - sum_e * logSumExp(tmp)); +// if (!std::isfinite(val)) { +// if (!grad.is_empty()) { +// grad.fill(std::numeric_limits::max()); +// } +// return std::numeric_limits::max(); +// } +// value -= val; +// +// // Only update grad if it's non-empty (i.e., for gradient-based optimization) +// if (!grad.is_empty()) { +// grad -= Qt * (E_Pi(current_mixture).col(i) - sum_e * Pi(current_mixture)) * X_pi.col(i).t(); +// if (!std::isfinite(arma::accu(grad))) { +// grad.fill(std::numeric_limits::max()); +// return std::numeric_limits::max(); +// } +// } +// } +// if (!grad.is_empty()) { +// grad += lambda * x; +// } +// return value + 0.5 * lambda * arma::dot(x, x); +// } +// void mnhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs, +// const double xtol_rel, const double ftol_rel, +// arma::uword maxeval) { +// auto objective_pi_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { +// auto* self = static_cast(data); +// arma::vec x_vec(const_cast(x), n, false, true); +// if (grad) { +// arma::vec grad_vec(grad, n, false, true); +// return self->objective_pi(x_vec, grad_vec); +// } else { +// arma::vec grad_dummy; +// return self->objective_pi(x_vec, grad_dummy); +// } +// }; +// +// arma::vec x_pi = arma::vectorise(eta_pi(current_mixture)); +// nlopt_opt opt_pi = nlopt_create(NLOPT_LD_LBFGS, x_pi.n_elem); +// nlopt_set_min_objective(opt_pi, objective_pi_wrapper, this); +// nlopt_set_xtol_abs1(opt_pi, xtol_abs); +// nlopt_set_ftol_abs(opt_pi, ftol_abs); +// nlopt_set_xtol_rel(opt_pi, xtol_rel); +// nlopt_set_ftol_rel(opt_pi, ftol_rel); +// nlopt_set_maxeval(opt_pi, maxeval); +// double 
minf; +// int status = nlopt_optimize(opt_pi, x_pi.memptr(), &minf); +// if (status < 0) { +// Rcpp::stop("M-step of initial probabilities errored with error code %i.", status); +// } +// eta_pi(current_mixture) = arma::mat(x_pi.memptr(), S - 1, K_pi); +// nlopt_destroy(opt_pi); +// } +// +// double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) { +// if (!grad.is_empty()) { +// grad.zeros(); +// } +// double value = 0; +// +// arma::mat eta_Arow = arma::mat(x.memptr(), S - 1, K_A); +// arma::mat gamma_Arow = sum_to_zero(eta_Arow, Qs); +// arma::vec A1(S); +// arma::mat Qt = Qs.t(); +// arma::vec tmp(S); +// if (!iv_A && !tv_A) { +// tmp = gamma_Arow * X_A.slice(0).col(0); +// A1 = softmax(tmp); +// } +// for (arma::uword i = 0; i < N; i++) { +// if (iv_A && !tv_A) { +// tmp = gamma_Arow * X_A.slice(i).col(0); +// A1 = softmax(tmp); +// } +// for (arma::uword t = 0; t < (Ti(i) - 1); t++) { +// if (tv_A) { +// tmp = gamma_Arow * X_A.slice(i).col(t); +// A1 = softmax(tmp); +// } +// double sum_e = sum(E_A(current_s).slice(t).col(i)); +// +// double val = arma::as_scalar(E_A(current_s).slice(t).col(i).t() * tmp - sum_e * logSumExp(tmp)); +// if (!std::isfinite(val)) { +// if (!grad.is_empty()) { +// grad.fill(std::numeric_limits::max()); +// } +// return std::numeric_limits::max(); +// } +// value -= val; +// if (!grad.is_empty()) { +// grad -= arma::vectorise( +// Qt * (E_A(current_s).slice(t).col(i) - sum_e * A1) * X_A.slice(i).col(t).t() +// ); +// if (!std::isfinite(arma::accu(grad))) { +// grad.fill(std::numeric_limits::max()); +// return std::numeric_limits::max(); +// } +// } +// } +// } +// if (!grad.is_empty()) { +// grad += lambda * x; +// } +// return value + 0.5 * lambda * arma::dot(x, x); +// } +// void mnhmm_base::mstep_A(const double ftol_abs, const double ftol_rel, +// const double xtol_abs, const double xtol_rel, +// arma::uword maxeval) { +// +// +// auto objective_A_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> 
double { +// auto* self = static_cast(data); +// arma::vec x_vec(const_cast(x), n, false, true); +// if (grad) { +// arma::vec grad_vec(grad, n, false, true); +// return self->objective_A(x_vec, grad_vec); +// } else { +// arma::vec grad_dummy; +// return self->objective_A(x_vec, grad_dummy); +// } +// }; +// +// arma::vec x_A(eta_A.slice(0).n_elem); +// nlopt_opt opt_A = nlopt_create(NLOPT_LD_LBFGS, x_A.n_elem); +// //nlopt_opt opt_A = nlopt_create(NLOPT_LN_SBPLX, x_A.n_elem); +// nlopt_set_min_objective(opt_A, objective_A_wrapper, this); +// nlopt_set_xtol_abs1(opt_A, xtol_abs); +// nlopt_set_ftol_abs(opt_A, ftol_abs); +// nlopt_set_xtol_rel(opt_A, xtol_rel); +// nlopt_set_ftol_rel(opt_A, ftol_rel); +// nlopt_set_maxeval(opt_A, maxeval); +// double minf; +// int status; +// for (arma::uword s = 0; s < S; s++) { +// current_s = s; +// x_A = arma::vectorise(eta_A.slice(s)); +// status = nlopt_optimize(opt_A, x_A.memptr(), &minf); +// if (status < 0) { +// Rcpp::stop("M-step of transition probabilities errored with error code %i.", status); +// } +// eta_A.slice(s) = arma::mat(x_A.memptr(), S - 1, K_A); +// } +// nlopt_destroy(opt_A); +// } +// +// double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) { +// if (!grad.is_empty()) { +// grad.zeros(); +// } +// double value = 0; +// +// arma::mat eta_Brow = arma::mat(x.memptr(), M - 1, K_B); +// arma::mat gamma_Brow = sum_to_zero(eta_Brow, Qm); +// arma::mat Qt = Qm.t(); +// arma::vec B1(M); +// arma::vec log_B1(M); +// +// if (!iv_B && !tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(0).col(0)); +// log_B1 = log(B1); +// } +// arma::mat I(M, M, arma::fill::eye); +// for (arma::uword i = 0; i < N; i++) { +// if (iv_B && !tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(i).col(0)); +// log_B1 = log(B1); +// } +// for (arma::uword t = 0; t < Ti(i); t++) { +// if (tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(i).col(t)); +// log_B1 = log(B1); +// } +// +// double val = E_B(t, i, current_s) * log_B1(obs(t, 
i)); +// if (!std::isfinite(val)) { +// if (!grad.is_empty()) { +// grad.fill(std::numeric_limits::max()); +// } +// return std::numeric_limits::max(); +// } +// value -= val; +// if (!grad.is_empty()) { +// grad -= arma::vectorise( +// E_B(t, i, current_s) * Qt * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t() +// ); +// if (!std::isfinite(arma::accu(grad))) { +// grad.fill(std::numeric_limits::max()); +// return std::numeric_limits::max(); +// } +// } +// } +// } +// if (!grad.is_empty()) { +// grad += lambda * x; +// } +// return value + 0.5 * lambda * arma::dot(x, x); +// } +// void mnhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, +// const double xtol_abs, const double xtol_rel, +// arma::uword maxeval) { +// +// auto objective_B_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { +// auto* self = static_cast(data); +// arma::vec x_vec(const_cast(x), n, false, true); +// if (grad) { +// arma::vec grad_vec(grad, n, false, true); +// return self->objective_B(x_vec, grad_vec); +// } else { +// arma::vec grad_dummy; +// return self->objective_B(x_vec, grad_dummy); +// } +// }; +// arma::vec x_B(eta_B.slice(0).n_elem); +// nlopt_opt opt_B = nlopt_create(NLOPT_LD_LBFGS, x_B.n_elem); +// //nlopt_opt opt_B = nlopt_create(NLOPT_LN_SBPLX, x_B.n_elem); +// nlopt_set_min_objective(opt_B, objective_B_wrapper, this); +// nlopt_set_xtol_abs1(opt_B, xtol_abs); +// nlopt_set_ftol_abs(opt_B, ftol_abs); +// nlopt_set_xtol_rel(opt_B, xtol_rel); +// nlopt_set_ftol_rel(opt_B, ftol_rel); +// nlopt_set_maxeval(opt_B, maxeval); +// double minf; +// int status; +// for (arma::uword s = 0; s < S; s++) { +// current_s = s; +// x_B = arma::vectorise(eta_B.slice(s)); +// status = nlopt_optimize(opt_B, x_B.memptr(), &minf); +// if (status < 0) { +// Rcpp::stop("M-step of emission probabilities errored with error code %i.", status); +// } +// eta_B.slice(s) = arma::mat(x_B.memptr(), M - 1, K_B); +// } +// nlopt_destroy(opt_B); +// } +// +// double 
mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) { +// if (!grad.is_empty()) { +// grad.zeros(); +// } +// double value = 0; +// +// arma::mat eta_Brow = arma::mat(x.memptr(), M(current_c) - 1, K_B); +// arma::mat gamma_Brow = sum_to_zero(eta_Brow, Qm(current_c)); +// arma::mat Qt = Qm(current_c).t(); +// arma::vec B1(M(current_c)); +// arma::vec log_B1(M(current_c)); +// +// if (!iv_B && !tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(0).col(0)); +// log_B1 = log(B1); +// } +// arma::mat I(M(current_c), M(current_c), arma::fill::eye); +// for (arma::uword i = 0; i < N; i++) { +// if (iv_B && !tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(i).col(0)); +// log_B1 = log(B1); +// } +// for (arma::uword t = 0; t < Ti(i); t++) { +// if (tv_B) { +// B1 = softmax(gamma_Brow * X_B.slice(i).col(t)); +// log_B1 = log(B1); +// } +// +// double val = E_B(current_c)(t, i, current_s) * log_B1(obs(current_c, t, i)); +// if (!std::isfinite(val)) { +// if (!grad.is_empty()) { +// grad.fill(std::numeric_limits::max()); +// } +// return std::numeric_limits::max(); +// } +// value -= val; +// if (!grad.is_empty()) { +// grad -= arma::vectorise( +// E_B(current_c)(t, i, current_s) * Qt * (I.col(obs(current_c, t, i)) - B1) * X_B.slice(i).col(t).t() +// ); +// if (!std::isfinite(arma::accu(grad))) { +// grad.fill(std::numeric_limits::max()); +// return std::numeric_limits::max(); +// } +// } +// } +// } +// if (!grad.is_empty()) { +// grad += lambda * x; +// } +// return value + 0.5 * lambda * arma::dot(x, x); +// } +// void mnhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel, +// const double xtol_abs, const double xtol_rel, +// arma::uword maxeval) { +// +// auto objective_B_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { +// auto* self = static_cast(data); +// arma::vec x_vec(const_cast(x), n, false, true); +// if (grad) { +// arma::vec grad_vec(grad, n, false, true); +// return self->objective_B(x_vec, grad_vec); +// } else { +// 
arma::vec grad_dummy; +// return self->objective_B(x_vec, grad_dummy); +// } +// }; +// arma::vec x_B(eta_B.slice(0).n_elem); +// nlopt_opt opt_B = nlopt_create(NLOPT_LD_LBFGS, x_B.n_elem); +// //nlopt_opt opt_B = nlopt_create(NLOPT_LN_SBPLX, x_B.n_elem); +// nlopt_set_min_objective(opt_B, objective_B_wrapper, this); +// nlopt_set_xtol_abs1(opt_B, xtol_abs); +// nlopt_set_ftol_abs(opt_B, ftol_abs); +// nlopt_set_xtol_rel(opt_B, xtol_rel); +// nlopt_set_ftol_rel(opt_B, ftol_rel); +// nlopt_set_maxeval(opt_B, maxeval); +// double minf; +// int status; +// for (arma::uword c = 0; c < C; c++) { +// current_c = c; +// for (arma::uword s = 0; s < S; s++) { +// current_s = s; +// x_B = arma::vectorise(eta_B(c).slice(s)); +// status = nlopt_optimize(opt_B, x_B.memptr(), &minf); +// if (status < 0) { +// Rcpp::stop("M-step of emission probabilities errored with error code %i.", status); +// } +// eta_B(c).slice(s) = arma::mat(x_B.memptr(), M(c) - 1, K_B); +// } +// } +// nlopt_destroy(opt_B); +// } +// // [[Rcpp::export]] +// Rcpp::List EM_LBFGS_nhmm_singlechannel( +// arma::mat& eta_pi, const arma::mat& X_pi, +// arma::cube& eta_A, const arma::cube& X_A, +// arma::cube& eta_B, const arma::cube& X_B, +// const arma::umat& obs, const arma::uvec& Ti, , const bool icpt_only_omega, +// const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, +// const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, +// const arma::uword maxeval, const double ftol_abs, const double ftol_rel, +// const double xtol_abs, const double xtol_rel, +// const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, +// const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level, +// const double lambda) { +// +// mnhmm_sc model( +// eta_A.n_slices, X_pi, X_A, X_B, Ti, +// iv_pi, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B, lambda +// ); +// +// // EM-algorithm begins +// arma::uword n_pi = 
model.eta_pi.n_elem; +// arma::uword n_A = model.eta_A.n_elem; +// arma::uword n_B = model.eta_B.n_elem; +// arma::rowvec current_pars(n_pi + n_A + n_B); +// arma::rowvec previous_pars(n_pi + n_A + n_B); +// +// previous_pars.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); +// previous_pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); +// previous_pars.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t(); +// +// double relative_change = ftol_rel + 1.0; +// double absolute_change = ftol_abs + 1.0; +// double relative_x_change = xtol_rel + 1.0; +// double absolute_x_change = xtol_abs+ 1.0; +// arma::uword iter = 0; +// double ll_new; +// double ll; +// arma::mat log_alpha(model.S, model.T); +// arma::mat log_beta(model.S, model.T); +// +// // Initial log-likelihood +// for (arma::uword i = 0; i < model.N; i++) { +// if (model.iv_pi || i == 0) { +// model.update_pi(i); +// } +// if (model.iv_A || i == 0) { +// model.update_A(i); +// } +// if (model.iv_B || i == 0) { +// model.update_B(i); +// } +// model.update_log_py(i); +// univariate_forward_nhmm( +// log_alpha, model.log_Pi, model.log_A, +// model.log_py.cols(0, model.Ti(i) - 1) +// ); +// univariate_backward_nhmm( +// log_beta, model.log_A, model.log_py.cols(0, model.Ti(i) - 1) +// ); +// +// double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); +// ll += ll_i; +// model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); +// model.estep_A(i, log_alpha, log_beta, ll_i); +// model.estep_B(i, log_alpha, log_beta, ll_i); +// } +// double penalty_term = 0.5 * lambda * arma::dot(previous_pars, previous_pars); +// ll -= penalty_term; +// +// if (print_level > 0) { +// Rcpp::Rcout<<"Initial value of the log-likelihood: "< 1) { +// Rcpp::Rcout<<"Initial parameter values"< ftol_rel && absolute_change > ftol_abs && +// absolute_x_change > xtol_abs && relative_x_change > xtol_rel && iter < maxeval) { +// iter++; +// ll_new = 0; +// +// // Minimize obj(E_pi, E_A, E_B, 
eta_pi, eta_A, eta_B, X_pi, X_A, X_B) +// // with respect to eta_pi, eta_A, eta_B +// model.mstep_pi(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// model.mstep_A(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// model.mstep_B(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// // Update model +// model.update_gamma_pi(); +// model.update_gamma_A(); +// model.update_gamma_B(); +// for (arma::uword i = 0; i < model.N; i++) { +// if (model.iv_pi || i == 0) { +// model.update_pi(i); +// } +// if (model.iv_A || i == 0) { +// model.update_A(i); +// } +// if (model.iv_B || i == 0) { +// model.update_B(i); +// } +// model.update_log_py(i); +// univariate_forward_nhmm( +// log_alpha, model.log_Pi, model.log_A, +// model.log_py.cols(0, model.Ti(i) - 1) +// ); +// univariate_backward_nhmm( +// log_beta, model.log_A, model.log_py.cols(0, model.Ti(i) - 1) +// ); +// double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); +// ll_new += ll_i; +// model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); +// model.estep_A(i, log_alpha, log_beta, ll_i); +// model.estep_B(i, log_alpha, log_beta, ll_i); +// } +// current_pars.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); +// current_pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); +// current_pars.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t(); +// +// penalty_term = 0.5 * lambda * arma::dot(current_pars, current_pars); +// ll_new -= penalty_term; +// +// relative_change = (ll_new - ll) / (std::abs(ll) + 1e-8); +// absolute_change = (ll_new - ll) / model.n_obs; +// absolute_x_change = arma::max(arma::abs(current_pars - previous_pars)); +// relative_x_change = arma::norm(current_pars - previous_pars, 1) / arma::norm(current_pars, 1); +// if (print_level > 0) { +// Rcpp::Rcout<<"Iteration: "< 1) { +// Rcpp::Rcout << "current parameter values"<< std::endl; +// Rcpp::Rcout<& eta_B, const arma::cube& X_B, +// const arma::ucube& obs, const 
bool iv_pi, +// const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, +// const arma::uvec& Ti, const arma::uword n_obs, +// const arma::uword maxeval, const double ftol_abs, const double ftol_rel, +// const double xtol_abs, const double xtol_rel, +// const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, +// const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level, +// const double lambda) { +// +// mnhmm_mc model( +// eta_A.n_slices, X_pi, X_A, X_B, Ti, +// iv_pi, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B, lambda +// ); +// +// // EM-algorithm begins +// arma::uword n_pi = model.eta_pi.n_elem; +// arma::uword n_A = model.eta_A.n_elem; +// arma::uvec n_Bc; +// for (arma::uword c = 0; c < model.C; c++) { +// n_Bc(c) = model.eta_B(c).n_elem; +// } +// arma::uword n_B = arma::accu(n_Bc); +// arma::rowvec current_pars(n_pi + n_A + n_B); +// arma::rowvec previous_pars(n_pi + n_A + n_B); +// +// previous_pars.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); +// previous_pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); +// for (arma::uword c = 0; c < model.C; c++) { +// previous_pars.cols( +// n_pi + n_A, n_pi + n_A + arma::accu(n_Bc.rows(0, c - 1)) - 1 +// ) = arma::vectorise(model.eta_B(c)).t(); +// } +// double relative_change = ftol_rel + 1.0; +// double absolute_change = ftol_abs + 1.0; +// double relative_x_change = xtol_rel + 1.0; +// double absolute_x_change = xtol_abs+ 1.0; +// arma::uword iter = 0; +// double ll_new; +// double ll; +// arma::mat log_alpha(model.S, model.T); +// arma::mat log_beta(model.S, model.T); +// +// // Initial log-likelihood +// for (arma::uword i = 0; i < model.N; i++) { +// if (model.iv_pi || i == 0) { +// model.update_pi(i); +// } +// if (model.iv_A || i == 0) { +// model.update_A(i); +// } +// if (model.iv_B || i == 0) { +// model.update_B(i); +// } +// model.update_log_py(i); +// univariate_forward_nhmm( +// log_alpha, model.log_Pi, 
model.log_A, +// model.log_py.cols(0, model.Ti(i) - 1) +// ); +// univariate_backward_nhmm( +// log_beta, model.log_A, model.log_py.cols(0, model.Ti(i) - 1) +// ); +// +// double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); +// ll += ll_i; +// model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); +// model.estep_A(i, log_alpha, log_beta, ll_i); +// model.estep_B(i, log_alpha, log_beta, ll_i); +// } +// double penalty_term = 0.5 * lambda * arma::dot(previous_pars, previous_pars); +// ll -= penalty_term; +// +// if (print_level > 0) { +// Rcpp::Rcout<<"Initial value of the log-likelihood: "< 1) { +// Rcpp::Rcout<<"Initial parameter values"< ftol_rel && absolute_change > ftol_abs && +// absolute_x_change > xtol_abs && relative_x_change > xtol_rel && iter < maxeval) { +// iter++; +// ll_new = 0; +// +// // Minimize obj(E_pi, E_A, E_B, eta_pi, eta_A, eta_B, X_pi, X_A, X_B) +// // with respect to eta_pi, eta_A, eta_B +// model.mstep_pi(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// model.mstep_A(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// model.mstep_B(ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_abs_m, maxeval_m); +// // Update model +// model.update_gamma_pi(); +// model.update_gamma_A(); +// model.update_gamma_B(); +// for (arma::uword i = 0; i < model.N; i++) { +// if (model.iv_pi || i == 0) { +// model.update_pi(i); +// } +// if (model.iv_A || i == 0) { +// model.update_A(i); +// } +// if (model.iv_B || i == 0) { +// model.update_B(i); +// } +// model.update_log_py(i); +// univariate_forward_nhmm( +// log_alpha, model.log_Pi, model.log_A, +// model.log_py.cols(0, model.Ti(i) - 1) +// ); +// univariate_backward_nhmm( +// log_beta, model.log_A, model.log_py.cols(0, model.Ti(i) - 1) +// ); +// double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); +// ll_new += ll_i; +// model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); +// model.estep_A(i, log_alpha, log_beta, ll_i); +// model.estep_B(i, log_alpha, log_beta, 
ll_i); +// } +// current_pars.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); +// current_pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); +// for (arma::uword c = 0; c < model.C; c++) { +// current_pars.cols( +// n_pi + n_A, n_pi + n_A + arma::accu(n_Bc.rows(0, c - 1)) - 1 +// ) = arma::vectorise(model.eta_B(c)).t(); +// } +// +// penalty_term = 0.5 * lambda * arma::dot(current_pars, current_pars); +// ll_new -= penalty_term; +// +// relative_change = (ll_new - ll) / (std::abs(ll) + 1e-8); +// absolute_change = (ll_new - ll) / model.n_obs; +// absolute_x_change = arma::max(arma::abs(current_pars - previous_pars)); +// relative_x_change = arma::norm(current_pars - previous_pars, 1) / arma::norm(current_pars, 1); +// if (print_level > 0) { +// Rcpp::Rcout<<"Iteration: "< 1) { +// Rcpp::Rcout << "current parameter values"<< std::endl; +// Rcpp::Rcout< gamma_pi; arma::field eta_A; arma::field gamma_A; - // Pi, A, and log_p(y) of _one_ id and cluster we are currently working with + // Pi, A, and log_p(y) of _one_ id we are currently working with arma::vec omega; arma::vec log_omega; arma::field Pi; diff --git a/src/mnhmm_mc.h b/src/mnhmm_mc.h index c3fb4521..76c50dd7 100644 --- a/src/mnhmm_mc.h +++ b/src/mnhmm_mc.h @@ -109,11 +109,6 @@ struct mnhmm_mc : public mnhmm_base { } } - void update_probs(const arma::uword i) { - update_pi(i); - update_A(i); - update_B(i); - } void update_log_py(const arma::uword i) { log_py.zeros(); for (arma::uword d = 0; d < D; d++) { @@ -124,5 +119,9 @@ struct mnhmm_mc : public mnhmm_base { } } } + void compute_state_obs_probs( + const arma::uword start, arma::field& obs_prob, + arma::cube& state_prob + ); }; #endif diff --git a/src/mnhmm_sc.h b/src/mnhmm_sc.h index 995dedb4..ba1fa673 100644 --- a/src/mnhmm_sc.h +++ b/src/mnhmm_sc.h @@ -92,11 +92,7 @@ struct mnhmm_sc : public mnhmm_base { } } } - void update_probs(const arma::uword i) { - update_pi(i); - update_A(i); - update_B(i); - } + void update_log_py(const 
arma::uword i) { for (arma::uword d = 0; d < D; d++) { for (arma::uword t = 0; t < Ti(i); t++) { @@ -104,5 +100,8 @@ struct mnhmm_sc : public mnhmm_base { } } } + void compute_state_obs_probs( + const arma::uword start, arma::cube& obs_prob, arma::cube& state_prob + ); }; #endif diff --git a/src/nhmm_base.h b/src/nhmm_base.h index e0e12b7f..b17870fb 100644 --- a/src/nhmm_base.h +++ b/src/nhmm_base.h @@ -165,6 +165,7 @@ struct nhmm_base { const arma::uword maxeval, const arma::uword print_level); double objective_pi(const arma::vec& x, arma::vec& grad); double objective_A(const arma::vec& x, arma::vec& grad); + }; #endif diff --git a/src/nhmm_forward.h b/src/nhmm_forward.h index 5ebf9e34..c60debee 100644 --- a/src/nhmm_forward.h +++ b/src/nhmm_forward.h @@ -24,26 +24,6 @@ void univariate_forward_nhmm( } } -// // time-invariant A -// template -// void univariate_forward_nhmm( -// submat& log_alpha, -// const arma::vec& log_pi, -// const arma::mat& log_A, -// const arma::mat& log_py) { -// -// arma::uword S = log_py.n_rows; -// arma::uword T = log_py.n_cols; -// log_alpha.col(0) = log_pi + log_py.col(0); -// for (arma::uword t = 1; t < T; t++) { -// for (arma::uword i = 0; i < S; i++) { -// log_alpha(i, t) = logSumExp( -// log_alpha.col(t - 1) + log_A.col(i) + log_py(i, t) -// ); -// } -// } -// } - template void forward_nhmm(Model& model, arma::cube& log_alpha) { for (arma::uword i = 0; i < model.N; i++) { @@ -81,14 +61,13 @@ void forward_mnhmm(Model& model, arma::cube& log_alpha) { if (model.iv_B || i == 0) { model.update_B(i); } - model.update_probs(i); model.update_log_py(i); for (arma::uword d = 0; d < model.D; d++) { arma::subview submat = log_alpha.slice(i).rows(d * model.S, (d + 1) * model.S - 1); univariate_forward_nhmm( submat, - model.omega(d) + model.log_Pi(d), + model.log_omega(d) + model.log_Pi(d), model.log_A(d), model.log_py.slice(d).cols(0, model.Ti(i) - 1) ); diff --git a/src/nhmm_mc.h b/src/nhmm_mc.h index c2c8f890..61423499 100644 --- 
a/src/nhmm_mc.h +++ b/src/nhmm_mc.h @@ -103,12 +103,6 @@ struct nhmm_mc : public nhmm_base { } } - void update_probs(const arma::uword i) { - update_pi(i); - update_A(i); - update_B(i); - } - void update_log_py(const arma::uword i) { log_py.zeros(); for (arma::uword t = 0; t < Ti(i); t++) { @@ -139,5 +133,10 @@ struct nhmm_mc : public nhmm_base { const arma::uword maxeval, const arma::uword print_level); double objective_B(const arma::vec& x, arma::vec& grad); + + void compute_state_obs_probs( + const arma::uword start, arma::field& obs_prob, + arma::cube& state_prob + ); }; #endif diff --git a/src/nhmm_sc.h b/src/nhmm_sc.h index a9f2a897..6cc56d1b 100644 --- a/src/nhmm_sc.h +++ b/src/nhmm_sc.h @@ -82,12 +82,6 @@ struct nhmm_sc : public nhmm_base { log_B = arma::log(B); } - void update_probs(const arma::uword i) { - update_pi(i); - update_A(i); - update_B(i); - } - void update_log_py(const arma::uword i) { for (arma::uword t = 0; t < Ti(i); t++) { log_py.col(t) = log_B.slice(t).col(obs(t, i)); @@ -112,5 +106,9 @@ struct nhmm_sc : public nhmm_base { const arma::uword maxeval, const arma::uword print_level); double objective_B(const arma::vec& x, arma::vec& grad); + + void compute_state_obs_probs( + const arma::uword start, arma::cube& obs_prob, arma::cube& state_prob + ); }; #endif diff --git a/src/nhmm_viterbi.h b/src/nhmm_viterbi.h index f434642f..ad5139fc 100644 --- a/src/nhmm_viterbi.h +++ b/src/nhmm_viterbi.h @@ -74,7 +74,7 @@ void viterbi_mnhmm(Model& model, arma::umat& q, arma::vec& logp) { for (arma::uword d = 0; d < model.D; d++) { logp_d = univariate_viterbi_nhmm( q_d, - model.omega(d) + model.log_Pi(d), + model.log_omega(d) + model.log_Pi(d), model.log_A(d).slices(0, model.Ti(i) - 1), model.log_py.slice(d).cols(0, model.Ti(i) - 1) ); diff --git a/src/state_obs_probs.cpp b/src/state_obs_probs.cpp new file mode 100644 index 00000000..91bf781c --- /dev/null +++ b/src/state_obs_probs.cpp @@ -0,0 +1,111 @@ +// state_obs_probs algorithm for NHMM +#include 
"state_obs_probs.h" +#include "nhmm_sc.h" +#include "nhmm_mc.h" +#include "mnhmm_sc.h" +#include "mnhmm_mc.h" + +// [[Rcpp::export]] +Rcpp::List state_obs_probs_nhmm_singlechannel( + arma::mat& eta_pi, const arma::mat& X_pi, + arma::cube& eta_A, const arma::cube& X_A, + arma::cube& eta_B, const arma::cube& X_B, + const arma::umat& obs, const arma::uvec Ti, + const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, + const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, + const arma::uword start) { + + nhmm_sc model( + eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A, + icpt_only_B, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B + ); + + arma::cube obs_prob(model.M, model.T, model.N, arma::fill::value(arma::datum::nan)); + arma::cube state_prob(model.S, model.T, model.N, arma::fill::value(arma::datum::nan)); + model.compute_state_obs_probs(start, obs_prob, state_prob); + return Rcpp::List::create( + Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob), + Rcpp::Named("state_prob") = Rcpp::wrap(state_prob) + ); +} + +// [[Rcpp::export]] +Rcpp::List state_obs_probs_nhmm_multichannel( + arma::mat& eta_pi, const arma::mat& X_pi, + arma::cube& eta_A, const arma::cube& X_A, + arma::field& eta_B, const arma::cube& X_B, + const arma::ucube& obs, const arma::uvec Ti, + const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, + const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, + const arma::uword start) { + + nhmm_mc model( + eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A, + icpt_only_B, iv_A, iv_B, tv_A, tv_B, obs, eta_pi, eta_A, eta_B + ); + arma::field obs_prob(model.C); + for (arma::uword c = 0; c < model.C; c++) { + obs_prob(c) = arma::cube(model.M(c), model.T, model.N, arma::fill::value(arma::datum::nan)); + } + arma::cube state_prob(model.S, model.T, model.N, arma::fill::value(arma::datum::nan)); + model.compute_state_obs_probs(start, obs_prob, state_prob); + return 
Rcpp::List::create( + Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob), + Rcpp::Named("state_prob") = Rcpp::wrap(state_prob) + ); +} + +// [[Rcpp::export]] +Rcpp::List state_obs_probs_mnhmm_singlechannel( + arma::mat& eta_omega, const arma::mat& X_omega, + arma::field& eta_pi, const arma::mat& X_pi, + arma::field& eta_A, const arma::cube& X_A, + arma::field& eta_B, const arma::cube& X_B, + const arma::umat& obs, const arma::uvec Ti, + const bool icpt_only_omega, const bool icpt_only_pi, + const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, + const bool iv_B, const bool tv_A, const bool tv_B, + const arma::uword start) { + + mnhmm_sc model( + eta_A(0).n_slices, eta_A.n_rows, X_omega, X_pi, X_A, X_B, Ti, + icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, + iv_A, iv_B, tv_A, tv_B, obs, eta_omega, eta_pi, eta_A, eta_B + ); + arma::cube obs_prob(model.M, model.T, model.N, arma::fill::value(arma::datum::nan)); + arma::cube state_prob(model.S * model.D, model.T, model.N, arma::fill::value(arma::datum::nan)); + model.compute_state_obs_probs(start, obs_prob, state_prob); + return Rcpp::List::create( + Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob), + Rcpp::Named("state_prob") = Rcpp::wrap(state_prob) + ); +} + +// [[Rcpp::export]] +Rcpp::List state_obs_probs_mnhmm_multichannel( + arma::mat& eta_omega, const arma::mat& X_omega, + arma::field& eta_pi, const arma::mat& X_pi, + arma::field& eta_A, const arma::cube& X_A, + arma::field& eta_B, const arma::cube& X_B, + const arma::ucube& obs, const arma::uvec Ti, + const bool icpt_only_omega, const bool icpt_only_pi, + const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, + const bool iv_B, const bool tv_A, const bool tv_B, + const arma::uword start) { + + mnhmm_mc model( + eta_A(0).n_slices, eta_A.n_rows, X_omega, X_pi, X_A, X_B, Ti, + icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, + iv_A, iv_B, tv_A, tv_B, obs, eta_omega, eta_pi, eta_A, eta_B + ); + arma::field obs_prob(model.C); + for 
(arma::uword c = 0; c < model.C; c++) { + obs_prob(c) = arma::cube(model.M(c), model.T, model.N, arma::fill::value(arma::datum::nan)); + } + arma::cube state_prob(model.S * model.D, model.T, model.N, arma::fill::value(arma::datum::nan)); + model.compute_state_obs_probs(start, obs_prob, state_prob); + return Rcpp::List::create( + Rcpp::Named("obs_prob") = Rcpp::wrap(obs_prob), + Rcpp::Named("state_prob") = Rcpp::wrap(state_prob) + ); +} diff --git a/src/state_obs_probs.h b/src/state_obs_probs.h new file mode 100644 index 00000000..7913f7e6 --- /dev/null +++ b/src/state_obs_probs.h @@ -0,0 +1,262 @@ +#ifndef AME_OBS_NHMM_H +#define AME_OBS_NHMM_H + +#include +#include "nhmm_sc.h" +#include "nhmm_mc.h" +#include "mnhmm_sc.h" +#include "mnhmm_mc.h" +#include "nhmm_forward.h" +#include "logsumexp.h" + +template +void univariate_state_prob( + submat& log_state_prob, + const arma::uword start, + const arma::uword end, + const arma::cube& log_A) { + + arma::uword S = log_A.n_rows; + for (arma::uword t = 0; t < start; t++) { + log_state_prob.col(t) -= logSumExp(log_state_prob.col(t)); + } + for (arma::uword t = start; t < end; t++) { + for (arma::uword s = 0; s < S; s++) { + log_state_prob(s, t) = logSumExp(log_state_prob.col(t - 1) + log_A.slice(t - 1).col(s)); + } + log_state_prob.col(t) -= logSumExp(log_state_prob.col(t)); + } +} + +template +void univariate_obs_prob( + arma::mat& log_obs_prob, + const submat& log_state_prob, + const arma::uword start, + const arma::uword end, + const arma::cube& log_B) { + arma::uword M = log_B.n_cols - 1; + for (arma::uword t = start; t < end; t++) { + for (arma::uword m = 0; m < M; m++) { + log_obs_prob(m, t) = logSumExp(log_state_prob.col(t) + + log_B.slice(t).col(m)); + } + } +} + +void nhmm_sc::compute_state_obs_probs( + const arma::uword start, arma::cube& obs_prob, arma::cube& state_prob) { + obs_prob.cols(0, start - 1).fill(-arma::datum::inf); + bool not_updated = true; + for (arma::uword i = 0; i < N; i++) { + arma::uword 
upper_bound = std::min(start, Ti(i)); + for (arma::uword t = 0; t < upper_bound; t++) { + obs_prob(obs(t, i), t, i) = 0; + } + if (start < Ti(i)) { + if (!icpt_only_pi || not_updated) { + update_pi(i); + } + if (iv_A || not_updated) { + update_A(i); + } + if (iv_B || not_updated) { + update_B(i); + } + not_updated = false; + update_log_py(i); + univariate_forward_nhmm( + state_prob.slice(i), log_Pi, log_A, + log_py.cols(0, start - 1) + ); + univariate_state_prob( + state_prob.slice(i), start, Ti(i), log_A + ); + univariate_obs_prob( + obs_prob.slice(i), + state_prob.slice(i), + start, Ti(i), + log_B + ); + } + } + obs_prob = arma::exp(obs_prob); + state_prob = arma::exp(state_prob); +} + +void nhmm_mc::compute_state_obs_probs( + const arma::uword start, arma::field& obs_prob, + arma::cube& state_prob) { + for (arma::uword c = 0; c < C; c++) { + obs_prob(c).cols(0, start - 1).fill(-arma::datum::inf); + } + bool not_updated = true; + for (arma::uword i = 0; i < N; i++) { + arma::uword upper_bound = std::min(start, Ti(i)); + for (arma::uword t = 0; t < upper_bound; t++) { + for (arma::uword c = 0; c < C; c++) { + obs_prob(c)(obs(c, t, i), t, i) = 0; + } + } + if (start < Ti(i)) { + if (!icpt_only_pi || not_updated) { + update_pi(i); + } + if (iv_A || not_updated) { + update_A(i); + } + if (iv_B || not_updated) { + update_B(i); + } + not_updated = false; + update_log_py(i); + univariate_forward_nhmm( + state_prob.slice(i), log_Pi, log_A, + log_py.cols(0, start - 1) + ); + univariate_state_prob( + state_prob.slice(i), start, Ti(i), log_A + ); + for (arma::uword c = 0; c < C; c++) { + univariate_obs_prob( + obs_prob(c).slice(i), + state_prob.slice(i), + start, Ti(i), + log_B(c) + ); + } + } + } + for (arma::uword c = 0; c < C; c++) { + obs_prob(c) = arma::exp(obs_prob(c)); + } + state_prob = arma::exp(state_prob); +} + +void mnhmm_sc::compute_state_obs_probs( + const arma::uword start, arma::cube& obs_prob, arma::cube& state_prob) { + obs_prob.cols(0, start - 
1).fill(-arma::datum::inf);
+  bool not_updated = true;
+  arma::cube tmp(M, T, D);
+  for (arma::uword i = 0; i < N; i++) {
+    arma::uword upper_bound = std::min(start, Ti(i));
+    for (arma::uword t = 0; t < upper_bound; t++) {
+      obs_prob(obs(t, i), t, i) = 0;
+    }
+    if (start < Ti(i)) {
+      if (!icpt_only_omega || not_updated) {
+        update_omega(i);
+      }
+      if (!icpt_only_pi || not_updated) {
+        update_pi(i);
+      }
+      if (iv_A || not_updated) {
+        update_A(i);
+      }
+      if (iv_B || not_updated) {
+        update_B(i);
+      }
+      not_updated = false;
+      update_log_py(i);
+      for (arma::uword d = 0; d < D; d++) {
+        arma::subview<double> submat =
+          state_prob.slice(i).rows(d * S, (d + 1) * S - 1);
+        univariate_forward_nhmm(
+          submat,
+          log_Pi(d),
+          log_A(d),
+          log_py.slice(d).cols(0, start - 1)
+        );
+        univariate_state_prob(
+          submat, start, Ti(i), log_A(d)
+        );
+        submat += log_omega(d);
+        univariate_obs_prob(
+          tmp.slice(d),
+          submat,
+          start, Ti(i),
+          log_B(d)
+        );
+      }
+      for (arma::uword t = start; t < Ti(i); t++) {
+        for (arma::uword m = 0; m < M; m++) {
+          obs_prob(m, t, i) = logSumExp(tmp.tube(m, t));
+        }
+      }
+    }
+  }
+  obs_prob = arma::exp(obs_prob);
+  state_prob = arma::exp(state_prob);
+}
+
+void mnhmm_mc::compute_state_obs_probs(
+    const arma::uword start, arma::field<arma::cube>& obs_prob,
+    arma::cube& state_prob) { // both outputs are filled in place on the log scale and exponentiated at the end
+  for (arma::uword c = 0; c < C; c++) {
+    obs_prob(c).cols(0, start - 1).fill(-arma::datum::inf); // log(0) for every symbol before `start`; NOTE(review): start - 1 wraps when start == 0, presumably callers pass start >= 1 - confirm
+  }
+  bool not_updated = true; // forces one full parameter update for the first sequence that is processed
+  arma::field<arma::cube> tmp(C);
+  for (arma::uword c = 0; c < C; c++) {
+    tmp(c) = arma::cube(M(c), T, D); // scratch log-probabilities, indexed (symbol m, time t, d) below
+  }
+  for (arma::uword i = 0; i < N; i++) {
+    arma::uword upper_bound = std::min(start, Ti(i));
+    for (arma::uword t = 0; t < upper_bound; t++) {
+      for (arma::uword c = 0; c < C; c++) {
+        obs_prob(c)(obs(c, t, i), t, i) = 0; // log(1) for the symbol actually observed at time t < start
+      }
+    }
+    if (start < Ti(i)) { // skipped when Ti(i) <= start, i.e. nothing at or after `start` to fill for sequence i
+      if (!icpt_only_omega || not_updated) {
+        update_omega(i);
+      }
+      if (!icpt_only_pi || not_updated) {
+        update_pi(i);
+      }
+      if (iv_A || not_updated) { // NOTE(review): iv_A / iv_B presumably flag sequence-specific covariates - confirm
+        update_A(i);
+      }
+      if (iv_B || not_updated) {
+        update_B(i);
+      }
+      not_updated = false;
+      update_log_py(i);
+      for (arma::uword d = 0; d < D; d++) {
+        arma::subview<double> submat =
+          state_prob.slice(i).rows(d * S, (d + 1) * S - 1); // view of the S rows of block d, so writes land in state_prob
+        univariate_forward_nhmm(
+          submat,
+          log_Pi(d),
+          log_A(d),
+          log_py.slice(d).cols(0, start - 1)
+        );
+        univariate_state_prob(
+          submat, start, Ti(i), log_A(d)
+        );
+        submat += log_omega(d); // adds log_omega(d) to block d (presumably the log cluster probability - confirm)
+        for (arma::uword c = 0; c < C; c++) {
+          univariate_obs_prob(
+            tmp(c).slice(d),
+            submat,
+            start, Ti(i),
+            log_B(c, d)
+          );
+        }
+      }
+      for (arma::uword c = 0; c < C; c++) {
+        for (arma::uword t = start; t < Ti(i); t++) {
+          for (arma::uword m = 0; m < M(c); m++) {
+            obs_prob(c)(m, t, i) = logSumExp(tmp(c).tube(m, t)); // combine the D slices on the log scale
+          }
+        }
+      }
+    }
+  }
+  for (arma::uword c = 0; c < C; c++) {
+    obs_prob(c) = arma::exp(obs_prob(c)); // back from the log scale
+  }
+  state_prob = arma::exp(state_prob);
+}
+#endif
+
diff --git a/tests/testthat/test-state_obs_probs.R b/tests/testthat/test-state_obs_probs.R
new file mode 100644
index 00000000..57c1b9aa
--- /dev/null
+++ b/tests/testthat/test-state_obs_probs.R
@@ -0,0 +1,118 @@
+
+test_that("'state_obs_probs' works for multichannel 'nhmm'", {
+  data("hmm_biofam")
+  set.seed(1)
+  expect_error(
+    fit <- estimate_nhmm(
+      hmm_biofam$observations, n_states = 5,
+      inits = hmm_biofam[
+        c("initial_probs", "transition_probs", "emission_probs")
+      ], maxeval = 1
+    ),
+    NA
+  )
+  obs <- create_obsArray(fit)
+  expect_error(
+    out <- state_obs_probs_nhmm_multichannel(
+      fit$etas$pi, fit$X_pi, fit$etas$A, fit$X_A,
+      fit$etas$B, fit$X_B, obs, fit$sequence_lengths,
+      attr(fit$X_pi, "icpt_only"), attr(fit$X_A, "icpt_only"),
+      attr(fit$X_B, "icpt_only"), attr(fit$X_A, "iv"),
+      attr(fit$X_B, "iv"), attr(fit$X_A, "tv"), attr(fit$X_B, "tv"),
+      start = 3L),
+    NA
+  )
+  expect_gte(min(unlist(out$obs_prob)), 0)
+  expect_lte(max(unlist(out$obs_prob)), 1)
+  expect_gte(min(out$state_prob), 0)
+  expect_lte(max(out$state_prob), 1)
+  expect_true(all(abs(apply(out$obs_prob[[1]], 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+  expect_true(all(abs(apply(out$obs_prob[[2]], 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+  expect_true(all(abs(apply(out$obs_prob[[3]], 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+  expect_true(all(abs(apply(out$state_prob, 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+})
+
+test_that("'state_obs_probs' works for single-channel 'nhmm'", {
+  data("hmm_biofam")
+  set.seed(1)
+  expect_error(
+    fit <- estimate_nhmm(
+      hmm_biofam$observations[[1]][1:100,], n_states = 3,
+      restarts = 2, maxeval = 2, lambda = 1
+    ),
+    NA
+  )
+  obs <- create_obsArray(fit)[1L, , ]
+  expect_error(
+    out <- state_obs_probs_nhmm_singlechannel(
+      fit$etas$pi, fit$X_pi, fit$etas$A, fit$X_A,
+      fit$etas$B, fit$X_B, obs, fit$sequence_lengths,
+      attr(fit$X_pi, "icpt_only"), attr(fit$X_A, "icpt_only"),
+      attr(fit$X_B, "icpt_only"), attr(fit$X_A, "iv"),
+      attr(fit$X_B, "iv"), attr(fit$X_A, "tv"), attr(fit$X_B, "tv"),
+      start = 3L),
+    NA
+  )
+  expect_gte(min(out$obs_prob), 0)
+  expect_lte(max(out$obs_prob), 1)
+  expect_gte(min(out$state_prob), 0)
+  expect_lte(max(out$state_prob), 1)
+  expect_true(all(abs(apply(out$obs_prob, 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+  expect_true(all(abs(apply(out$state_prob, 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+})
+
+test_that("'forward_backward' works for multichannel 'mnhmm'", {
+  data("hmm_biofam")
+  set.seed(1)
+  expect_error(
+    fit <- estimate_mnhmm(
+      hmm_biofam$observations, n_states = 3, n_clusters = 2,
+      maxeval = 1
+    ),
+    NA
+  )
+  expect_error(
+    fb <- forward_backward(fit, as_data_frame = FALSE),
+    NA
+  )
+  expect_gte(min(fb$forward_probs), -2000)
+  expect_gte(min(fb$backward_probs), -2000)
+  expect_lte(max(fb$forward_probs), 0)
+  expect_lte(max(fb$backward_probs), 0)
+
+  expect_error(
+    fb <- forward_backward(fit),
+    NA
+  )
+  expect_lte(max(exp(fb$log_probability)), 1)
+  expect_gte(min(exp(fb$log_probability)), 0)
+})
+
+test_that("'state_obs_probs' works for single-channel 'mnhmm'", {
+  set.seed(1)
+  expect_error(
+    fit <- estimate_mnhmm(
+      hmm_biofam$observations[[1]], n_states = 4, n_clusters = 2,
+      restarts = 2, maxeval = 1
+    ),
+    NA
+  )
+  obs <- create_obsArray(fit)[1L, , ]
+  expect_error(
+    out <- state_obs_probs_mnhmm_singlechannel(
+      fit$etas$omega, fit$X_omega, fit$etas$pi, fit$X_pi, fit$etas$A, fit$X_A,
+      fit$etas$B, fit$X_B, obs, fit$sequence_lengths,
+      attr(fit$X_omega, "icpt_only"), attr(fit$X_pi, "icpt_only"), # omega before pi, matching icpt_only_omega, icpt_only_pi in the wrapper signature
+      attr(fit$X_A, "icpt_only"),
+      attr(fit$X_B, "icpt_only"), attr(fit$X_A, "iv"),
+      attr(fit$X_B, "iv"), attr(fit$X_A, "tv"), attr(fit$X_B, "tv"),
+      start = 3L),
+    NA
+  )
+  expect_gte(min(out$obs_prob), 0)
+  expect_lte(max(out$obs_prob), 1)
+  expect_gte(min(out$state_prob), 0)
+  expect_lte(max(out$state_prob), 1)
+  expect_true(all(abs(apply(out$obs_prob, 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+  expect_true(all(abs(apply(out$state_prob, 2:3, sum) - 1) < sqrt(.Machine$double.eps)))
+})