Skip to content

Commit

Permalink
Merge pull request #1840 from clinssen/test-stdp-synapse
Browse files Browse the repository at this point in the history
Add STDP synapse unit testing
  • Loading branch information
heplesser authored Oct 27, 2021
2 parents ce58929 + 7c7d949 commit cfa80ce
Show file tree
Hide file tree
Showing 5 changed files with 447 additions and 151 deletions.
27 changes: 6 additions & 21 deletions testsuite/pytests/test_jonke_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
Test functionality of the Tetzlaff stdp synapse
"""

import unittest
import nest
import numpy as np


@nest.ll_api.check_stack
class JonkeSynapseTest(unittest.TestCase):
class TestJonkeSynapse:
"""
Test the weight change by STDP.
The test is performed by generating two Poisson spike trains,
Expand Down Expand Up @@ -73,12 +72,12 @@ def test_weight_drift(self):
weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.synapse_parameters["weight"])
self.assertAlmostEqual(
np.testing.assert_almost_equal(
weight_reproduced_independently,
weight_by_nest,
msg=f"{self.synapse_parameters['synapse_model']} test:\n" +
f"Resulting synaptic weight {weight_by_nest} " +
f"differs from expected {weight_reproduced_independently}")
err_msg=f"{self.synapse_parameters['synapse_model']} test:\n" +
f"Resulting synaptic weight {weight_by_nest} " +
f"differs from expected {weight_reproduced_independently}")

def do_the_nest_simulation(self):
"""
Expand Down Expand Up @@ -111,7 +110,7 @@ def do_the_nest_simulation(self):
# reveal small differences in the weight change between NEST
# and ours, some low-probability events (say, coinciding
# spikes) can well not have occurred. To generate and
# test every possible combination of pre/post precedence, we
# test every possible combination of pre/post order, we
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
# post: 2 3 4 8 9 10 12
Expand Down Expand Up @@ -244,17 +243,3 @@ def depress(self, _delta_t, weight, Kminus):
if weight < 0:
weight = 0
return weight


def suite():
suite = unittest.TestLoader().loadTestsFromTestCase(JonkeSynapseTest)
return unittest.TestSuite([suite])


def run():
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())


if __name__ == "__main__":
run()
101 changes: 54 additions & 47 deletions testsuite/pytests/test_stdp_multiplicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@

# This script tests the parrot_neuron in NEST.

import nest
import unittest
import math
import nest
import numpy as np
import pytest

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
DEBUG_PLOTS = True
except Exception:
DEBUG_PLOTS = False


@nest.ll_api.check_stack
class StdpSpikeMultiplicity(unittest.TestCase):
class TestStdpSpikeMultiplicity:
"""
Test correct handling of spike multiplicity in STDP.
Expand All @@ -51,23 +58,29 @@ class StdpSpikeMultiplicity(unittest.TestCase):
delta, since in this case all spikes are at the end of the step, i.e.,
all spikes have identical times independent of delta.
2. We choose delta values that are decrease by factors of 2. The
2. We choose delta values that are decreased by factors of 2. The
plasticity rules depend on spike-time differences through
::
exp(dT / tau)
where dT is the time between pre- and postsynaptic spikes. We construct
pre- and postsynaptic spike times so that
dT = pre_post_shift + m * delta
::
with m * delta < resolution << pre_post_shift. The time-dependence
dT = pre_post_shift + m * delta
with ``m * delta < resolution << pre_post_shift``. The time-dependence
of the plasticity rule is therefore to good approximation linear in
delta.
We can thus test as follows: Let w_pl be the weight obtained with the
plain parrot, and w_ps_j the weight obtained with the precise parrot
for delta_j = delta0 / 2^j. Then,
We can thus test as follows: Let ``w_pl`` be the weight obtained with the
plain parrot, and ``w_ps_j`` the weight obtained with the precise parrot
for ``delta_j = delta0 / 2^j``. Then,
::
( w_ps_{j+1} - w_pl ) / ( w_ps_j - w_pl ) ~ 0.5 for all j
Expand Down Expand Up @@ -157,8 +170,7 @@ def run_protocol(self, pre_post_shift):
# create spike recorder --- debugging only
spikes = nest.Create("spike_recorder")
nest.Connect(
pre_parrot + post_parrot +
pre_parrot_ps + post_parrot_ps,
pre_parrot + post_parrot + pre_parrot_ps + post_parrot_ps,
spikes
)

