Skip to content

Commit

Permalink
ame probs and obs
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 15, 2024
1 parent 1743101 commit 3aab8ae
Show file tree
Hide file tree
Showing 20 changed files with 1,526 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
.project
.settings
# Extra source folders and compiled code
vignettes/cache
src-i386
src-x64
*.o
Expand Down
6 changes: 3 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ S3method("state_names<-",hmm)
S3method("state_names<-",mhmm)
S3method("state_names<-",mnhmm)
S3method("state_names<-",nhmm)
S3method(ame,mnhmm)
S3method(ame,nhmm)
S3method(ame_probs,mnhmm)
S3method(ame_probs,nhmm)
S3method(bootstrap_coefs,mnhmm)
S3method(bootstrap_coefs,nhmm)
S3method(cluster_names,mhmm)
Expand Down Expand Up @@ -74,7 +74,7 @@ S3method(vcov,mhmm)
export("cluster_names<-")
export("state_names<-")
export(alphabet)
export(ame)
export(ame_probs)
export(bootstrap_coefs)
export(build_hmm)
export(build_lcm)
Expand Down
16 changes: 16 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,19 @@ softmax <- function(x) {
.Call(`_seqHMM_softmax`, x)
}

state_obs_probs_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) {
.Call(`_seqHMM_state_obs_probs_nhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)
}

state_obs_probs_nhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) {
.Call(`_seqHMM_state_obs_probs_nhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)
}

