Skip to content

Commit

Permalink
add bounds, remove pseudocounts
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Dec 3, 2024
1 parent 72cbbe2 commit 2240195
Show file tree
Hide file tree
Showing 24 changed files with 167 additions and 274 deletions.
16 changes: 8 additions & 8 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,20 @@ logSumExp <- function(x) {
.Call(`_seqHMM_logSumExp`, x)
}

EM_LBFGS_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) {
.Call(`_seqHMM_EM_LBFGS_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)
EM_LBFGS_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) {
.Call(`_seqHMM_EM_LBFGS_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)
}

EM_LBFGS_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) {
.Call(`_seqHMM_EM_LBFGS_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)
EM_LBFGS_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) {
.Call(`_seqHMM_EM_LBFGS_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)
}

EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)
EM_LBFGS_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) {
.Call(`_seqHMM_EM_LBFGS_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)
}

EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, pseudocount, bound)
EM_LBFGS_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound) {
.Call(`_seqHMM_EM_LBFGS_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, n_obs, maxeval, ftol_abs, ftol_rel, xtol_abs, xtol_rel, print_level, maxeval_m, ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, print_level_m, lambda, bound)
}

backward_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) {
Expand Down
14 changes: 4 additions & 10 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000,
init <- model$etas
gammas_mle <- model$gammas
lambda <- model$estimation_results$lambda
pseudocount <- model$estimation_results$pseudocount
bound <- model$estimation_results$bound
p <- progressr::progressor(along = seq_len(nsim))
original_options <- options(future.globals.maxSize = Inf)
Expand All @@ -109,8 +108,7 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000,
seq_len(nsim), function(i) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount,
bound = bound, ...)
method = method, bound = bound, ...)
if (fit$estimation_results$return_code >= 0) {
fit$gammas <- permute_states(fit$gammas, gammas_mle)
} else {
Expand All @@ -137,8 +135,7 @@ bootstrap_coefs.nhmm <- function(model, nsim = 1000,
N, T_, M, S, formula_pi, formula_A, formula_B,
d, time, id, init, 0)$model
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount,
bound = bound, ...)
method = method, bound = bound, ...)
if (fit$estimation_results$return_code >= 0) {
fit$gammas <- permute_states(fit$gammas, gammas_mle)
} else {
Expand Down Expand Up @@ -185,7 +182,6 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
gammas_mle <- model$gammas
pcp_mle <- posterior_cluster_probabilities(model)
lambda <- model$estimation_results$lambda
pseudocount <- model$estimation_results$pseudocount
bound <- model$estimation_results$bound
D <- model$n_clusters
p <- progressr::progressor(along = seq_len(nsim))
Expand All @@ -196,8 +192,7 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
seq_len(nsim), function(i) {
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount,
bound = bound, ...)
method = method, bound = bound, ...)
if (fit$estimation_results$return_code >= 0) {
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
Expand Down Expand Up @@ -234,8 +229,7 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
N, T_, M, S, D, formula_pi, formula_A, formula_B, formula_omega,
d, time, id, init, 0)$model
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount,
bound = bound, ...)
method = method, bound = bound, ...)
if (fit$estimation_results$return_code >= 0) {
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
Expand Down
9 changes: 4 additions & 5 deletions R/dnm_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- unlist(create_initial_values(inits, model, init_sd))
fit <- nloptr(
x0 = init, eval_f = objectivef, lb = -bound, ub = bound,
opts = control_restart
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control_restart
)
p()
fit
Expand Down Expand Up @@ -162,8 +162,8 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
}

