Skip to content

Commit

Permalink
more tests, fix get_pi_all and get_omega_all
Browse files (browse the repository at this point in the history)
  • Loading branch information
helske committed Sep 24, 2024
1 parent ce09b3d commit c7433a1
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 29 deletions.
2 changes: 1 addition & 1 deletion R/coef.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ coef.nhmm <- function(object, probs = c(0.025, 0.975), ...) {
# }

list(
gamma_pinitial = gamma_pi,
gamma_initial = gamma_pi,
beta_transition = gamma_A,
beta_emission = gamma_B
)
Expand Down
8 changes: 4 additions & 4 deletions R/model_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ model_matrix_initial_formula <- function(formula, data, n_sequences,
na.action = stats::na.pass
)
missing_values <- which(!complete.cases(X))
iv <- nrow(unique(X[-missing_values, ])) > 1
stopifnot_(
length(missing_values) == 0,
length(missing_values) == 0L,
c(
"Missing cases are not allowed in covariates of `initial_formula`.",
"Use {.fn complete.cases} to detect them, then fix or impute them.",
Expand All @@ -49,6 +48,7 @@ model_matrix_initial_formula <- function(formula, data, n_sequences,
)
)
)
iv <- nrow(unique(X)) > 1L
coef_names <- colnames(X)
n_pars <- (n_states - 1L) * ncol(X)
}
Expand Down Expand Up @@ -183,9 +183,8 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters,
na.action = stats::na.pass
)
missing_values <- which(!complete.cases(X))
iv <- nrow(unique(X[-missing_values, ])) > 1
stopifnot_(
length(missing_values) == 0,
length(missing_values) == 0L,
c(
"Missing cases are not allowed in covariates of `cluster_formula`.",
"Use {.fn complete.cases} to detect them, then fix or impute them.",
Expand All @@ -195,6 +194,7 @@ model_matrix_cluster_formula <- function(formula, data, n_sequences, n_clusters,
)
)
)
iv <- nrow(unique(X)) > 1L
coef_names <- colnames(X)
n_pars <- (n_clusters - 1L) * ncol(X)
}
Expand Down
12 changes: 10 additions & 2 deletions src/get_parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ arma::vec get_omega(const arma::mat& gamma_raw, const arma::vec& X, const bool l
// [[Rcpp::export]]
arma::mat get_omega_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace) {
arma::mat beta = arma::join_cols(arma::zeros<arma::rowvec>(gamma_raw.n_cols), gamma_raw);
return softmax(beta * X, logspace).t();
arma::mat omega(beta.n_rows, X.n_cols);
for (unsigned int i = 0; i < X.n_cols; i++) {
omega.col(i) = softmax(beta * X.col(i), logspace);
}
return omega;
}
// gamma_raw is (S - 1) x K (start from, covariates)
// X a vector of length K
Expand All @@ -26,7 +30,11 @@ arma::vec get_pi(const arma::mat& gamma_raw, const arma::vec& X, const bool logs
// [[Rcpp::export]]
arma::mat get_pi_all(const arma::mat& gamma_raw, const arma::mat& X, const bool logspace) {
arma::mat beta = arma::join_cols(arma::zeros<arma::rowvec>(gamma_raw.n_cols), gamma_raw);
return softmax(beta * X, logspace).t();
arma::mat pi(beta.n_rows, X.n_cols);
for (unsigned int i = 0; i < X.n_cols; i++) {
pi.col(i) = softmax(beta * X.col(i), logspace);
}
return pi;
}

// gamma_raw is (S - 1) x K x S (transition to, covariates, transition from)
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test-build_lcm.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ test_that("build_lcm returns object of class 'mhmm'", {
emission_probs = cbind(1, matrix(0, 2, s - 1))),
NA
)
expect_error(
build_lcm(list(obs, obs), n_clusters = k),
NA
)
expect_warning(
model <- build_lcm(
list(obs, obs), n_clusters = k,
Expand Down
72 changes: 55 additions & 17 deletions tests/testthat/test-fit_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,42 @@ test_that("'fit_model' works for 'hmm'", {
),
NA
)

model <- build_hmm(
observations = sim$observations,
transition_probs = transition_probs,
emission_probs = emission_probs,
initial_probs = initial_probs
)

set.seed(1)
expect_error(
fit2 <- fit_model(
model, em_step = TRUE, global_step = TRUE,
local_step = TRUE,
control_em = list(restart = list(times = 2), maxeval = 10),
control_global = list(maxeval = 10),
control_local = list(maxeval = 10)
),
NA
)
set.seed(1)
expect_error(
fit <- fit_model(
fit1 <- fit_model(
model, em_step = TRUE, global_step = TRUE,
local_step = TRUE,
control_em = list(restart = list(times = 2), maxeval = 100),
control_global = list(maxeval = 100),
control_local = list(maxeval = 100)
control_em = list(restart = list(times = 2), maxeval = 10),
control_global = list(maxeval = 10),
control_local = list(maxeval = 10),
log_space = FALSE
),
NA
)
expect_equal(fit1$model, fit2$model)
expect_equal(logLik(fit1$model), logLik(fit2$model, log_space = FALSE))
expect_equal(
logLik(fit1$model, threads = 2L)[1],
sum(logLik(fit1$model, partials = TRUE))
)
})

test_that("'fit_model' works for 'mhmm'", {
Expand All @@ -39,28 +57,28 @@ test_that("'fit_model' works for 'mhmm'", {
emission_probs_2 <- matrix(c(0.1, 0.8, 0.9, 0.2), 2, 2)
colnames(emission_probs_1) <- colnames(emission_probs_2) <-
c("heads", "tails")

transition_probs_1 <- matrix(c(9, 0.1, 1, 9.9) / 10, 2, 2)
transition_probs_2 <- matrix(c(35, 1, 1, 35) / 36, 2, 2)
rownames(emission_probs_1) <- rownames(transition_probs_1) <-
colnames(transition_probs_1) <- c("coin 1", "coin 2")
rownames(emission_probs_2) <- rownames(transition_probs_2) <-
colnames(transition_probs_2) <- c("coin 3", "coin 4")

initial_probs_1 <- c(1, 0)
initial_probs_2 <- c(1, 0)

n <- 30
covariate_1 <- runif(n)
covariate_2 <- sample(c("A", "B"),
size = n, replace = TRUE,
prob = c(0.3, 0.7)
)
dataf <- data.frame(covariate_1, covariate_2)

coefs <- cbind(cluster_1 = c(0, 0, 0), cluster_2 = c(-1.5, 3, -0.7))
rownames(coefs) <- c("(Intercept)", "covariate_1", "covariate_2B")

expect_error(
sim <- simulate_mhmm(
n = n, initial_probs = list(initial_probs_1, initial_probs_2),
Expand All @@ -71,7 +89,7 @@ test_that("'fit_model' works for 'mhmm'", {
),
NA
)

expect_error(
model <- build_mhmm(
sim$observations,
Expand All @@ -82,15 +100,35 @@ test_that("'fit_model' works for 'mhmm'", {
data = dataf),
NA
)


set.seed(1)
expect_error(
fit <- fit_model(
fit1 <- fit_model(
model, em_step = TRUE, global_step = TRUE,
local_step = TRUE,
control_em = list(restart = list(times = 2), maxeval = 100),
control_global = list(maxeval = 100),
control_local = list(maxeval = 100)
control_em = list(restart = list(times = 2), maxeval = 10),
control_global = list(maxeval = 10),
control_local = list(maxeval = 10)
),
NA
)

set.seed(1)
expect_error(
fit2 <- fit_model(
model, em_step = TRUE, global_step = TRUE,
local_step = TRUE,
control_em = list(restart = list(times = 2), maxeval = 10),
control_global = list(maxeval = 10),
control_local = list(maxeval = 10),
log_space = FALSE
),
NA
)
expect_equal(fit1$model, fit2$model)
expect_equal(logLik(fit1$model), logLik(fit2$model, log_space = FALSE))
expect_equal(
logLik(fit1$model, threads = 2L)[1],
sum(logLik(fit1$model, partials = TRUE))
)
})
35 changes: 30 additions & 5 deletions tests/testthat/test-get_probs.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("'get_probs' works for multichannel 'nhmm'", {
test_that("'get_probs' and 'coef' works for multichannel 'nhmm'", {
data("hmm_biofam")
set.seed(1)
expect_error(
Expand All @@ -14,8 +14,12 @@ test_that("'get_probs' works for multichannel 'nhmm'", {
p <- get_probs(fit),
NA
)
expect_error(
coef(fit),
NA
)
})
test_that("'get_probs' works for single-channel 'nhmm'", {
test_that("'get_probs' and 'coef' works for single-channel 'nhmm'", {
data("hmm_biofam")
set.seed(1)
expect_error(
Expand All @@ -28,9 +32,13 @@ test_that("'get_probs' works for single-channel 'nhmm'", {
p <- get_probs(fit),
NA
)
expect_error(
coef(fit),
NA
)
})

test_that("'get_probs' works for multichannel 'mnhmm'", {
test_that("'get_probs' and 'coef' works for multichannel 'mnhmm'", {
data("hmm_biofam")
set.seed(1)
expect_error(
Expand All @@ -43,18 +51,35 @@ test_that("'get_probs' works for multichannel 'mnhmm'", {
p <- get_probs(fit),
NA
)
expect_error(
coef(fit),
NA
)
})

test_that("'get_probs' works for single-channel 'mnhmm'", {
test_that("'get_probs' and 'coef' works for single-channel 'mnhmm'", {
set.seed(1)
d <- data.frame(
group = rep(1:50, each = 16),
time = 1:16,
z = rnorm(16 * 50),
w = 1:16
)
expect_error(
fit <- estimate_mnhmm(
hmm_biofam$observations[[1]], n_states = 4, n_clusters = 2, maxeval = 1
hmm_biofam$observations[[1]][1:50, ], n_states = 4, n_clusters = 2,
initial_formula = ~ z, cluster_formula = ~ z,
transition_formula = ~w, emission_formula = ~ w,
data = d, time = "time", id = "group", maxeval = 1
),
NA
)
expect_error(
p <- get_probs(fit),
NA
)
expect_error(
coef(fit),
NA
)
})

0 comments on commit c7433a1

Please sign in to comment.