From 039f5f8419fc90ddd51e64e4c14e0e6a50c12e5a Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Sat, 9 Nov 2024 22:23:10 +0200 Subject: [PATCH] pseudocounts --- R/RcppExports.R | 8 ++++---- R/estimate_mnhmm.R | 6 ++++-- R/estimate_nhmm.R | 9 +++++++-- R/fit_mnhmm.R | 11 ++++------- R/fit_nhmm.R | 15 ++++++--------- man/estimate_mnhmm.Rd | 1 + man/estimate_nhmm.Rd | 6 ++++++ src/RcppExports.cpp | 18 ++++++++++-------- src/nhmm_EM.cpp | 28 ++++++++++++++++------------ 9 files changed, 58 insertions(+), 44 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 7ba41d0..f0b5d08 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -121,12 +121,12 @@ logSumExp <- function(x) { .Call(`_seqHMM_logSumExp`, x) } -EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) { - .Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) +EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) { + .Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) } -EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) { - .Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda) +EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) { + .Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount) } backward_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) { diff --git a/R/estimate_mnhmm.R b/R/estimate_mnhmm.R index 790b8b7..67c3d45 100644 --- a/R/estimate_mnhmm.R +++ b/R/estimate_mnhmm.R @@ -44,7 +44,8 @@ estimate_mnhmm <- function( transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1, data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2, - restarts = 0L, lambda = 0, method = "EM", store_data = TRUE, ...) { + restarts = 0L, lambda = 0, method = "EM", pseudocount = 0, + store_data = TRUE, ...) { call <- match.call() model <- build_mnhmm( @@ -63,7 +64,8 @@ estimate_mnhmm <- function( if (store_data) { model$data <- data } - out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, ...) + out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, + pseudocount, ...) attr(out, "call") <- call out diff --git a/R/estimate_nhmm.R b/R/estimate_nhmm.R index 0173b64..0c4618a 100644 --- a/R/estimate_nhmm.R +++ b/R/estimate_nhmm.R @@ -56,6 +56,10 @@ #' @param method Optimization method used. Default is `"EM"` which uses EM #' algorithm with L-BFGS in the M-step. Another option is `"DNM"` which uses #' direct maximization of the log-likelihood using [nloptr::nloptr()]. +#' @param pseudocount. A positive scalar to be added for the expected counts of +#' E-step. Only used in EM algorithm. Default is 0. Larger values can be used +#' to avoid zero probabilities in initial, transition, and emission +#' probabilities, i.e. these have similar role as `lambda`. #' @param store_data If `TRUE` (default), original data frame passed as `data` #' is stored to the model object. For large datasets, this can be set to #' `FALSE`, in which case you might need to pass the data separately to some @@ -88,7 +92,7 @@ estimate_nhmm <- function( transition_formula = ~1, emission_formula = ~1, data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL, inits = "random", init_sd = 2, restarts = 0L, lambda = 0, method = "EM", - store_data = TRUE, ...) { + pseudocount = 0, store_data = TRUE, ...) { call <- match.call() @@ -110,7 +114,8 @@ estimate_nhmm <- function( if (store_data) { model$data <- data } - out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, ...) + out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, pseudocount, + ...) attr(out, "call") <- call out } diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index e177968..eb20326 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -1,9 +1,9 @@ #' Estimate a Mixture Non-homogeneous Hidden Markov Model #' #' @noRd -fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, - save_all_solutions = FALSE, control_restart = list(), - control_mstep = list(), ...) { +fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, + pseudocount = 0, save_all_solutions = FALSE, + control_restart = list(), control_mstep = list(), ...) { stopifnot_( checkmate::test_int(x = restarts, lower = 0L), "Argument {.arg restarts} must be a single integer." @@ -21,10 +21,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, list(...) ) control_restart <- utils::modifyList(control, control_restart) - control_mstep <- utils::modifyList( - c(control, list(pseudocount = 0)), - control_mstep - ) + control_mstep <- utils::modifyList(control, control_mstep) M <- model$n_symbols S <- model$n_states diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index 780439a..71a1301 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -1,7 +1,7 @@ #' Estimate a Non-homogeneous Hidden Markov Model #' #' @noRd -fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, +fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocount = 0, save_all_solutions = FALSE, control_restart = list(), control_mstep = list(), ...) { @@ -22,10 +22,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, list(...) ) control_restart <- utils::modifyList(control, control_restart) - control_mstep <- utils::modifyList( - c(control, list(pseudocount = 0)), - control_mstep - ) + control_mstep <- utils::modifyList(control, control_mstep) M <- model$n_symbols S <- model$n_states @@ -235,7 +232,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda) + control_mstep$print_level, lambda, pseudocount) } else { EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -246,7 +243,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda) + control_mstep$print_level, lambda, pseudocount) } }, future.seed = TRUE) @@ -274,7 +271,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, control_mstep$pseudocount) + control_mstep$print_level, lambda, pseudocount) } else { out <- EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -285,7 +282,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, control_mstep$pseudocount) + control_mstep$print_level, lambda, pseudocount) } end_time <- proc.time() # if (out$status < 0) { diff --git a/man/estimate_mnhmm.Rd b/man/estimate_mnhmm.Rd index de477dd..aaaa39d 100644 --- a/man/estimate_mnhmm.Rd +++ b/man/estimate_mnhmm.Rd @@ -23,6 +23,7 @@ estimate_mnhmm( restarts = 0L, lambda = 0, method = "EM", + pseudocount = 0, store_data = TRUE, ... ) diff --git a/man/estimate_nhmm.Rd b/man/estimate_nhmm.Rd index eb9d3d5..5e25a0c 100644 --- a/man/estimate_nhmm.Rd +++ b/man/estimate_nhmm.Rd @@ -20,6 +20,7 @@ estimate_nhmm( restarts = 0L, lambda = 0, method = "EM", + pseudocount = 0, store_data = TRUE, ... ) @@ -91,6 +92,11 @@ Other useful arguments are \code{algorithm} (default uses LBFGS), \code{print_level} (default is \code{0}, no console output of optimization), and arguments for adjusting the stopping criteria of the optimization (see details).} + +\item{pseudocount.}{A positive scalar to be added for the expected counts of +E-step. Only used in EM algorithm. Default is 0. Larger values can be used +to avoid zero probabilities in initial, transition, and emission +probabilities, i.e. these have similar role as \code{lambda}.} } \value{ Object of class \code{nhmm}. diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index b48d9e5..d6df805 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -398,8 +398,8 @@ BEGIN_RCPP END_RCPP } // EM_LBFGS_nhmm_singlechannel -Rcpp::List EM_LBFGS_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda); -RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP) { +Rcpp::List EM_LBFGS_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount); +RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -432,13 +432,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)); + Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount)); return rcpp_result_gen; END_RCPP } // EM_LBFGS_nhmm_multichannel -Rcpp::List EM_LBFGS_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda); -RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP) { +Rcpp::List EM_LBFGS_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount); +RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -471,7 +472,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda)); + Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount)); return rcpp_result_gen; END_RCPP } @@ -1355,8 +1357,8 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_get_B_ame", (DL_FUNC) &_seqHMM_get_B_ame, 5}, {"_seqHMM_get_omega_ame", (DL_FUNC) &_seqHMM_get_omega_ame, 4}, {"_seqHMM_logSumExp", (DL_FUNC) &_seqHMM_logSumExp, 1}, - {"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 29}, - {"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 29}, + {"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 30}, + {"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 30}, {"_seqHMM_backward_nhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_nhmm_singlechannel, 15}, {"_seqHMM_backward_nhmm_multichannel", (DL_FUNC) &_seqHMM_backward_nhmm_multichannel, 15}, {"_seqHMM_backward_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_mnhmm_singlechannel, 18}, diff --git a/src/nhmm_EM.cpp b/src/nhmm_EM.cpp index 8798499..4f6d1cc 100644 --- a/src/nhmm_EM.cpp +++ b/src/nhmm_EM.cpp @@ -58,7 +58,8 @@ void nhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs, Rcpp::stop( "Some of the values in gamma_pi are nonfinite likely due to zero " "expected initial state counts.\n" - "Try increasing the penalty lambda to avoid extreme probabilities." + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities." ); } return; @@ -95,8 +96,8 @@ void nhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs, if (status < 0) { Rcpp::stop( "M-step of initial state probabilities errored with error code %i.\n" - "Try increasing the penalty lambda to avoid extreme probabilities.", - status + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities.", status ); } eta_pi = arma::mat(x_pi.memptr(), S - 1, K_pi); @@ -172,7 +173,8 @@ void nhmm_base::mstep_A(const double ftol_abs, const double ftol_rel, Rcpp::stop( "Some of the values in gamma_A are nonfinite likely due to zero " "expected transition counts.\n" - "Try increasing the penalty lambda to avoid extreme probabilities." + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities." ); } } @@ -216,8 +218,8 @@ void nhmm_base::mstep_A(const double ftol_abs, const double ftol_rel, if (status < 0) { Rcpp::stop( "M-step of transition probabilities errored with error code %i.\n" - "Try increasing the penalty lambda to avoid extreme probabilities.", - status + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities.", status ); } eta_A.slice(s) = arma::mat(x_A.memptr(), S - 1, K_A); @@ -300,7 +302,8 @@ void nhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, Rcpp::stop( "Some of the values in gamma_B are nonfinite likely due to zero " "expected emission counts.\n" - "Try increasing the penalty lambda to avoid extreme probabilities." + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities." ); } } @@ -341,8 +344,8 @@ void nhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, if (status < 0) { Rcpp::stop( "M-step of emission probabilities errored with error code %i.\n" - "Try increasing the penalty lambda to avoid extreme probabilities.", - status + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities.", status ); } eta_B.slice(s) = arma::mat(x_B.memptr(), M - 1, K_B); @@ -429,7 +432,8 @@ void nhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel, Rcpp::stop( "Some of the values in gamma_B are nonfinite likely due to zero " "expected emission counts.\n" - "Try increasing the penalty lambda to avoid extreme probabilities." + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities." ); } } @@ -472,8 +476,8 @@ void nhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel, if (status < 0) { Rcpp::stop( "M-step of emission probabilities errored with error code %i.\n" - "Try increasing the penalty lambda to avoid extreme probabilities.", - status + "Try increasing the penalty lambda or adding pseudocounts " + "to avoid extreme probabilities.", status ); } eta_B(c).slice(s) = arma::mat(x_B.memptr(), M(c) - 1, K_B);