A VAE-based Framework for Learning Multi-Level Neural Granger-Causal Connectivity

(Copyright 2024) by Jiahe Lin, Huitian Lei and George Michailidis; paper accepted in TMLR, 2024. [link to paper]

Environment Setup

Assuming Anaconda, Miniconda or Miniforge is already installed, set up the environment with the following commands:

conda create -n vae-gc python=3.9
conda activate vae-gc
conda install pyyaml numpy pandas scipy scikit-learn
conda install matplotlib seaborn 
pip install pytorch-lightning torch

See also requirements.txt.

To verify that PyTorch can access your GPU (the command should print True):

python -c "import torch; print(torch.cuda.is_available())"
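Beyond the GPU check, the minimal sketch below confirms that the packages installed above are importable. It is not part of the repository and uses only the standard library. The module list is an assumption that maps each install name to its import name (e.g., pyyaml is imported as yaml, scikit-learn as sklearn).

```python
import importlib.util

# Import names for the packages installed above. Install and import names differ:
# pyyaml -> yaml, scikit-learn -> sklearn, pytorch-lightning -> pytorch_lightning.
REQUIRED_MODULES = [
    "yaml", "numpy", "pandas", "scipy", "sklearn",
    "matplotlib", "seaborn", "torch", "pytorch_lightning",
]


def missing_modules(modules):
    """Return the top-level module names that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("all packages found" if not missing else f"missing: {missing}")
```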

Repo Layout

We outline the major components in this repository for ease of navigation.

  • bin/: shell scripts for execution; see also the section Experiments In the Paper below

  • src/:

    • torch-based modules whose forward passes correspond to the proposed methods. Models supported: OneLayer and TwoLayer, for single-entity and multi-entity VAE-based models, respectively.
    • PyTorch Lightning-based modules that encapsulate the forward-backward propagation pipeline for running the multi-entity/single-entity VAE-based methods on a given dataset
    • similar to the above, but specifically for synthetic datasets, where the underlying true GC graphs are known. In particular, the dataloader and graph evaluation (through torchmetrics) are integrated into every training step to facilitate model development and tracking.
      • Of note, the metrics printed during training do not correspond to the final metrics reported in the paper (e.g., AUROC and AUPRC). In particular, when the graph type is numeric, a Pearson correlation-type metric between the true and estimated graphs is computed at the individual-sample level.
    • datasets/: dataset objects, built on a common base class, that read a specific type of dataset from disk so that it can later be loaded properly through a DataLoader.
  • generator/: scripts used for generating synthetic data

    • simulator/: various simulator objects for synthetic data generation of the corresponding setting
  • utils/: utilities

    • wrapper functions for training models in the synthetic and real data experiments
    • utility functions for data processing and trajectory parsing
    • functions for evaluating results
  • configs/: data parameters (e.g., # of entities, # of nodes, trajectory length, etc.) and hyperparameters for all VAE-based methods. Naming conventions:

    • no suffix: the base config; the synthetic data setting parameters are specified in its data_params section, and the remaining sections correspond to multi-entity learning using a node-centric decoder
    • pattern *_edge: edge-centric decoder
    • pattern *_oneSub: parameters corresponding to single-entity learning

    Alternatively, one can deviate from these naming conventions: use any custom config file name and pass it via --config in the run command to override the default.

  • root/:

    • script for running synthetic data experiments using the multi-entity VAE-based method
    • script for running synthetic data experiments using the single-entity VAE-based method
    • script for running real data experiments using the multi-entity VAE-based method
    • script for running real data experiments using the single-entity VAE-based method
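The two kinds of graph-evaluation metrics mentioned above can be illustrated without any dependencies. This is a sketch only, not the repository's torchmetrics-based code, and it rests on several assumptions: graphs are p x p nested lists, diagonal (self) entries are excluded, and per-sample correlations are averaged. AUROC is computed through the rank-sum (Mann-Whitney U) identity; AUPRC is omitted.

```python
import math


def off_diagonal(graph):
    """Flatten a p x p graph row by row, skipping diagonal (self) entries (an assumption)."""
    p = len(graph)
    return [graph[i][j] for i in range(p) for j in range(p) if i != j]


def pearson(x, y):
    """Sample Pearson correlation of two equal-length sequences (NaN if either is constant)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    sd_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    sd_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    if sd_x == 0 or sd_y == 0:
        return float("nan")
    return cov / (sd_x * sd_y)


def mean_sample_correlation(true_graphs, est_graphs):
    """Average, over samples, of the Pearson correlation between true and estimated graphs."""
    scores = [pearson(off_diagonal(t), off_diagonal(e)) for t, e in zip(true_graphs, est_graphs)]
    return sum(scores) / len(scores)


def auroc(labels, scores):
    """AUROC for binary edge labels via the rank-sum identity; tied scores share ranks."""
    order = sorted(range(len(scores)), key=lambda k: scores[k])
    ranks = [0.0] * len(scores)
    start = 0
    while start < len(order):
        end = start
        # Extend the block while the next score ties with the current one.
        while end + 1 < len(order) and scores[order[end + 1]] == scores[order[start]]:
            end += 1
        for k in range(start, end + 1):
            ranks[order[k]] = (start + end) / 2 + 1  # average 1-based rank of the tied block
        start = end + 1
    n_pos = sum(1 for label in labels if label == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUROC needs both positive and negative labels")
    rank_sum_pos = sum(r for r, label in zip(ranks, labels) if label == 1)
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
```

For a binary ground-truth graph, one would pass the flattened truth as labels and the flattened estimated edge strengths as scores, e.g. auroc(off_diagonal(true_graph), off_diagonal(est_graph)).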

Experiments In the Paper

Synthetic data

  • Run synthetic data experiments based on the VAE-based methods, including the proposed one (multi-entity learning) and its single-entity counterpart:
    cd bin
    mkdir -p logs
    ## data generation is included by default; set the toggle (in the shell script) to false if it is not needed (e.g., the data have already been generated)
    ## args: SETTING_NAME, GPU_ID, CONFIG_VERSION (defaults to none, i.e., no suffix)
    ## choose SETTING_NAME amongst {Lorenz96, LinearVAR, NonLinearVAR, Lotka, Springs5}
    bash [SETTING_NAME] 0 &>logs/[SETTING_NAME]_run.log
  • Evaluate a single run (a specific experiment setting and a single data seed):
    ## the following command should be executed in the root dir
    python -u --ds_str=[SETTING_NAME] --seed=0
  • Evaluate all data replicates for a specific experiment setting:
    cd bin
    ## argv: ds_str
    bash [SETTING_NAME]
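As a rough illustration of how SETTING_NAME and CONFIG_VERSION relate to the configs/ naming convention, the sketch below maps the two to a file path. The .yaml extension and the flat configs/<SETTING_NAME>[_<CONFIG_VERSION>].yaml layout are assumptions (pyyaml is among the dependencies, but the exact file names are not stated here); passing a custom file via --config bypasses the convention entirely.

```python
from pathlib import Path
from typing import Optional


def resolve_config(setting_name: str, config_version: Optional[str] = None,
                   config_dir: str = "configs") -> Path:
    """Build a config path from SETTING_NAME and an optional CONFIG_VERSION suffix.

    Assumption: configs are stored as <config_dir>/<SETTING_NAME>[_<CONFIG_VERSION>].yaml.
    No suffix selects the base config; e.g. "edge" or "oneSub" select the variants.
    """
    stem = setting_name if not config_version else f"{setting_name}_{config_version}"
    return Path(config_dir) / f"{stem}.yaml"
```

Under these assumptions, resolve_config("Lorenz96") points at configs/Lorenz96.yaml and resolve_config("Lorenz96", "oneSub") at configs/Lorenz96_oneSub.yaml.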

Real data: multi-subject EEG

The data are available from rsed2017-dataverse. Once downloaded, the files should be placed under data_real/EEG_CSV_files/ and named Subject[ID]_EO.csv or Subject[ID]_EC.csv, depending on the underlying neurophysiological experiment setting.
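Before preprocessing, it can help to check that the downloaded files follow this naming scheme. The stdlib-only sketch below is not part of the repository; it assumes that [ID] consists of digits only and that the file extension is lowercase .csv.

```python
import re
from pathlib import Path

# Expected names: Subject[ID]_EO.csv or Subject[ID]_EC.csv.
# Assumption: [ID] consists of digits only.
EEG_FILE_PATTERN = re.compile(r"^Subject(\d+)_(EO|EC)\.csv$")


def parse_eeg_filename(name):
    """Return (subject_id, setting) for a well-formed file name, else None."""
    match = EEG_FILE_PATTERN.match(name)
    return match.groups() if match else None


def index_eeg_files(data_dir="data_real/EEG_CSV_files"):
    """Group subject IDs by setting ("EO"/"EC") and collect CSV names that do not match."""
    found = {"EO": [], "EC": []}
    unmatched = []
    for path in Path(data_dir).glob("*.csv"):
        parsed = parse_eeg_filename(path.name)
        if parsed is None:
            unmatched.append(path.name)
        else:
            subject_id, setting = parsed
            found[setting].append(subject_id)
    for ids in found.values():
        ids.sort(key=int)  # numeric order, so Subject2 comes before Subject10
    return found, sorted(unmatched)
```

Any names reported as unmatched would need renaming before the preprocessing step.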

  • Preprocess the raw datasets, parsing the long trajectories into a form the VAE-based method can consume:
    python -u --ds_str='EEG_EC,EEG_EO'
  • Run the experiments
    cd bin
    ## argv: ds_str; choose between EEG_EO and EEG_EC
    bash EEG_EO &>logs/EO_log.log
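The preprocessing step above turns each long recording into shorter trajectories; the actual parsing logic lives in the repository's script. Purely as an illustration, one common way to do this is fixed-length sliding windows, sketched below. The window length and stride are hypothetical parameters, not values taken from the paper, and a trailing partial window is simply dropped.

```python
from typing import List, Sequence


def segment_trajectory(trajectory: Sequence[Sequence[float]], window: int,
                       stride: int) -> List[List[Sequence[float]]]:
    """Split a long (T x p) trajectory into windows of `window` consecutive time steps.

    A new window starts every `stride` steps; a trailing partial window is dropped.
    """
    if window <= 0 or stride <= 0:
        raise ValueError("window and stride must be positive")
    n_steps = len(trajectory)
    return [list(trajectory[start:start + window])
            for start in range(0, n_steps - window + 1, stride)]
```

Setting stride equal to window gives non-overlapping segments; a smaller stride yields more, overlapping samples from the same recording.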

Run Your Own Datasets

See ./demo.ipynb. In the notebook, we generate a demo dataset and outline the steps/files required to utilize our end-to-end training pipeline.

Citation and Contact

To cite this work:

    title = {A VAE-based Framework for Learning Multi-Level Neural Granger-Causal Connectivity},
    author = {Lin, Jiahe and Lei, Huitian and Michailidis, George},
    year = {2024},
    journal = {Transactions on Machine Learning Research},
    url = {}
  • For questions on the paper and/or collaborations based on the methods (extensions or applications), contact George Michailidis
  • For questions on the code implementation, contact Jiahe Lin and/or Huitian Lei


Below we list the repositories referenced while this codebase was being developed.


Below are the competitor models considered in the paper and their corresponding repositories.