Skip to content

Commit

Permalink
refactor stan models, alternative restart strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 12, 2024
1 parent e710774 commit 092b34b
Show file tree
Hide file tree
Showing 201 changed files with 5,578 additions and 101,074 deletions.
2 changes: 1 addition & 1 deletion R/build_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ build_mnhmm <- function(
class = "mnhmm",
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type),
type = paste0(out$extras$multichannel, "mnhmm"),
intercept_only = out$extras$intercept_only
)
}
2 changes: 1 addition & 1 deletion R/build_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build_nhmm <- function(
class = "nhmm",
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type),
type = paste0(out$extras$multichannel, "nhmm"),
intercept_only = out$extras$intercept_only
)
}
27 changes: 25 additions & 2 deletions R/get_coefs.R → R/coef.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,22 @@ coef.nhmm <- function(object, probs = c(0.025, 0.975), ...) {
"Argument {.arg probs} must be a {.cls numeric} vector with values
between 0 and 1."
)
sds <- sqrt(diag(solve(-object$estimation_results$hessian)))
p_i <- length(beta_i_raw)
p_s <- length(beta_s_raw)
p_o <- length(beta_o_raw)
sds <- try(
sqrt(diag(solve(-object$estimation_results$hessian))),
silent = TRUE
)
if (inherits(sds, "try-error")) {
warning_(
paste0(
"Standard errors could not be computed due to singular Hessian.",
"Confidence intervals will not be provided."
)
)
sds <- rep(NA, p_i + p_s + p_o)
}
for(i in seq_along(probs)) {
q <- qnorm(probs[i])
beta_i[paste0("q", 100 * probs[i])] <- beta_i_raw + q * sds[seq_len(p_i)]
Expand Down Expand Up @@ -125,11 +137,22 @@ coef.mnhmm <- function(object, probs = c(0.025, 0.5, 0.975), ...) {
"Argument {.arg probs} must be a {.cls numeric} vector with values
between 0 and 1."
)
sds <- sqrt(diag(solve(-object$estimation_results$hessian)))
p_i <- length(beta_i_raw)
p_s <- length(beta_s_raw)
p_o <- length(beta_o_raw)
p_c <- length(theta_raw)
sds <- try(
sqrt(diag(solve(-object$estimation_results$hessian))),
silent = TRUE
)
if (inherits(sds, "try-error")) {
warning_(
"Standard errors could not be computed due to singular Hessian.
Confidence intervals will not be provided."
)
sds <- rep(NA, p_i + p_s + p_o + p_c)
}

for(i in seq_along(probs)) {
q <- qnorm(probs[i])
beta_i[paste0("q", 100 * probs[i])] <- beta_i_raw + q * sds[seq_len(p_i)]
Expand Down
3 changes: 0 additions & 3 deletions R/create_base_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
extras = list(
n_pars = n_pars,
multichannel = ifelse(n_channels > 1, "multichannel_", ""),
model_type = paste0(
pi$type, A$type, B$type, if (mixture) theta$type else ""
),
intercept_only = icp_only_i && icp_only_s && icp_only_o && icp_only_d
)
)
Expand Down
3 changes: 2 additions & 1 deletion R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, threads = 1L, store_data = TRUE, verbose = TRUE, restart_method = "1", ...) {
restarts = 0L, threads = 1L, store_data = TRUE, verbose = TRUE,
restart_method = "1", ...) {

call <- match.call()
model <- build_mnhmm(
Expand Down
11 changes: 9 additions & 2 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs),
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
verbose = FALSE
), dots)
Expand Down Expand Up @@ -114,7 +117,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs),
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
init = init,
verbose = verbose
Expand All @@ -127,6 +133,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
model$stan_model <- model_code@model_code
model$estimation_results <- list(
hessian = out$hessian,
penalized_loglik_N = out$par[["ploglik_N"]],
penalized_loglik = out$value,
loglik = out$par[["log_lik"]],
penalty = out$par[["prior"]],
Expand Down
103 changes: 56 additions & 47 deletions R/fit_mnhmm2.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,75 +49,81 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
if (is.null(dots$tol_param)) dots$tol_param <- 1e-4
if (is.null(dots$check_data)) dots$check_data <- FALSE

n <- ceiling(model$n_sequences / restarts)
out0 <- future_lapply(seq_len(restarts), function(i) {
n <- as.integer(max(5 * D, ceiling(model$n_sequences / restarts)))
# 5 * restarts runs with subset data
out0 <- future_lapply(seq_len(5 * restarts), function(i) {
init <- create_initial_values(
inits, S, M, init_sd, K_i, K_s, K_o, K_d, D
)
subset_data <- list(
N = n,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
max_M = max(model$n_symbols),
M = M,
S = S,
C = model$n_channels,
D = D,
K_i = K_i,
K_s = K_s,
K_o = K_o,
K_d = K_d,
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = subsample_obs(obs, n))

ids <- sample(model$n_sequences, size = n)
do.call(
optimizing,
c(list(
model_code, init = init,
data = subset_data,
data = list(
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
max_M = max(model$n_symbols),
M = M,
S = S,
C = model$n_channels,
D = D,
K_i = K_i,
K_s = K_s,
K_o = K_o,
K_d = K_d,
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs,
ids = ids,
N_sample = n
),
as_vector = FALSE,
verbose = FALSE
), dots)
)[c("par", "value", "return_code")]
},
future.seed = TRUE)

data <- list(
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
max_M = max(model$n_symbols),
M = M,
S = S,
C = model$n_channels,
D = D,
K_i = K_i,
K_s = K_s,
K_o = K_o,
K_d = K_d,
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs)
# take restarts/5 best solutions and run to the end
logliks <- unlist(lapply(out0, "[[", "value"))
idx <- head(order(logliks, decreasing = TRUE), restarts / 5)

out <- future_lapply(seq_len(restarts), function(i) {
out <- future_lapply(seq_len(restarts / 5), function(i) {
do.call(
optimizing,
c(list(
model_code, init = out0[[i]]$par,
data = data,
model_code, init = out0[[idx[i]]]$par,
data = list(
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
max_M = max(model$n_symbols),
M = M,
S = S,
C = model$n_channels,
D = D,
K_i = K_i,
K_s = K_s,
K_o = K_o,
K_d = K_d,
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
verbose = FALSE
), dots)
)[c("par", "value", "return_code")]
},
future.seed = TRUE)

logliks <- unlist(lapply(out, "[[", "value"))
return_codes <- unlist(lapply(out, "[[", "return_code"))
successful <- which(return_codes == 0)
Expand Down Expand Up @@ -152,7 +158,10 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
X_s = model$X_transition,
X_o = model$X_emission,
X_d = model$X_cluster,
obs = obs),
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
init = init,
verbose = verbose
Expand All @@ -165,7 +174,7 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
model$stan_model <- model_code@model_code
model$estimation_results <- list(
hessian = out$hessian,
penalized_loglik = out$value,
penalized_loglik = out$value,
loglik = out$par[["log_lik"]],
penalty = out$par[["prior"]],
return_code = out$return_code,
Expand Down
11 changes: 9 additions & 2 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
obs = obs),
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
verbose = FALSE
), dots)
Expand Down Expand Up @@ -104,7 +107,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
X_i = model$X_initial,
X_s = model$X_transition,
X_o = model$X_emission,
obs = obs),
obs = obs,
ids = seq_len(model$n_sequences),
N_sample = model$n_sequences
),
as_vector = FALSE,
init = init,
verbose = verbose
Expand All @@ -115,6 +121,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
model$stan_model <- model_code@model_code
model$estimation_results <- list(
hessian = out$hessian,
penalized_loglik_N = out$par[["ploglik_N"]],
penalized_loglik = out$value,
loglik = out$par[["log_lik"]],
penalty = out$par[["prior"]],
Expand Down
20 changes: 4 additions & 16 deletions R/model_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ model_matrix_initial_formula <- function(formula, data, n_sequences,
time, id) {
icp_only <- intercept_only(formula)
if (icp_only) {
type <- "c"
n_pars <- n_states - 1L
X <- matrix(1, n_sequences, 1)
coef_names <- "(Intercept)"
Expand All @@ -30,11 +29,9 @@ model_matrix_initial_formula <- function(formula, data, n_sequences,
)
)
coef_names <- colnames(X)
type <- "v"
n_pars <- (n_states - 1L) * ncol(X)
}
list(formula = formula, type = type, n_pars = n_pars, X = X,
coef_names = coef_names)
list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names)
}
#' Create the Model Matrix based on NHMM Formulas
#'
Expand All @@ -44,7 +41,6 @@ model_matrix_transition_formula <- function(formula, data, n_sequences,
time, id, sequence_lengths) {
icp_only <- intercept_only(formula)
if (icp_only) {
type <- "c"
n_pars <- n_states * (n_states - 1L)
X <- array(1, c(length_of_sequences, n_sequences, 1L))
coef_names <- "(Intercept)"
Expand Down Expand Up @@ -74,11 +70,9 @@ model_matrix_transition_formula <- function(formula, data, n_sequences,
X[missing_values] <- 0
coef_names <- colnames(X)
dim(X) <- c(length_of_sequences, n_sequences, ncol(X))
type <- "v"
n_pars <- n_states * (n_states - 1L) * dim(X)[3]
}
list(formula = formula, type = type, n_pars = n_pars, X = X,
coef_names = coef_names)
list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names)
}
#' Create the Model Matrix based on NHMM Formulas
#'
Expand All @@ -89,7 +83,6 @@ model_matrix_emission_formula <- function(formula, data, n_sequences,
time, id, sequence_lengths) {
icp_only <- intercept_only(formula)
if (icp_only) {
type <- "c"
n_pars <- n_channels * n_states * (n_symbols - 1L)
X <- array(1, c(length_of_sequences, n_sequences, 1L))
coef_names <- "(Intercept)"
Expand Down Expand Up @@ -119,11 +112,9 @@ model_matrix_emission_formula <- function(formula, data, n_sequences,
X[missing_values] <- 0
coef_names <- colnames(X)
dim(X) <- c(length_of_sequences, n_sequences, ncol(X))
type <- "v"
n_pars <- n_channels * n_states * (n_symbols - 1L) * dim(X)[3]
}
list(formula = formula, type = type, n_pars = n_pars, X = X,
coef_names = coef_names)
list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names)
}
#' Create the Model Matrix based on NHMM Formulas
#'
Expand All @@ -132,7 +123,6 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters,
time, id) {
icp_only <- intercept_only(formula)
if (icp_only) {
type <- "c"
n_pars <- n_clusters - 1L
X <- matrix(1, n_sequences, 1)
coef_names <- "(Intercept)"
Expand All @@ -156,9 +146,7 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters,
)
)
coef_names <- colnames(X)
type <- "v"
n_pars <- (n_clusters - 1L) * ncol(X)
}
list(formula = formula, type = type, n_pars = n_pars, X = X,
coef_names = coef_names)
list(formula = formula, n_pars = n_pars, X = X, coef_names = coef_names)
}
3 changes: 2 additions & 1 deletion R/plot.ame.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#' found in the input object `x`, an error is thrown.
#' @export
plot.ame <- function(x, type, probs = c(0.025, 0.975)) {
type <- match.arg(type, c("initial", "transition", "emission", "cluster"), ...)

type <- match.arg(type, c("initial", "transition", "emission", "cluster"))

cluster <- time <- state <- state_from <- state_to <- observation <-
estimate <- NULL
Expand Down
Loading

0 comments on commit 092b34b

Please sign in to comment.