Expand Down Expand Up @@ -194,47 +206,42 @@ def run_protocol(self, pre_post_shift):
post_weights['parrot'].append(w_post)
post_weights['parrot_ps'].append(w_post_ps)

if DEBUG_PLOTS:
fig, ax = plt.subplots(nrows=2)
fig.suptitle("Final obtained weights")
ax[0].plot(post_weights["parrot"], marker="o", label="parrot")
ax[0].plot(post_weights["parrot_ps"], marker="o", label="parrot_ps")
ax[0].set_ylabel("final weight")
ax[0].set_xticklabels([])
ax[1].semilogy(np.abs(np.array(post_weights["parrot"]) - np.array(post_weights["parrot_ps"])),
marker="o", label="error")
ax[1].set_xticks([i for i in range(len(deltas))])
ax[1].set_xticklabels(["{0:.1E}".format(d) for d in deltas])
ax[1].set_xlabel("timestep [ms]")
for _ax in ax:
_ax.grid(True)
_ax.legend()
plt.savefig("/tmp/test_stdp_multiplicity.png")
plt.close(fig)
print(post_weights)
return post_weights

def test_ParrotNeuronSTDPProtocolPotentiation(self):
"""Check weight convergence on potentiation."""

post_weights = self.run_protocol(pre_post_shift=10.0)
w_plain = np.array(post_weights['parrot'])
w_precise = np.array(post_weights['parrot_ps'])
@pytest.mark.parametrize("pre_post_shift", [10., # test potentiation
-10.]) # test depression
def test_stdp_multiplicity(self, pre_post_shift, max_abs_err=1E-3):
"""Check that for smaller and smaller timestep, weights obtained from parrot and precise parrot converge.
assert all(w_plain == w_plain[0]), 'Plain weights differ'
dw = w_precise - w_plain
dwrel = dw[1:] / dw[:-1]
assert all(np.round(dwrel, decimals=3) ==
0.5), 'Precise weights do not converge.'
Enforce a maximum allowed absolute error ``max_abs_err`` between the final weights for the smallest timestep
tested.
def test_ParrotNeuronSTDPProtocolDepression(self):
"""Check weight convergence on depression."""
Enforce that the error should strictly decrease with smaller timestep."""

post_weights = self.run_protocol(pre_post_shift=-10.0)
post_weights = self.run_protocol(pre_post_shift=pre_post_shift)
w_plain = np.array(post_weights['parrot'])
w_precise = np.array(post_weights['parrot_ps'])

assert all(w_plain == w_plain[0]), 'Plain weights differ'
dw = w_precise - w_plain
dwrel = dw[1:] / dw[:-1]
assert all(np.round(dwrel, decimals=3) ==
0.5), 'Precise weights do not converge.'


def suite():

# makeSuite is sort of obsolete http://bugs.python.org/issue2721
# using loadTestsFromTestCase instead.
suite = unittest.TestLoader().loadTestsFromTestCase(StdpSpikeMultiplicity)
return unittest.TestSuite([suite])


def run():
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())


if __name__ == "__main__":
run()
assert all(w_plain == w_plain[0]), 'Plain weights should be independent of timestep!'
abs_err = np.abs(w_precise - w_plain)
assert abs_err[-1] < max_abs_err, 'Final absolute error is ' + '{0:.2E}'.format(abs_err[-1]) \
+ ' but should be <= ' + '{0:.2E}'.format(max_abs_err)
assert np.all(np.diff(abs_err) < 0), 'Error should decrease with smaller timestep!'
78 changes: 27 additions & 51 deletions testsuite/pytests/test_stdp_nn_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
# and stdp_nn_restr_synapse in NEST.

import nest
import unittest
import numpy as np
import pytest

from math import exp


@nest.ll_api.check_stack
class STDPNNSynapsesTest(unittest.TestCase):
class TestSTDPNNSynapses:
"""
Test the weight change by STDP
with three nearest-neighbour spike pairing schemes.
Expand All @@ -43,12 +45,12 @@ class STDPNNSynapsesTest(unittest.TestCase):
Instead, it directly iterates through the spike history.
"""

