Skip to content

Commit

Permalink
fix return code handling, add iterations to output
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 21, 2024
1 parent 6d44d0b commit 53d0b37
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
5 changes: 3 additions & 2 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {

logliks <- unlist(lapply(out, "[[", "objective"))
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes == 0)
successful <- which(return_codes > 0)
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
} else {
Expand All @@ -182,7 +182,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
opts = dots
)
if (out$status < 0) {
warning_(paste("Local optimization terminated:", out$message))
warning_(paste("Optimization terminated due to error:", out$message))
}
pars <- out$solution
model$coefficients$gamma_pi_raw <- create_gamma_pi_raw_mnhmm(
Expand Down Expand Up @@ -222,6 +222,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
loglik = out$objective,
return_code = out$status,
message = out$message,
iterations = out$iterations,
logliks_of_restarts = if(restarts > 0L) logliks else NULL,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL
)
Expand Down
5 changes: 3 additions & 2 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {

logliks <- unlist(lapply(out, "[[", "objective"))
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes == 0)
successful <- which(return_codes > 0)
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
} else {
Expand All @@ -152,7 +152,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
opts = dots
)
if (out$status < 0) {
warning_(paste("Local optimization terminated:", out$message))
warning_(paste("Optimization terminated due to error:", out$message))
}
pars <- out$solution
model$coefficients$gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i)
Expand Down Expand Up @@ -185,6 +185,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) {
loglik = out$objective,
return_code = out$status,
message = out$message,
iterations = out$iterations,
logliks_of_restarts = if(restarts > 0L) logliks else NULL,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL
)
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test-forward_backward.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ test_that("'forward_backward' works for multichannel 'nhmm'", {
hmm_biofam$observations, n_states = 5,
inits = hmm_biofam[
c("initial_probs", "transition_probs", "emission_probs")
]
], maxeval = 1
),
NA
)
Expand All @@ -65,16 +65,16 @@ test_that("'forward_backward' works for single-channel 'nhmm'", {
expect_error(
fit <- estimate_nhmm(
hmm_biofam$observations[[1]], n_states = 3,
restarts = 2, threads = 1
restarts = 2, threads = 1, maxeval = 2
),
NA
)
expect_error(
fb <- forward_backward(fit, as_data_frame = FALSE),
NA
)
expect_gte(min(fb$forward_probs), -60)
expect_gte(min(fb$backward_probs), -60)
expect_gte(min(fb$forward_probs), -2000)
expect_gte(min(fb$backward_probs), -2000)
expect_lte(max(fb$forward_probs), 0)
expect_lte(max(fb$backward_probs), 0)

Expand All @@ -99,8 +99,8 @@ test_that("'forward_backward' works for multichannel 'mnhmm'", {
fb <- forward_backward(fit, as_data_frame = FALSE),
NA
)
expect_gte(min(fb$forward_probs), -130)
expect_gte(min(fb$backward_probs), -120)
expect_gte(min(fb$forward_probs), -2000)
expect_gte(min(fb$backward_probs), -2000)
expect_lte(max(fb$forward_probs), 0)
expect_lte(max(fb$backward_probs), 0)

Expand Down

0 comments on commit 53d0b37

Please sign in to comment.