Spiking neural network (SNN) framework written on top of PyTorch for efficient simulation of SNNs on both CPU and GPU. The framework is intended for use with correlation-based learning methods. The library adheres to the highly modular and dynamic design of PyTorch and does not require its users to learn a new framework.
The framework's strength lies in how easily new Neuron and Connection objects can be defined and mixed: they work together seamlessly, even when different variants are combined in a single network.
PySNN is designed to mostly provide low level objects to its user that can be combined and mixed, just as in PyTorch. The biggest difference is that a network now consists of two types of modules, instead of the single nn.Module in regular PyTorch. These new modules are the pysnn.Neuron and pysnn.Connection.
Documentation can be found at: https://basbuller.github.io/PySNN/
Design of the PySNN framework took inspiration from the following two libraries:
Installation can be done with pip:
$ pip install pysnn
If you want to make updates to the library without having to reinstall it, use the following install command instead:
$ git clone https://github.com/BasBuller/PySNN.git
$ cd PySNN/
$ pip install -e .
Some examples need additional libraries. To install these, run:
$ pip install pysnn[examples]
Code is formatted with Black using a pre-commit hook. To configure it, run:
$ pre-commit install
Installing PySNN requires Python 3.6 or higher; Python 2 is not supported. It also requires PyTorch 1.2 or higher.
The intention is to mirror most of the structure of the PyTorch framework. As an example, the following piece of code shows how closely a spiking neural network definition in PySNN resembles a network definition in PyTorch. The network's graph is cyclic, due to the feedback connection from the output neurons to the hidden neurons.
from pysnn.connection import Linear
from pysnn.network import SNNNetwork
from pysnn.neuron import Input, LIFNeuron

# batch_size, n_in, n_hidden, n_out and the input_dynamics,
# connection_dynamics and neuron_dynamics tuples are assumed to be defined beforehand.


class Network(SNNNetwork):
    def __init__(self):
        super(Network, self).__init__()

        # Input
        self.input = Input((batch_size, 1, n_in), *input_dynamics)

        # Layer 1
        self.mlp1_c = Linear(n_in, n_hidden, *connection_dynamics)
        self.neuron1 = LIFNeuron((batch_size, 1, n_hidden), *neuron_dynamics)
        self.add_layer("fc1", self.mlp1_c, self.neuron1)

        # Layer 2
        self.mlp2_c = Linear(n_hidden, n_out, *connection_dynamics)
        self.neuron2 = LIFNeuron((batch_size, 1, n_out), *neuron_dynamics)
        self.add_layer("fc2", self.mlp2_c, self.neuron2)

        # Feedback connection from neuron 2 to neuron 1
        self.mlp2_prev = Linear(n_out, n_hidden, *connection_dynamics)
        self.add_layer("fc2_back", self.mlp2_prev, self.neuron1)

    def forward(self, input):
        spikes, trace = self.input(input)

        # Layer 1: feedforward input plus feedback from layer 2's previous timestep
        x_forw, _ = self.mlp1_c(spikes, trace)
        x_prev, _ = self.mlp2_prev(self.neuron2.spikes, self.neuron2.trace)
        spikes, trace = self.neuron1([x_forw, x_prev])

        # Layer out
        x_out, _ = self.mlp2_c(spikes, trace)
        spikes, trace = self.neuron2(x_out)

        return spikes
Any help, suggestions, or additions to PySNN are greatly appreciated! Feel free to open a pull request or start a discussion about the library. If you do make a pull request, please have a look at the contribution guidelines first.
The overall structure of a network definition is the same as in PyTorch where possible. All newly defined objects inherit from the nn.Module class. The biggest differences are as follows:
- Each layer consists of a Connection and a Neuron object, because each implements its own time-based dynamics.
- Training does not use gradients.
- Neurons have a state that persists between consecutive timesteps.
- Networks inherit from a special pysnn.SNNNetwork class.
The Neuron object is the main difference from ANNs. Neurons have highly non-linear (and non-differentiable) behaviour: each has an internal voltage, and once that voltage surpasses a threshold value the neuron generates a binary spike (a non-differentiable operation). The spike is then propagated to the next layer of Neurons through a Connection object. Defining a new Neuron class is fairly simple: one only has to define new neuronal dynamics functions for the Neuron's voltage and trace. The supporting functions are (almost) all defined in the Neuron base class.
For an introduction to (biological) neuronal dynamics and spiking neural networks, the reader is referred to Neuronal Dynamics by Wulfram Gerstner, Werner M. Kistler, Richard Naud and Liam Paninski.
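The voltage, threshold, reset cycle and the spike trace described above can be sketched in plain Python. This is a hypothetical toy: `ToyLIF`, `step`, and the decay parameters `alpha_v` and `alpha_t` are assumptions for illustration, not PySNN's `LIFNeuron` API, and the exponential-decay trace is just one common choice of trace dynamics.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyLIF:
    """Hypothetical leaky integrate-and-fire layer (not the PySNN API)."""

    n: int                   # number of neurons
    thresh: float = 1.0      # firing threshold
    v_rest: float = 0.0      # resting and reset voltage
    alpha_v: float = 0.9     # voltage decay factor per timestep (assumed)
    alpha_t: float = 0.8     # trace decay factor per timestep (assumed)
    v: List[float] = field(default_factory=list, init=False)
    trace: List[float] = field(default_factory=list, init=False)

    def __post_init__(self) -> None:
        self.reset_state()

    def reset_state(self) -> None:
        # State that persists between timesteps is cleared here
        self.v = [self.v_rest] * self.n
        self.trace = [0.0] * self.n

    def step(self, x: List[float]) -> Tuple[List[int], List[float]]:
        spikes = []
        for i, inp in enumerate(x):
            # Leaky integration: decay toward rest, then add the input
            self.v[i] = self.v_rest + self.alpha_v * (self.v[i] - self.v_rest) + inp
            fired = self.v[i] >= self.thresh  # non-differentiable threshold
            if fired:
                self.v[i] = self.v_rest       # reset after a spike
            # Trace: decay, then add 1 whenever the neuron spikes
            self.trace[i] = self.alpha_t * self.trace[i] + (1.0 if fired else 0.0)
            spikes.append(int(fired))
        return spikes, list(self.trace)
```

With a constant input of `[0.6, 0.2]`, the first neuron crosses the threshold on the second step (0.9 * 0.6 + 0.6 = 1.14) and resets, while the second neuron keeps integrating without spiking.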
The Connection object contains the connection weights and routes signals between different layers. It differs from PyTorch layers mainly in that it keeps a state of its past activity between iterations and can delay signal transmission between layers.
To keep track of traces and delays in the tensors that pass information between layers, an extra dimension is needed compared to the PyTorch conventions. Due to the addition of spike traces, each spiking tensor contains an extra trace dimension as its last dimension. The resulting dimension ordering for an image tensor is as follows (the trace dimension is denoted R so it is not confused with time in video data):
[batch size, channels, height, width, traces] (B,C,H,W,R)
For fully connected layers the resulting tensor is as follows (free dimension can be used the same as in PyTorch):
[batch size, free dimension, input elements, traces] (B,F,I,R)
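To make the trailing trace dimension concrete, the hypothetical helpers below (`with_trace_dim` and `fc_shape`, not part of PySNN) compute the extended shapes from a regular PyTorch-style shape. They only manipulate shape tuples, so this sketch runs without PyTorch installed.

```python
from typing import Tuple


def with_trace_dim(torch_shape: Tuple[int, ...], n_traces: int = 1) -> Tuple[int, ...]:
    """Append the trace dimension R as the last dimension (hypothetical helper)."""
    if n_traces < 1:
        raise ValueError("a spiking tensor needs at least one trace")
    return tuple(torch_shape) + (n_traces,)


def fc_shape(batch: int, n_inputs: int, free: int = 1, n_traces: int = 1) -> Tuple[int, int, int, int]:
    """(B, F, I, R) layout for fully connected layers (hypothetical helper)."""
    return (batch, free, n_inputs, n_traces)


# Image tensor: (B, C, H, W) becomes (B, C, H, W, R)
print(with_trace_dim((8, 3, 28, 28)))
# Fully connected input of 784 elements: (B, F, I, R)
print(fc_shape(8, 784))
```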
Currently, explicit 3D convolution, as is common in video processing, is not supported. Luckily, SNNs have a built-in temporal dimension and are (currently still in theory) well suited for processing videos event by event, which removes the need for 3D convolution.
Traces are stored in both the Neuron and Connection objects. Currently, Connection objects take traces from their pre-synaptic Neurons and propagate them over time without any further processing. If desired, separate trace processing can be implemented in a custom Connection object.
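The split between the default pass-through behaviour and custom trace processing can be sketched with two toy classes. `ToyLinearConnection`, `ScaledTraceConnection`, and the `process_trace` hook are hypothetical names for illustration; PySNN's own Connection classes may structure this differently.

```python
from typing import List, Tuple


class ToyLinearConnection:
    """Hypothetical fully connected Connection (not the PySNN API).

    Computes the weighted input from pre-synaptic spikes and hands the
    pre-synaptic trace on unchanged.
    """

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # weights[j][i]: from pre-neuron i to post-neuron j

    def process_trace(self, trace: List[float]) -> List[float]:
        # Default: no extra processing, the trace is only propagated
        return list(trace)

    def forward(self, spikes: List[int], trace: List[float]) -> Tuple[List[float], List[float]]:
        x = [sum(w * s for w, s in zip(row, spikes)) for row in self.weights]
        return x, self.process_trace(trace)


class ScaledTraceConnection(ToyLinearConnection):
    """Custom Connection that processes traces itself (here: simple scaling)."""

    def __init__(self, weights: List[List[float]], scale: float):
        super().__init__(weights)
        self.scale = scale

    def process_trace(self, trace: List[float]) -> List[float]:
        return [self.scale * t for t in trace]
```

Overriding a single hook keeps the weighted-sum logic in one place, so a custom Connection only has to describe how its traces differ.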
Traces are stored in a tensor in each Connection, along with the delay of each trace propagating through it. Only one trace (or signal) can be tracked through each synapse. If the delay through a synapse becomes very long (longer than the refractory period of the pre-synaptic neuron), a new signal can enter the Connection before the previous one has travelled through it. In the current implementation the old signal is then overwritten, meaning its information is lost before it can be used!
It is up to the user to ensure that refractory periods are at least as long as the synaptic delay in the following Connection!
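The single-slot behaviour and the overwrite hazard can be illustrated with a toy delay line for one synapse. `ToySynapseDelay` and its `step` method are hypothetical and only model the behaviour described above, not PySNN's tensor-based implementation. Delivery is checked before a new signal is accepted, so a signal entering exactly `delay` timesteps after the previous one (a refractory period equal to the delay) is safe.

```python
from typing import Optional


class ToySynapseDelay:
    """Hypothetical single-synapse delay line (not the PySNN API).

    Holds at most one signal in transit. A new signal entering while the
    previous one is still travelling overwrites it.
    """

    def __init__(self, delay: int):
        if delay < 1:
            raise ValueError("delay must be at least one timestep")
        self.delay = delay
        self.reset_state()

    def reset_state(self) -> None:
        self.signal: Optional[float] = None  # value currently in transit
        self.countdown = 0                   # timesteps until delivery

    def step(self, incoming: Optional[float]) -> Optional[float]:
        """Advance one timestep and return the signal arriving now, if any."""
        delivered = None
        if self.signal is not None:
            self.countdown -= 1
            if self.countdown == 0:
                delivered, self.signal = self.signal, None
        if incoming is not None:
            # Any signal still in transit is overwritten and its information lost
            self.signal, self.countdown = incoming, self.delay
        return delivered


syn = ToySynapseDelay(delay=3)
# Second signal arrives one step later, before the first is delivered
print([syn.step(x) for x in [1.0, 2.0, None, None, None]])
```

Only `2.0` comes out of the synapse; the first signal is silently dropped, which is exactly the situation the refractory-period rule above is meant to prevent.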
Make sure each module has a self.reset_state() method! It is called from the SNNNetwork class and is needed for proper simulation of multiple inputs.
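The reset contract can be sketched in plain Python: every stateful module exposes `reset_state()`, and the network calls it on each child so that one input's state does not leak into the next. `ToyModule` and `ToyNetwork` are hypothetical stand-ins; in PySNN this role is played by `SNNNetwork`, whose internals may differ.

```python
from typing import List


class ToyModule:
    """Hypothetical stateful module standing in for a Neuron or Connection."""

    def __init__(self, size: int):
        self.size = size
        self.reset_state()

    def reset_state(self) -> None:
        # Clear everything that persists between timesteps
        self.state: List[float] = [0.0] * self.size


class ToyNetwork:
    """Hypothetical network that resets all child modules between inputs."""

    def __init__(self, modules: List[ToyModule]):
        self.modules = modules

    def reset_state(self) -> None:
        for m in self.modules:
            # Raises AttributeError if a module lacks reset_state()
            m.reset_state()
```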