diff --git a/NEWS.md b/NEWS.md index c8c3ddae3..596296247 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ ### New Features +* Add experimental support for the `pathfinder` and `laplace` algorithms +in the `cmdstanr` backend. (#1591) * Automatically recompute fit criteria previously stored in the model if potentially results-changing arguments are provided to the criterion method. * Allow to turn off automatic broadcasting of `constant` priors. diff --git a/R/backends.R b/R/backends.R index b3eb7beba..b8af3cf87 100644 --- a/R/backends.R +++ b/R/backends.R @@ -215,6 +215,7 @@ fit_model <- function(model, backend, ...) { } else { stop2("Algorithm '", algorithm, "' is not supported.") } + # TODO: add support for pathfinder and laplace out <- repair_stanfit(out) out } @@ -242,13 +243,6 @@ fit_model <- function(model, backend, ...) { stop2("Argument 'future' is not supported by backend 'cmdstanr'.") } args <- nlist(data = sdata, seed, init) - if (use_threading(threads)) { - if (algorithm %in% c("sampling", "fixed_param")) { - args$threads_per_chain <- threads$threads - } else if (algorithm %in% c("fullrank", "meanfield")) { - args$threads <- threads$threads - } - } if (use_opencl(opencl)) { args$opencl_ids <- opencl$ids } @@ -282,18 +276,33 @@ fit_model <- function(model, backend, ...) { show_exceptions = silent == 0, fixed_param = algorithm == "fixed_param" ) + if (use_threading(threads)) { + args$threads_per_chain <- threads$threads + } out <- do_call(model$sample, args) } else if (algorithm %in% c("fullrank", "meanfield")) { - # vb does not support parallel execution c(args) <- nlist(iter, algorithm) + if (use_threading(threads)) { + args$threads <- threads$threads + } out <- do_call(model$variational, args) + } else if (algorithm %in% c("pathfinder")) { + if (use_threading(threads)) { + args$num_threads <- threads$threads + } + out <- do_call(model$pathfinder, args) + } else if (algorithm %in% c("laplace")) { + if (use_threading(threads)) { + args$threads <- threads$threads + } + out <- do_call(model$laplace, args) } else { stop2("Algorithm '", algorithm, "' is not supported.") } out <- read_csv_as_stanfit( out$output_files(), variables = out$metadata()$variables, - model = model, exclude = exclude + model = model, exclude = exclude, algorithm = algorithm ) if (empty_model) { @@ -416,7 +425,7 @@ backend_choices <- function() { # supported Stan algorithms algorithm_choices <- function() { - c("sampling", "meanfield", "fullrank", "fixed_param") + c("sampling", "meanfield", "fullrank", "pathfinder", "laplace", "fixed_param") } # check if the model was fit the the required backend @@ -651,6 +660,8 @@ file_refit_options <- function() { #' if you want to allow updating the model without recompilation. #' @param exclude Character vector of variables to exclude from the stanfit. Only #' used when \code{variables} is also specified. +#' @param algorithm The algorithm with which the model was fitted. +#' See \code{\link{brm}} for details. #' #' @return A stanfit object consistent with the structure of the \code{fit} #' slot of a brmsfit object. @@ -672,7 +683,7 @@ file_refit_options <- function() { #' #' @export read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = NULL, - model = NULL, exclude = "") { + model = NULL, exclude = "", algorithm = "sampling") { require_package("cmdstanr") if (!is.null(variables)) { @@ -680,8 +691,16 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N variables <- repair_variable_names(variables) variables <- unique(sub("\\[.+", "", variables)) variables <- setdiff(variables, exclude) - # temp fix for cmdstanr not recognizing the variable names it produces #1473 - variables <- ifelse(variables == "lp_approx__", "log_g__", variables) + # cmdstanr deals with special variables inconsistently + # below is an attempt to deal with this somehow (part 1) + if (algorithm %in% c("meanfield", "fullrank")) { + # temp fix for cmdstanr not recognizing the variable names it produces #1473 + variables <- ifelse(variables == "lp_approx__", "log_g__", variables) + } else if (algorithm %in% "pathfinder") { + variables <- setdiff(variables, "lp_approx__") + } else if (algorithm %in% "laplace") { + variables <- setdiff(variables, c("lp__", "lp_approx__")) + } } csfit <- cmdstanr::read_cmdstan_csv( @@ -695,8 +714,18 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N # @model_pars svars <- variables %||% csfit$metadata$stan_variables - if ("lp__" %in% svars) { - svars <- c(setdiff(svars, "lp__"), "lp__") + # cmdstanr deals with special variables inconsistently + # below is an attempt to deal with this somehow (part 2) + special_vars <- c("lp__", "lp_approx__", "log_g__") + vars_in_draws <- variables(csfit$draws) + for (v in intersect(special_vars, svars)) { + if (v %in% vars_in_draws) { + # put special vars at the end + svars <- c(setdiff(svars, v), v) + } else { + # remove special vars as they do not seem to be stored in draws + svars <- setdiff(svars, v) + } } pars_oi <- svars par_names <- csfit$metadata$model_params diff --git a/man/read_csv_as_stanfit.Rd b/man/read_csv_as_stanfit.Rd index 2437c5009..c7ef15d82 100644 --- a/man/read_csv_as_stanfit.Rd +++ b/man/read_csv_as_stanfit.Rd @@ -9,7 +9,8 @@ read_csv_as_stanfit( variables = NULL, sampler_diagnostics = NULL, model = NULL, - exclude = "" + exclude = "", + algorithm = "sampling" ) } \arguments{ @@ -24,6 +25,9 @@ if you want to allow updating the model without recompilation.} \item{exclude}{Character vector of variables to exclude from the stanfit. Only used when \code{variables} is also specified.} + +\item{algorithm}{The algorithm with which the model was fitted. +See \code{\link{brm}} for details.} } \value{ A stanfit object consistent with the structure of the \code{fit} diff --git a/tests/local/tests.models-5.R b/tests/local/tests.models-5.R index 8ad695d1c..2bb82d756 100644 --- a/tests/local/tests.models-5.R +++ b/tests/local/tests.models-5.R @@ -167,6 +167,29 @@ test_that("projpred methods can be run", { # expect_is(vs, "vsel") }) +test_that("alternative algorithms can be used", { + fit <- brm( + count ~ zBase * Trt, data = epilepsy, + backend = "cmdstanr", algorithm = "meanfield" + ) + summary(fit) + expect_is(fit, "brmsfit") + + fit <- brm( + count ~ zBase * Trt, data = epilepsy, + backend = "cmdstanr", algorithm = "pathfinder" + ) + summary(fit) + expect_is(fit, "brmsfit") + + fit <- brm( + count ~ zBase * Trt, data = epilepsy, + backend = "cmdstanr", algorithm = "laplace" + ) + summary(fit) + expect_is(fit, "brmsfit") +}) + test_that(paste( "Families sratio() and cratio() are equivalent for symmetric distribution", "functions (here only testing the logit link)"