out <- nloptr(
x0 = init, eval_f = objectivef, lb = -bound, ub = bound,
opts = control
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control
)
if (out$status < 0) {
warning_(
Expand Down Expand Up @@ -205,7 +205,6 @@ dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL,
all_solutions = all_solutions,
lambda = lambda,
pseudocount = 0,
bound = bound,
method = "DNM",
algorithm = control$algorithm
Expand Down
9 changes: 4 additions & 5 deletions R/dnm_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- unlist(create_initial_values(inits, model, init_sd))
fit <- nloptr(
x0 = init, eval_f = objectivef, lb = -bound, ub = bound,
opts = control_restart
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control_restart
)
p()
fit
Expand Down Expand Up @@ -128,8 +128,8 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
}

out <- nloptr(
x0 = init, eval_f = objectivef, lb = -bound, ub = bound,
opts = control
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control
)
if (out$status < 0) {
warning_(
Expand Down Expand Up @@ -165,7 +165,6 @@ dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL,
all_solutions = all_solutions,
lambda = lambda,
pseudocount = 0,
bound = bound,
method = "DNM",
algorithm = control$algorithm
Expand Down
20 changes: 10 additions & 10 deletions R/em_dnm_mnhmm.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda,
bound, control, control_restart, control_mstep,
save_all_solutions) {
M <- model$n_symbols
Expand Down Expand Up @@ -143,7 +143,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
} else {
eta_B <- unlist(init$B, recursive = FALSE)
fit <- EM_LBFGS_mnhmm_multichannel(
Expand All @@ -156,7 +156,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
}
em_return_code <- fit$return_code
if (em_return_code >= 0) {
Expand All @@ -171,8 +171,8 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
)
)
fit <- nloptr(
x0 = init, eval_f = objectivef,
lb = -bound, ub = bound, opts = control_restart
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control_restart
)
p()
fit
Expand Down Expand Up @@ -229,7 +229,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
} else {
eta_B <- unlist(init$B, recursive = FALSE)
out <- EM_LBFGS_mnhmm_multichannel(
Expand All @@ -242,7 +242,7 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
}
em_return_code <- out$return_code
if (em_return_code >= 0) {
Expand All @@ -263,8 +263,9 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
}
}
out <- nloptr(
x0 = unlist(init), eval_f = objectivef, lb = -bound, ub = bound,
opts = control
x0 = unlist(init), eval_f = objectivef,
lb = -rep(bound, length(unlist(init))),
ub = rep(bound, length(unlist(init))), opts = control
)
if (out$status < 0) {
warning_(
Expand Down Expand Up @@ -306,7 +307,6 @@ em_dnm_mnhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL,
all_solutions = all_solutions,
lambda = lambda,
pseudocount = pseudocount,
bound = bound,
method = "EM-DNM",
algorithm = control$algorithm,
Expand Down
20 changes: 10 additions & 10 deletions R/em_dnm_nhmm.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda,
bound, control, control_restart, control_mstep,
save_all_solutions) {
M <- model$n_symbols
Expand Down Expand Up @@ -108,7 +108,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
} else {
fit <- EM_LBFGS_nhmm_multichannel(
init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs,
Expand All @@ -119,7 +119,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control_restart$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
}
em_return_code <- fit$return_code
if (em_return_code >= 0) {
Expand All @@ -133,8 +133,8 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
)
)
fit <- nloptr(
x0 = init, eval_f = objectivef, lb = -bound, ub = bound,
opts = control_restart
x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
ub = rep(bound, length(init)), opts = control_restart
)
p()
fit
Expand Down Expand Up @@ -187,7 +187,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
} else {
out <- EM_LBFGS_nhmm_multichannel(
init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs,
Expand All @@ -198,7 +198,7 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda, pseudocount, bound)
control_mstep$print_level, lambda, bound)
}
em_return_code <- out$return_code
if (em_return_code >= 0) {
Expand All @@ -218,8 +218,9 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
}
}
out <- nloptr(
x0 = unlist(init), eval_f = objectivef, lb = -bound, ub = bound,
opts = control
x0 = unlist(init), eval_f = objectivef,
lb = -rep(bound, length(unlist(init))),
ub = rep(bound, length(unlist(init))), opts = control
)
if (out$status < 0) {
warning_(
Expand Down Expand Up @@ -252,7 +253,6 @@ em_dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, pseudocount,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL,
all_solutions = all_solutions,
lambda = lambda,
pseudocount = pseudocount,
bound = bound,
method = "EM-DNM",
algorithm = control$algorithm,
Expand Down
Loading

0 comments on commit 2240195

Please sign in to comment.