Skip to content

Commit

Permalink
change def ar priors to enforce stationarity; add noncent for CAR mod…
Browse files Browse the repository at this point in the history
…els; improve loo
  • Loading branch information
nicholasjclark committed Jun 4, 2024
1 parent 8e4117b commit cbf740c
Show file tree
Hide file tree
Showing 64 changed files with 592 additions and 411 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# mvgam 1.1.2
* Added options for silencing some of the 'Stan' compiler and modeling messages using the `silent` argument in `mvgam()`
* Added an option to use `trend_model = 'None'` in State-Space models, increasing flexibility by ensuring the process error evolves as white noise (#51)
* Added an option to use the non-centred parameterisation for some autoregressive trend models,
which speeds up mixing most of the time
* Changed default priors on autoregressive coefficients (AR1, AR2, AR3) to enforce
stationarity, which is a much more sensible prior in the majority of contexts
* Fixed a small bug that prevented `conditional_effects.mvgam()` from handling effects with three-way interactions

# mvgam 1.1.1
Expand Down
48 changes: 26 additions & 22 deletions R/add_trend_lines.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
add_trend_lines = function(model_file, stan = FALSE,
use_lv, trend_model, drift){

if(use_lv & trend_model == 'None'){
trend_model <- 'RW'
}

# Add in necessary trend structure
if(stan){
if(trend_model == 'None'){
Expand Down Expand Up @@ -336,7 +340,7 @@ add_trend_lines = function(model_file, stan = FALSE,
if(drift){
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0(c('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=0,upper=1.5>[n_lv] ar1;\n\n'),
paste0(c('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=0,upper=1>[n_lv] ar1;\n\n'),
'// latent factor drift terms\nvector[n_lv] drift;')

model_file[grep('LV_raw[1, j] ~ ', model_file, fixed = T)] <-
Expand All @@ -347,7 +351,7 @@ add_trend_lines = function(model_file, stan = FALSE,

} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0(c('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=0,upper=1.5>[n_series] ar1;\n\n'),
paste0(c('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=0,upper=1>[n_series] ar1;\n\n'),
'// latent trend drift terms\nvector[n_series] drift;')

model_file[grep('trend[1, s] ~ ', model_file, fixed = T)] <-
Expand All @@ -364,7 +368,7 @@ add_trend_lines = function(model_file, stan = FALSE,
} else {
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=0,upper=1.5>[n_lv] ar1;')
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=0,upper=1>[n_lv] ar1;')

model_file[grep('// dynamic factor estimates', model_file) + 6] <-
paste0('LV_raw[2:n, j] ~ normal(ar1[j] * LV_raw[1:(n - 1), j], 0.1);')
Expand All @@ -374,7 +378,7 @@ add_trend_lines = function(model_file, stan = FALSE,

} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=0,upper=1.5>[n_series] ar1;')
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=0,upper=1>[n_series] ar1;')

model_file[grep('// trend estimates', model_file) + 6] <-
paste0('trend[2:n, s] ~ normal(ar1[s] * trend[1:(n - 1), s], sigma[s]);')
Expand All @@ -391,7 +395,7 @@ add_trend_lines = function(model_file, stan = FALSE,
if(drift){
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0(c('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;\n\n'),
paste0(c('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;\n\n'),
'// latent factor drift terms\nvector[n_lv] drift;')

model_file[grep('LV_raw[1, j] ~ ', model_file, fixed = T)] <-
Expand All @@ -402,7 +406,7 @@ add_trend_lines = function(model_file, stan = FALSE,

} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0(c('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;\n\n'),
paste0(c('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;\n\n'),
'// latent trend drift terms\nvector[n_series] drift;')

model_file[grep('trend[1, s] ~ ', model_file, fixed = T)] <-
Expand All @@ -419,7 +423,7 @@ add_trend_lines = function(model_file, stan = FALSE,
} else {
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;')
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;')

model_file[grep('// dynamic factor estimates', model_file) + 6] <-
paste0('LV_raw[2:n, j] ~ normal(ar1[j] * LV_raw[1:(n - 1), j], 0.1);')
Expand All @@ -429,7 +433,7 @@ add_trend_lines = function(model_file, stan = FALSE,

} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;')
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;')

model_file[grep('// trend estimates', model_file) + 6] <-
paste0('trend[2:n, s] ~ normal(ar1[s] * trend[1:(n - 1), s], sigma[s]);')
Expand All @@ -447,8 +451,8 @@ add_trend_lines = function(model_file, stan = FALSE,
if(drift){
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar2;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1,upper=1>[n_lv] ar2;\n\n',
'// latent factor drift terms\nvector[n_lv] drift;')

model_file[grep('LV_raw[1, j] ~ ', model_file, fixed = T)] <-
Expand All @@ -465,7 +469,7 @@ add_trend_lines = function(model_file, stan = FALSE,
'}\n}\n')
} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;\n\n',
'// latent trend AR2 terms\nvector<lower=-1,upper=1>[n_series] ar2;\n\n',
'// latent trend drift terms\nvector[n_series] drift;')

Expand All @@ -490,8 +494,8 @@ add_trend_lines = function(model_file, stan = FALSE,
} else {
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar2;')
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1,upper=1>[n_lv] ar2;')
model_file[grep('// dynamic factor estimates', model_file) + 2] <-
paste0('LV_raw[1, j] ~ normal(0, 0.1);')

Expand All @@ -506,7 +510,7 @@ add_trend_lines = function(model_file, stan = FALSE,
'}\n}\n')
} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;\n\n',
'// latent trend AR2 terms\nvector<lower=-1,upper=1>[n_series] ar2;')
model_file[grep('// trend estimates', model_file) + 2] <-
paste0('trend[1, s] ~ normal(0, sigma[s]);')
Expand Down Expand Up @@ -534,9 +538,9 @@ add_trend_lines = function(model_file, stan = FALSE,
if(drift){
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar2;\n\n',
'// latent factor AR3 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar3;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1,upper=1>[n_lv] ar2;\n\n',
'// latent factor AR3 terms\nvector<lower=-1,upper=1>[n_lv] ar3;\n\n',
'// latent factor drift terms\nvector[n_lv] drift;')

model_file[grep('LV_raw[1, s] ~ ', model_file, fixed = T)] <-
Expand All @@ -558,7 +562,7 @@ add_trend_lines = function(model_file, stan = FALSE,
'}\n}\n')
} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;\n\n',
'// latent trend AR2 terms\nvector<lower=-1,upper=1>[n_series] ar2;\n\n',
'// latent trend AR3 terms\nvector<lower=-1,upper=1>[n_series] ar3;\n\n',
'// latent trend drift terms\nvector[n_series] drift;')
Expand Down Expand Up @@ -590,9 +594,9 @@ add_trend_lines = function(model_file, stan = FALSE,
} else {
if(use_lv){
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar2;\n\n',
'// latent factor AR3 terms\nvector<lower=-1.5,upper=1.5>[n_lv] ar3;')
paste0('row_vector[num_basis] b_raw;\n\n// latent factor AR1 terms\nvector<lower=-1,upper=1>[n_lv] ar1;\n\n',
'// latent factor AR2 terms\nvector<lower=-1,upper=1>[n_lv] ar2;\n\n',
'// latent factor AR3 terms\nvector<lower=-1,upper=1>[n_lv] ar3;')
model_file[grep('// dynamic factor estimates', model_file) + 2] <-
paste0('LV_raw[1, j] ~ normal(0, 0.1);')

Expand All @@ -612,7 +616,7 @@ add_trend_lines = function(model_file, stan = FALSE,
'}\n}\n')
} else {
model_file[grep('// raw basis', model_file) + 1] <-
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1.5,upper=1.5>[n_series] ar1;\n\n',
paste0('row_vector[num_basis] b_raw;\n\n// latent trend AR1 terms\nvector<lower=-1,upper=1>[n_series] ar1;\n\n',
'// latent trend AR2 terms\nvector<lower=-1,upper=1>[n_series] ar2;\n\n',
'// latent trend AR3 terms\nvector<lower=-1,upper=1>[n_series] ar3;')
model_file[grep('// trend estimates', model_file) + 2] <-
Expand Down
2 changes: 2 additions & 0 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#' \item \code{\link[mgcv]{nb}} with log-link, for count data
#' \item \code{\link[brms]{lognormal}} with identity-link, for non-negative real-valued data
#' \item \code{\link[brms]{bernoulli}} with logit-link, for binary data
#' \item \code{\link[brms]{beta_binomial}} with logit-link, as for `binomial()` but allows
#' for overdispersion
#' }
#'
#'Finally, \code{mvgam} supports the three extended families described here:
Expand Down
2 changes: 1 addition & 1 deletion R/forecast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ forecast_draws = function(object,
}
if(attr(object$model_data, 'trend_model') == 'None' |
nmix_notrend){
if(type == 'trend' & !nmix_notrend){
if(type == 'trend' & !nmix_notrend & !use_lv){
stop('No trend_model was used in this model',
call. = FALSE)
}
Expand Down
25 changes: 13 additions & 12 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ get_mvgam_priors = function(formula,
# If trend_formula supplied, first run get_mvgam_priors for the observation model
# and then modify the resulting output
if(!missing(trend_formula)){
if(trend_model == 'None') trend_model <- 'RW'
validate_trend_formula(trend_formula)
prior_df <- get_mvgam_priors(formula = orig_formula,
data = data,
Expand Down Expand Up @@ -403,11 +404,11 @@ get_mvgam_priors = function(formula,
}
}

# No point in latent variables if trend model is None
if(trend_model == 'None' & use_lv){
use_lv <- FALSE
warning('No point in latent variables if trend model is None; changing use_lv to FALSE')
}
# # No point in latent variables if trend model is None
# if(trend_model == 'None' & use_lv){
# use_lv <- FALSE
# warning('No point in latent variables if trend model is None; changing use_lv to FALSE')
# }

# Fill in missing observations in data_train so the size of the dataset is correct when
# building the initial JAGS model
Expand Down Expand Up @@ -884,7 +885,7 @@ get_mvgam_priors = function(formula,
}

if(trend_model == 'CAR1'){
trend_df <- data.frame(param_name = c(paste0('vector<lower=0,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=0,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=0>[',
Expand Down Expand Up @@ -912,7 +913,7 @@ get_mvgam_priors = function(formula,

if(trend_model == 'AR1'){
if(use_stan){
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=0>[',
Expand All @@ -937,7 +938,7 @@ get_mvgam_priors = function(formula,
');'
)))
} else {
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=0>[',
Expand Down Expand Up @@ -967,7 +968,7 @@ get_mvgam_priors = function(formula,

if(trend_model == 'AR2'){
if(use_stan){
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=-1,upper=1>[',
Expand Down Expand Up @@ -1004,7 +1005,7 @@ get_mvgam_priors = function(formula,
');'
)))
} else {
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=-1,upper=1>[',
Expand Down Expand Up @@ -1046,7 +1047,7 @@ get_mvgam_priors = function(formula,

if(trend_model == 'AR3'){
if(use_stan){
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=-1,upper=1>[',
Expand Down Expand Up @@ -1095,7 +1096,7 @@ get_mvgam_priors = function(formula,
');'
)))
} else {
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1.5,upper=1.5>[',
trend_df <- data.frame(param_name = c(paste0('vector<lower=-1,upper=1>[',
ifelse(use_lv, 'n_lv', 'n_series'),
'] ar1;'),
paste0('vector<lower=-1,upper=1>[',
Expand Down
4 changes: 2 additions & 2 deletions R/logLik.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#'@importFrom parallel setDefaultCluster stopCluster
#'@param object \code{list} object returned from \code{mvgam}
#'@param linpreds Optional `matrix` of linear predictor draws to use for calculating
#'poitwise log-likelihoods
#'@param newdata Optional `data.frame` of `list` object specifying which series each column
#'pointwise log-likelihoods
#'@param newdata Optional `data.frame` or `list` object specifying which series each column
#'in `linpreds` belongs to. If `linpreds` is supplied, then `newdata` must also be supplied
#'@param family_pars Optional `list` containing posterior draws of
#'family-specific parameters (i.e. shape, scale or overdispersion parameters). Required if
Expand Down
Loading

0 comments on commit cbf740c

Please sign in to comment.