Skip to content

Commit

Permalink
fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 6, 2024
1 parent 8915cec commit 1ff2404
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 14 deletions.
6 changes: 3 additions & 3 deletions R/average_marginal_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ average_marginal_prediction <- function(
"Argument {.arg variable} must be a single character string."
)
stopifnot_(
length(values) != 2,
length(values) == 2,
"Argument {.arg values} should contain two values for
variable {.var variable}.")
if (is.null(newdata)) {
if (!is.null(newdata)) {
time <- model$time_variable
id <- model$id_variable
stopifnot_(
Expand All @@ -62,7 +62,7 @@ average_marginal_prediction <- function(
)
} else {
stopifnot_(
!is.null(model$data),
is.null(model$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
Expand Down
10 changes: 5 additions & 5 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ predict.nhmm <- function(
)
} else {
stopifnot_(
!is.null(object$data),
is.null(object$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
Expand Down Expand Up @@ -105,22 +105,22 @@ predict.mnhmm <- function(
)
} else {
stopifnot_(
!is.null(object$data),
is.null(object$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
object <- update(object, newdata = newdata)
}

beta_i_raw <- stan_to_cpp_initial(
object$coefficients$beta_i_raw
object$coefficients$beta_i_raw, object$n_clusters
)
beta_s_raw <- stan_to_cpp_transition(
object$coefficients$beta_s_raw
object$coefficients$beta_s_raw, object$n_clusters
)
beta_o_raw <- stan_to_cpp_emission(
object$coefficients$beta_o_raw,
1,
object$n_clusters,
object$n_channels > 1
)
X_initial <- t(object$X_initial)
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-build_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ test_that("estimate_mnhmm errors with incorrect observations", {
})
test_that("build_mnhmm works with vector of characters as observations", {
expect_error(
estimate_mnhmm("y", s, d, data = data, time = "time", id = "id", iter = 0),
estimate_mnhmm("y", s, d, data = data, time = "time", id = "id", iter = 0,
verbose = FALSE),
NA
)
})
6 changes: 3 additions & 3 deletions tests/testthat/test-build_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ test_that("build_nhmm returns object of class 'nhmm'", {
model <- build_nhmm(
obs, s, initial_formula = ~ x, transition_formula = ~z,
emission_formula = ~ z, data = data,
time = "time", id = "id", state_names = 1:s, channel_names = "obs",
verbose = FALSE
time = "time", id = "id", state_names = 1:s, channel_names = "obs"
),
NA
)
Expand Down Expand Up @@ -97,7 +96,8 @@ test_that("estimate_nhmm errors with incorrect observations", {
})
test_that("build_nhmm works with vector of characters as observations", {
expect_error(
estimate_nhmm("y", s, data = data, time = "time", id = "id", iter = 0),
estimate_nhmm("y", s, data = data, time = "time", id = "id", iter = 0,
verbose = FALSE),
NA
)
})
4 changes: 2 additions & 2 deletions tests/testthat/test-simulate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ test_that("simulate_mnhmm, coef and get_probs works", {
)
expect_error(
fit <- estimate_mnhmm(
sim$observations, n_states = 2,
sim$model$obs, n_states = 2,
n_clusters = 3,
initial_formula = ~1, transition_formula = ~ x,
emission_formula = ~ x + z, cluster_formula = ~w,
data = d, time = "month", id = "person",
init = sim$model$coefficients,
inits = sim$model$coefficients,
iter = 1, verbose = FALSE, hessian = FALSE),
NA
)
Expand Down

0 comments on commit 1ff2404

Please sign in to comment.