Skip to content

Commit

Permalink
feature issue #1591
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Mar 18, 2024
1 parent f21da87 commit 02784e9
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 16 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

### New Features

* Add experimental support for the `pathfinder` and `laplace` algorithms
in the `cmdstanr` backend. (#1591)
* Automatically recompute fit criteria previously stored in the model
if potentially results-changing arguments are provided to the criterion method.
* Allow to turn off automatic broadcasting of `constant` priors.
Expand Down
59 changes: 44 additions & 15 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ fit_model <- function(model, backend, ...) {
} else {
stop2("Algorithm '", algorithm, "' is not supported.")
}
# TODO: add support for pathfinder and laplace
out <- repair_stanfit(out)
out
}
Expand Down Expand Up @@ -242,13 +243,6 @@ fit_model <- function(model, backend, ...) {
stop2("Argument 'future' is not supported by backend 'cmdstanr'.")
}
args <- nlist(data = sdata, seed, init)
if (use_threading(threads)) {
if (algorithm %in% c("sampling", "fixed_param")) {
args$threads_per_chain <- threads$threads
} else if (algorithm %in% c("fullrank", "meanfield")) {
args$threads <- threads$threads
}
}
if (use_opencl(opencl)) {
args$opencl_ids <- opencl$ids
}
Expand Down Expand Up @@ -282,18 +276,33 @@ fit_model <- function(model, backend, ...) {
show_exceptions = silent == 0,
fixed_param = algorithm == "fixed_param"
)
if (use_threading(threads)) {
args$threads_per_chain <- threads$threads
}
out <- do_call(model$sample, args)
} else if (algorithm %in% c("fullrank", "meanfield")) {
# vb does not support parallel execution
c(args) <- nlist(iter, algorithm)
if (use_threading(threads)) {
args$threads <- threads$threads
}
out <- do_call(model$variational, args)
} else if (algorithm %in% c("pathfinder")) {
if (use_threading(threads)) {
args$num_threads <- threads$threads
}
out <- do_call(model$pathfinder, args)
} else if (algorithm %in% c("laplace")) {
if (use_threading(threads)) {
args$threads <- threads$threads
}
out <- do_call(model$laplace, args)
} else {
stop2("Algorithm '", algorithm, "' is not supported.")
}

out <- read_csv_as_stanfit(
out$output_files(), variables = out$metadata()$variables,
model = model, exclude = exclude
model = model, exclude = exclude, algorithm = algorithm
)

if (empty_model) {
Expand Down Expand Up @@ -416,7 +425,7 @@ backend_choices <- function() {

# supported Stan algorithms
algorithm_choices <- function() {
c("sampling", "meanfield", "fullrank", "fixed_param")
c("sampling", "meanfield", "fullrank", "pathfinder", "laplace", "fixed_param")
}

# check if the model was fit the the required backend
Expand Down Expand Up @@ -651,6 +660,8 @@ file_refit_options <- function() {
#' if you want to allow updating the model without recompilation.
#' @param exclude Character vector of variables to exclude from the stanfit. Only
#' used when \code{variables} is also specified.
#' @param algorithm The algorithm with which the model was fitted.
#' See \code{\link{brm}} for details.
#'
#' @return A stanfit object consistent with the structure of the \code{fit}
#' slot of a brmsfit object.
Expand All @@ -672,16 +683,24 @@ file_refit_options <- function() {
#'
#' @export
read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = NULL,
model = NULL, exclude = "") {
model = NULL, exclude = "", algorithm = "sampling") {
require_package("cmdstanr")

if (!is.null(variables)) {
# ensure that only relevant variables are read from CSV
variables <- repair_variable_names(variables)
variables <- unique(sub("\\[.+", "", variables))
variables <- setdiff(variables, exclude)
# temp fix for cmdstanr not recognizing the variable names it produces #1473
variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
# cmdstanr deals with special variables inconsistently
# below is an attempt to deal with this somehow (part 1)
if (algorithm %in% c("meanfield", "fullrank")) {
# temp fix for cmdstanr not recognizing the variable names it produces #1473
variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
} else if (algorithm %in% "pathfinder") {
variables <- setdiff(variables, "lp_approx__")
} else if (algorithm %in% "laplace") {
variables <- setdiff(variables, c("lp__", "lp_approx__"))
}
}

csfit <- cmdstanr::read_cmdstan_csv(
Expand All @@ -695,8 +714,18 @@ read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = N

# @model_pars
svars <- variables %||% csfit$metadata$stan_variables
if ("lp__" %in% svars) {
svars <- c(setdiff(svars, "lp__"), "lp__")
# cmdstanr deals with special variables inconsistently
# below is an attempt to deal with this somehow (part 2)
special_vars <- c("lp__", "lp_approx__", "log_g__")
vars_in_draws <- variables(csfit$draws)
for (v in intersect(special_vars, svars)) {
if (v %in% vars_in_draws) {
# put special vars at the end
svars <- c(setdiff(svars, v), v)
} else {
# remove special vars as they do not seem to be stored in draws
svars <- setdiff(svars, v)
}
}
pars_oi <- svars
par_names <- csfit$metadata$model_params
Expand Down
6 changes: 5 additions & 1 deletion man/read_csv_as_stanfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/local/tests.models-5.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,29 @@ test_that("projpred methods can be run", {
# expect_is(vs, "vsel")
})

test_that("alternative algorithms can be used", {
fit <- brm(
count ~ zBase * Trt, data = epilepsy,
backend = "cmdstanr", algorithm = "meanfield"
)
summary(fit)
expect_is(fit, "brmsfit")

fit <- brm(
count ~ zBase * Trt, data = epilepsy,
backend = "cmdstanr", algorithm = "pathfinder"
)
summary(fit)
expect_is(fit, "brmsfit")

fit <- brm(
count ~ zBase * Trt, data = epilepsy,
backend = "cmdstanr", algorithm = "laplace"
)
summary(fit)
expect_is(fit, "brmsfit")
})

test_that(paste(
"Families sratio() and cratio() are equivalent for symmetric distribution",
"functions (here only testing the logit link)"
Expand Down

0 comments on commit 02784e9

Please sign in to comment.