Skip to content

Commit

Permalink
Merge pull request #251 from paul-buerkner/feature-issue-246
Browse files Browse the repository at this point in the history
Feature issue 246
  • Loading branch information
paul-buerkner authored Aug 15, 2017
2 parents bac1dcf + ed136e4 commit 04dcc3e
Show file tree
Hide file tree
Showing 42 changed files with 1,959 additions and 1,411 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using Stan
Version: 1.8.0.1
Version: 1.8.0.2
Date: 2017-07-19
Authors@R: person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
role = c("aut", "cre"))
Expand Down
18 changes: 14 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method("+",brmsformula)
S3method("+",brmsprior)
S3method(LOO,brmsfit)
S3method(VarCorr,brmsfit)
S3method(WAIC,brmsfit)
Expand All @@ -11,8 +13,6 @@ S3method(as.data.frame,brmsVarCorr)
S3method(as.data.frame,brmsfit)
S3method(as.matrix,brmsfit)
S3method(as.mcmc,brmsfit)
S3method(auxpar_family,default)
S3method(auxpar_family,mixfamily)
S3method(bayes_R2,brmsfit)
S3method(bayes_factor,brmsfit)
S3method(bridge_sampler,brmsfit)
Expand All @@ -25,11 +25,15 @@ S3method(coef,brmsfit)
S3method(control_params,brmsfit)
S3method(data_effects,btl)
S3method(data_effects,btnl)
S3method(dpar_family,default)
S3method(dpar_family,mixfamily)
S3method(expose_functions,brmsfit)
S3method(extract_draws,brmsfit)
S3method(extract_draws,btl)
S3method(extract_draws,btnl)
S3method(family,brmsfit)
S3method(family_names,brmsformula)
S3method(family_names,brmsterms)
S3method(family_names,default)
S3method(family_names,family)
S3method(family_names,mixfamily)
Expand All @@ -56,6 +60,9 @@ S3method(loo,brmsfit)
S3method(loo_linpred,brmsfit)
S3method(loo_predict,brmsfit)
S3method(loo_predictive_interval,brmsfit)
S3method(make_Jmo_list,brmsterms)
S3method(make_Jmo_list,btl)
S3method(make_Jmo_list,btnl)
S3method(make_gp_list,brmsterms)
S3method(make_gp_list,btl)
S3method(make_gp_list,btnl)
Expand Down Expand Up @@ -122,8 +129,8 @@ S3method(summary,family)
S3method(summary,mixfamily)
S3method(update,brmsfit)
S3method(update,brmsformula)
S3method(valid_auxpars,default)
S3method(valid_auxpars,mixfamily)
S3method(valid_dpars,default)
S3method(valid_dpars,mixfamily)
S3method(vcov,brmsfit)
S3method(waic,brmsfit)
export("add_ic<-")
Expand Down Expand Up @@ -211,6 +218,7 @@ export(kfold)
export(lasso)
export(launch_shiny)
export(launch_shinystan)
export(lf)
export(log_lik)
export(log_posterior)
export(logit_scaled)
Expand All @@ -233,6 +241,7 @@ export(monotonic)
export(neff_ratio)
export(negbinomial)
export(ngrps)
export(nlf)
export(nsamples)
export(nuts_params)
export(parnames)
Expand Down Expand Up @@ -284,6 +293,7 @@ export(rskew_normal)
export(rstudent_t)
export(rvon_mises)
export(rwiener)
export(set_nl)
export(set_prior)
export(skew_normal)
export(sratio)
Expand Down
2 changes: 1 addition & 1 deletion R/brm.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
if (future) {
require_package("future")
if (cores > 1L) {
warning("Argument 'cores' is ignored when using 'future'.")
warning2("Argument 'cores' is ignored when using 'future'.")
}
args$chains <- 1L
futures <- fits <- vector("list", chains)
Expand Down
37 changes: 20 additions & 17 deletions R/brmsfit-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ restructure <- function(x, rstr_summary = FALSE) {
}
}
if (version <= "0.10.0.9000") {
if (length(bterms$auxpars$mu$nlpars)) {
if (length(bterms$dpars$mu$nlpars)) {
# nlpar and group have changed positions
change <- change_old_re(x$ranef, pars = parnames(x),
dims = x$fit@sim$dims_oi)
Expand Down Expand Up @@ -137,6 +137,9 @@ restructure <- function(x, rstr_summary = FALSE) {
any(grepl("^Xme_", parnames(x)))
}
}
if (version <= "1.8.0.1") {
x$prior[, c("resp", "dpar")] <- ""
}
stan_env <- attributes(x$fit)$.MISC
if (rstr_summary && exists("summary", stan_env)) {
stan_summary <- get("summary", stan_env)
Expand Down Expand Up @@ -238,7 +241,7 @@ prepare_conditions <- function(x, conditions = NULL, effects = NULL,
lapply(get_effect(bterms, "offset"), rhs),
re$form, lapply(re$gcall, "[[", "weightvars"),
bterms$adforms[c("se", "disp", "trials", "cat")],
bterms$auxpars$mu$covars
bterms$dpars$mu$covars
)
req_vars <- unique(ulapply(req_vars, all.vars))
req_vars <- setdiff(req_vars, rsv_vars)
Expand Down Expand Up @@ -662,8 +665,8 @@ get_cov_matrix_ident <- function(sigma, nrows, se2 = 0) {
mat
}

get_auxpar <- function(x, i = NULL) {
# get samples of an auxiliary parameter
get_dpar <- function(x, i = NULL) {
# get samples of an distributional parameter
# Args:
# x: object to extract postarior samples from
# i: the current observation number
Expand All @@ -673,7 +676,7 @@ get_auxpar <- function(x, i = NULL) {
family <- x[["f"]]
x <- get_eta(x, i = i)
if (!nzchar(family$family)) {
# apply links for auxiliary parameters only
# apply links for distributional parameters only
# the main family link is applied later on
x <- ilink(x, family$link)
}
Expand All @@ -692,21 +695,21 @@ get_auxpar <- function(x, i = NULL) {
get_sigma <- function(x, data, i = NULL, dim = NULL) {
# get the residual standard devation of linear models
# Args:
# see get_auxpar
# see get_dpar
# dim: target dimension of output matrices (used in fitted)
stopifnot(is.atomic(x) || is.list(x))
out <- get_se(data = data, i = i, dim = dim)
if (!is.null(x)) {
out <- sqrt(out^2 + get_auxpar(x, i = i)^2)
out <- sqrt(out^2 + get_dpar(x, i = i)^2)
}
mult_disp(out, data = data, i = i, dim = dim)
}

get_shape <- function(x, data, i = NULL, dim = NULL) {
# get the shape parameter of gamma, weibull and negbinomial models
# Args: see get_auxpar
# Args: see get_dpar
stopifnot(is.atomic(x) || is.list(x))
x <- get_auxpar(x, i = i)
x <- get_dpar(x, i = i)
mult_disp(x, data = data, i = i, dim = dim)
}

Expand All @@ -715,14 +718,14 @@ get_zi_hu <- function(draws, i = NULL, par = c("zi", "hu")) {
# also works with deprecated models fitted with brms < 1.0.0
# which were using multivariate syntax
# Args:
# see get_auxpar
# see get_dpar
# par: parameter to extract; either 'zi' or 'hu'
par <- match.arg(par)
if (!is.null(draws$data$N_trait)) {
j <- if (!is.null(i)) i else seq_len(draws$data$N_trait)
out <- ilink(get_eta(draws$mu, j + draws$data$N_trait), "logit")
} else {
out <- get_auxpar(draws[[par]], i = i)
out <- get_dpar(draws[[par]], i = i)
}
out
}
Expand All @@ -736,7 +739,7 @@ get_theta <- function(draws, i = NULL) {
families <- family_names(draws$f)
theta <- vector("list", length(families))
for (j in seq_along(families)) {
theta[[j]] <- get_auxpar(draws[[paste0("theta", j)]], i = i)
theta[[j]] <- get_dpar(draws[[paste0("theta", j)]], i = i)
}
theta <- do.call(abind, c(theta, along = 3))
for (n in seq_len(dim(theta)[2])) {
Expand All @@ -752,10 +755,10 @@ get_theta <- function(draws, i = NULL) {
get_disc <- function(draws, i = NULL, ncat = NULL) {
# convenience function to extract discrimination parameters
# Args:
# see get_auxpar
# see get_dpar
# ncat: number of response categories
if (!is.null(draws[["disc"]])) {
disc <- get_auxpar(draws[["disc"]], i)
disc <- get_dpar(draws[["disc"]], i)
if (!is.null(dim(disc))) {
stopifnot(is.numeric(ncat))
disc <- array(disc, dim = c(dim(disc), ncat - 1))
Expand All @@ -768,7 +771,7 @@ get_disc <- function(draws, i = NULL, ncat = NULL) {

get_se <- function(data, i = NULL, dim = NULL) {
# extract user-defined standard errors
# Args: see get_auxpar
# Args: see get_dpar
se <- data[["se"]]
if (!is.null(se)) {
if (!is.null(i)) {
Expand All @@ -785,7 +788,7 @@ get_se <- function(data, i = NULL, dim = NULL) {

mult_disp <- function(x, data, i = NULL, dim = NULL) {
# multiply existing samples by 'disp' data
# Args: see get_auxpar
# Args: see get_dpar
if (!is.null(data$disp)) {
if (!is.null(i)) {
x <- x * data$disp[i]
Expand Down Expand Up @@ -884,7 +887,7 @@ fixef_pars <- function() {
default_plot_pars <- function() {
# list all parameter classes to be included in plots by default
c(fixef_pars(), "^sd_", "^cor_", "^sigma_", "^rescor_",
paste0("^", auxpars(), "[[:digit:]]*$"), "^delta$",
paste0("^", dpars(), "[[:digit:]]*$"), "^delta$",
"^theta", "^ar", "^ma", "^arr", "^lagsar", "^errorsar",
"^car", "^sdcar", "^sigmaLL", "^sds_", "^sdgp_", "^lscale_")
}
Expand Down
Loading

0 comments on commit 04dcc3e

Please sign in to comment.