From 53d0b3709caa627cea75ab6f222133df9534b941 Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Sat, 21 Sep 2024 23:21:18 +0300 Subject: [PATCH] fix return code handling, add iterations to output --- R/fit_mnhmm.R | 5 +++-- R/fit_nhmm.R | 5 +++-- tests/testthat/test-forward_backward.R | 12 ++++++------ 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 8fd6d35d..8c0e9c0a 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -163,7 +163,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { logliks <- unlist(lapply(out, "[[", "objective")) return_codes <- unlist(lapply(out, "[[", "status")) - successful <- which(return_codes == 0) + successful <- which(return_codes > 0) optimum <- successful[which.max(logliks[successful])] init <- out[[optimum]]$solution } else { @@ -182,7 +182,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { opts = dots ) if (out$status < 0) { - warning_(paste("Local optimization terminated:", out$message)) + warning_(paste("Optimization terminated due to error:", out$message)) } pars <- out$solution model$coefficients$gamma_pi_raw <- create_gamma_pi_raw_mnhmm( @@ -222,6 +222,7 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { loglik = out$objective, return_code = out$status, message = out$message, + iterations = out$iterations, logliks_of_restarts = if(restarts > 0L) logliks else NULL, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL ) diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index dcfdaac8..78bf390e 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -133,7 +133,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { logliks <- unlist(lapply(out, "[[", "objective")) return_codes <- unlist(lapply(out, "[[", "status")) - successful <- which(return_codes == 0) + successful <- which(return_codes > 0) optimum <- successful[which.max(logliks[successful])] init <- out[[optimum]]$solution } else { @@ -152,7 +152,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { opts = dots ) if (out$status < 0) { - warning_(paste("Local optimization terminated:", out$message)) + warning_(paste("Optimization terminated due to error:", out$message)) } pars <- out$solution model$coefficients$gamma_pi_raw <- create_gamma_pi_raw_nhmm(pars[seq_len(n_i)], S, K_i) @@ -185,6 +185,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, hessian, ...) { loglik = out$objective, return_code = out$status, message = out$message, + iterations = out$iterations, logliks_of_restarts = if(restarts > 0L) logliks else NULL, return_codes_of_restarts = if(restarts > 0L) return_codes else NULL ) diff --git a/tests/testthat/test-forward_backward.R b/tests/testthat/test-forward_backward.R index c2a0e860..5256e080 100644 --- a/tests/testthat/test-forward_backward.R +++ b/tests/testthat/test-forward_backward.R @@ -39,7 +39,7 @@ test_that("'forward_backward' works for multichannel 'nhmm'", { hmm_biofam$observations, n_states = 5, inits = hmm_biofam[ c("initial_probs", "transition_probs", "emission_probs") - ] + ], maxeval = 1 ), NA ) @@ -65,7 +65,7 @@ test_that("'forward_backward' works for single-channel 'nhmm'", { expect_error( fit <- estimate_nhmm( hmm_biofam$observations[[1]], n_states = 3, - restarts = 2, threads = 1 + restarts = 2, threads = 1, maxeval = 2 ), NA ) @@ -73,8 +73,8 @@ test_that("'forward_backward' works for single-channel 'nhmm'", { fb <- forward_backward(fit, as_data_frame = FALSE), NA ) - expect_gte(min(fb$forward_probs), -60) - expect_gte(min(fb$backward_probs), -60) + expect_gte(min(fb$forward_probs), -2000) + expect_gte(min(fb$backward_probs), -2000) expect_lte(max(fb$forward_probs), 0) expect_lte(max(fb$backward_probs), 0) @@ -99,8 +99,8 @@ test_that("'forward_backward' works for multichannel 'mnhmm'", { fb <- forward_backward(fit, as_data_frame = FALSE), NA ) - expect_gte(min(fb$forward_probs), -130) - expect_gte(min(fb$backward_probs), -120) + expect_gte(min(fb$forward_probs), -2000) + expect_gte(min(fb$backward_probs), -2000) expect_lte(max(fb$forward_probs), 0) expect_lte(max(fb$backward_probs), 0)