state_obs_probs_mnhmm_singlechannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) {
.Call(`_seqHMM_state_obs_probs_mnhmm_singlechannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)
}

state_obs_probs_mnhmm_multichannel <- function(eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start) {
.Call(`_seqHMM_state_obs_probs_mnhmm_multichannel`, eta_omega, X_omega, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_omega, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B, start)
}

127 changes: 127 additions & 0 deletions R/ame_obs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#' Average Marginal Effects for NHMM Responses
#'
#' The function `ame_obs` computes the average marginal effect (AME) of the
#' model covariate \eqn{X} at time t on the current and future responses by
#' marginalizing over the sequences and latent states. Under the assumption of
#' no unobserved confounding (i.e., there are no unobserved variables that
#' influence the covariate \eqn{X} and the outcome \eqn{Y}), these can be
#' regarded as the causal effects. In case `values` argument is a single value
#' \eqn{x}, the function returns the interventional distributions
#' \deqn{P(y_{t+k} | do(X_t = x))}
#' and in a case `values` contains two values \eqn{x} and \eqn{w} a shift in
#' interventional distributions, i.e.,
#' \deqn{P(y_{t+k} | do(X_t = x)) - P(y_{t+k} | do(X_t = w))}.
#'
#' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`.
#' @param variable Name of the variable of interest.
#' @param values Vector containing one or two values for `variable`.
#' See details.
#' @param start_time Time(s) of intervention. Either a scalar or vector.
#' Intervention is applied to all provided time points.
#' @param newdata Optional data frame which is used for marginalization.
#' @param probs Quantiles of interest of average marginal effect.
#' @param ... Ignored.
#' @rdname ame_obs
#' @export
ame_obs <- function(model, variable, values, start_time, ...) {
UseMethod("ame_obs", model)
}
#' @rdname ame_obs
#' @export
ame_obs.nhmm <- function(
model, variable, values, start_time, newdata = NULL, probs = c(0.05, 0.95),
...) {
stopifnot_(
attr(model, "intercept_only") == FALSE,
"Model does not contain any covariates."
)
stopifnot_(
checkmate::test_string(x = variable),
"Argument {.arg variable} must be a single character string."
)
stopifnot_(
length(values) == 2,
"Argument {.arg values} should contain two values for
variable {.var variable}.")
time <- model$time_variable
id <- model$id_variable
if (!is.null(newdata)) {
stopifnot_(
is.data.frame(newdata),
"Argument {.arg newdata} must be a {.cls data.frame} object."
)
stopifnot_(
!is.null(newdata[[id]]),
"Can't find grouping variable {.var {id}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[time]]),
"Can't find time index variable {.var {time}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[variable]]),
"Can't find time variable {.var {variable}} in {.arg newdata}."
)
} else {
stopifnot_(
!is.null(model$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
newdata <- model$data
}
stopifnot_(
!is.null(model$boot),
paste0(
"Model does not contain bootstrap samples of coefficients. ",
"Run {.fn bootstrap_coefs} first."
)
)
newdata[[variable]] <- values[1]
model1 <- update(model, newdata)
newdata[[variable]] <- values[2]
model2 <- update(model, newdata)
C <- model$n_channels
if (C == 1L) {
times <- colnames(model$observations)
symbol_names <- list(model$symbol_names)
} else {
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names
}
stop("WIP")
if (model$n_channels == 1) {

obs <- create_obsArray(model)[1L, , ]
out1 <- state_obs_probs_nhmm_singlechannel(
model1$etas$pi, model1$X_pi, model1$etas$A, model1$X_A,
model1$etas$B, model1$X_B, obs, model1$sequence_lengths,
attr(model1$X_pi, "icpt_only"), attr(model1$X_A, "icpt_only"),
attr(model1$X_B, "icpt_only"), attr(model1$X_A, "iv"),
attr(model1$X_B, "iv"), attr(model1$X_A, "tv"), attr(model1$X_B, "tv"),
start = start_time)
out2 <- state_obs_probs_nhmm_singlechannel(
model2$etas$pi, model2$X_pi, model2$etas$A, model2$X_A,
model2$etas$B, model2$X_B, obs, model2$sequence_lengths,
attr(model2$X_pi, "icpt_only"), attr(model2$X_A, "icpt_only"),
attr(model2$X_B, "icpt_only"), attr(model2$X_A, "iv"),
attr(model2$X_B, "iv"), attr(model2$X_A, "tv"), attr(model2$X_B, "tv"),
start = start_time)
}

class(out) <- "ame_obs"
attr(out, "model") <- "nhmm"
out
}

#' @rdname ame_obs
#' @export
ame_obs.mnhmm <- function(
model, variable, values, start_time, newdata = NULL, probs = c(0.05, 0.95),
...) {

stop("Not yet implemented")
class(out) <- "ame_obs"
attr(out, "model") <- "mnhmm"
out
}
42 changes: 30 additions & 12 deletions R/ame.R → R/ame_probs.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
#' Average Marginal Effects for Non-homogenous Hidden Markov Models
#' Average Marginal Effects on NHMM Parameters
#'
#' The function `ame_probs` computes the average marginal effect (AME) of the
#' model covariate \eqn{X} on the model parameters by marginalizing over the sequences.
#' Under the assumption of no unobserved confounding (i.e., there are no
#' unobserved variables that influence the covariate \eqn{X} and the outcome),
#' these can be regarded as the causal effects of the covariate on
#' the initial, emission, and transition probabilities of the model. In case
#' `values` argument is a single value \eqn{x}, the function returns the
#' interventional initial, transition, emission probabilities
#' \deqn{P(z_1 | do(X_1 = x))}
#' \deqn{P(z_t | do(X_{t-1} = x), z_{t-1})}
#' \deqn{P(y_t | do(X_t = x), z_t)}
#' and in a case `values` contains two values \eqn{x} and \eqn{w} a shift in
#' interventional distributions, i.e.,
#' \deqn{P(z_1 | do(X_1 = x)) - P(z_1 | do(X_1 = w))}
#' \deqn{P(z_t | do(X_{t-1} = x), z_{t-1}) - P(z_t | do(X_{t-1} = w), z_{t-1})}
#' \deqn{P(y_t | do(X_t = x), z_t) - P(y_t | do(X_t = w), z_t)}.
#'
#' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`.
#' @param variable Name of the variable of interest.
#' @param values Vector containing one or two values for `variable`.
#' @param values Vector containing one or two values for `variable`.
#' See details.
#' @param newdata Optional data frame which is used for marginalization.
#' @param probs Quantiles of interest of average marginal effect.
#' @param ... Ignored.
#' @rdname ame
#' @rdname ame_probs
#' @export
ame <- function(model, variable, values, ...) {
UseMethod("ame", model)
ame_probs <- function(model, variable, values, ...) {
UseMethod("ame_probs", model)
}
#' @rdname ame
#' @rdname ame_probs
#' @export
ame.nhmm <- function(
ame_probs.nhmm <- function(
model, variable, values, newdata = NULL, probs = c(0.05, 0.95),
...) {
stopifnot_(
Expand Down Expand Up @@ -184,19 +202,19 @@ ame.nhmm <- function(
transition = ame_A,
emission = ame_B
)
class(out) <- "amp"
class(out) <- "ace_prob"
attr(out, "model") <- "nhmm"
out
}

#' @rdname ame
#' @rdname ame_probs
#' @export
ame.mnhmm <- function(
ame_probs.mnhmm <- function(
model, variable, values, newdata = NULL, probs = c(0.05, 0.95),
...) {

x <- lapply(
split_mnhmm(model), ame, variable = variable, values = values,
split_mnhmm(model), ame_probs, variable = variable, values = values,
newdata = newdata, probs = probs
)
out <- lapply(c("pi", "A", "B"), function(z) {
Expand Down Expand Up @@ -235,7 +253,7 @@ ame.mnhmm <- function(
),
qs_omega
)
class(out) <- "amp"
class(out) <- "ace_prob"
attr(out, "model") <- "mnhmm"
out
}
30 changes: 0 additions & 30 deletions man/ame.Rd

This file was deleted.

46 changes: 46 additions & 0 deletions man/ame_probs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3aab8ae

Please sign in to comment.