Skip to content

Commit

Permalink
amp working
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 9, 2024
1 parent a26c103 commit 17fbc7d
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 140 deletions.
153 changes: 82 additions & 71 deletions R/average_marginal_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ average_marginal_prediction <- function(
length(values) == 2,
"Argument {.arg values} should contain two values for
variable {.var variable}.")
time <- model$time_variable
id <- model$id_variable
if (!is.null(newdata)) {
time <- model$time_variable
id <- model$id_variable
stopifnot_(
is.data.frame(newdata),
"Argument {.arg newdata} must be a {.cls data.frame} object."
Expand Down Expand Up @@ -77,83 +77,94 @@ average_marginal_prediction <- function(
seed <- sample(.Machine$integer.max, 1)
set.seed(seed)
pred <- predict(model, newdata, nsim, return_samples = TRUE)
if (length(values) == 2) {
newdata[[variable]] <- values[2]
set.seed(seed)
pred2 <- predict(model, newdata, nsim, return_samples = TRUE)
pred <- mapply("-", pred, pred2, SIMPLIFY = FALSE)
}
T <- model$length_of_sequences
N <- model$n_sequences
S <- model$n_states
M <- model$n_symbols
C <- model$n_channels
D <- model$n_clusters
if (C == 1) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
symbol_names <- list(model$symbol_names)
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names

newdata[[variable]] <- values[2]
set.seed(seed)
pred2 <- predict(model, newdata, nsim, return_samples = TRUE)
pars <- c("initial_probs", "transition_probs", "emission_probs",
if (D > 1) "cluster_probs")
for (i in pars) {
pred[[i]]$estimate <- pred[[i]]$estimate - pred2[[i]]$estimate
}
if (nsim > 0) {
for (i in pars) {
pred$samples[[i]]$estimate <- pred$samples[[i]]$estimate -
pred2$samples[[i]]$estimate
}
}
marginalize <- c(
switch(
marginalize_B_over,
"clusters" = c("cluster", "state", "id"),
"states" = c("state", "id"),
"sequences" = "id"),
"time", "channel", "observation")

pi <- data.frame(
cluster = rep(model$cluster_names, each = S * N),
id = rep(ids, each = S),
state = model$state_names,
estimate = unlist(pred$pi)
) |>
dplyr::group_by(cluster, state) |>
dplyr::summarise(estimate = mean(estimate))
channel <- if (model$n_channels > 1) "channel" else NULL
group_by_B <- switch(
marginalize_B_over,
"clusters" = c("time", channel, "observation"),
"states" = c("cluster", "time", channel, "observation"),
"sequences" = c("cluster", "time", "state", channel, "observation")
)

A <- data.frame(
cluster = rep(model$cluster_names, each = S^2 * T * N),
id = rep(ids, each = S^2 * T),
time = rep(times, each = S^2),
state_from = model$state_names,
state_to = rep(model$state_names, each = S),
estimate = unlist(pred$A)
) |>
dplyr::group_by(cluster, time, state_from, state_to) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::rename(!!time_var := time)
qs <- function(x, probs) {
x <- quantile(x, probs)
names(x) <- paste0("q", 100 * probs)
as.data.frame(t(x))
}
out <- list()
out$initial_probs <- cbind(
pred$initial_probs |>
dplyr::group_by(cluster, state) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::ungroup(),
pred$samples$initial_probs |>
dplyr::group_by(cluster, state, replication) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::group_by(cluster, state) |>
dplyr::summarise(qs(estimate, probs)) |>
dplyr::ungroup() |>
dplyr::select(-c(cluster, state))
)

out$transition_probs <- cbind(
pred$transition_probs |>
dplyr::group_by(cluster, time, state_from, state_to) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::ungroup(),
pred$samples$transition_probs |>
dplyr::group_by(cluster, time, state_from, state_to, replication) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::group_by(cluster, time, state_from, state_to) |>
dplyr::summarise(qs(estimate, probs)) |>
dplyr::ungroup() |>
dplyr::select(-c(cluster, time, state_from, state_to))
) |> dplyr::rename(!!time := time)

B <- data.frame(
cluster = rep(model$cluster_names, each = S * sum(M) * T * N),
id = unlist(lapply(seq_len(C), function(i) rep(ids, each = S * M[i] * T))),
time = unlist(lapply(seq_len(C), function(i) rep(times, each = S * M[i]))),
state = model$state_names,
channel = unlist(lapply(seq_len(C), function(i) {
rep(model$channel_names[i], each = S * M[i]* T * N)
})),
observation = unlist(lapply(seq_len(C), function(i) {
rep(symbol_names[[i]], each = S)
})),
estimate = unlist(pred$B)
) |>
dplyr::group_by(across(all_of(marginalize))) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::rename(!!time_var := time)
out$emission_probs <- cbind(
pred$emission_probs |>
dplyr::group_by(dplyr::pick(group_by_B)) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::ungroup(),
pred$samples$emission_probs |>
dplyr::group_by(dplyr::pick(c(group_by_B, "replication"))) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::group_by(dplyr::pick(group_by_B)) |>
dplyr::summarise(qs(estimate, probs)) |>
dplyr::ungroup() |>
dplyr::select(dplyr::starts_with("q"))
) |> dplyr::rename(!!time := time)

if (D > 1) {
omega <- data.frame(
cluster = model$cluster_names,
id = rep(ids, each = D),
estimate = c(pred$omega)
out$cluster_probs <- cbind(
pred$cluster_probs |>
dplyr::group_by(cluster) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::ungroup(),
pred$samples$cluster_probs |>
dplyr::group_by(cluster, replication) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::group_by(cluster) |>
dplyr::summarise(qs(estimate, probs)) |>
dplyr::ungroup() |>
dplyr::select(-cluster)
)
out <- list(omega = omega, pi = pi, A = A, B = B)
} else {
out <- list(pi = pi, A = A, B = B)
}
}
class(out) <- "amp"
attr(out, "seed") <- seed
attr(out, "marginalize_B_over") <- marginalize_B_over
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 1L, threads = 1L, store_data = TRUE, verbose = TRUE, ...) {
restarts = 0L, threads = 1L, store_data = TRUE, verbose = TRUE, ...) {

model <- build_mnhmm(
observations, n_states, n_clusters, initial_formula,
Expand Down
22 changes: 10 additions & 12 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#' random initial values. Default is `2`. If you want to fix the initial values
#' of the regression coefficients to zero, use `init_sd = 0`.
#' @param restarts Number of times to run optimization using random starting
#' values. Default is 1.
#' values (in addition to the final run). Default is 0.
#' @param threads Number of parallel threads for optimization with restarts.
#' Default is 1.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
Expand All @@ -57,24 +57,22 @@
#' @examples
#' data("mvad", package = "TraMineR")
#'
#' mvad_alphabet <-
#' c("employment", "FE", "HE", "joblessness", "school", "training")
#' mvad_labels <- c("employment", "further education", "higher education",
#' "joblessness", "school", "training")
#' mvad_scodes <- c("EM", "FE", "HE", "JL", "SC", "TR")
#' mvad_seq <- seqdef(mvad, 15:86, alphabet = mvad_alphabet,
#' states = mvad_scodes, labels = mvad_labels, xtstep = 6,
#' cpal = unname(colorpalette[[6]]))
#' d <- reshape(mvad, direction = "long", varying = list(15:86),
#' v.names = "activity")
#'
#' set.seed(1)
#' \dontrun{
#' fit <- estimate_nhmm(mvad_seq, n_states = 3)
#' set.seed(1)
#' fit <- estimate_nhmm("activity", n_states = 3,
#' data = d, time = "time", id = "id",
#' initial_formula = ~ 1, emission_formula = ~ male + gcse5eq,
#' transition_formula = ~ male + gcse5eq, inits = "random"
#' )
#' }
estimate_nhmm <- function(
observations, n_states, initial_formula = ~1,
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 1L, threads = 1L,
inits = "random", init_sd = 2, restarts = 0L, threads = 1L,
store_data = TRUE, verbose = TRUE, ...) {

model <- build_nhmm(
Expand Down
4 changes: 4 additions & 0 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
)
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
)
obs <- create_obsArray(model) + 1L
if (model$n_channels == 1) {
obs <- array(obs, dim(obs)[2:3])
Expand Down
4 changes: 4 additions & 0 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
)
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
)
obs <- create_obsArray(model) + 1L
if (model$n_channels == 1) {
obs <- array(obs, dim(obs)[2:3])
Expand Down
24 changes: 17 additions & 7 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,19 @@ predict.nhmm <- function(
})
)
colnames(out$emission_probs)[1] <- object$id_variable
colnames(out$emission_probs)[2] <- object$time_variable
colnames(out$emission_probs)[2] <- object$time_variable
if (C == 1) out$emission_probs$channel <- NULL

if (nsim > 0) {
samples <- sample_parameters(object, nsim, probs, return_samples)
samples <- sample_parameters_nhmm(object, nsim)
if (return_samples) {
out$samples <- samples
out$samples <- samples_to_df(object, samples)
} else {
out$quantiles <- samples
out$quantiles <- list(
initial_probs = fast_quantiles(samples$pi, probs),
transition_probs = fast_quantiles(samples$A, probs),
emission_probs = fast_quantiles(samples$B, probs)
)
}
}
out
Expand Down Expand Up @@ -229,11 +234,16 @@ predict.mnhmm <- function(
estimate = c(get_omega(theta_raw, X_cluster, 0))
)
if (nsim > 0) {
samples <- sample_parameters(object, nsim, probs, return_samples)
samples <- sample_parameters_mnhmm(object, nsim)
if (return_samples) {
out$samples <- samples
out$samples <- samples_to_df(object, samples)
} else {
out$quantiles <- samples
out$quantiles <- list(
initial_probs = fast_quantiles(samples$pi, probs),
transition_probs = fast_quantiles(samples$A, probs),
emission_probs = fast_quantiles(samples$B, probs),
cluster_probs = fast_quantiles(samples$omega, probs)
)
}
}
out
Expand Down
Loading

0 comments on commit 17fbc7d

Please sign in to comment.