Skip to content

Commit

Permalink
Merge pull request #86 from nicholasjclark/hierarchical_cors
Browse files Browse the repository at this point in the history
Hierarchical cors
  • Loading branch information
nicholasjclark authored Oct 30, 2024
2 parents 79e31b7 + f9efe63 commit b3cdfd1
Show file tree
Hide file tree
Showing 78 changed files with 2,625 additions and 164 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ jobs:
shell: Rscript {0}

- name: Test coverage
run: covr::codecov()
run: covr::codecov(line_exclusions = "R/stan_utils.R")
shell: Rscript {0}
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ export(gp)
export(hindcast)
export(hypotheses)
export(irf)
export(jsdgam)
export(lfo_cv)
export(lognormal)
export(loo)
Expand Down
10 changes: 5 additions & 5 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# mvgam 1.1.4 (development version; not yet on CRAN)
## New functionalities
* Added a `stability.mvgam` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21)
* Added function `jsdgam()` to estimate Joint Species Distribution Models in which both the latent factors and the observation model components can include any of mvgam's complex linear predictor effects. See `?mvgam::jsdgam` for details
* Added a `stability.mvgam()` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added functionality to estimate hierarchical error correlations when using multivariate latent process models and when the data are nested among levels of a relevant grouping factor (#75); see `?mvgam::AR` for an example
* Added `ZMVN()` error models for estimating Zero-Mean Multivariate Normal errors; convenient for working with non time-series data where latent residuals are expected to be correlated (such as when fitting Joint Species Distribution Models); see `?mvgam::ZMVN` for examples
=======
* Added a `fevd.mvgam` method to compute forecast error variance decompositions from models fit with Vector Autoregressive dynamics
* Added a `fevd.mvgam()` method to compute forecast error variance decompositions from models fit with Vector Autoregressive dynamics (#21 and #76)

## Bug fixes
* Fixed a minor bug in the way `trend_map` recognises levels of the `series` factor
Expand All @@ -15,8 +15,8 @@
* Allow intercepts to be included in process models when `trend_formula` is supplied. This breaks the assumption that the process has to be zero-centred, adding more modelling flexibility but also potentially inducing nonidentifiabilities with respect to any observation model intercepts. Thoughtful priors are a must for these models
* Added `standata.mvgam_prefit`, `stancode.mvgam` and `stancode.mvgam_prefit` methods for better alignment with 'brms' workflows
* Added 'gratia' to *Enhancements* to allow popular methods such as `draw()` to be used for 'mvgam' models if 'gratia' is already installed
* Added an `ensemble.mvgam_forecast` method to generate evenly weighted combinations of probabilistic forecast distributions
* Added an `irf.mvgam` method to compute Generalized and Orthogonalized Impulse Response Functions (IRFs) from models fit with Vector Autoregressive dynamics
* Added an `ensemble.mvgam_forecast()` method to generate evenly weighted combinations of probabilistic forecast distributions
* Added an `irf.mvgam()` method to compute Generalized and Orthogonalized Impulse Response Functions (IRFs) from models fit with Vector Autoregressive dynamics

## Deprecations
* The `drift` argument has been deprecated. It is now recommended for users to include parametric fixed effects of "time" in their respective GAM formulae to capture any expected drift effects
Expand Down
4 changes: 2 additions & 2 deletions R/add_MACor.R
Original file line number Diff line number Diff line change
Expand Up @@ -1893,7 +1893,7 @@ add_MaCor = function(model_file,
model_file[grep("// posterior predictions",
model_file, fixed = TRUE)] <-
paste0('// computed (full) error covariance matrix\n',
'matrix[n_lv, n_lv] Sigma;\n',
'matrix[n_lv, n_lv] Sigma;\n',
'Sigma = rep_matrix(0, n_lv, n_lv);\n',
'for (g in 1 : n_groups){\n',
'Sigma[group_inds[g], group_inds[g]] = multiply_lower_tri_self_transpose(L_Sigma_group[g]);\n',
Expand All @@ -1904,7 +1904,7 @@ add_MaCor = function(model_file,
model_file[grep("// posterior predictions",
model_file, fixed = TRUE)] <-
paste0('// computed (full) error covariance matrix\n',
'matrix[n_series, n_series] Sigma;\n',
'matrix[n_series, n_series] Sigma;\n',
'Sigma = rep_matrix(0, n_series, n_series);\n',
'for (g in 1 : n_groups){\n',
'Sigma[group_inds[g], group_inds[g]] = multiply_lower_tri_self_transpose(L_Sigma_group[g]);\n',
Expand Down
32 changes: 22 additions & 10 deletions R/forecast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,24 @@ forecast.mvgam = function(object,
type <- match.arg(arg = type, choices = c("link", "response",
"trend", "expected",
"detection", "latent_N"))

if(inherits(object, 'jsdgam')){
orig_trend_model <- attr(object$model_data, 'prepped_trend_model')
} else {
orig_trend_model <- object$trend_model
}

data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
n_series <- NCOL(object$ytimes)

# Check whether a forecast has already been computed
forecasts_exist <- FALSE
if(!is.null(object$test_data) && !missing(data_test)){
object$test_data <- validate_series_time(object$test_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
data_test <- validate_series_time(data_test,
trend_model = object$trend_model)
trend_model = orig_trend_model)
if(max(data_test$index..time..index) <=
max(object$test_data$index..time..index)){
forecasts_exist <- TRUE
Expand Down Expand Up @@ -126,7 +133,7 @@ forecast.mvgam = function(object,

if(is.null(object$test_data)){
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
data.frame(series = object$obs_data$series,
time = object$obs_data$time) %>%
dplyr::group_by(series) %>%
Expand Down Expand Up @@ -176,7 +183,7 @@ forecast.mvgam = function(object,
data_test$y <- rep(NA, NROW(data_test))
}
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
}

# Generate draw-specific forecasts
Expand All @@ -198,7 +205,7 @@ forecast.mvgam = function(object,

# Extract hindcasts
data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2],
length.out = NCOL(object$ytimes) + 1)
starts <- ends + 1
Expand Down Expand Up @@ -330,12 +337,12 @@ forecast.mvgam = function(object,
} else {
# If forecasts already exist, simply extract them
data_test <- validate_series_time(object$test_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
last_train <- max(object$obs_data$index..time..index) -
(min(object$obs_data$index..time..index) - 1)

data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2],
length.out = NCOL(object$ytimes) + 1)
starts <- ends + 1
Expand Down Expand Up @@ -593,8 +600,13 @@ forecast_draws = function(object,

# Check arguments
validate_pos_integer(n_cores)
if(inherits(object, 'jsdgam')){
orig_trend_model <- attr(object$model_data, 'prepped_trend_model')
} else {
orig_trend_model <- object$trend_model
}
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
data_test <- sort_data(data_test)
n_series <- NCOL(object$ytimes)
use_lv <- object$use_lv
Expand Down Expand Up @@ -695,7 +707,7 @@ forecast_draws = function(object,

# No need to compute in parallel if there was no trend model
nmix_notrend <- FALSE
if(!inherits(object$trend_model, 'mvgam_trend') &
if(!inherits(orig_trend_model, 'mvgam_trend') &
object$family == 'nmix'){
nmix_notrend <- TRUE
}
Expand Down
3 changes: 2 additions & 1 deletion R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ utils::globalVariables(c("y", "year", "smooth_vals", "smooth_num",
"time_lag", "dis_time", "maxt", "orig_rows",
"matches", "time.", "file_name", ".data",
"horizon", "target", "Series", "evd", "mean_evd",
"total_evd"))
"total_evd", "smooth_label", "by_variable",
"gr", "tot_subgrs", "subgr"))
16 changes: 10 additions & 6 deletions R/index-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ variables.mvgam = function(x, ...){
# Linear predictor parameters
observation_linpreds <- data.frame(orig_name = parnames[grepl('mus[',
parnames,
fixed = TRUE)],
fixed = TRUE) &
!grepl('trend_mus[',
parnames,
fixed = TRUE)],
alias = NA)

if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
trend_linpreds <- data.frame(orig_name = parnames[grepl('trend_mus[',
parnames,
fixed = TRUE)],
Expand All @@ -71,7 +74,7 @@ variables.mvgam = function(x, ...){
mgcv_names <- names(coef(x$mgcv_model))
observation_betas <- data.frame(orig_name = b_names, alias = mgcv_names)

if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
b_names <- colnames(mcmc_chains(x$model_output, 'b_trend'))
mgcv_names <- gsub('series', 'trend',
paste0(names(coef(x$trend_mgcv_model)), '_trend'))
Expand All @@ -97,7 +100,7 @@ variables.mvgam = function(x, ...){
}

trend_re_params <- NULL
if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
if(any(unlist(purrr::map(x$trend_mgcv_model$smooth, inherits, 'random.effect')))){
re_labs <- unlist(lapply(purrr::map(x$trend_mgcv_model$smooth, 'term'),
paste, collapse = ','))[
Expand Down Expand Up @@ -125,7 +128,7 @@ variables.mvgam = function(x, ...){
observation_smoothpars <- NULL
}

if(any(grepl('rho_trend[', parnames, fixed = TRUE))){
if(any(grepl('rho_trend[', parnames, fixed = TRUE)) & !inherits(x, 'jsdgam')){
trend_smoothpars <- data.frame(orig_name = parnames[grepl('rho_trend[',
parnames,
fixed = TRUE)],
Expand All @@ -136,7 +139,8 @@ variables.mvgam = function(x, ...){

# Trend state parameters
if(any(grepl('trend[', parnames, fixed = TRUE) &
!grepl('_trend[', parnames, fixed = TRUE))){
!grepl('_trend[', parnames, fixed = TRUE)) &
!inherits(x, 'jsdgam')){
trend_states <- grepl('trend[', parnames, fixed = TRUE) &
!grepl('_trend[', parnames, fixed = TRUE)
trends <- data.frame(orig_name = parnames[trend_states],
Expand Down
Loading

0 comments on commit b3cdfd1

Please sign in to comment.