Skip to content

Commit

Permalink
simulate_mnhmm, fix stacked sequence plot warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 3, 2024
1 parent 16a2e21 commit e5eda39
Show file tree
Hide file tree
Showing 67 changed files with 1,947 additions and 1,546 deletions.
45 changes: 24 additions & 21 deletions Examples/seqHMMexample.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ library(TraMineR)

data(biofam3c)

# Building sequence objects (starting at age 15)
marr.seq <- seqdef(biofam3c$married, start = 15)
child.seq <- seqdef(biofam3c$children, start = 15)
left.seq <- seqdef(biofam3c$left, start = 15)

# Choosing colours for states
attr(marr.seq, "cpal") <- c("#AB82FF", "#E6AB02", "#E7298A")
attr(child.seq, "cpal") <- c("#66C2A5", "#FC8D62")
attr(left.seq, "cpal") <- c("#A6CEE3", "#E31A1C")

# Building sequence objects (starting at age 15, with custom color palette)
marr.seq <- seqdef(
biofam3c$married, start = 15,
cpal = c("#AB82FF", "#E6AB02", "#E7298A")
)
child.seq <- seqdef(
biofam3c$children, start = 15,
cpal = c("#66C2A5", "#FC8D62")
)
left.seq <- seqdef(
biofam3c$left, start = 15,
cpal = c("#A6CEE3", "#E31A1C")
)

# Plotting multichannel sequence data

Expand Down Expand Up @@ -137,9 +140,9 @@ ssplot(
most probable paths of hidden states",
# Labels for hidden states (most common states)
mpp.labels = c("1: Childless single, with parents",
"2: Childless single, left home",
"3: Married without children",
"4: Married parent, left home"),
"2: Childless single, left home",
"3: Married without children",
"4: Married parent, left home"),
# Colours for hidden states
mpp.col = c("olivedrab", "bisque", "plum", "indianred"),
# Labels for x axis
Expand Down Expand Up @@ -244,17 +247,17 @@ B3_left <- matrix(c(0.01, 0.99, # High probability for living with parents

# Starting values for transition matrices
A1 <- matrix(c(0.8, 0.16, 0.03, 0.01,
0, 0.9, 0.07, 0.03,
0, 0, 0.9, 0.1,
0, 0, 0, 1),
0, 0.9, 0.07, 0.03,
0, 0, 0.9, 0.1,
0, 0, 0, 1),
nrow = 4, ncol = 4, byrow = TRUE)

A2 <- matrix(c(0.8, 0.10, 0.05, 0.03, 0.01, 0.01,
0, 0.7, 0.1, 0.1, 0.05, 0.05,
0, 0, 0.85, 0.01, 0.1, 0.04,
0, 0, 0, 0.9, 0.05, 0.05,
0, 0, 0, 0, 0.9, 0.1,
0, 0, 0, 0, 0, 1),
0, 0.7, 0.1, 0.1, 0.05, 0.05,
0, 0, 0.85, 0.01, 0.1, 0.04,
0, 0, 0, 0.9, 0.05, 0.05,
0, 0, 0, 0, 0.9, 0.1,
0, 0, 0, 0, 0, 1),
nrow = 6, ncol = 6, byrow = TRUE)

# Starting values for initial state probabilities
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ S3method(print,mhmm)
S3method(print,mnhmm)
S3method(print,nhmm)
S3method(print,summary.mhmm)
S3method(print,summary.mnhmm)
S3method(print,summary.nhmm)
S3method(state_names,hmm)
S3method(state_names,mhmm)
S3method(summary,mhmm)
Expand Down Expand Up @@ -72,6 +74,7 @@ export(simulate_emission_probs)
export(simulate_hmm)
export(simulate_initial_probs)
export(simulate_mhmm)
export(simulate_mnhmm)
export(simulate_transition_probs)
export(sort_sequences)
export(ssp)
Expand Down
4 changes: 2 additions & 2 deletions R/HMMplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,