@pytest.fixture(autouse=True)
def setUp(self):
self.resolution = 0.1 # [ms]
self.presynaptic_firing_rate = 20.0 # [Hz]
self.postsynaptic_firing_rate = 20.0 # [Hz]
self.simulation_duration = 1e+4 # [ms]
self.hardcoded_trains_length = 15. # [ms]
self.synapse_parameters = {
"receptor_type": 1,
"delay": self.resolution,
Expand All @@ -66,6 +68,18 @@ def setUp(self):
"tau_minus": 33.7
}

# While the random sequences, fairly long, would supposedly
# reveal small differences in the weight change between NEST
# and ours, some low-probability events (say, coinciding
# spikes) can well not have occured. To generate and
# test every possible combination of pre/post order, we
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
# post: 2 3 4 8 9 10 12
self.hardcoded_pre_times = np.array([1, 5, 6, 7, 9, 11, 12, 13], dtype=float)
self.hardcoded_post_times = np.array([2, 3, 4, 8, 9, 10, 12], dtype=float)
self.hardcoded_trains_length = 2. + max(np.amax(self.hardcoded_pre_times), np.amax(self.hardcoded_post_times))

def do_nest_simulation_and_compare_to_reproduced_weight(self,
pairing_scheme):
synapse_model = "stdp_" + pairing_scheme + "_synapse"
Expand All @@ -75,13 +89,13 @@ def do_nest_simulation_and_compare_to_reproduced_weight(self,
weight_reproduced_independently = self.reproduce_weight_drift(
pre_spikes, post_spikes,
self.synapse_parameters["weight"])
self.assertAlmostEqual(
np.testing.assert_almost_equal(
weight_reproduced_independently,
weight_by_nest,
msg=synapse_model + " test: "
"Resulting synaptic weight %e "
"differs from expected %e" % (
weight_by_nest, weight_reproduced_independently))
err_msg=synapse_model + " test: "
"Resulting synaptic weight %e "
"differs from expected %e" % (
weight_by_nest, weight_reproduced_independently))

def do_the_nest_simulation(self):
"""
Expand All @@ -93,12 +107,10 @@ def do_the_nest_simulation(self):
nest.ResetKernel()
nest.resolution = self.resolution

neurons = nest.Create(
presynaptic_neuron, postsynaptic_neuron = nest.Create(
"parrot_neuron",
2,
params=self.neuron_parameters)
presynaptic_neuron = neurons[0]
postsynaptic_neuron = neurons[1]

generators = nest.Create(
"poisson_generator",
Expand All @@ -110,32 +122,13 @@ def do_the_nest_simulation(self):
presynaptic_generator = generators[0]
postsynaptic_generator = generators[1]

# While the random sequences, fairly long, would supposedly
# reveal small differences in the weight change between NEST
# and ours, some low-probability events (say, coinciding
# spikes) can well not have occured. To generate and
# test every possible combination of pre/post precedence, we
# append some hardcoded spike sequences:
# pre: 1 5 6 7 9 11 12 13
# post: 2 3 4 8 9 10 12
(
hardcoded_pre_times,
hardcoded_post_times
) = [
[
self.simulation_duration - self.hardcoded_trains_length + t
for t in train
] for train in (
(1, 5, 6, 7, 9, 11, 12, 13),
(2, 3, 4, 8, 9, 10, 12)
)
]

spike_senders = nest.Create(
"spike_generator",
2,
params=({"spike_times": hardcoded_pre_times},
{"spike_times": hardcoded_post_times})
params=({"spike_times": self.hardcoded_pre_times
+ self.simulation_duration - self.hardcoded_trains_length},
{"spike_times": self.hardcoded_post_times
+ self.simulation_duration - self.hardcoded_trains_length})
)
pre_spike_generator = spike_senders[0]
post_spike_generator = spike_senders[1]
Expand Down Expand Up @@ -285,20 +278,3 @@ def test_nn_pre_centered_synapse(self):

def test_nn_restr_synapse(self):
self.do_nest_simulation_and_compare_to_reproduced_weight("nn_restr")


def suite():

# makeSuite is sort of obsolete http://bugs.python.org/issue2721
# using loadTestsFromTestCase instead.
suite = unittest.TestLoader().loadTestsFromTestCase(STDPNNSynapsesTest)
return unittest.TestSuite([suite])


def run():
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())


if __name__ == "__main__":
run()
Loading

0 comments on commit cfa80ce

Please sign in to comment.