Skip to content

Commit

Permalink
Merge pull request #3 from Roche/doc_update
Browse files Browse the repository at this point in the history
Update `README.md`
  • Loading branch information
dengemann authored Dec 19, 2024
2 parents 21ee855 + 33050e6 commit 8edf34c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 7 deletions.
71 changes: 66 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,87 @@
# GREEN architecture
# GREEN architecture (Gabor Riemann EEGNet)
![CI](https://github.com/Roche/neuro-green/actions/workflows/lint_and_test.yaml/badge.svg)
---

## About the architecture
The model is a deep learning architecture designed for EEG data that combines wavelet transforms and Riemannian geometry. The model is composed of the following layers:
It is based on the following layers:

- Convolution: Uses complex-valued Gabor wavelets with parameters that are learned during training.

- Pooling: Derives features from the wavelet-transformed signal, such as covariance matrices.

- Shrinkage layer: applies [shrinkage](https://scikit-learn.org/1.5/modules/covariance.html#basic-shrinkage) to the covariance matrices.

- Riemannian Layers: Applies transformations to the matrices, leveraging the geometry of the Symmetric Positive Definite (SPD) manifold.

- Fully Connected Layers: Standard fully connected layers for final processing.

![alt text](assets/concept_figure.png)


## Getting started
## Getting started
Clone the repository and install locally.

```
pip install -e .
```

## Dependencies

You will need the following dependencies to get most out of GREEN.

```
scikit-learn
torch
geotorch
lightning
mne
```

## Examples

Examples illustrating how to train the presented model can be found in the `green/research_code` folder. The notebook `example.ipynb` shows how to train the model on raw EEG data. And the notebook `example_wo_wav.ipynb` shows how to train a submodel that uses covariance matrices as input.

In addition, being pure PyTorch, the GREEN model can easily be integrated to [`braindecode`](https://braindecode.org/stable/index.html) routines.

```python
import torch
from braindecode import EEGRegressor
from green.wavelet_layers import RealCovariance
from green.research_code.pl_utils import get_green

green_model = get_green(
n_freqs=5, # Learning 5 wavelets
n_ch=22, # EEG data with 22 channels
sfreq=100, # Sampling frequency of 100 Hz
dropout=0.5, # Dropout rate of 0.5 in FC layers
hidden_dim=[100], # Use 100 units in the hidden layer
pool_layer=RealCovariance(), # Compute covariance after wavelet transform
bi_out=[20], # Use a BiMap layer outputing a 20x20 matrix
out_dim=1, # Output dimension of 1, for regression
)

device = "cuda" if torch.cuda.is_available() else "cpu"
EarlyStopping(monitor="valid_loss", patience=10, load_best=True)
clf = EEGRegressor(
module=green_model,
criterion=torch.nn.CrossEntropyLoss,
optimizer=torch.optim.AdamW,
device=device,
callbacks=[], # Callbacks can be added here, e.g. EarlyStopping
)
```

## Citation
When using our code, please cite the reference article:

``` bibtex
@article {Paillard2024.05.14.594142,
@article {paillard_2024_green,
author = {Paillard, Joseph and Hipp, Joerg F and Engemann, Denis A},
title = {GREEN: a lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration},
elocation-id = {2024.05.14.594142},
year = {2024},
doi = {10.1101/2024.05.14.594142},
URL = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142},
eprint = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142.full.pdf},
journal = {bioRxiv}
}
```
9 changes: 9 additions & 0 deletions green/research_code/example.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training GREEN with `lightning`\n",
"\n",
"The notebook is a simple example of how to train the GREEN model. It uses dummy data, `mne.Epochs`."
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down
12 changes: 10 additions & 2 deletions green/research_code/example_wo_wav.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a GREEN submodel without learning wavelets\n",
"\n",
"This notebook provides a minimal example of training a Green submodel without learning wavelets. The difference with the full model is that inputs are covariance matrices instead of raw EEG data. "
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand All @@ -11,8 +20,7 @@
"import torch\n",
"\n",
"from research_code.pl_utils import get_green_g2, GreenClassifierLM\n",
"from research_code.crossval_utils import pl_crossval\n",
"\n"
"from research_code.crossval_utils import pl_crossval"
]
},
{
Expand Down

0 comments on commit 8edf34c

Please sign in to comment.