diff --git a/R/RcppExports.R b/R/RcppExports.R index bd2832f..fe0684c 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -129,20 +129,20 @@ logSumExp <- function(x) { .Call(`_seqHMM_logSumExp`, x) } -EM_LBFGS_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) { - .Call(`_seqHMM_EM_LBFGS_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) +EM_LBFGS_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) { + .Call(`_seqHMM_EM_LBFGS_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) } -EM_LBFGS_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) { 
- .Call(`_seqHMM_EM_LBFGS_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) +EM_LBFGS_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) { + .Call(`_seqHMM_EM_LBFGS_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) } -EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) { - .Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) +EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, 
xtol_rel_m, print_level_m, lambda, bound) { + .Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) } -EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) { - .Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) +EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) { + .Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) } backward_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) { diff --git a/R/bootstrap.R b/R/bootstrap.R index f8dc75c..41b9cb2 100644 --- a/R/bootstrap.R +++ b/R/bootstrap.R @@ -99,7 +99,6 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000, init <- model$etas 
gammas_mle <- model$gammas lambda <- model$estimation_results$lambda - pseudocount <- model$estimation_results$pseudocount bound <- model$estimation_results$bound p <- progressr::progressor(along = seq_len(nsim)) original_options <- options(future.globals.maxSize = Inf) @@ -109,8 +108,7 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000, seq_len(nsim), function(i) { mod <- bootstrap_model(model) fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda, - method = method, pseudocount = pseudocount, - bound = bound, ...) + method = method, bound = bound, ...) if (fit$estimation_results$return_code >= 0) { fit$gammas <- permute_states(fit$gammas, gammas_mle) } else { @@ -137,8 +135,7 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000, N, T_, M, S, formula_pi, formula_A, formula_B, d, time, id, init, 0)$model fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda, - method = method, pseudocount = pseudocount, - bound = bound, ...) + method = method, bound = bound, ...) if (fit$estimation_results$return_code >= 0) { fit$gammas <- permute_states(fit$gammas, gammas_mle) } else { @@ -185,7 +182,6 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000, gammas_mle <- model$gammas pcp_mle <- posterior_cluster_probabilities(model) lambda <- model$estimation_results$lambda - pseudocount <- model$estimation_results$pseudocount bound <- model$estimation_results$bound D <- model$n_clusters p <- progressr::progressor(along = seq_len(nsim)) @@ -196,8 +192,7 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000, seq_len(nsim), function(i) { mod <- bootstrap_model(model) fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda, - method = method, pseudocount = pseudocount, - bound = bound, ...) + method = method, bound = bound, ...) 
if (fit$estimation_results$return_code >= 0) { fit <- permute_clusters(fit, pcp_mle) for (j in seq_len(D)) { @@ -234,8 +229,7 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000, N, T_, M, S, D, formula_pi, formula_A, formula_B, formula_omega, d, time, id, init, 0)$model fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda, - method = method, pseudocount = pseudocount, - bound = bound, ...) + method = method, bound = bound, ...) if (fit$estimation_results$return_code >= 0) { fit <- permute_clusters(fit, pcp_mle) for (j in seq_len(D)) { diff --git a/R/dnm_mnhmm.R b/R/dnm_mnhmm.R index 95228cd..6a24b60 100644 --- a/R/dnm_mnhmm.R +++ b/R/dnm_mnhmm.R @@ -132,8 +132,8 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, out <- future.apply::future_lapply(seq_len(restarts), function(i) { init <- unlist(create_initial_values(inits, model, init_sd)) fit <- nloptr( - x0 = init, eval_f = objectivef, lb = -bound, ub = bound, - opts = control_restart + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control_restart ) p() fit @@ -162,8 +162,8 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, } out <- nloptr( - x0 = init, eval_f = objectivef, lb = -bound, ub = bound, - opts = control + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control ) if (out$status < 0) { warning_( @@ -205,7 +205,6 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL, all_solutions = all_solutions, lambda = lambda, - pseudocount = 0, bound = bound, method = "DNM", algorithm = control$algorithm diff --git a/R/dnm_nhmm.R b/R/dnm_nhmm.R index a751d8b..fd2ae46 100644 --- a/R/dnm_nhmm.R +++ b/R/dnm_nhmm.R @@ -98,8 +98,8 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, out <- 
future.apply::future_lapply(seq_len(restarts), function(i) { init <- unlist(create_initial_values(inits, model, init_sd)) fit <- nloptr( - x0 = init, eval_f = objectivef, lb = -bound, ub = bound, - opts = control_restart + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control_restart ) p() fit @@ -128,8 +128,8 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, } out <- nloptr( - x0 = init, eval_f = objectivef, lb = -bound, ub = bound, - opts = control + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control ) if (out$status < 0) { warning_( @@ -165,7 +165,6 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL, all_solutions = all_solutions, lambda = lambda, - pseudocount = 0, bound = bound, method = "DNM", algorithm = control$algorithm diff --git a/R/em_dnm_mnhmm.R b/R/em_dnm_mnhmm.R index 8f83624..f1dc53a 100644 --- a/R/em_dnm_mnhmm.R +++ b/R/em_dnm_mnhmm.R @@ -1,4 +1,4 @@ -em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, +em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions) { M <- model$n_symbols @@ -143,7 +143,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { eta_B <- unlist(init$B, recursive = FALSE) fit <- EM_LBFGS_mnhmm_multichannel( @@ -156,7 +156,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, 
control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } em_return_code <- fit$return_code if (em_return_code >= 0) { @@ -171,8 +171,8 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, ) ) fit <- nloptr( - x0 = init, eval_f = objectivef, - lb = -bound, ub = bound, opts = control_restart + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control_restart ) p() fit @@ -229,7 +229,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { eta_B <- unlist(init$B, recursive = FALSE) out <- EM_LBFGS_mnhmm_multichannel( @@ -242,7 +242,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } em_return_code <- out$return_code if (em_return_code >= 0) { @@ -263,8 +263,9 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, } } out <- nloptr( - x0 = unlist(init), eval_f = objectivef, lb = -bound, ub = bound, - opts = control + x0 = unlist(init), eval_f = objectivef, + lb = -rep(bound, length(unlist(init))), + ub = rep(bound, length(unlist(init))), opts = control ) if (out$status < 0) { warning_( @@ -306,7 +307,6 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL, all_solutions = all_solutions, lambda = lambda, - 
pseudocount = pseudocount, bound = bound, method = "EM-DNM", algorithm = control$algorithm, diff --git a/R/em_dnm_nhmm.R b/R/em_dnm_nhmm.R index b49166c..40bdbe4 100644 --- a/R/em_dnm_nhmm.R +++ b/R/em_dnm_nhmm.R @@ -1,4 +1,4 @@ -em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, +em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions) { M <- model$n_symbols @@ -108,7 +108,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { fit <- EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -119,7 +119,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } em_return_code <- fit$return_code if (em_return_code >= 0) { @@ -133,8 +133,8 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, ) ) fit <- nloptr( - x0 = init, eval_f = objectivef, lb = -bound, ub = bound, - opts = control_restart + x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)), + ub = rep(bound, length(init)), opts = control_restart ) p() fit @@ -187,7 +187,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, 
bound) } else { out <- EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -198,7 +198,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } em_return_code <- out$return_code if (em_return_code >= 0) { @@ -218,8 +218,9 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, } } out <- nloptr( - x0 = unlist(init), eval_f = objectivef, lb = -bound, ub = bound, - opts = control + x0 = unlist(init), eval_f = objectivef, + lb = -rep(bound, length(unlist(init))), + ub = rep(bound, length(unlist(init))), opts = control ) if (out$status < 0) { warning_( @@ -252,7 +253,6 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL, all_solutions = all_solutions, lambda = lambda, - pseudocount = pseudocount, bound = bound, method = "EM-DNM", algorithm = control$algorithm, diff --git a/R/em_mnhmm.R b/R/em_mnhmm.R index e896232..0575be1 100644 --- a/R/em_mnhmm.R +++ b/R/em_mnhmm.R @@ -1,4 +1,4 @@ -em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, +em_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions) { M <- model$n_symbols @@ -50,7 +50,7 @@ em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { eta_B <- unlist(init$B, recursive = FALSE) fit <- EM_LBFGS_mnhmm_multichannel( @@ -63,7 
+63,7 @@ em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } p() fit @@ -99,7 +99,7 @@ em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { eta_B <- unlist(init$B, recursive = FALSE) out <- EM_LBFGS_mnhmm_multichannel( @@ -112,7 +112,7 @@ em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } if (out$return_code < 0) { warning_( @@ -148,7 +148,6 @@ em_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, x_rel_change = out$relative_x_change, x_abs_change = out$absolute_x_change, lambda = lambda, - pseudocount = pseudocount, bound = bound, method = "EM" ) diff --git a/R/em_nhmm.R b/R/em_nhmm.R index 8d6577a..01e3132 100644 --- a/R/em_nhmm.R +++ b/R/em_nhmm.R @@ -1,4 +1,4 @@ -em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, +em_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions) { M <- model$n_symbols @@ -44,7 +44,7 @@ em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - 
control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { fit <- EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -55,7 +55,7 @@ em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control_restart$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } p() fit @@ -89,7 +89,7 @@ em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } else { out <- EM_LBFGS_nhmm_multichannel( init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs, @@ -100,7 +100,7 @@ em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, control$print_level, control_mstep$maxeval, control_mstep$ftol_abs, control_mstep$ftol_rel, control_mstep$xtol_abs, control_mstep$xtol_rel, - control_mstep$print_level, lambda, pseudocount, bound) + control_mstep$print_level, lambda, bound) } if (out$return_code < 0) { warning_( @@ -132,7 +132,6 @@ em_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount, x_rel_change = out$relative_x_change, x_abs_change = out$absolute_x_change, lambda = lambda, - pseudocount = pseudocount, bound = bound, method = "EM" ) diff --git a/R/estimate_mnhmm.R b/R/estimate_mnhmm.R index 06d2c79..f6bfa79 100644 --- a/R/estimate_mnhmm.R +++ b/R/estimate_mnhmm.R @@ -44,7 +44,7 @@ estimate_mnhmm <- function( transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1, data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 
2, - restarts = 0L, lambda = 1e-4, method = "EM-DNM", pseudocount = 0, + restarts = 0L, lambda = 1e-4, method = "EM-DNM", bound = 50, store_data = TRUE, ...) { call <- match.call() @@ -61,8 +61,7 @@ estimate_mnhmm <- function( model$data <- data } start_time <- proc.time() - out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, - pseudocount, bound, ...) + out <- fit_mnhmm(model, inits, init_sd, restarts, lambda, method, bound, ...) end_time <- proc.time() out$estimation_results$time <- end_time - start_time attr(out, "call") <- call diff --git a/R/estimate_nhmm.R b/R/estimate_nhmm.R index 70a0bd6..47ccf8c 100644 --- a/R/estimate_nhmm.R +++ b/R/estimate_nhmm.R @@ -74,25 +74,22 @@ #' @param restarts Number of times to run optimization using random starting #' values (in addition to the final run). Default is 0. #' @param lambda Penalization factor `lambda` for penalized log-likelihood, where the -#' penalization is `0.5 * lambda * sum(parameters^2)`. Note that with +#' penalization is `0.5 * lambda * sum(eta^2)`. Note that with #' `method = "L-BFGS"` both objective function (log-likelihood) and #' the penalization term is scaled with number of non-missing observations. -#' Default is `1e-4` for ensuring numerical stability of L-BFGS by avoiding -#' extreme probabilities. +#' Default is `1e-4`, as small positive values help to ensure numerical +#' stability of L-BFGS by avoiding extreme probabilities. See also argument +#' `bound` for hard constraints. #' @param method Optimization method used. Option `"EM"` uses EM #' algorithm with L-BFGS in the M-step. Option `"DNM"` uses #' direct maximization of the log-likelihood, by default using L-BFGS. Option #' `"EM-DNM"` (the default) runs first a maximum of 10 iterations of EM and #' then switches to L-BFGS (but other algorithms of NLopt can be used). -#' @param pseudocount A positive scalar to be added for the expected counts of -#' E-step. Only used in EM and EM-DNM algorithms. Default is 0. 
Larger values -#' can be used to avoid extreme initial, transition, and emission -#' probabilities, i.e. these have similar role as `lambda`. #' @param bound Positive value defining the hard bounds for the working #' parameters \eqn{\eta}, which are used to avoid extreme probabilities and #' corresponding numerical issues especially in the M-step of EM algorithm. #' Default is 50, i.e., \eqn{-50<\eta<50}. Note that he bounds are not enforced -#' for M-step in intercept-only case with `lambda=0`. +#' for M-step in intercept-only case with `lambda = 0`. #' @param store_data If `TRUE` (default), original data frame passed as `data` #' is stored to the model object. For large datasets, this can be set to #' `FALSE`, in which case you might need to pass the data separately to some @@ -120,7 +117,7 @@ estimate_nhmm <- function( transition_formula = ~1, emission_formula = ~1, data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL, inits = "random", init_sd = 2, restarts = 0L, - lambda = 1e-4, method = "EM-DNM", pseudocount = 0, bound = 50, + lambda = 1e-4, method = "EM-DNM", bound = 50, store_data = TRUE, ...) { @@ -138,8 +135,7 @@ estimate_nhmm <- function( model$data <- data } start_time <- proc.time() - out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, pseudocount, - bound, ...) + out <- fit_nhmm(model, inits, init_sd, restarts, lambda, method, bound, ...) end_time <- proc.time() out$estimation_results$time <- end_time - start_time attr(out, "call") <- call diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 422ef4c..e703207 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -2,7 +2,7 @@ #' #' @noRd fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, - pseudocount, bound, save_all_solutions = FALSE, + bound, save_all_solutions = FALSE, control_restart = list(), control_mstep = list(), ...) 
{ stopifnot_( @@ -68,7 +68,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, all_solutions <- NULL if (method == "EM-DNM") { out <- em_dnm_mnhmm( - model, inits, init_sd, restarts, lambda, pseudocount, bound, control, + model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions ) } @@ -80,7 +80,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, } if (method == "EM") { out <- em_mnhmm( - model, inits, init_sd, restarts, lambda, pseudocount, bound, control, + model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions ) } diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index 17e84b2..6b47dbd 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -1,7 +1,7 @@ #' Estimate a Non-homogeneous Hidden Markov Model #' #' @noRd -fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocount, +fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, bound, save_all_solutions = FALSE, control_restart = list(), control_mstep = list(), ...) 
{ @@ -64,7 +64,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun } if (method == "EM-DNM") { out <- em_dnm_nhmm( - model, inits, init_sd, restarts, lambda, pseudocount, bound, control, + model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions ) } @@ -76,7 +76,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun } if (method == "EM") { out <- em_nhmm( - model, inits, init_sd, restarts, lambda, pseudocount, bound, control, + model, inits, init_sd, restarts, lambda, bound, control, control_restart, control_mstep, save_all_solutions ) } diff --git a/R/utilities.R b/R/utilities.R index 47858c3..f25d1e8 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -159,8 +159,8 @@ error_msg <- function(error) { if (!is.null(gamma) && error %in% (-100 * (1:4))) { return(paste0(x, "Error in M-step of ", gamma, " encountered expected count of zero. ", - "Try increasing the pseudocounts or regularization via lambda to avoid ", - "extreme probabilities.") + "Try increasing the regularization via lambda or adjust parameter bounds ", + "to avoid extreme probabilities.") ) } diff --git a/man/estimate_mnhmm.Rd b/man/estimate_mnhmm.Rd index 6315171..d51ce62 100644 --- a/man/estimate_mnhmm.Rd +++ b/man/estimate_mnhmm.Rd @@ -23,7 +23,6 @@ estimate_mnhmm( restarts = 0L, lambda = 1e-04, method = "EM-DNM", - pseudocount = 0, bound = 50, store_data = TRUE, ... @@ -87,11 +86,12 @@ of the regression coefficients to zero, use \code{init_sd = 0}.} values (in addition to the final run). Default is 0.} \item{lambda}{Penalization factor \code{lambda} for penalized log-likelihood, where the -penalization is \code{0.5 * lambda * sum(parameters^2)}. Note that with +penalization is \code{0.5 * lambda * sum(eta^2)}. Note that with \code{method = "L-BFGS"} both objective function (log-likelihood) and the penalization term is scaled with number of non-missing observations. 
-Default is \code{1e-4} for ensuring numerical stability of L-BFGS by avoiding -extreme probabilities.} +Default is \code{1e-4}, as small positive values help to ensure numerical +stability of L-BFGS by avoiding extreme probabilities. See also argument +\code{bound} for hard constraints.} \item{method}{Optimization method used. Option \code{"EM"} uses EM algorithm with L-BFGS in the M-step. Option \code{"DNM"} uses @@ -99,16 +99,11 @@ direct maximization of the log-likelihood, by default using L-BFGS. Option \code{"EM-DNM"} (the default) runs first a maximum of 10 iterations of EM and then switches to L-BFGS (but other algorithms of NLopt can be used).} -\item{pseudocount}{A positive scalar to be added for the expected counts of -E-step. Only used in EM and EM-DNM algorithms. Default is 0. Larger values -can be used to avoid extreme initial, transition, and emission -probabilities, i.e. these have similar role as \code{lambda}.} - \item{bound}{Positive value defining the hard bounds for the working parameters \eqn{\eta}, which are used to avoid extreme probabilities and corresponding numerical issues especially in the M-step of EM algorithm. Default is 50, i.e., \eqn{-50<\eta<50}. Note that he bounds are not enforced -for M-step in intercept-only case with \code{lambda=0}.} +for M-step in intercept-only case with \code{lambda = 0}.} \item{store_data}{If \code{TRUE} (default), original data frame passed as \code{data} is stored to the model object. For large datasets, this can be set to diff --git a/man/estimate_nhmm.Rd b/man/estimate_nhmm.Rd index af27640..b2c3fb7 100644 --- a/man/estimate_nhmm.Rd +++ b/man/estimate_nhmm.Rd @@ -20,7 +20,6 @@ estimate_nhmm( restarts = 0L, lambda = 1e-04, method = "EM-DNM", - pseudocount = 0, bound = 50, store_data = TRUE, ... @@ -74,11 +73,12 @@ of the regression coefficients to zero, use \code{init_sd = 0}.} values (in addition to the final run). 
Default is 0.} \item{lambda}{Penalization factor \code{lambda} for penalized log-likelihood, where the -penalization is \code{0.5 * lambda * sum(parameters^2)}. Note that with +penalization is \code{0.5 * lambda * sum(eta^2)}. Note that with \code{method = "L-BFGS"} both objective function (log-likelihood) and the penalization term is scaled with number of non-missing observations. -Default is \code{1e-4} for ensuring numerical stability of L-BFGS by avoiding -extreme probabilities.} +Default is \code{1e-4}, as small positive values help to ensure numerical +stability of L-BFGS by avoiding extreme probabilities. See also argument +\code{bound} for hard constraints.} \item{method}{Optimization method used. Option \code{"EM"} uses EM algorithm with L-BFGS in the M-step. Option \code{"DNM"} uses @@ -86,16 +86,11 @@ direct maximization of the log-likelihood, by default using L-BFGS. Option \code{"EM-DNM"} (the default) runs first a maximum of 10 iterations of EM and then switches to L-BFGS (but other algorithms of NLopt can be used).} -\item{pseudocount}{A positive scalar to be added for the expected counts of -E-step. Only used in EM and EM-DNM algorithms. Default is 0. Larger values -can be used to avoid extreme initial, transition, and emission -probabilities, i.e. these have similar role as \code{lambda}.} - \item{bound}{Positive value defining the hard bounds for the working parameters \eqn{\eta}, which are used to avoid extreme probabilities and corresponding numerical issues especially in the M-step of EM algorithm. Default is 50, i.e., \eqn{-50<\eta<50}. Note that he bounds are not enforced -for M-step in intercept-only case with \code{lambda=0}.} +for M-step in intercept-only case with \code{lambda = 0}.} \item{store_data}{If \code{TRUE} (default), original data frame passed as \code{data} is stored to the model object. 
For large datasets, this can be set to diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f1b7601..652d087 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -478,8 +478,8 @@ BEGIN_RCPP END_RCPP } // EM_LBFGS_mnhmm_singlechannel -Rcpp::List EM_LBFGS_mnhmm_singlechannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount, const double bound); -RcppExport SEXP _seqHMM_EM_LBFGS_mnhmm_singlechannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP, SEXP boundSEXP) { +Rcpp::List EM_LBFGS_mnhmm_singlechannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const 
arma::umat& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double bound); +RcppExport SEXP _seqHMM_EM_LBFGS_mnhmm_singlechannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP boundSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -515,15 +515,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); - Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP); Rcpp::traits::input_parameter< const double >::type bound(boundSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_mnhmm_singlechannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, 
xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_mnhmm_singlechannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)); return rcpp_result_gen; END_RCPP } // EM_LBFGS_mnhmm_multichannel -Rcpp::List EM_LBFGS_mnhmm_multichannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount, const double bound); -RcppExport SEXP _seqHMM_EM_LBFGS_mnhmm_multichannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP 
xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP, SEXP boundSEXP) { +Rcpp::List EM_LBFGS_mnhmm_multichannel(arma::mat& eta_omega, const arma::mat& X_omega, arma::field& eta_pi, const arma::mat& X_pi, arma::field& eta_A, const arma::cube& X_A, arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec Ti, const bool icpt_only_omega, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double bound); +RcppExport SEXP _seqHMM_EM_LBFGS_mnhmm_multichannel(SEXP eta_omegaSEXP, SEXP X_omegaSEXP, SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_omegaSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP boundSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -559,15 +558,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); - Rcpp::traits::input_parameter< const double >::type 
pseudocount(pseudocountSEXP); Rcpp::traits::input_parameter< const double >::type bound(boundSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_mnhmm_multichannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_mnhmm_multichannel(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)); return rcpp_result_gen; END_RCPP } // EM_LBFGS_nhmm_singlechannel -Rcpp::List EM_LBFGS_nhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount, const double bound); -RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP 
n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP, SEXP boundSEXP) { +Rcpp::List EM_LBFGS_nhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double bound); +RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP boundSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -600,15 +598,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type 
lambda(lambdaSEXP); - Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP); Rcpp::traits::input_parameter< const double >::type bound(boundSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)); return rcpp_result_gen; END_RCPP } // EM_LBFGS_nhmm_multichannel -Rcpp::List EM_LBFGS_nhmm_multichannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount, const double bound); -RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP 
n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP, SEXP boundSEXP) { +Rcpp::List EM_LBFGS_nhmm_multichannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::field& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double bound); +RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP boundSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -641,9 +638,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const double >::type xtol_rel_m(xtol_rel_mSEXP); Rcpp::traits::input_parameter< const arma::uword >::type print_level_m(print_level_mSEXP); Rcpp::traits::input_parameter< const double >::type 
lambda(lambdaSEXP); - Rcpp::traits::input_parameter< const double >::type pseudocount(pseudocountSEXP); Rcpp::traits::input_parameter< const double >::type bound(boundSEXP); - rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)); + rcpp_result_gen = Rcpp::wrap(EM_LBFGS_nhmm_multichannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)); return rcpp_result_gen; END_RCPP } @@ -1639,10 +1635,10 @@ static const R_CallMethodDef CallEntries[] = { {"_seqHMM_get_B_ame", (DL_FUNC) &_seqHMM_get_B_ame, 5}, {"_seqHMM_get_omega_ame", (DL_FUNC) &_seqHMM_get_omega_ame, 4}, {"_seqHMM_logSumExp", (DL_FUNC) &_seqHMM_logSumExp, 1}, - {"_seqHMM_EM_LBFGS_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_mnhmm_singlechannel, 34}, - {"_seqHMM_EM_LBFGS_mnhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_mnhmm_multichannel, 34}, - {"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 31}, - {"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 31}, + {"_seqHMM_EM_LBFGS_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_mnhmm_singlechannel, 33}, + {"_seqHMM_EM_LBFGS_mnhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_mnhmm_multichannel, 33}, + {"_seqHMM_EM_LBFGS_nhmm_singlechannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_singlechannel, 30}, + {"_seqHMM_EM_LBFGS_nhmm_multichannel", (DL_FUNC) &_seqHMM_EM_LBFGS_nhmm_multichannel, 30}, {"_seqHMM_backward_nhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_nhmm_singlechannel, 15}, 
{"_seqHMM_backward_nhmm_multichannel", (DL_FUNC) &_seqHMM_backward_nhmm_multichannel, 15}, {"_seqHMM_backward_mnhmm_singlechannel", (DL_FUNC) &_seqHMM_backward_mnhmm_singlechannel, 18}, diff --git a/src/mnhmm_EM.cpp b/src/mnhmm_EM.cpp index 85d1ab1..da7bbf1 100644 --- a/src/mnhmm_EM.cpp +++ b/src/mnhmm_EM.cpp @@ -28,30 +28,20 @@ double mnhmm_base::objective_omega(const arma::vec& x, arma::vec& grad) { const arma::vec& counts = E_omega.col(i); idx = arma::find(counts); if (idx.n_elem > 0) { - double sum_eo = arma::accu(counts.rows(idx)); // this is != 1 if pseudocounts are used double val = arma::dot(counts.rows(idx), log_omega.rows(idx)); if (!std::isfinite(val)) { - if (!grad.is_empty()) { - grad.zeros(); - } + grad.zeros(); return maxval; } value -= val; // Only update grad if it's non-empty (i.e., for gradient-based optimization) - if (!grad.is_empty()) { - diff.zeros(); - diff.rows(idx) = counts(idx) - sum_eo * omega.rows(idx); - grad -= arma::vectorise(tQd * diff * X_omega.col(i).t()); - if (!grad.is_finite()) { - grad.zeros(); - return maxval; - } - } + diff.zeros(); + diff.rows(idx) = counts(idx) - omega.rows(idx); + grad -= arma::vectorise(tQd * diff * X_omega.col(i).t()); } } - if (!grad.is_empty()) { - grad += lambda * x; - } + grad += lambda * x; + return value + 0.5 * lambda * std::pow(arma::norm(x, 2), 2); } @@ -73,13 +63,8 @@ void mnhmm_base::mstep_omega(const double xtol_abs, const double ftol_abs, auto objective_omega_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); - if (grad) { - arma::vec grad_vec(grad, n, false, true); - return self->objective_omega(x_vec, grad_vec); - } else { - arma::vec grad_dummy; - return self->objective_omega(x_vec, grad_dummy); - } + arma::vec grad_vec(grad, n, false, true); + return self->objective_omega(x_vec, grad_vec); }; arma::vec x_omega = arma::vectorise(eta_omega); @@ -125,30 +110,19 @@ double 
mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) { const arma::vec& counts = E_Pi(current_d).col(i); idx = arma::find(counts); if (idx.n_elem > 0) { - double sum_epi = arma::accu(counts.rows(idx)); // this is != 1 if pseudocounts are used double val = arma::dot(counts.rows(idx), log_pi(current_d).rows(idx)); if (!std::isfinite(val)) { - if (!grad.is_empty()) { - grad.zeros(); - } + grad.zeros(); return maxval; } value -= val; - // Only update grad if it's non-empty (i.e., for gradient-based optimization) - if (!grad.is_empty()) { - diff.zeros(); - diff.rows(idx) = counts.rows(idx) - sum_epi * pi(current_d).rows(idx); - grad -= arma::vectorise(tQs * diff * X_pi.col(i).t()); - if (!grad.is_finite()) { - grad.zeros(); - return maxval; - } - } + diff.zeros(); + diff.rows(idx) = counts.rows(idx) - pi(current_d).rows(idx); + grad -= arma::vectorise(tQs * diff * X_pi.col(i).t()); } } - if (!grad.is_empty()) { - grad += lambda * x; - } + grad += lambda * x; + return value + 0.5 * lambda * std::pow(arma::norm(x, 2), 2); } @@ -172,13 +146,8 @@ void mnhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs, auto objective_pi_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); - if (grad) { - arma::vec grad_vec(grad, n, false, true); - return self->objective_pi(x_vec, grad_vec); - } else { - arma::vec grad_dummy; - return self->objective_pi(x_vec, grad_dummy); - } + arma::vec grad_vec(grad, n, false, true); + return self->objective_pi(x_vec, grad_vec); }; arma::vec x_pi = arma::vectorise(eta_pi(0)); @@ -244,27 +213,19 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) { } double val = arma::dot(counts.rows(idx), log_A1.rows(idx)); if (!std::isfinite(val)) { - if (!grad.is_empty()) { - grad.zeros(); - } + grad.zeros(); return maxval; } value -= val; - if (!grad.is_empty()) { - diff.zeros(); - diff.rows(idx) = counts.rows(idx) - 
sum_ea * A1.rows(idx); - grad -= arma::vectorise(tQs * diff * X_A.slice(i).col(t).t()); - if (!grad.is_finite()) { - grad.zeros(); - return maxval; - } - } + + diff.zeros(); + diff.rows(idx) = counts.rows(idx) - sum_ea * A1.rows(idx); + grad -= arma::vectorise(tQs * diff * X_A.slice(i).col(t).t()); } } } - if (!grad.is_empty()) { - grad += lambda * x; - } + grad += lambda * x; + return value + 0.5 * lambda * std::pow(arma::norm(x, 2), 2); } void mnhmm_base::mstep_A(const double ftol_abs, const double ftol_rel, @@ -295,13 +256,8 @@ void mnhmm_base::mstep_A(const double ftol_abs, const double ftol_rel, auto objective_A_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); - if (grad) { - arma::vec grad_vec(grad, n, false, true); - return self->objective_A(x_vec, grad_vec); - } else { - arma::vec grad_dummy; - return self->objective_A(x_vec, grad_dummy); - } + arma::vec grad_vec(grad, n, false, true); + return self->objective_A(x_vec, grad_vec); }; arma::vec x_A(eta_A(0).slice(0).n_elem); @@ -370,27 +326,18 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) { double val = e_b * log_B1(obs(t, i)); if (!std::isfinite(val)) { - if (!grad.is_empty()) { - grad.zeros(); - } + grad.zeros(); return maxval; } value -= val; - if (!grad.is_empty()) { - grad -= arma::vectorise(tQm * - e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t()); - if (!grad.is_finite()) { - grad.zeros(); - return maxval; - } - } + grad -= arma::vectorise(tQm * + e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t()); } } } } - if (!grad.is_empty()) { - grad += lambda * x; - } + grad += lambda * x; + return value + 0.5 * lambda * std::pow(arma::norm(x, 2), 2); } void mnhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, @@ -424,13 +371,8 @@ void mnhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, auto objective_B_wrapper = [](unsigned n, const double* x, 
double* grad, void* data) -> double { auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); - if (grad) { - arma::vec grad_vec(grad, n, false, true); - return self->objective_B(x_vec, grad_vec); - } else { - arma::vec grad_dummy; - return self->objective_B(x_vec, grad_dummy); - } + arma::vec grad_vec(grad, n, false, true); + return self->objective_B(x_vec, grad_vec); }; arma::vec x_B(eta_B(0).slice(0).n_elem); nlopt_opt opt_B = nlopt_create(NLOPT_LD_LBFGS, x_B.n_elem); @@ -500,27 +442,18 @@ double mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) { } double val = e_b * log_B1(obs(current_c, t, i)); if (!std::isfinite(val)) { - if (!grad.is_empty()) { - grad.zeros(); - } + grad.zeros(); return maxval; } value -= val; - if (!grad.is_empty()) { - grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) * - X_B.slice(i).col(t).t()); - if (!grad.is_finite()) { - grad.zeros(); - return maxval; - } - } + grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) * + X_B.slice(i).col(t).t()); } } } } - if (!grad.is_empty()) { - grad += lambda * x; - } + grad += lambda * x; + return value + 0.5 * lambda * std::pow(arma::norm(x, 2), 2); } void mnhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel, @@ -556,13 +489,8 @@ void mnhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel, auto objective_B_wrapper = [](unsigned n, const double* x, double* grad, void* data) -> double { auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); - if (grad) { - arma::vec grad_vec(grad, n, false, true); - return self->objective_B(x_vec, grad_vec); - } else { - arma::vec grad_dummy; - return self->objective_B(x_vec, grad_dummy); - } + arma::vec grad_vec(grad, n, false, true); + return self->objective_B(x_vec, grad_vec); }; double minf; int return_code; @@ -617,7 +545,7 @@ Rcpp::List EM_LBFGS_mnhmm_singlechannel( const double xtol_abs, const double xtol_rel, const arma::uword print_level, const 
arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, - const double lambda, const double pseudocount, const double bound) { + const double lambda, const double bound) { mnhmm_sc model( eta_A(0).n_slices, eta_A.n_rows, X_omega, X_pi, X_A, X_B, Ti, @@ -857,7 +785,7 @@ Rcpp::List EM_LBFGS_mnhmm_multichannel( const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, - const double lambda, const double pseudocount, const double bound) { + const double lambda, const double bound) { mnhmm_mc model( eta_A(0).n_slices, eta_A.n_rows, X_omega, X_pi, X_A, X_B, Ti, diff --git a/src/mnhmm_base.h b/src/mnhmm_base.h index a52a2a1..6a1aa52 100644 --- a/src/mnhmm_base.h +++ b/src/mnhmm_base.h @@ -232,25 +232,24 @@ struct mnhmm_base { } void estep_omega(const arma::uword i, const arma::vec ll_i, - const double ll, const double pseudocount = 0) { - E_omega.col(i) = arma::exp(ll_i - ll) + pseudocount; + const double ll) { + E_omega.col(i) = arma::exp(ll_i - ll); } void estep_pi(const arma::uword i, const arma::uword d, const arma::vec& log_alpha, - const arma::vec& log_beta, const double ll, - const double pseudocount = 0) { - E_Pi(d).col(i) = arma::exp(log_alpha + log_beta - ll) + pseudocount; + const arma::vec& log_beta, const double ll) { + E_Pi(d).col(i) = arma::exp(log_alpha + log_beta - ll); } void estep_A(const arma::uword i, const arma::uword d, const arma::mat& log_alpha, const arma::mat& log_beta, - const double ll, const double pseudocount = 0) { + const double ll) { for (arma::uword k = 0; k < S; k++) { // from for (arma::uword j = 0; j < S; j++) { // to for (arma::uword t = 0; t < (Ti(i) - 1); t++) { // time E_A(k, d)(j, i, t) = exp(log_alpha(k, t) + log_A(d)(k, j, t) + - log_beta(j, t + 1) + 
log_py(j, t + 1, d) - ll) + pseudocount; + log_beta(j, t + 1) + log_py(j, t + 1, d) - ll); } } } diff --git a/src/mnhmm_mc.h b/src/mnhmm_mc.h index f2cc49b..00c0998 100644 --- a/src/mnhmm_mc.h +++ b/src/mnhmm_mc.h @@ -136,13 +136,13 @@ struct mnhmm_mc : public mnhmm_base { } void estep_B(const arma::uword i, const arma::uword d, const arma::mat& log_alpha, const arma::mat& log_beta, - const double ll, const double pseudocount = 0) { + const double ll) { for (arma::uword k = 0; k < S; k++) { // state for (arma::uword t = 0; t < Ti(i); t++) { // time double pp = exp(log_alpha(k, t) + log_beta(k, t) - ll); for (arma::uword c = 0; c < C; c++) { // channel if (obs(c, t, i) < M(c)) { - E_B(c, d)(t, i, k) = pp + pseudocount; + E_B(c, d)(t, i, k) = pp; } else { E_B(c, d)(t, i, k) = 0.0; } diff --git a/src/mnhmm_sc.h b/src/mnhmm_sc.h index a65fe20..d383e55 100644 --- a/src/mnhmm_sc.h +++ b/src/mnhmm_sc.h @@ -111,11 +111,11 @@ struct mnhmm_sc : public mnhmm_base { } void estep_B(const arma::uword i, const arma::uword d, const arma::mat& log_alpha, const arma::mat& log_beta, - const double ll, const double pseudocount = 0) { + const double ll) { for (arma::uword k = 0; k < S; k++) { // state for (arma::uword t = 0; t < Ti(i); t++) { // time if (obs(t, i) < M) { - E_B(d)(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll) + pseudocount; + E_B(d)(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll); } else { E_B(d)(t, i, k) = 0.0; } diff --git a/src/nhmm_EM.cpp b/src/nhmm_EM.cpp index 0277d43..32ca63d 100644 --- a/src/nhmm_EM.cpp +++ b/src/nhmm_EM.cpp @@ -27,7 +27,6 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) { const arma::vec& counts = E_Pi.col(i); idx = arma::find(counts); if (idx.n_elem > 0) { - double sum_epi = arma::accu(counts.rows(idx)); // this is != 1 if pseudocounts are used double val = arma::dot(counts.rows(idx), log_pi.rows(idx)); if (!std::isfinite(val)) { grad.zeros(); @@ -36,7 +35,7 @@ double nhmm_base::objective_pi(const arma::vec& 
x, arma::vec& grad) { value -= val; diff.zeros(); - diff.rows(idx) = counts.rows(idx) - sum_epi * pi.rows(idx); + diff.rows(idx) = counts.rows(idx) - pi.rows(idx); grad -= arma::vectorise(tQs * diff * X_pi.col(i).t()); } @@ -438,7 +437,7 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, - const double lambda, const double pseudocount, const double bound) { + const double lambda, const double bound) { nhmm_sc model( eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A, @@ -488,9 +487,9 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( ); double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); ll += ll_i; - model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount); - model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount); - model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount); + model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); + model.estep_A(i, log_alpha, log_beta, ll_i); + model.estep_B(i, log_alpha, log_beta, ll_i); } double penalty_term = 0.5 * lambda * std::pow(arma::norm(pars, 2), 2); ll -= penalty_term; @@ -573,9 +572,9 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( ); double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); ll_new += ll_i; - model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount); - model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount); - model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount); + model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); + model.estep_A(i, log_alpha, log_beta, ll_i); + model.estep_B(i, log_alpha, log_beta, ll_i); } pars_new.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); @@ -644,7 +643,7 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword 
maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, - const double lambda, const double pseudocount, const double bound) { + const double lambda, const double bound) { nhmm_mc model( eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A, @@ -702,9 +701,9 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( ); double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); ll += ll_i; - model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount); - model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount); - model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount); + model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); + model.estep_A(i, log_alpha, log_beta, ll_i); + model.estep_B(i, log_alpha, log_beta, ll_i); } double penalty_term = 0.5 * lambda * std::pow(arma::norm(pars, 2), 2); ll -= penalty_term; @@ -786,9 +785,9 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( ); double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1)); ll_new += ll_i; - model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount); - model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount); - model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount); + model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i); + model.estep_A(i, log_alpha, log_beta, ll_i); + model.estep_B(i, log_alpha, log_beta, ll_i); } pars_new.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t(); pars_new.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); diff --git a/src/nhmm_base.h b/src/nhmm_base.h index 24f2ef1..3c5f320 100644 --- a/src/nhmm_base.h +++ b/src/nhmm_base.h @@ -141,19 +141,17 @@ struct nhmm_base { } void estep_pi(const arma::uword i, const arma::vec& log_alpha, - const arma::vec& log_beta, const double ll, - const double pseudocount = 0) { - E_Pi.col(i) = arma::exp(log_alpha + log_beta - ll) + pseudocount; + const arma::vec& log_beta, const double ll) { + E_Pi.col(i) = arma::exp(log_alpha + 
log_beta - ll); } void estep_A(const arma::uword i, const arma::mat& log_alpha, - const arma::mat& log_beta, const double ll, - const double pseudocount = 0) { + const arma::mat& log_beta, const double ll) { for (arma::uword k = 0; k < S; k++) { // from for (arma::uword j = 0; j < S; j++) { // to for (arma::uword t = 0; t < (Ti(i) - 1); t++) { // time E_A(k)(j, i, t) = exp(log_alpha(k, t) + log_A(k, j, t) + - log_beta(j, t + 1) + log_py(j, t + 1) - ll) + pseudocount; + log_beta(j, t + 1) + log_py(j, t + 1) - ll); } } } diff --git a/src/nhmm_mc.h b/src/nhmm_mc.h index ce1bfee..43d204c 100644 --- a/src/nhmm_mc.h +++ b/src/nhmm_mc.h @@ -114,14 +114,13 @@ struct nhmm_mc : public nhmm_base { } } void estep_B(const arma::uword i, const arma::mat& log_alpha, - const arma::mat& log_beta, const double ll, - const double pseudocount = 0) { + const arma::mat& log_beta, const double ll) { for (arma::uword k = 0; k < S; k++) { // state for (arma::uword t = 0; t < Ti(i); t++) { // time double pp = exp(log_alpha(k, t) + log_beta(k, t) - ll); for (arma::uword c = 0; c < C; c++) { // channel if (obs(c, t, i) < M(c)) { - E_B(c)(t, i, k) = pp + pseudocount; + E_B(c)(t, i, k) = pp; } else { E_B(c)(t, i, k) = 0.0; } diff --git a/src/nhmm_sc.h b/src/nhmm_sc.h index 9eec310..e742ae8 100644 --- a/src/nhmm_sc.h +++ b/src/nhmm_sc.h @@ -90,12 +90,11 @@ struct nhmm_sc : public nhmm_base { } } void estep_B(const arma::uword i, const arma::mat& log_alpha, - const arma::mat& log_beta, const double ll, - const double pseudocount = 0) { + const arma::mat& log_beta, const double ll) { for (arma::uword k = 0; k < S; k++) { // state for (arma::uword t = 0; t < Ti(i); t++) { // time if (obs(t, i) < M) { - E_B(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll) + pseudocount; + E_B(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll); } else { E_B(t, i, k) = 0.0; }