PyTorch implementation of the bistable recurrent cell (BRC) and recurrently neuromodulated bistable recurrent cell (nBRC).
The available classes, BRCLayer, nBRCLayer, BRC and nBRC, are documented in brc.py.
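For orientation, here is a minimal single-step sketch of the two cell variants. It is not the code in brc.py: the class name BistableCellSketch, the use of biases and the weight initialization are assumptions. It assumes the usual bistable update, where a feedback term a_t in (0, 2) can exceed 1 and so let a neuron settle in one of two stable states. The BRC uses one recurrent weight per neuron (w * h), while the nBRC uses full recurrent matrices (W h) for its gates.

```python
import torch
import torch.nn as nn


class BistableCellSketch(nn.Module):
    """Hypothetical single-step BRC/nBRC cell (not the repository's class).

    Assumed update, with x the input and h the previous hidden state:
      a = 1 + tanh(U_a x + r_a(h))    # feedback in (0, 2); a > 1 enables bistability
      c = sigmoid(U_c x + r_c(h))     # update gate in (0, 1)
      h' = c * h + (1 - c) * tanh(U x + a * h)
    BRC:  r(h) = w * h  (elementwise, one weight per neuron)
    nBRC: r(h) = W h    (full matrix, recurrent neuromodulation)
    """

    def __init__(self, input_size, hidden_size, neuromodulated=False):
        super().__init__()
        self.neuromodulated = neuromodulated
        # Input projections U_a, U_c and U (biases are an assumption).
        self.U_a = nn.Linear(input_size, hidden_size)
        self.U_c = nn.Linear(input_size, hidden_size)
        self.U = nn.Linear(input_size, hidden_size)
        if neuromodulated:
            # nBRC: gates receive the whole previous hidden state.
            self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
            self.W_c = nn.Linear(hidden_size, hidden_size, bias=False)
        else:
            # BRC: each neuron only sees its own previous state.
            self.w_a = nn.Parameter(torch.randn(hidden_size) * 0.1)
            self.w_c = nn.Parameter(torch.randn(hidden_size) * 0.1)

    def forward(self, x, h):
        if self.neuromodulated:
            rec_a, rec_c = self.W_a(h), self.W_c(h)
        else:
            rec_a, rec_c = self.w_a * h, self.w_c * h
        a = 1.0 + torch.tanh(self.U_a(x) + rec_a)
        c = torch.sigmoid(self.U_c(x) + rec_c)
        # Convex mix of the old state and a bounded candidate: h stays in [-1, 1].
        return c * h + (1.0 - c) * torch.tanh(self.U(x) + a * h)
```

Unrolling it over a sequence is a plain loop, e.g. h = cell(x_t, h) for each time step, starting from h = torch.zeros(batch, hidden_size).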
git clone https://github.com/glambrechts/bistable-recurrent-cell
cd bistable-recurrent-cell/
See main.py for a copy-first-input benchmark with the BRC cell.
python3 main.py
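The idea behind a copy-first-input task can be sketched as a small data generator. This is not necessarily how main.py builds its data: the function name copy_first_input_batch, the standard-normal scalar inputs and the (seq_len, batch, 1) layout (torch.nn.GRU's default batch_first=False convention) are assumptions. The target is the first element of each sequence, which the network must keep in memory until the last time step.

```python
import torch


def copy_first_input_batch(batch_size, seq_len, generator=None):
    """Hypothetical generator for a copy-first-input batch.

    Assumed setup: i.i.d. standard normal scalar inputs; the target is the
    first input of each sequence, to be output after the last time step.
    """
    # Inputs laid out as (seq_len, batch, features), GRU-style.
    x = torch.randn(seq_len, batch_size, 1, generator=generator)
    # Target: the value seen at t = 0, shape (batch, 1).
    y = x[0].clone()
    return x, y
```

Since the inputs have unit variance, a model that always predicts 0 has an expected mean squared error of about 1; longer sequences make beating that baseline harder because the first value must be stored for more steps.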
The implementation mirrors torch.nn.GRU, in that the output of the RNN is its hidden state. A small wrapper in main.py adds a linear layer on top of the recurrent cell.
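A wrapper of that kind can be sketched generically. This is not the wrapper from main.py. The class name RecurrentWithReadout is hypothetical, and the sketch assumes the recurrent module follows the torch.nn.GRU convention, where rnn(x) returns (output, h_n) and output has shape (seq_len, batch, hidden). It is demonstrated with torch.nn.GRU itself; check brc.py for the exact signature of BRC or nBRC before swapping them in.

```python
import torch
import torch.nn as nn


class RecurrentWithReadout(nn.Module):
    """Hypothetical wrapper: a GRU-like recurrent module plus a linear readout."""

    def __init__(self, rnn, hidden_size, output_size):
        super().__init__()
        self.rnn = rnn
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # output holds the hidden state at every time step: (seq_len, batch, hidden).
        output, _ = self.rnn(x)
        # Predict from the hidden state at the last time step.
        return self.readout(output[-1])


model = RecurrentWithReadout(nn.GRU(1, 32), hidden_size=32, output_size=1)
prediction = model(torch.randn(50, 16, 1))  # shape (16, 1)
```

Reading out only the last hidden state suits tasks with a single target per sequence, such as copy-first-input; applying self.readout to the whole output tensor would give one prediction per time step instead.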
Also note that the parameter train_h0 makes the initial hidden state a trainable parameter of the recurrent cell.
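As an illustration of what a trainable initial state involves, the sketch below wraps torch.nn.GRU. It is an assumption about the general technique, not the repository's implementation of train_h0, and the class name GRUWithTrainableH0 is hypothetical. The initial state is registered as an nn.Parameter, so it receives gradients and is updated by the optimizer like any other weight, and it is shared across the batch by expanding it along the batch dimension.

```python
import torch
import torch.nn as nn


class GRUWithTrainableH0(nn.Module):
    """Hypothetical GRU whose initial hidden state is learned."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers)
        # One learned vector per layer, shared by every sequence in a batch.
        self.h0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))

    def forward(self, x):
        batch_size = x.shape[1]  # x: (seq_len, batch, input_size)
        # Broadcast the learned state to (num_layers, batch, hidden_size).
        h0 = self.h0.expand(-1, batch_size, -1).contiguous()
        return self.gru(x, h0)
```

Initializing h0 to zeros means training starts from the same behavior as a non-trainable zero state; the optimizer can then move it wherever it helps the task.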