# Colors for the (combinations of) observed states
if (identical(cpal, "auto")) {
pie.colors <- attr(x$observations, "cpal")
pie.colors <- TraMineR::cpal(x$observations)
} else if (length(cpal) != ncol(x$emiss)) {
warning_("The length of {.arg cpal} does not match the number of observed
states. Automatic color palette was used.")
pie.colors <- attr(x$observations, "cpal")
pie.colors <- TraMineR::cpal(x$observations)
} else if (!all(isColor(cpal))) {
stop_("Please provide a vector of colors for {.arg cpal} or use value
{.val 'auto'} for automatic color palette.")
Expand Down
19 changes: 15 additions & 4 deletions R/average_marginal_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ average_marginal_effect <- function(
)
} else {
stopifnot(
!is.null(object$data),
!is.null(model$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
Expand Down Expand Up @@ -84,10 +84,21 @@ average_marginal_effect <- function(
ame_B <- if (model$n_channels == 1) {
get_B(beta_o_raw, X_emission1, 0) - get_B(beta_o_raw, X_emission2, 0)
} else {
get_multichannel_B(beta_o_raw, X_emission1, S, C, M, 0) -
get_multichannel_B(beta_o_raw, X_emission2, S, C, M, 0)
get_multichannel_B(
beta_o_raw,
X_emission1,
model$n_states,
model$n_channels,
model$n_symbols,
0, 0) -
get_multichannel_B(
beta_o_raw,
X_emission2,
model$n_states,
model$n_channels,
model$n_symbols,
0, 0)
}
browser()
if (nsim > 0) {
stopifnot_(
checkmate::test_numeric(
Expand Down
17 changes: 10 additions & 7 deletions R/build_hmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,24 @@
#' # Building sequence objects
#' marr_seq <- seqdef(biofam3c$married,
#' start = 15,
#' alphabet = c("single", "married", "divorced")
#' alphabet = c("single", "married", "divorced"),
#' cpal = c("violetred2", "darkgoldenrod2", "darkmagenta")
#' )
#' child_seq <- seqdef(biofam3c$children,
#' start = 15,
#' alphabet = c("childless", "children")
#' alphabet = c("childless", "children"),
#' cpal = c("darkseagreen1", "coral3")
#' )
#' left_seq <- seqdef(biofam3c$left,
#' start = 15,
#' alphabet = c("with parents", "left home")
#' alphabet = c("with parents", "left home"),
#' cpal = c("lightblue", "red3")
#' )
#'
#' # Define colors
#' attr(marr_seq, "cpal") <- c("violetred2", "darkgoldenrod2", "darkmagenta")
#' attr(child_seq, "cpal") <- c("darkseagreen1", "coral3")
#' attr(left_seq, "cpal") <- c("lightblue", "red3")
#' # You could also define the colors using cpal function from TraMineR
#' # cpal(marr_seq) <- c("violetred2", "darkgoldenrod2", "darkmagenta")
#' # cpal(child_seq) <- c("darkseagreen1", "coral3")
#' # cpal(left_seq) <- c("lightblue", "red3")
#'
#' # Left-to-right HMM with 3 hidden states and random starting values
#' set.seed(1010)
Expand Down
2 changes: 1 addition & 1 deletion R/build_lcm.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ build_lcm <- function(observations, n_clusters, emission_probs,
if (n_channels > 1L) {
n_clusters <- nrow(emission_probs[[1]])
emission_probs_list <- vector("list", n_clusters)
for (i in 1:n_channels) {
for (i in seq_len(n_channels)) {
stopifnot_(
nrow(emission_probs[[i]]) != n_clusters,
"Different number of rows in the list components of {.arg emission_probs}."
Expand Down
14 changes: 6 additions & 8 deletions R/build_mhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,20 @@
#' ## Building sequence objects
#' marr_seq <- seqdef(biofam3c$married,
#' start = 15,
#' alphabet = c("single", "married", "divorced")
#' alphabet = c("single", "married", "divorced"),
#' cpal = c("#AB82FF", "#E6AB02", "#E7298A")
#' )
#' child_seq <- seqdef(biofam3c$children,
#' start = 15,
#' alphabet = c("childless", "children")
#' alphabet = c("childless", "children"),
#' cpal = c("#66C2A5", "#FC8D62")
#' )
#' left_seq <- seqdef(biofam3c$left,
#' start = 15,
#' alphabet = c("with parents", "left home")
#' alphabet = c("with parents", "left home"),
#' cpal = c("#A6CEE3", "#E31A1C")
#' )
#'
#' ## Choosing colors
#' attr(marr_seq, "cpal") <- c("#AB82FF", "#E6AB02", "#E7298A")
#' attr(child_seq, "cpal") <- c("#66C2A5", "#FC8D62")
#' attr(left_seq, "cpal") <- c("#A6CEE3", "#E31A1C")
#'
#' ## MHMM with random starting values, no covariates
#' set.seed(468)
#' init_mhmm_bf1 <- build_mhmm(
Expand Down
6 changes: 2 additions & 4 deletions R/build_mm.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@
#' mvad_scodes <- c("EM", "FE", "HE", "JL", "SC", "TR")
#' mvad_seq <- seqdef(mvad, 17:86,
#' alphabet = mvad_alphabet,
#' states = mvad_scodes, labels = mvad_labels, xtstep = 6
#' states = mvad_scodes, labels = mvad_labels, xtstep = 6,
#' cpal = colorpalette[[6]]
#' )
#'
#' # Define a color palette for the sequence data
#' attr(mvad_seq, "cpal") <- colorpalette[[6]]
#'
#' # Estimate the Markov model
#' mm_mvad <- build_mm(observations = mvad_seq)
#'
Expand Down
2 changes: 0 additions & 2 deletions R/build_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ build_mnhmm <- function(
structure(
out$model,
class = "mnhmm",
time_variable = time,
id_variable = id,
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type)
Expand Down
2 changes: 0 additions & 2 deletions R/build_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ build_nhmm <- function(
structure(
out$model,
class = "nhmm",
time_variable = time,
id_variable = id,
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type)
Expand Down
6 changes: 3 additions & 3 deletions R/check_build_arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@

if (is.null(channel_names)) {
if (is.null(channel_names <- names(x))) {
channel_names <- paste("Channel", 1:n_channels)
channel_names <- paste("Channel", seq_len(n_channels))
}
} else {
if (length(channel_names) != n_channels) {
warning_(
"The length of {.arg channel_names} does not match the number of
channels. Names were not used."
)
channel_names <- paste("Channel", 1:n_channels)
channel_names <- paste("Channel", seq_len(n_channels))
}
}
for (i in seq_len(n_channels)) {
Expand Down Expand Up @@ -238,4 +238,4 @@
)
data <- data[order(data[[id]], data[[time]]), ]
fill_time(data, id, time)
}
}
7 changes: 0 additions & 7 deletions R/check_missing_pattern.R

This file was deleted.

4 changes: 2 additions & 2 deletions R/cluster_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cluster_names.mnhmm <- function(object) {
#' @export
`cluster_names<-.mnhmm` <- function(object, value) {
stopifnot_(
length(Value) == object$n_clusters,
length(value) == object$n_clusters,
"New cluster names should be a vector of length {object$n_clusters}."
)
object$cluster_names <- value
Expand All @@ -38,7 +38,7 @@ cluster_names.mnhmm <- function(object) {
#' @export
`cluster_names<-.mhmm` <- function(object, value) {
stopifnot_(
length(Value) == object$n_clusters,
length(value) == object$n_clusters,
"New cluster names should be a vector of length {object$n_clusters}."
)
object$cluster_names <- value
Expand Down
2 changes: 2 additions & 0 deletions R/create_base_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
list(
model = list(
observations = observations,
time_variable = time,
id_variable = id,
X_initial = pi$X, X_transition = A$X, X_emission = B$X,
X_cluster = if(mixture) theta$X else NULL,
initial_formula = pi$formula,
Expand Down
11 changes: 8 additions & 3 deletions R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
#' (mixtures).
#' @param cluster_formula of class [formula()] for the
#' mixture probabilities.
#' @param inits If `inits = "random"` (default), random initial values are
#' used. Otherwise `inits` should be list of initial values. If coefficients
#' are given using list components `beta_i_raw`, `beta_s_raw`, `beta_o_raw`,
#' and `theta_raw`, these are used as is, alternatively initial values can be
#' given in terms of the initial state, transition, emission, and mixture
#' probabilities using list components `initial_probs`, `emission_probs`,
#' `transition_probs`, and `cluster_probs`. These can also be mixed, i.e. you
#' can give only `initial_probs` and `beta_s_raw`.
#' @param cluster_names A vector of optional labels for the clusters. If this
#' is `NULL` (the default), numbered clusters are used.
#' @param inits Optional initial values for the initial state, transition,
#' emission, and mixture probabilities. Either a list with `initial_probs`,
#' `emission_probs`, `transition_probs`, `cluster_probs`, or `"random"`.
#' @return Object of class `mnhmm`.
#' @export
#' @examples
Expand Down
10 changes: 7 additions & 3 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
#' is `NULL` (the default), numbered states are used.
#' @param channel_names A vector of optional names for the channels. If this
#' is `NULL` (the default), numbered channels are used.
#' @param inits Optional initial values for the initial state, transition,
#' emission, and mixture probabilities. Either a list with `initial_probs`,
#' `emission_probs`, `transition_probs`, `cluster_probs`, or `"random"`.
#' @param inits If `inits = "random"` (default), random initial values are
#' used. Otherwise `inits` should be list of initial values. If coefficients
#' are given using list components `beta_i_raw`, `beta_s_raw`, `beta_o_raw`,
#' these are used as is, alternatively initial values can be given in terms of
#' the initial state, transition, and emission probabilities using list
#' components `initial_probs`, `emission_probs`, and `transition_probs`. These
#' can also be mixed, i.e. you can give only `initial_probs` and `beta_s_raw`.
#' @param init_sd Standard deviation of the normal distribution used to generate
#' random initial values. Default is `2`. If you want to fix the initial values
#' of the regression coefficients to zero, use `init_sd = 0`.
Expand Down
6 changes: 0 additions & 6 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
)
dots <- list(...)
if (!is.null(dots$no_hessian)) {
hessian <- !no_hessian
} else {
hessian <- TRUE
}
obs <- create_obsArray(model) + 1L
if (model$n_channels == 1) {
obs <- array(obs, dim(obs)[2:3])
Expand Down
Loading

0 comments on commit e5eda39

Please sign in to comment.