Update leakyparallel.py #326

Merged: 2 commits, Jun 1, 2024
55 changes: 36 additions & 19 deletions snntorch/_neurons/leakyparallel.py
@@ -24,13 +24,24 @@ class LeakyParallel(nn.Module):

Several differences between `LeakyParallel` and `Leaky` include:

- * Negative hidden states are clipped due to the forced ReLU operation in RNN
- * Linear weights are included in addition to recurrent weights
- * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise
- * There is no explicit reset mechanism
- * Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel`
- * Only the output spike is returned. Membrane potential is not accessible by default
- * RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' the membrane potential between LIF neurons, and so the hidden matrix is forced to a diagonal matrix by default. This can be disabled by setting `weight_hh_enable=True`.
+ * Negative hidden states are clipped due to the
+   forced ReLU operation in RNN.
+ * Linear weights are included in addition to
+   recurrent weights.
+ * `beta` is clipped between [0,1] and cloned to
+   `weight_hh_l` only upon layer initialization.
+   It is unused otherwise.
+ * There is no explicit reset mechanism.
+ * Several functions such as `init_hidden`, `output`,
+   `inhibition`, and `state_quant` are unavailable
+   in `LeakyParallel`.
+ * Only the output spike is returned. Membrane potential
+   is not accessible by default.
+ * RNN uses a hidden matrix of size (num_hidden, num_hidden)
+   to transform the hidden state vector. This would 'leak'
+   the membrane potential between LIF neurons, and so the
+   hidden matrix is forced to a diagonal matrix by default.
+   This can be disabled by setting `weight_hh_enable=True`.

Example::

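The last bullet in the hunk above describes how the hidden-to-hidden matrix is forced to a diagonal so that membrane potential does not leak between LIF neurons. As a rough illustration (separate from the docstring's own example, which is collapsed in this diff view), the sketch below inspects that matrix. It assumes `LeakyParallel` is exposed at the package top level like other snnTorch neurons, takes `input_size`, `hidden_size`, and `beta` as constructor arguments, and exposes its underlying RNN as `rnn` with PyTorch's usual `weight_hh_l0` naming; treat those names as assumptions, not a verified API.

```python
import torch
import snntorch as snn

# Hypothetical sizes, for illustration only.
lif = snn.LeakyParallel(input_size=8, hidden_size=4, beta=0.9)

# The hidden-to-hidden matrix is (hidden_size, hidden_size). By default it is
# forced to a diagonal so membrane potential does not leak between neurons;
# off-diagonal entries would only be active with weight_hh_enable=True.
w_hh = lif.rnn.weight_hh_l0.detach()
print(w_hh.shape)        # expected: torch.Size([4, 4])
print(torch.diag(w_hh))  # diagonal entries derived from beta (clipped to [0, 1])
```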
@@ -117,22 +128,28 @@ def forward(self, x):

where:

- `L = sequence length`
+ * **L** = sequence length

- `N = batch size`
+ * **N** = batch size

- `H_{in} = input_size`
+ * **H_{in}** = input_size

- `H_{out} = hidden_size`
+ * **H_{out}** = hidden_size

Learnable Parameters:
- - **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden weights of shape (hidden_size, input_size)
- - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden weights of the k-th layer which are sampled from `beta` of shape (hidden_size, hidden_size)
- - **bias_ih_l** - the learnable input-hidden bias of the k-th layer, of shape (hidden_size)
- - **bias_hh_l** - the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size)
- - **threshold** (torch.Tensor) - optional learnable thresholds
-   must be manually passed in, of shape `1` or`` (input_size).
- - **graded_spikes_factor** (torch.Tensor) - optional learnable graded spike factor
+ - **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden
+   weights of shape (hidden_size, input_size).
+ - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden
+   weights of the k-th layer, which are sampled from `beta`, of shape
+   (hidden_size, hidden_size).
+ - **bias_ih_l** - the learnable input-hidden bias of the k-th layer,
+   of shape (hidden_size).
+ - **bias_hh_l** - the learnable hidden-hidden bias of the k-th layer,
+   of shape (hidden_size).
+ - **threshold** (torch.Tensor) - optional learnable thresholds, must be
+   manually passed in, of shape `1` or (input_size).
+ - **graded_spikes_factor** (torch.Tensor) - optional learnable graded
+   spike factor.

"""

@@ -303,4 +320,4 @@ def _threshold_buffer(self, threshold, learn_threshold):
if learn_threshold:
self.threshold = nn.Parameter(threshold)
else:
self.register_buffer("threshold", threshold)
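The `_threshold_buffer` hunk above uses a standard PyTorch idiom: when `learn_threshold` is true the threshold becomes a trainable `nn.Parameter`; otherwise it is registered as a buffer, so it is still saved in the state dict and moved with the module across devices but receives no gradients. A minimal, self-contained sketch of that idiom, using a hypothetical demo module rather than snnTorch's actual class:

```python
import torch
import torch.nn as nn

class ThresholdDemo(nn.Module):
    """Illustrative only; mirrors the parameter-vs-buffer branch above."""
    def __init__(self, threshold=1.0, learn_threshold=False):
        super().__init__()
        threshold = torch.as_tensor(threshold, dtype=torch.float)
        if learn_threshold:
            # Gradients flow into the threshold during training.
            self.threshold = nn.Parameter(threshold)
        else:
            # Saved with state_dict() and moved by .to(device), but not trained.
            self.register_buffer("threshold", threshold)

demo = ThresholdDemo(threshold=1.0, learn_threshold=True)
print(isinstance(demo.threshold, nn.Parameter))  # True
```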