From a5023a965dc8bc135860cef8140ec4409e3d20c0 Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Tue, 24 Sep 2024 11:06:53 +0300 Subject: [PATCH] inits for mixtures --- R/create_initial_values.R | 8 ++++---- tests/testthat/test-simulate_mnhmm.R | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/create_initial_values.R b/R/create_initial_values.R index 3fdc26e9..70c52683 100644 --- a/R/create_initial_values.R +++ b/R/create_initial_values.R @@ -49,7 +49,7 @@ create_gamma_pi_inits <- function(x, S, K, init_sd = 0, D = 1) { create_gamma_pi_raw_mnhmm(rnorm((S - 1) * K * D, sd = init_sd), S, K, D) } else { stopifnot_( - length(x) == (S - 1) * K * D, + length(unlist(x)) == (S - 1) * K * D, paste0( "Number of initial values for {.val gamma_pi} is not equal to ", "(S - 1) * K * D = {(S - 1) * K * D}." @@ -78,7 +78,7 @@ create_gamma_A_inits <- function(x, S, K, init_sd = 0, D = 1) { create_gamma_A_raw_mnhmm(rnorm((S - 1) * K * S * D, sd = init_sd), S, K, D) } else { stopifnot_( - length(x) == (S - 1) * K * S * D, + length(unlist(x)) == (S - 1) * K * S * D, paste0( "Number of initial values for {.val gamma_A} is not equal to ", "(S - 1) * K * S * D = {(S - 1) * K * S * D}." @@ -110,7 +110,7 @@ create_gamma_B_inits <- function(x, S, M, K, init_sd = 0, D = 1) { ) } else { stopifnot_( - length(x) == sum((M - 1) * K * S) * D, + length(unlist(x)) == sum((M - 1) * K * S) * D, paste0( "Number of initial values for {.val gamma_B} is not equal to ", "sum((M - 1) * K * S) * D = {sum((M - 1) * K * S) * D}." @@ -142,7 +142,7 @@ create_gamma_B_inits <- function(x, S, M, K, init_sd = 0, D = 1) { ) } else { stopifnot_( - length(x) == (M - 1) * K * S * D, + length(unlist(x)) == (M - 1) * K * S * D, paste0( "Number of initial values for {.val gamma_B} is not equal to ", "(M - 1) * K * S * D = {(M - 1) * K * S * D}." diff --git a/tests/testthat/test-simulate_mnhmm.R b/tests/testthat/test-simulate_mnhmm.R index 5505184f..1bff8616 100644 --- a/tests/testthat/test-simulate_mnhmm.R +++ b/tests/testthat/test-simulate_mnhmm.R @@ -42,7 +42,7 @@ test_that("simulate_mnhmm, coef and get_probs works", { initial_formula = ~1, transition_formula = ~ x, emission_formula = ~ x + z, cluster_formula = ~w, data = d, time = "month", id = "person", - inits = sim$model$coefficients), + inits = sim$model$coefficients, maxeval = 1), NA ) expect_error(